diff --git a/.github/workflows/pretest-and-test.yml b/.github/workflows/pretest-and-test.yml index 3ccb18b0b..44ebb5722 100644 --- a/.github/workflows/pretest-and-test.yml +++ b/.github/workflows/pretest-and-test.yml @@ -38,7 +38,7 @@ jobs: - name: Code Style run: | - pip install pysen black==21.11b1 flake8==4.0.1 isort==5.10.1 mypy==0.991 + pip install pysen black==23.3.0 flake8==4.0.1 isort==5.10.1 mypy==0.991 pip install types-PyYAML types-setuptools cp "$(pip show torch | awk '/^Location:/ { print $2 }')/torch/__init__.py" stubs/torch/__init__.py MYPYPATH="${PWD}/stubs" pysen run lint diff --git a/docs/source/conf.py b/docs/source/conf.py index b2ef36240..fd906b89b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,9 +17,9 @@ # -- Project information ----------------------------------------------------- -project = 'pytorch-pfn-extras' -copyright = '2021, Preferred Networks, Inc.' -author = 'Preferred Networks, Inc.' +project = "pytorch-pfn-extras" +copyright = "2021, Preferred Networks, Inc." +author = "Preferred Networks, Inc." # -- General configuration --------------------------------------------------- @@ -28,15 +28,15 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon', - 'myst_parser', + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "myst_parser", ] autodoc_typehints = "description" # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -51,9 +51,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'pydata_sphinx_theme' +html_theme = "pydata_sphinx_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/example/ignite-mnist.py b/example/ignite-mnist.py index c048d5ea5..2d2750102 100644 --- a/example/ignite-mnist.py +++ b/example/ignite-mnist.py @@ -1,20 +1,20 @@ from argparse import ArgumentParser +import pytorch_pfn_extras as ppe +import pytorch_pfn_extras.training.extensions as extensions +import torch +import torch.nn.functional as F +from ignite.engine import ( + Events, + create_supervised_evaluator, + create_supervised_trainer, +) +from ignite.metrics import Accuracy, Loss from torch import nn from torch.optim import SGD from torch.utils.data import DataLoader -import torch -import torch.nn.functional as F -from torchvision.transforms import Compose, ToTensor, Normalize from torchvision.datasets import MNIST - -from ignite.engine import Events -from ignite.engine import create_supervised_trainer -from ignite.engine import create_supervised_evaluator -from ignite.metrics import Accuracy, Loss - -import pytorch_pfn_extras as ppe -import pytorch_pfn_extras.training.extensions as extensions +from torchvision.transforms import Compose, Normalize, ToTensor class Net(nn.Module): @@ -40,55 +40,75 @@ def get_data_loaders(train_batch_size, val_batch_size): data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) train_loader = DataLoader( - MNIST(download=True, root="../data", transform=data_transform, - train=True), - batch_size=train_batch_size, shuffle=True) + MNIST( + download=True, root="../data", transform=data_transform, train=True + ), + batch_size=train_batch_size, + shuffle=True, + ) val_loader = DataLoader( - MNIST(download=False, root="../data", transform=data_transform, - train=False), - batch_size=val_batch_size, shuffle=False) + MNIST( + download=False, + root="../data", + transform=data_transform, + train=False, + ), + batch_size=val_batch_size, + shuffle=False, + ) return train_loader, val_loader def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval): train_loader, val_loader = get_data_loaders( - train_batch_size, val_batch_size) + train_batch_size, val_batch_size + ) model = Net() - device = 'cpu' + device = "cpu" if torch.cuda.is_available(): - device = 'cuda:0' + device = "cuda:0" model = model.to(device) optimizer = SGD(model.parameters(), lr=lr, momentum=momentum) optimizer.step() trainer = create_supervised_trainer( - model, optimizer, F.nll_loss, device=device) + model, optimizer, F.nll_loss, device=device + ) evaluator = create_supervised_evaluator( model, - metrics={'acc': Accuracy(), 'loss': Loss(F.nll_loss)}, - device=device) + metrics={"acc": Accuracy(), "loss": Loss(F.nll_loss)}, + device=device, + ) # manager.extend(...) 
also works my_extensions = [ extensions.LogReport(), extensions.ProgressBar(), extensions.observe_lr(optimizer=optimizer), - extensions.ParameterStatistics(model, prefix='model'), + extensions.ParameterStatistics(model, prefix="model"), extensions.VariableStatisticsPlot(model), extensions.snapshot(), extensions.IgniteEvaluator( - evaluator, val_loader, model, progress_bar=True), - extensions.PlotReport(['train/loss'], 'epoch', filename='loss.png'), - extensions.PrintReport([ - 'epoch', 'iteration', 'train/loss', 'lr', - 'model/fc2.bias/grad/min', 'val/loss', 'val/acc', - ]), + evaluator, val_loader, model, progress_bar=True + ), + extensions.PlotReport(["train/loss"], "epoch", filename="loss.png"), + extensions.PrintReport( + [ + "epoch", + "iteration", + "train/loss", + "lr", + "model/fc2.bias/grad/min", + "val/loss", + "val/acc", + ] + ), ] - models = {'main': model} - optimizers = {'main': optimizer} + models = {"main": model} + optimizers = {"main": optimizer} manager = ppe.training.IgniteExtensionsManager( - trainer, models, optimizers, args.epochs, - extensions=my_extensions) + trainer, models, optimizers, args.epochs, extensions=my_extensions + ) # Lets load the snapshot if args.snapshot is not None: @@ -97,30 +117,57 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval): @trainer.on(Events.ITERATION_COMPLETED) def report_loss(engine): - ppe.reporting.report({'train/loss': engine.state.output}) + ppe.reporting.report({"train/loss": engine.state.output}) trainer.run(train_loader, max_epochs=epochs) if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument('--batch_size', type=int, default=64, - help='input batch size for training (default: 64)') - parser.add_argument('--val_batch_size', type=int, default=1000, - help='input batch size for validation (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, - help='learning rate (default: 0.01)') - parser.add_argument('--momentum', type=float, default=0.5, - help='SGD momentum (default: 0.5)') - parser.add_argument('--log_interval', type=int, default=10, - help='how many batches to wait before logging ' - 'training status') - parser.add_argument('--snapshot', type=str, default=None, - help='path to snapshot file') + parser.add_argument( + "--batch_size", + type=int, + default=64, + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--val_batch_size", + type=int, + default=1000, + help="input batch size for validation (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--lr", type=float, default=0.01, help="learning rate (default: 0.01)" + ) + parser.add_argument( + "--momentum", + type=float, + default=0.5, + help="SGD momentum (default: 0.5)", + ) + parser.add_argument( + "--log_interval", + type=int, + default=10, + help="how many batches to wait before logging " "training status", + ) + parser.add_argument( + "--snapshot", type=str, default=None, help="path to snapshot file" + ) args = parser.parse_args() - run(args.batch_size, args.val_batch_size, args.epochs, args.lr, - args.momentum, args.log_interval) + run( + args.batch_size, + args.val_batch_size, + args.epochs, + args.lr, + args.momentum, + args.log_interval, + ) diff --git a/example/mnist.py b/example/mnist.py index cab6b1cb0..ba2ddaa74 100644 --- a/example/mnist.py +++ 
b/example/mnist.py @@ -1,13 +1,13 @@ import argparse + +import pytorch_pfn_extras as ppe +import pytorch_pfn_extras.training.extensions as extensions import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms -import pytorch_pfn_extras as ppe -import pytorch_pfn_extras.training.extensions as extensions - class Net(nn.Module): def __init__(self): @@ -33,56 +33,98 @@ def train(manager, args, model, device, train_loader): while not manager.stop_trigger: model.train() for _, (data, target) in enumerate(train_loader): - with manager.run_iteration(step_optimizers=['main']): + with manager.run_iteration(step_optimizers=["main"]): data, target = data.to(device), target.to(device) output = model(data) loss = F.nll_loss(output, target) - ppe.reporting.report({'train/loss': loss.item()}) + ppe.reporting.report({"train/loss": loss.item()}) loss.backward() def test(args, model, device, data, target): - """ The extension loops over the iterator in order to - drive the evaluator progress bar and reporting - averages + """The extension loops over the iterator in order to + drive the evaluator progress bar and reporting + averages """ model.eval() data, target = data.to(device), target.to(device) output = model(data) # Final result will be average of averages of the same size - test_loss = F.nll_loss(output, target, reduction='mean').item() - ppe.reporting.report({'val/loss': test_loss}) + test_loss = F.nll_loss(output, target, reduction="mean").item() + ppe.reporting.report({"val/loss": test_loss}) pred = output.argmax(dim=1, keepdim=True) correct = pred.eq(target.view_as(pred)).sum().item() - ppe.reporting.report({'val/acc': correct / len(data)}) + ppe.reporting.report({"val/acc": correct / len(data)}) def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, - metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--momentum', type=float, default=0.5, metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument('--no-cuda', dest='cuda', - action='store_false', default=True, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - parser.add_argument('--snapshot', type=str, default=None, - help='path to snapshot file') - parser.add_argument('--slack', type=str, default=None, - help='post to the specified Slack channel') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + 
"--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)", + ) + parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)", + ) + parser.add_argument( + "--no-cuda", + dest="cuda", + action="store_false", + default=True, + help="disables CUDA training", + ) + parser.add_argument( + "--seed", + type=int, + default=1, + metavar="S", + help="random seed (default: 1)", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=False, + help="For Saving the current Model", + ) + parser.add_argument( + "--snapshot", type=str, default=None, help="path to snapshot file" + ) + parser.add_argument( + "--slack", + type=str, + default=None, + help="post to the specified Slack channel", + ) args = parser.parse_args() use_cuda = args.cuda and torch.cuda.is_available() @@ -90,68 +132,98 @@ def main(): device = torch.device("cuda" if use_cuda else "cpu") - kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} train_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, - **kwargs) # type: ignore[arg-type] + datasets.MNIST( + "../data", + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ), + batch_size=args.batch_size, + shuffle=True, + **kwargs, # type: ignore[arg-type] + ) test_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.test_batch_size, shuffle=True, - **kwargs) # type: ignore[arg-type] + datasets.MNIST( + "../data", + train=False, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ), + batch_size=args.test_batch_size, + shuffle=True, + **kwargs, # type: ignore[arg-type] + ) model = Net() model.to(device) optimizer = optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum) + model.parameters(), lr=args.lr, momentum=args.momentum + ) # manager.extend(...) also works my_extensions = [ extensions.LogReport(), - # Enables TensorBoard support. # Run `tensorboard --logdir runs` to launch the TensorBoard. 
extensions.LogReport( - writer=ppe.writing.TensorBoardWriter(out_dir='runs'), - trigger=(1, 'iteration')), + writer=ppe.writing.TensorBoardWriter(out_dir="runs"), + trigger=(1, "iteration"), + ), extensions.ProgressBar(), extensions.observe_lr(optimizer=optimizer), - extensions.ParameterStatistics(model, prefix='model'), + extensions.ParameterStatistics(model, prefix="model"), extensions.VariableStatisticsPlot(model), extensions.Evaluator( - test_loader, model, - eval_func=lambda data, target: - test(args, model, device, data, target), - progress_bar=True), + test_loader, + model, + eval_func=lambda data, target: test( + args, model, device, data, target + ), + progress_bar=True, + ), extensions.PlotReport( - ['train/loss', 'val/loss'], 'epoch', filename='loss.png'), - extensions.PrintReport(['epoch', 'iteration', - 'train/loss', 'lr', 'model/fc2.bias/grad/min', - 'val/loss', 'val/acc']), + ["train/loss", "val/loss"], "epoch", filename="loss.png" + ), + extensions.PrintReport( + [ + "epoch", + "iteration", + "train/loss", + "lr", + "model/fc2.bias/grad/min", + "val/loss", + "val/acc", + ] + ), extensions.snapshot(), ] if args.slack is not None: - my_extensions.append(extensions.Slack( - channel=args.slack, - msg='Epoch #{manager.epoch}: val/loss = {val/loss}', - # Surround the username with <> to mention. - end_msg='{default}\n<@your_slack_user_name>', - - # Upload any artifacts generated during the training. - filenames=['result/statistics.png'], - # You can specify when to upload these files. - # e.g., only at the final epoch: - # upload_trigger=(args.epochs, 'epoch'), - )) + my_extensions.append( + extensions.Slack( + channel=args.slack, + msg="Epoch #{manager.epoch}: val/loss = {val/loss}", + # Surround the username with <> to mention. + end_msg="{default}\n<@your_slack_user_name>", + # Upload any artifacts generated during the training. + filenames=["result/statistics.png"], + # You can specify when to upload these files. 
+ # e.g., only at the final epoch: + # upload_trigger=(args.epochs, 'epoch'), + ) + ) # Custom stop triggers can be added to the manager and # their status accessed through `manager.stop_trigger` @@ -159,10 +231,13 @@ def main(): # trigger = ppe.training.triggers.EarlyStoppingTrigger( # check_trigger=(1, 'epoch'), monitor='val/loss') manager = ppe.training.ExtensionsManager( - model, optimizer, args.epochs, + model, + optimizer, + args.epochs, extensions=my_extensions, iters_per_epoch=len(train_loader), - stop_trigger=trigger) + stop_trigger=trigger, + ) # Lets load the snapshot if args.snapshot is not None: state = torch.load(args.snapshot) @@ -172,9 +247,9 @@ def main(): # to get access to the reporter and other facilities # test(args, model, device, test_loader) - if (args.save_model): + if args.save_model: torch.save(model.state_dict(), "mnist_cnn.pt") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/example/mnist_custom_logic.py b/example/mnist_custom_logic.py index 9721cf101..4fcb01f3b 100644 --- a/example/mnist_custom_logic.py +++ b/example/mnist_custom_logic.py @@ -1,13 +1,13 @@ import argparse + +import pytorch_pfn_extras as ppe +import pytorch_pfn_extras.training.extensions as extensions import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms -import pytorch_pfn_extras as ppe -import pytorch_pfn_extras.training.extensions as extensions - class Net(nn.Module): def __init__(self): @@ -29,7 +29,6 @@ def forward(self, x): class CustomLogic(ppe.handler.Logic): - def __init__(self, steps_per_update): self.steps_per_update = steps_per_update super().__init__() @@ -58,67 +57,132 @@ def train_step_optimizers(self, models, optimizers, batch_idx): def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, - metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--momentum', type=float, default=0.5, metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument('--device', type=str, default='cuda', - help='PyTorch device specifier') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - parser.add_argument('--snapshot', type=str, default=None, - help='path to snapshot file') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)", + ) + parser.add_argument( + "--momentum", + 
type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)", + ) + parser.add_argument( + "--device", type=str, default="cuda", help="PyTorch device specifier" + ) + parser.add_argument( + "--seed", + type=int, + default=1, + metavar="S", + help="random seed (default: 1)", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=False, + help="For Saving the current Model", + ) + parser.add_argument( + "--snapshot", type=str, default=None, help="path to snapshot file" + ) args = parser.parse_args() torch.manual_seed(args.seed) - use_cuda = args.device.startswith('cuda') + use_cuda = args.device.startswith("cuda") - kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} train_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, + datasets.MNIST( + "../data", + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ), + batch_size=args.batch_size, + shuffle=True, collate_fn=ppe.dataloaders.utils.CollateAsDict( - names=['data', 'target']), **kwargs) # type: ignore[arg-type] + names=["data", "target"] + ), + **kwargs, # type: ignore[arg-type] + ) test_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.test_batch_size, shuffle=True, + datasets.MNIST( + "../data", + train=False, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ), + batch_size=args.test_batch_size, + shuffle=True, collate_fn=ppe.dataloaders.utils.CollateAsDict( - names=['data', 'target']), **kwargs) # type: ignore[arg-type] + names=["data", "target"] + ), + **kwargs, # type: ignore[arg-type] + ) model = Net() optimizer = optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum) + model.parameters(), lr=args.lr, momentum=args.momentum + ) my_extensions = [ extensions.LogReport(), extensions.ProgressBar(), extensions.observe_lr(optimizer=optimizer), - extensions.ParameterStatistics(model, prefix='model'), + extensions.ParameterStatistics(model, prefix="model"), extensions.VariableStatisticsPlot(model), extensions.PlotReport( - ['train/loss', 'val/loss'], 'epoch', filename='loss.png'), - extensions.PrintReport(['epoch', 'iteration', - 'train/loss', 'lr', 'model/fc2.bias/grad/min', - 'val/loss', 'val/accuracy']), + ["train/loss", "val/loss"], "epoch", filename="loss.png" + ), + extensions.PrintReport( + [ + "epoch", + "iteration", + "train/loss", + "lr", + "model/fc2.bias/grad/min", + "val/loss", + "val/accuracy", + ] + ), extensions.snapshot(), ] @@ -138,13 +202,13 @@ def forward(self, data, target): if model.training: loss = F.nll_loss(output, target) - ppe.reporting.report({'train/loss': loss.item()}) - return {'loss': loss} + ppe.reporting.report({"train/loss": loss.item()}) + return {"loss": loss} # Final result will be average of averages of the same size - test_loss = F.nll_loss(output, target, reduction='mean').item() + test_loss = F.nll_loss(output, target, reduction="mean").item() pred = output.argmax(dim=1, keepdim=True) - return {'loss': test_loss, 'output': pred} + return {"loss": 
test_loss, "output": pred} model_with_loss = ModelWithLoss(model) trainer = ppe.engine.create_trainer( @@ -158,9 +222,10 @@ def forward(self, data, target): model_with_loss, device=args.device, progress_bar=True, - metrics=[ppe.training.metrics.AccuracyMetric('target', 'output')], - options={'eval_report_keys': ['loss', 'accuracy']}), - options={'train_report_keys': ['loss']}, + metrics=[ppe.training.metrics.AccuracyMetric("target", "output")], + options={"eval_report_keys": ["loss", "accuracy"]}, + ), + options={"train_report_keys": ["loss"]}, logic=CustomLogic(3), ) @@ -174,9 +239,9 @@ def forward(self, data, target): trainer.run(train_loader, test_loader) - if (args.save_model): + if args.save_model: torch.save(model.state_dict(), "mnist_cnn.pt") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/example/mnist_ddp.py b/example/mnist_ddp.py index b972be291..abee0502a 100644 --- a/example/mnist_ddp.py +++ b/example/mnist_ddp.py @@ -1,14 +1,13 @@ import argparse +import pytorch_pfn_extras as ppe +import pytorch_pfn_extras.training.extensions as extensions import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms -import pytorch_pfn_extras as ppe -import pytorch_pfn_extras.training.extensions as extensions - class Net(nn.Module): def __init__(self): @@ -33,18 +32,18 @@ def train(manager, args, model, device, train_loader): while not manager.stop_trigger: model.train() for _, (data, target) in enumerate(train_loader): - with manager.run_iteration(step_optimizers=['main']): + with manager.run_iteration(step_optimizers=["main"]): data, target = data.to(device), target.to(device) output = model(data) loss = F.nll_loss(output, target) - ppe.reporting.report({'train/loss': loss.item()}) + ppe.reporting.report({"train/loss": loss.item()}) loss.backward() def test(args, model, device, data, target): - """ The extension loops over the iterator in order to - drive the evaluator progress bar and reporting - averages + """The extension loops over the iterator in order to + drive the evaluator progress bar and reporting + averages """ model.eval() test_loss = 0.0 @@ -52,52 +51,95 @@ def test(args, model, device, data, target): data, target = data.to(device), target.to(device) output = model(data) # Final result will be average of averages of the same size - test_loss += F.nll_loss(output, target, reduction='mean').item() - ppe.reporting.report({'val/loss': test_loss}) + test_loss += F.nll_loss(output, target, reduction="mean").item() + ppe.reporting.report({"val/loss": test_loss}) pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() - ppe.reporting.report({'val/acc': correct / len(data)}) + ppe.reporting.report({"val/acc": correct / len(data)}) def init_distributed(use_cuda=True): # setup env for torch.distributed - comm_world_size, comm_rank, comm_local_rank = ( - ppe.distributed.initialize_ompi_environment( - backend="nccl", init_method="env")) + ( + comm_world_size, + comm_rank, + comm_local_rank, + ) = ppe.distributed.initialize_ompi_environment( + backend="nccl", init_method="env" + ) if comm_rank == 0: print("World size = {}".format(comm_world_size)) print("Rank = {}, Local Rank = {}".format(comm_rank, comm_local_rank)) torch.cuda.set_device(comm_local_rank) device = torch.device( - "cuda:{}".format(comm_local_rank) if use_cuda else "cpu") + "cuda:{}".format(comm_local_rank) if use_cuda else "cpu" + ) return comm_world_size, comm_rank, 
comm_local_rank, device def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, - metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--momentum', type=float, default=0.5, metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument('--no-cuda', dest='cuda', - action='store_false', default=True, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - parser.add_argument('--snapshot', type=str, default=None, - help='path to snapshot file') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)", + ) + parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)", + ) + parser.add_argument( + "--no-cuda", + dest="cuda", + action="store_false", + default=True, + help="disables CUDA training", + ) + parser.add_argument( + "--seed", + type=int, + default=1, + metavar="S", + help="random seed (default: 1)", + ) + parser.add_argument( + "--save-model", + action="store_true", + default=False, + help="For Saving the current Model", + ) + parser.add_argument( + "--snapshot", type=str, default=None, help="path to snapshot file" + ) args = parser.parse_args() use_cuda = args.cuda and torch.cuda.is_available() @@ -106,14 +148,15 @@ def main(): torch.manual_seed(args.seed) comm_world_size, comm_rank, comm_local_rank, device = init_distributed( - use_cuda) + use_cuda + ) if comm_rank == 0: print("World size = {}".format(comm_world_size)) print("Rank = {}, Local Rank = {}".format(comm_rank, comm_local_rank)) print("Device = {}".format(device)) - kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} - dataset_root = '../data' + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} + dataset_root = "../data" if comm_local_rank == 0: # download mnist datasets.MNIST(dataset_root, download=True) @@ -122,37 +165,53 @@ def main(): train_dataset = datasets.MNIST( dataset_root, train=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)), - ])) + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ) test_dataset = datasets.MNIST( dataset_root, train=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)), - ])) + 
transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ) train_sampler = torch.utils.data.DistributedSampler[int]( - train_dataset, num_replicas=comm_world_size, rank=comm_rank) + train_dataset, num_replicas=comm_world_size, rank=comm_rank + ) train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.batch_size, sampler=train_sampler, - **kwargs) # type: ignore[arg-type] + train_dataset, + batch_size=args.batch_size, + sampler=train_sampler, + **kwargs, # type: ignore[arg-type] + ) test_dataset_indices = list(range(len(test_dataset))) local_test_dataset_indices = test_dataset_indices[ - comm_rank:len(test_dataset_indices):comm_world_size] + comm_rank : len(test_dataset_indices) : comm_world_size + ] local_test_dataset = torch.utils.data.Subset( - test_dataset, local_test_dataset_indices) + test_dataset, local_test_dataset_indices + ) test_loader = torch.utils.data.DataLoader( - local_test_dataset, batch_size=args.test_batch_size, shuffle=True, - **kwargs) # type: ignore[arg-type] + local_test_dataset, + batch_size=args.test_batch_size, + shuffle=True, + **kwargs, # type: ignore[arg-type] + ) model = ppe.nn.parallel.DistributedDataParallel(Net().to(device)) optimizer = optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum) + model.parameters(), lr=args.lr, momentum=args.momentum + ) # manager.extend(...) also works if comm_local_rank == 0: @@ -160,19 +219,30 @@ def main(): extensions.LogReport(), extensions.ProgressBar(), extensions.observe_lr(optimizer=optimizer), - extensions.ParameterStatistics(model, prefix='model'), + extensions.ParameterStatistics(model, prefix="model"), extensions.VariableStatisticsPlot(model), extensions.Evaluator( - test_loader, model, - eval_func=lambda data, target: - test(args, model, device, data, target), - progress_bar=True), + test_loader, + model, + eval_func=lambda data, target: test( + args, model, device, data, target + ), + progress_bar=True, + ), extensions.PlotReport( - ['train/loss', 'val/loss'], 'epoch', filename='loss.png'), - extensions.PrintReport(['epoch', 'iteration', - 'train/loss', 'lr', - 'model/fc2.bias/grad/min', - 'val/loss', 'val/acc']), + ["train/loss", "val/loss"], "epoch", filename="loss.png" + ), + extensions.PrintReport( + [ + "epoch", + "iteration", + "train/loss", + "lr", + "model/fc2.bias/grad/min", + "val/loss", + "val/acc", + ] + ), extensions.snapshot(), ] else: @@ -184,10 +254,13 @@ def main(): # trigger = ppe.training.triggers.EarlyStoppingTrigger( # check_trigger=(1, 'epoch'), monitor='val/loss') manager = ppe.training.ExtensionsManager( - model, optimizer, args.epochs, + model, + optimizer, + args.epochs, extensions=my_extensions, iters_per_epoch=len(train_loader), - stop_trigger=trigger) + stop_trigger=trigger, + ) # Lets load the snapshot if args.snapshot is not None: state = torch.load(args.snapshot) @@ -197,12 +270,12 @@ def main(): # to get access to the reporter and other facilities # test(args, model, device, test_loader) - if (args.save_model): + if args.save_model: torch.save(model.state_dict(), "mnist_cnn.pt") # Wait for all processes to finish to complete successfully torch.distributed.barrier() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/example/mnist_trainer.py b/example/mnist_trainer.py index 55f9b5fcf..f0caff434 100644 --- a/example/mnist_trainer.py +++ b/example/mnist_trainer.py @@ -1,15 +1,14 @@ import argparse import numpy +import pytorch_pfn_extras as ppe +import 
pytorch_pfn_extras.training.extensions as extensions import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms -import pytorch_pfn_extras as ppe -import pytorch_pfn_extras.training.extensions as extensions - class Net(nn.Module): def __init__(self): @@ -40,88 +39,169 @@ def forward(self, data, target): if self.training: loss = F.nll_loss(output, target) - ppe.reporting.report({'train/loss': loss.item()}) - return {'loss': loss} + ppe.reporting.report({"train/loss": loss.item()}) + return {"loss": loss} # Final result will be average of averages of the same size - test_loss = F.nll_loss(output, target, reduction='mean').item() + test_loss = F.nll_loss(output, target, reduction="mean").item() pred = output.argmax(dim=1, keepdim=True) - return {'loss': test_loss, 'output': pred} + return {"loss": test_loss, "output": pred} def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, - metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--momentum', type=float, default=0.5, metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument('--device', type=str, default='cuda', - help='PyTorch device specifier') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--deterministic', action='store_true', default=False, - help='make the behavior deterministic') - parser.add_argument('--save-model', action='store_true', default=False, - help='For Saving the current Model') - parser.add_argument('--snapshot', type=str, default=None, - help='path to snapshot file') - parser.add_argument('--compare-dump', type=str, default=None, - help='directory to save comparer dump to') - parser.add_argument('--compare-with', type=str, default=None, - help='directory to load comparer dump from') - parser.add_argument('--profiler', type=str, default=None, - help='output mode for profiler results') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=10, + metavar="N", + help="number of epochs to train (default: 10)", + ) + parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)", + ) + parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)", + ) + parser.add_argument( + "--device", type=str, default="cuda", help="PyTorch device specifier" + ) + parser.add_argument( + "--seed", + type=int, + default=1, + metavar="S", + help="random seed (default: 1)", + ) + parser.add_argument( + "--deterministic", + action="store_true", + default=False, + help="make the behavior deterministic", + ) + 
parser.add_argument( + "--save-model", + action="store_true", + default=False, + help="For Saving the current Model", + ) + parser.add_argument( + "--snapshot", type=str, default=None, help="path to snapshot file" + ) + parser.add_argument( + "--compare-dump", + type=str, + default=None, + help="directory to save comparer dump to", + ) + parser.add_argument( + "--compare-with", + type=str, + default=None, + help="directory to load comparer dump from", + ) + parser.add_argument( + "--profiler", + type=str, + default=None, + help="output mode for profiler results", + ) args = parser.parse_args() torch.manual_seed(args.seed) numpy.random.seed(args.seed) torch.use_deterministic_algorithms(args.deterministic) - use_cuda = args.device.startswith('cuda') + use_cuda = args.device.startswith("cuda") - kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} train_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, + datasets.MNIST( + "../data", + train=True, + download=True, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ), + batch_size=args.batch_size, + shuffle=True, collate_fn=ppe.dataloaders.utils.CollateAsDict( - names=['data', 'target']), **kwargs) # type: ignore[arg-type] + names=["data", "target"] + ), + **kwargs, # type: ignore[arg-type] + ) test_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.test_batch_size, shuffle=True, + datasets.MNIST( + "../data", + train=False, + transform=transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), + ] + ), + ), + batch_size=args.test_batch_size, + shuffle=True, collate_fn=ppe.dataloaders.utils.CollateAsDict( - names=['data', 'target']), **kwargs) # type: ignore[arg-type] + names=["data", "target"] + ), + **kwargs, # type: ignore[arg-type] + ) model = Net() optimizer = optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum) + model.parameters(), lr=args.lr, momentum=args.momentum + ) my_extensions = [ extensions.LogReport(), extensions.ProgressBar(), extensions.observe_lr(optimizer=optimizer), - extensions.ParameterStatistics(model, prefix='model'), + extensions.ParameterStatistics(model, prefix="model"), extensions.VariableStatisticsPlot(model), extensions.PlotReport( - ['train/loss', 'val/loss'], 'epoch', filename='loss.png'), - extensions.PrintReport(['epoch', 'iteration', - 'train/loss', 'lr', 'model/fc2.bias/grad/min', - 'val/loss', 'val/accuracy']), + ["train/loss", "val/loss"], "epoch", filename="loss.png" + ), + extensions.PrintReport( + [ + "epoch", + "iteration", + "train/loss", + "lr", + "model/fc2.bias/grad/min", + "val/loss", + "val/accuracy", + ] + ), extensions.snapshot(), ] @@ -133,25 +213,37 @@ def main(): profile = None if args.profiler is not None: - if args.profiler == 'tensorboard': + if args.profiler == "tensorboard": + def callback(prof): - torch.profiler.tensorboard_trace_handler('./prof') # type: ignore[attr-defined] - elif args.profiler == 'export_chrome_trace': + torch.profiler.tensorboard_trace_handler("./prof") # type: ignore[attr-defined] + + elif args.profiler == 
"export_chrome_trace": + def callback(prof): - prof.export_chrome_trace('./prof') - elif args.profiler == 'export_stacks': + prof.export_chrome_trace("./prof") + + elif args.profiler == "export_stacks": + def callback(prof): - prof.export_stacks('./prof') - elif args.profiler == 'to_pickle': + prof.export_stacks("./prof") + + elif args.profiler == "to_pickle": + def callback(prof): import pandas as pd + df = pd.DataFrame([e.__dict__ for e in prof.events()]) df.to_pickle(f"{trainer.epoch}.pkl") - elif args.profiler == 'print': + + elif args.profiler == "print": + def callback(prof): table = prof.key_averages().table( - sort_by="self_cuda_time_total", row_limit=-1) + sort_by="self_cuda_time_total", row_limit=-1 + ) print(table) + else: assert False profile = torch.profiler.profile( # type: ignore[attr-defined] @@ -160,7 +252,8 @@ def callback(prof): torch.profiler.ProfilerActivity.CUDA, # type: ignore[attr-defined] ], schedule=torch.profiler.schedule( # type: ignore[attr-defined] - wait=0, warmup=0, active=len(train_loader)), + wait=0, warmup=0, active=len(train_loader) + ), on_trace_ready=callback, ) @@ -176,10 +269,10 @@ def callback(prof): model_with_loss, device=args.device, progress_bar=True, - metrics=[ppe.training.metrics.AccuracyMetric('target', 'output')], - options={'eval_report_keys': ['loss', 'accuracy']}, + metrics=[ppe.training.metrics.AccuracyMetric("target", "output")], + options={"eval_report_keys": ["loss", "accuracy"]}, ), - options={'train_report_keys': ['loss']}, + options={"train_report_keys": ["loss"]}, profile=profile, ) @@ -194,11 +287,11 @@ def callback(prof): if args.compare_dump is not None or args.compare_with is not None: comp = ppe.utils.comparer.Comparer( compare_fn=ppe.utils.comparer.get_default_comparer(rtol=1e-2), - outputs=['loss'], + outputs=["loss"], ) if args.compare_dump is None: # Compare the engine with an existing dump directory. 
- comp.add_dump('baseline', args.compare_with) + comp.add_dump("baseline", args.compare_with) comp.add_engine(args.device, trainer, train_loader, test_loader) comp.compare() else: @@ -209,9 +302,9 @@ def callback(prof): trainer.run(train_loader, test_loader) - if (args.save_model): + if args.save_model: torch.save(model.state_dict(), "mnist_cnn.pt") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pyproject.toml b/pyproject.toml index e5a6980ea..9fc29385e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,12 +5,16 @@ requires = ["setuptools<64", "wheel"] version = "0.10.1" [tool.pysen.lint] -enable_black = false +enable_black = true enable_flake8 = true -enable_isort = false +enable_isort = true enable_mypy = true mypy_preset = "strict" line_length = 80 [[tool.pysen.lint.mypy_targets]] paths = ["pytorch_pfn_extras"] + +[tool.pysen.lint.source] + includes = ["."] + excludes = ["pytorch_pfn_extras/onnx/", "tests/pytorch_pfn_extras_tests/onnx_tests/"] diff --git a/pytorch_pfn_extras/__init__.py b/pytorch_pfn_extras/__init__.py index 7dccbf965..6898f25c4 100644 --- a/pytorch_pfn_extras/__init__.py +++ b/pytorch_pfn_extras/__init__.py @@ -1,11 +1,12 @@ # Configure the logging before instantiating anything else from pytorch_pfn_extras import logging # NOQA + logging._configure_logging() from pytorch_pfn_extras import config # NOQA from pytorch_pfn_extras import cuda # NOQA -from pytorch_pfn_extras import dataset # NOQA from pytorch_pfn_extras import dataloaders # NOQA +from pytorch_pfn_extras import dataset # NOQA from pytorch_pfn_extras import distributed # NOQA from pytorch_pfn_extras import engine # NOQA from pytorch_pfn_extras import handler # NOQA @@ -16,14 +17,12 @@ from pytorch_pfn_extras import training # NOQA from pytorch_pfn_extras import utils # NOQA from pytorch_pfn_extras import writing # NOQA - -from pytorch_pfn_extras._tensor import from_ndarray # NOQA from pytorch_pfn_extras._tensor import as_ndarray # NOQA -from pytorch_pfn_extras._tensor import get_xp # NOQA from pytorch_pfn_extras._tensor import as_numpy_dtype # NOQA +from pytorch_pfn_extras._tensor import from_ndarray # NOQA from pytorch_pfn_extras._tensor import from_numpy_dtype # NOQA -from pytorch_pfn_extras.runtime._to import to # NOQA -from pytorch_pfn_extras.runtime._map import map # NOQA +from pytorch_pfn_extras._tensor import get_xp # NOQA from pytorch_pfn_extras._torch_version import requires # NOQA - from pytorch_pfn_extras._version import __version__ # NOQA +from pytorch_pfn_extras.runtime._map import map # NOQA +from pytorch_pfn_extras.runtime._to import to # NOQA diff --git a/pytorch_pfn_extras/_cupy/__init__.py b/pytorch_pfn_extras/_cupy/__init__.py index 8d66a5704..cb0ad364b 100644 --- a/pytorch_pfn_extras/_cupy/__init__.py +++ b/pytorch_pfn_extras/_cupy/__init__.py @@ -1,15 +1,18 @@ try: import cupy # NOQA + _cupy_import_error = None except Exception as e: from pytorch_pfn_extras._cupy import _cupy_stub as cupy # NOQA + _cupy_import_error = e def ensure_cupy() -> None: if _cupy_import_error is not None: raise RuntimeError( - f'CuPy is not available. Reason:\n{_cupy_import_error}') + f"CuPy is not available. 
Reason:\n{_cupy_import_error}" + ) def is_available() -> bool: diff --git a/pytorch_pfn_extras/_tensor.py b/pytorch_pfn_extras/_tensor.py index 3c646784f..d94ad3016 100644 --- a/pytorch_pfn_extras/_tensor.py +++ b/pytorch_pfn_extras/_tensor.py @@ -1,11 +1,9 @@ +from typing import Any, Dict, Union + import numpy import torch import torch.utils.dlpack -from typing import Any, Dict, Union - -from pytorch_pfn_extras._cupy import cupy -from pytorch_pfn_extras._cupy import ensure_cupy - +from pytorch_pfn_extras._cupy import cupy, ensure_cupy _NDArray = Any # TypeVar("_NDArray", numpy.ndarray, cupy.ndarray) _NumpyDtype = Any # numpy.dtype @@ -32,8 +30,9 @@ def from_ndarray(ndarray: _NDArray) -> torch.Tensor: elif isinstance(ndarray, numpy.ndarray): return torch.from_numpy(_copy_if_negative_strides(ndarray)) raise TypeError( - 'expected numpy.ndarray or cupy.ndarray ' - f'(got {type(ndarray).__name__})') + "expected numpy.ndarray or cupy.ndarray " + f"(got {type(ndarray).__name__})" + ) def _copy_if_negative_strides(ndarray: _NDArray) -> _NDArray: @@ -53,18 +52,18 @@ def as_ndarray(tensor: torch.Tensor) -> _NDArray: cannot be tracked in the computational graph. """ devtype = tensor.device.type - if devtype == 'cpu': + if devtype == "cpu": return tensor.detach().numpy() - elif devtype == 'cuda': + elif devtype == "cuda": ensure_cupy() - if hasattr(cupy, 'from_dlpack'): + if hasattr(cupy, "from_dlpack"): # TODO: Avoid using ``torch.utils.dlpack.to_dlpack``. # => return cupy.from_dlpack(tensor) # Blocked by PyTorch 1.10 bug # (https://github.com/pytorch/pytorch/pull/67618) return cupy.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)) return cupy.fromDlpack(torch.utils.dlpack.to_dlpack(tensor)) - raise ValueError(f'Tensor is on unsupported device: {devtype}') + raise ValueError(f"Tensor is on unsupported device: {devtype}") def get_xp(obj: Union[_NDArray, torch.Tensor]) -> Any: @@ -78,21 +77,22 @@ def get_xp(obj: Union[_NDArray, torch.Tensor]) -> Any: elif isinstance(obj, torch.device): devtype = obj.type elif isinstance(obj, numpy.ndarray): - devtype = 'cpu' + devtype = "cpu" elif isinstance(obj, cupy.ndarray): - devtype = 'cuda' + devtype = "cuda" else: raise TypeError( - 'expected torch.Tensor, torch.device, numpy.ndarray, ' - f'or cupy.ndarray (got {type(obj).__name__})') + "expected torch.Tensor, torch.device, numpy.ndarray, " + f"or cupy.ndarray (got {type(obj).__name__})" + ) - if devtype == 'cpu': + if devtype == "cpu": return numpy - elif devtype == 'cuda': + elif devtype == "cuda": ensure_cupy() return cupy - raise ValueError(f'unsupported device type: {devtype}') + raise ValueError(f"unsupported device type: {devtype}") def as_numpy_dtype(torch_dtype: torch.dtype) -> _NumpyDtype: @@ -107,7 +107,7 @@ def as_numpy_dtype(torch_dtype: torch.dtype) -> _NumpyDtype: """ numpy_dtype = _torch_dtype_mapping.get(torch_dtype, None) if numpy_dtype is None: - raise TypeError(f'NumPy does not support {torch_dtype} equivalent') + raise TypeError(f"NumPy does not support {torch_dtype} equivalent") return numpy_dtype @@ -123,27 +123,26 @@ def from_numpy_dtype(numpy_dtype: _NumpyDtype) -> torch.dtype: """ torch_dtype = _numpy_dtype_mapping.get(numpy_dtype, None) if torch_dtype is None: - raise TypeError(f'PyTorch does not support {numpy_dtype} equivalent') + raise TypeError(f"PyTorch does not support {numpy_dtype} equivalent") return torch_dtype _torch_dtype_mapping: Dict[torch.dtype, _NumpyDtype] = { # https://pytorch.org/docs/stable/tensors.html # https://numpy.org/doc/stable/user/basics.types.html - - 
torch.float32: numpy.dtype('float32'), - torch.float64: numpy.dtype('float64'), - torch.float16: numpy.dtype('float16'), + torch.float32: numpy.dtype("float32"), + torch.float64: numpy.dtype("float64"), + torch.float16: numpy.dtype("float16"), # unsupported: torch.bfloat16 # unsupported: torch.complex32 - torch.complex64: numpy.dtype('complex64'), - torch.complex128: numpy.dtype('complex128'), - torch.uint8: numpy.dtype('uint8'), - torch.int8: numpy.dtype('int8'), - torch.int16: numpy.dtype('int16'), - torch.int32: numpy.dtype('int32'), - torch.int64: numpy.dtype('int64'), - torch.bool: numpy.dtype('bool'), + torch.complex64: numpy.dtype("complex64"), + torch.complex128: numpy.dtype("complex128"), + torch.uint8: numpy.dtype("uint8"), + torch.int8: numpy.dtype("int8"), + torch.int16: numpy.dtype("int16"), + torch.int32: numpy.dtype("int32"), + torch.int64: numpy.dtype("int64"), + torch.bool: numpy.dtype("bool"), } _numpy_dtype_mapping: Dict[_NumpyDtype, torch.dtype] = { diff --git a/pytorch_pfn_extras/_torch_version.py b/pytorch_pfn_extras/_torch_version.py index 695b637cb..0806d34ae 100644 --- a/pytorch_pfn_extras/_torch_version.py +++ b/pytorch_pfn_extras/_torch_version.py @@ -2,6 +2,6 @@ from packaging.version import Version -def requires(version: str, package: str = 'torch') -> bool: +def requires(version: str, package: str = "torch") -> bool: pkg_ver = pkg_resources.get_distribution(package).version return Version(pkg_ver.split("+")[0].split("-")[0]) >= Version(version) diff --git a/pytorch_pfn_extras/_version.py b/pytorch_pfn_extras/_version.py index 400a10474..f6104e0c2 100644 --- a/pytorch_pfn_extras/_version.py +++ b/pytorch_pfn_extras/_version.py @@ -1 +1 @@ -__version__ = '0.6.7' +__version__ = "0.6.7" diff --git a/pytorch_pfn_extras/config.py b/pytorch_pfn_extras/config.py index adf04b6d0..eafd74835 100644 --- a/pytorch_pfn_extras/config.py +++ b/pytorch_pfn_extras/config.py @@ -1,8 +1,16 @@ import json import os -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union import reprlib - +from typing import ( + Any, + Callable, + Dict, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) ConfigKey = Tuple[Union[str, int], ...] AttrKey = Tuple[Union[str, int], ...] @@ -13,20 +21,21 @@ LoadTrace = Tuple[Tuple[str, ConfigKey], ...] 
-def customize_type(**default_kwargs: Any) -> Callable[ - [Callable[..., Any]], Callable[..., Any]]: +def customize_type( + **default_kwargs: Any, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: def deco(type_: Callable[..., Any]) -> Callable[..., Any]: type_._custom_default_kwargs = default_kwargs # type: ignore[attr-defined] # NOQA return type_ + return deco class Config(object): - def __init__( - self, - config: Any, - types: Optional[Mapping[str, Callable[..., Any]]] = None, + self, + config: Any, + types: Optional[Mapping[str, Callable[..., Any]]] = None, ) -> None: self._cache: Dict[KeyPair, Any] = {((), None): config} self._types = types or {} @@ -36,21 +45,21 @@ def __getitem__(self, key: str) -> Any: @classmethod def load_path( - cls, - path: str, - *, - loader: Optional[Loader] = None, - types: Optional[Mapping[str, Callable[..., Any]]] = None, - ) -> 'Config': + cls, + path: str, + *, + loader: Optional[Loader] = None, + types: Optional[Mapping[str, Callable[..., Any]]] = None, + ) -> "Config": if loader is None: loader = _json_loader return cls(_load(path, loader, ()), types) def _eval( - self, - config_key: ConfigKey, - attr_key: Optional[AttrKey], - trace: DumpTrace, + self, + config_key: ConfigKey, + attr_key: Optional[AttrKey], + trace: DumpTrace, ) -> Any: if (config_key, attr_key) in self._cache: return self._cache[(config_key, attr_key)] @@ -58,9 +67,7 @@ def _eval( circular = (config_key, attr_key) in trace trace = (*trace, (config_key, attr_key)) if circular: - raise RuntimeError( - 'Circular dependency', - _dump_trace(trace)) + raise RuntimeError("Circular dependency", _dump_trace(trace)) def cache(value: Any) -> Any: self._cache[(config_key, attr_key)] = value @@ -69,18 +76,19 @@ def cache(value: Any) -> Any: if attr_key: obj = self._eval(config_key, attr_key[:-1], trace) try: - if isinstance(attr_key[-1], str) \ - and hasattr(obj, attr_key[-1]): + if isinstance(attr_key[-1], str) and hasattr(obj, attr_key[-1]): return cache(getattr(obj, attr_key[-1])) else: return cache(obj[attr_key[-1]]) except Exception as e: e.args = e.args + ( - '{} not in {} ({})'.format( + "{} not in {} ({})".format( attr_key[-1], _dump_key(config_key, attr_key[:-1]), - reprlib.repr(obj)), - _dump_trace(trace)) + reprlib.repr(obj), + ), + _dump_trace(trace), + ) raise e elif attr_key is None: @@ -89,52 +97,64 @@ def cache(value: Any) -> Any: return cache(config[config_key[-1]]) except Exception as e: e.args = e.args + ( - '{} not in {}'.format( - config_key[-1], - _dump_key(config_key[:-1], None)), - _dump_trace(trace)) + "{} not in {}".format( + config_key[-1], _dump_key(config_key[:-1], None) + ), + _dump_trace(trace), + ) raise e else: config = self._eval(config_key, None, trace) if isinstance(config, dict): - if 'type' in config: + if "type" in config: try: - type_ = self._types[config['type']] + type_ = self._types[config["type"]] except Exception as e: e.args = e.args + ( - '{} not in types'.format(config['type']), - _dump_trace(trace)) + "{} not in types".format(config["type"]), + _dump_trace(trace), + ) raise e else: type_ = dict kwargs = {} for k in config.keys(): - if not k == 'type': + if not k == "type": kwargs[k] = self._eval((*config_key, k), (), trace) for k, v in getattr( - type_, '_custom_default_kwargs', {}).items(): + type_, "_custom_default_kwargs", {} + ).items(): if k not in kwargs: kwargs[k] = self._eval( - *_parse_key(v, config_key)[:2], trace) + *_parse_key(v, config_key)[:2], trace + ) try: return cache(type_(**kwargs)) except Exception as e: e.args = e.args 
+ ( - '{} ({}) failed with kwargs {}'.format( - config['type'], type_, reprlib.repr(kwargs)), - _dump_trace(trace)) + "{} ({}) failed with kwargs {}".format( + config["type"], type_, reprlib.repr(kwargs) + ), + _dump_trace(trace), + ) raise e elif isinstance(config, list): - return cache([ - self._eval((*config_key, i), (), trace) - for i in range(len(config))]) - elif isinstance(config, str) and config.startswith('@'): - return cache(self._eval( - *_parse_key(config[1:], config_key[:-1])[:2], trace)) + return cache( + [ + self._eval((*config_key, i), (), trace) + for i in range(len(config)) + ] + ) + elif isinstance(config, str) and config.startswith("@"): + return cache( + self._eval( + *_parse_key(config[1:], config_key[:-1])[:2], trace + ) + ) else: return cache(config) @@ -142,13 +162,12 @@ def update_via_args(self, args: Sequence[Tuple[str, Any]]) -> None: for k, v in args: n_k, c_k = _parse_key(k, ())[:2] if (n_k, c_k) in self._cache: - if ( - isinstance(self._cache[(n_k, c_k)], bool) - and isinstance(v, str) + if isinstance(self._cache[(n_k, c_k)], bool) and isinstance( + v, str ): if not v.lower() in ("true", "false"): raise ValueError( - f'bool should be true/false. Found {v}' + f"bool should be true/false. Found {v}" ) v = v.lower() == "true" self._cache[(n_k, c_k)] = type(self._cache[(n_k, c_k)])(v) @@ -157,22 +176,22 @@ def update_via_args(self, args: Sequence[Tuple[str, Any]]) -> None: def _parse_key( - key: str, current_config_key: ConfigKey + key: str, current_config_key: ConfigKey ) -> Tuple[ConfigKey, Optional[AttrKey], bool]: - if key.startswith('!'): + if key.startswith("!"): key = key[1:] escape = True else: escape = False - if key.startswith('/'): + if key.startswith("/"): key = key[1:] rel = False else: rel = True - config_key_str = key.split('/') - config_key_str[-1], *attr_key_list = config_key_str[-1].split('.') + config_key_str = key.split("/") + config_key_str[-1], *attr_key_list = config_key_str[-1].split(".") config_key = [_parse_k(k) for k in config_key_str] attr_key: Optional[AttrKey] = tuple(_parse_k(k) for k in attr_key_list) @@ -186,9 +205,9 @@ def _parse_key( i = 0 while i < len(config_key): - if config_key[i] in {'', '.'}: + if config_key[i] in {"", "."}: config_key.pop(i) - elif config_key[i] == '..': + elif config_key[i] == "..": assert i > 0 config_key.pop(i) config_key.pop(i - 1) @@ -207,21 +226,21 @@ def _parse_k(k: str) -> Union[str, int]: def _dump_key(config_key: ConfigKey, attr_key: Optional[AttrKey]) -> str: - config_key_str = '/' + '/'.join(str(k) for k in config_key) + config_key_str = "/" + "/".join(str(k) for k in config_key) if attr_key: - attr_key_str = '.'.join(str(k) for k in attr_key) - return config_key_str + '.' + attr_key_str + attr_key_str = ".".join(str(k) for k in attr_key) + return config_key_str + "." + attr_key_str elif attr_key is None: - return '!' + config_key_str + return "!" 
+ config_key_str else: return config_key_str def _dump_trace(trace: DumpTrace) -> str: - return ' -> '.join( - _dump_key(config_key, attr_key) - for config_key, attr_key in trace) + return " -> ".join( + _dump_key(config_key, attr_key) for config_key, attr_key in trace + ) def _load(path: str, loader: Loader, trace: LoadTrace) -> ConfigType: @@ -230,31 +249,37 @@ def _load(path: str, loader: Loader, trace: LoadTrace) -> ConfigType: trace = (*trace, (path, ())) if circular: raise RuntimeError( - 'Circular import', - ' -> '.join('{} of {}'.format(_dump_key(config_key, None), path) - for path, config_key in trace)) + "Circular import", + " -> ".join( + "{} of {}".format(_dump_key(config_key, None), path) + for path, config_key in trace + ), + ) config = loader(path) return _expand_import(config, os.path.dirname(path), loader, trace) def _expand_import( - config: ConfigType, - workdir: str, - loader: Loader, - trace: LoadTrace, + config: ConfigType, + workdir: str, + loader: Loader, + trace: LoadTrace, ) -> ConfigType: path, config_key = trace[-1] if isinstance(config, dict): - config = {k: _expand_import(v, workdir, loader, - (*trace, (path, (*config_key, k)))) - for k, v in config.items()} - if 'import' in config: - path = config['import'] + config = { + k: _expand_import( + v, workdir, loader, (*trace, (path, (*config_key, k))) + ) + for k, v in config.items() + } + if "import" in config: + path = config["import"] if not os.path.isabs(path): path = os.path.join(workdir, path) config_orig, config = config, _load(path, loader, trace) for k, v in config_orig.items(): - if k == 'import': + if k == "import": continue config_key, attr_key, rel = _parse_key(k, ()) assert attr_key == () @@ -267,15 +292,20 @@ def _expand_import( c[config_key[-1]] = v except Exception as e: e.args = e.args + ( - '{} not in {}'.format( - _dump_key(config_key, attr_key), path),) + "{} not in {}".format( + _dump_key(config_key, attr_key), path + ), + ) raise e return config elif isinstance(config, list): - return [_expand_import(v, workdir, loader, - (*trace, (path, (*config_key, i)))) - for i, v in enumerate(config)] + return [ + _expand_import( + v, workdir, loader, (*trace, (path, (*config_key, i))) + ) + for i, v in enumerate(config) + ] else: return config diff --git a/pytorch_pfn_extras/config_types.py b/pytorch_pfn_extras/config_types.py index d5060b727..61cfdf148 100644 --- a/pytorch_pfn_extras/config_types.py +++ b/pytorch_pfn_extras/config_types.py @@ -1,14 +1,13 @@ import warnings -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from pytorch_pfn_extras import config - if TYPE_CHECKING: import optuna -def optuna_types(trial: 'optuna.trial.Trial') -> Dict[str, Any]: +def optuna_types(trial: "optuna.trial.Trial") -> Dict[str, Any]: types = { "optuna_suggest_categorical": trial.suggest_categorical, "optuna_suggest_discrete_uniform": trial.suggest_discrete_uniform, @@ -21,15 +20,15 @@ def optuna_types(trial: 'optuna.trial.Trial') -> Dict[str, Any]: def load_path_with_optuna_types( - path: str, - trial: 'optuna.trial.Trial', - loader: Optional[config.Loader] = None, - types: Optional[Dict[str, Callable[..., Any]]] = None, + path: str, + trial: "optuna.trial.Trial", + loader: Optional[config.Loader] = None, + types: Optional[Dict[str, Callable[..., Any]]] = None, ) -> config.Config: if types is None: types = {} for key, value in optuna_types(trial).items(): if key in types: - warnings.warn(key + ' is overwritten by optuna suggest.') + 
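The evaluation rules being rewrapped in the config.py hunks above are easier to see from the call side: a dict carrying a "type" key is instantiated through the `types` mapping, its remaining keys become keyword arguments, strings beginning with "@" are resolved as references to other evaluated keys, and `customize_type` supplies defaults pulled from other config paths. A minimal sketch (the `MLP` class and the key names are made up for illustration):

    from pytorch_pfn_extras.config import Config, customize_type


    class MLP:
        def __init__(self, n_units, n_out):
            self.n_units = n_units
            self.n_out = n_out


    # customize_type() records defaults that _eval() above fills in from other
    # config paths when the kwarg is absent (here n_out is taken from /n_out).
    @customize_type(n_out="/n_out")
    def make_mlp(n_units, n_out):
        return MLP(n_units, n_out)


    cfg = Config(
        {
            "n_out": 10,
            "model": {"type": "mlp", "n_units": 64},  # dict with "type" -> built via `types`
            "alias": "@/model",                       # "@" strings resolve to another key
        },
        types={"mlp": make_mlp},
    )

    model = cfg["/model"]            # make_mlp(n_units=64, n_out=10)
    assert cfg["/alias"] is model    # results are cached, so the reference is the same object
    print(cfg["/model.n_units"])     # "." walks attributes/items of the result -> 64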
warnings.warn(key + " is overwritten by optuna suggest.") types[key] = value return config.Config.load_path(path, loader=loader, types=types) diff --git a/pytorch_pfn_extras/cuda/__init__.py b/pytorch_pfn_extras/cuda/__init__.py index e13ff6773..61642511e 100644 --- a/pytorch_pfn_extras/cuda/__init__.py +++ b/pytorch_pfn_extras/cuda/__init__.py @@ -1,3 +1,5 @@ from pytorch_pfn_extras.cuda._allocator import stream # NOQA from pytorch_pfn_extras.cuda._allocator import use_torch_mempool_in_cupy # NOQA -from pytorch_pfn_extras.cuda._allocator import use_default_mempool_in_cupy # NOQA +from pytorch_pfn_extras.cuda._allocator import ( # NOQA + use_default_mempool_in_cupy, +) diff --git a/pytorch_pfn_extras/cuda/_allocator.py b/pytorch_pfn_extras/cuda/_allocator.py index e57365b70..12b363013 100644 --- a/pytorch_pfn_extras/cuda/_allocator.py +++ b/pytorch_pfn_extras/cuda/_allocator.py @@ -2,10 +2,7 @@ from typing import Any, Generator, Optional import torch - -from pytorch_pfn_extras._cupy import cupy -from pytorch_pfn_extras._cupy import is_available, ensure_cupy - +from pytorch_pfn_extras._cupy import cupy, ensure_cupy, is_available _allocator = None @@ -49,7 +46,8 @@ def use_torch_mempool_in_cupy() -> None: ensure_cupy() _allocator = cupy.cuda.memory.PythonFunctionAllocator( - _torch_alloc, _torch_free) + _torch_alloc, _torch_free + ) cupy.cuda.set_allocator(_allocator.malloc) @@ -58,11 +56,11 @@ def _torch_alloc(size: int, device_id: int) -> Any: cupy_stream_ptr = cupy.cuda.get_current_stream().ptr if torch_stream_ptr != cupy_stream_ptr: raise RuntimeError( - 'The current stream set in PyTorch and CuPy must be same.' - ' Use `pytorch_pfn_extras.cuda.stream` instead of' - ' `torch.cuda.stream`.') - return torch.cuda.caching_allocator_alloc( - size, device_id, torch_stream_ptr) + "The current stream set in PyTorch and CuPy must be same." + " Use `pytorch_pfn_extras.cuda.stream` instead of" + " `torch.cuda.stream`." 
+ ) + return torch.cuda.caching_allocator_alloc(size, device_id, torch_stream_ptr) def _torch_free(mem_ptr: int, device_id: int) -> None: diff --git a/pytorch_pfn_extras/dataloaders/__init__.py b/pytorch_pfn_extras/dataloaders/__init__.py index 8a7549391..7f6abee2f 100644 --- a/pytorch_pfn_extras/dataloaders/__init__.py +++ b/pytorch_pfn_extras/dataloaders/__init__.py @@ -1,2 +1,2 @@ -from pytorch_pfn_extras.dataloaders.dataloader import DataLoader # NOQA from pytorch_pfn_extras.dataloaders import utils # NOQA +from pytorch_pfn_extras.dataloaders.dataloader import DataLoader # NOQA diff --git a/pytorch_pfn_extras/dataloaders/utils.py b/pytorch_pfn_extras/dataloaders/utils.py index 1bc338c03..f50fdbd1b 100644 --- a/pytorch_pfn_extras/dataloaders/utils.py +++ b/pytorch_pfn_extras/dataloaders/utils.py @@ -15,9 +15,11 @@ class CollateAsDict: """ def __init__( - self, names: Sequence[str], - collate_fn: Callable[..., Any] = - torch.utils.data._utils.collate.default_collate, + self, + names: Sequence[str], + collate_fn: Callable[ + ..., Any + ] = torch.utils.data._utils.collate.default_collate, ) -> None: self.names = names self.collate_fn = collate_fn diff --git a/pytorch_pfn_extras/dataset/__init__.py b/pytorch_pfn_extras/dataset/__init__.py index fa29a4a3b..b17bfe96e 100644 --- a/pytorch_pfn_extras/dataset/__init__.py +++ b/pytorch_pfn_extras/dataset/__init__.py @@ -1,3 +1,7 @@ -from pytorch_pfn_extras.dataset.tabular.tabular_dataset import TabularDataset # NOQA -from pytorch_pfn_extras.dataset.shared_dataset import ItemNotFoundException # NOQA from pytorch_pfn_extras.dataset.shared_dataset import SharedDataset # NOQA +from pytorch_pfn_extras.dataset.shared_dataset import ( # NOQA + ItemNotFoundException, +) +from pytorch_pfn_extras.dataset.tabular.tabular_dataset import ( # NOQA + TabularDataset, +) diff --git a/pytorch_pfn_extras/dataset/shared_dataset.py b/pytorch_pfn_extras/dataset/shared_dataset.py index 48ad8011e..0d91ad5a5 100644 --- a/pytorch_pfn_extras/dataset/shared_dataset.py +++ b/pytorch_pfn_extras/dataset/shared_dataset.py @@ -52,10 +52,11 @@ class ItemNotFoundException(Exception): class SharedDataset(torch.utils.data.Dataset): - """ Dataset that caches the load samples in shared memory + """Dataset that caches the load samples in shared memory Args """ + def __init__(self, sm_size, cache_type=InfiniteCache): super().__init__() self.cache = cache_type(sm_size) @@ -64,7 +65,8 @@ def __getitem__(self, idx): x = self.cache.get_value(idx) if x is None: raise ItemNotFoundException( - 'Item {} is not in the cache'.format(idx)) + "Item {} is not in the cache".format(idx) + ) return x def is_cached(self, idx): diff --git a/pytorch_pfn_extras/dataset/tabular/__init__.py b/pytorch_pfn_extras/dataset/tabular/__init__.py index 79f19e53c..49b1898ab 100644 --- a/pytorch_pfn_extras/dataset/tabular/__init__.py +++ b/pytorch_pfn_extras/dataset/tabular/__init__.py @@ -4,6 +4,7 @@ from pytorch_pfn_extras.dataset.tabular import _slice # NOQA from pytorch_pfn_extras.dataset.tabular import _transform # NOQA from pytorch_pfn_extras.dataset.tabular import _with_converter # NOQA - -from pytorch_pfn_extras.dataset.tabular.delegate_dataset import DelegateDataset # NOQA +from pytorch_pfn_extras.dataset.tabular.delegate_dataset import ( # NOQA + DelegateDataset, +) from pytorch_pfn_extras.dataset.tabular.from_data import from_data # NOQA diff --git a/pytorch_pfn_extras/dataset/tabular/_asmode.py b/pytorch_pfn_extras/dataset/tabular/_asmode.py index 543b61f52..82ae2460e 100644 --- 
a/pytorch_pfn_extras/dataset/tabular/_asmode.py +++ b/pytorch_pfn_extras/dataset/tabular/_asmode.py @@ -4,7 +4,6 @@ class _Astuple(tabular_dataset.TabularDataset): - def __init__(self, dataset): self._dataset = dataset @@ -27,7 +26,6 @@ def convert(self, data): class _Asdict(tabular_dataset.TabularDataset): - def __init__(self, dataset): self._dataset = dataset diff --git a/pytorch_pfn_extras/dataset/tabular/_concat.py b/pytorch_pfn_extras/dataset/tabular/_concat.py index 3ab36b488..43ef08c60 100644 --- a/pytorch_pfn_extras/dataset/tabular/_concat.py +++ b/pytorch_pfn_extras/dataset/tabular/_concat.py @@ -4,11 +4,10 @@ class _Concat(tabular_dataset.TabularDataset): - def __init__(self, *datasets): for dataset in datasets[1:]: if not dataset.keys == datasets[0].keys: - raise ValueError('All datasets must have the same keys') + raise ValueError("All datasets must have the same keys") self._datasets = datasets @@ -32,12 +31,16 @@ def get_examples(self, indices, key_indices): if indices is None: examples = [ dataset.get_examples(None, key_indices) - for dataset in self._datasets] + for dataset in self._datasets + ] return tuple( - [data - for sub_examples in examples - for data in sub_examples[col_index]] - for col_index in range(n_cols)) + [ + data + for sub_examples in examples + for data in sub_examples[col_index] + ] + for col_index in range(n_cols) + ) elif isinstance(indices, slice): start, stop, step = indices.indices(len(self)) @@ -53,8 +56,9 @@ def get_examples(self, indices, key_indices): sub_stop = min(sub_stop, len(dataset)) else: if sub_start >= len(dataset): - sub_start = \ + sub_start = ( len(dataset) + (sub_start - len(dataset)) % step + ) sub_stop = max(sub_stop, -1) if len(range(sub_start, sub_stop, step)) > 0: @@ -62,8 +66,11 @@ def get_examples(self, indices, key_indices): sub_start = None if sub_stop < 0 and step < 0: sub_stop = None - examples.append(dataset.get_examples( - slice(sub_start, sub_stop, step), key_indices)) + examples.append( + dataset.get_examples( + slice(sub_start, sub_stop, step), key_indices + ) + ) offset += len(dataset) @@ -75,10 +82,13 @@ def get_examples(self, indices, key_indices): if step < 0: examples.reverse() return tuple( - [data - for sub_examples in examples - for data in sub_examples[col_index]] - for col_index in range(n_cols)) + [ + data + for sub_examples in examples + for data in sub_examples[col_index] + ] + for col_index in range(n_cols) + ) else: examples = {} @@ -90,12 +100,12 @@ def get_examples(self, indices, key_indices): if index < offset or offset + len(dataset) <= index: continue sub_indices.append(index - offset) - example_indices[p] = ( - dataset_index, len(sub_indices) - 1) + example_indices[p] = (dataset_index, len(sub_indices) - 1) if len(sub_indices) > 0: examples[dataset_index] = dataset.get_examples( - sub_indices, key_indices) + sub_indices, key_indices + ) offset += len(dataset) @@ -105,9 +115,12 @@ def get_examples(self, indices, key_indices): return list(examples.values())[0] else: return tuple( - [examples[dataset_index][col_index][p] - for dataset_index, p in example_indices] - for col_index in range(n_cols)) + [ + examples[dataset_index][col_index][p] + for dataset_index, p in example_indices + ] + for col_index in range(n_cols) + ) def convert(self, data): return self._datasets[0].convert(data) diff --git a/pytorch_pfn_extras/dataset/tabular/_join.py b/pytorch_pfn_extras/dataset/tabular/_join.py index f04cbaa2c..e99541f8b 100644 --- a/pytorch_pfn_extras/dataset/tabular/_join.py +++ 
b/pytorch_pfn_extras/dataset/tabular/_join.py @@ -4,14 +4,13 @@ class _Join(tabular_dataset.TabularDataset): - def __init__(self, *datasets): keys = set(datasets[0].keys) for dataset in datasets[1:]: if not len(dataset) == len(datasets[0]): - raise ValueError('All datasets must have the same length') + raise ValueError("All datasets must have the same length") if len(keys.intersection(dataset.keys)) > 0: - raise ValueError('All keys must be unique among all datasets') + raise ValueError("All keys must be unique among all datasets") keys = keys.union(dataset.keys) self._datasets = datasets @@ -35,7 +34,8 @@ def get_examples(self, indices, key_indices): return tuple( col for dataset in self._datasets - for col in dataset.get_examples(indices, None)) + for col in dataset.get_examples(indices, None) + ) examples = {} key_offset = 0 @@ -52,7 +52,8 @@ def get_examples(self, indices, key_indices): sub_key_indices = tuple(sub_key_indices) sub_examples = dataset.get_examples(indices, sub_key_indices) for sub_key_index, col_example in zip( - sub_key_indices, sub_examples): + sub_key_indices, sub_examples + ): examples[key_offset + sub_key_index] = col_example key_offset += len(dataset.keys) diff --git a/pytorch_pfn_extras/dataset/tabular/_slice.py b/pytorch_pfn_extras/dataset/tabular/_slice.py index 2eb02c89d..2c3fa222a 100644 --- a/pytorch_pfn_extras/dataset/tabular/_slice.py +++ b/pytorch_pfn_extras/dataset/tabular/_slice.py @@ -1,11 +1,9 @@ # mypy: ignore-errors -from pytorch_pfn_extras.dataset.tabular import tabular_dataset -from pytorch_pfn_extras.dataset.tabular import _utils +from pytorch_pfn_extras.dataset.tabular import _utils, tabular_dataset class _Slice(tabular_dataset.TabularDataset): - def __init__(self, dataset, indices, keys): if keys is None: self._unary = None @@ -13,7 +11,7 @@ def __init__(self, dataset, indices, keys): self._unary = False else: self._unary = True - keys = keys, + keys = (keys,) self._dataset = dataset self._indices = _utils._as_indices(indices, len(dataset)) @@ -33,8 +31,9 @@ def keys(self): if self._key_indices is None: return self._dataset.keys else: - return tuple(self._dataset.keys[key_index] - for key_index in self._key_indices) + return tuple( + self._dataset.keys[key_index] for key_index in self._key_indices + ) @property def mode(self): @@ -47,7 +46,8 @@ def mode(self): def get_examples(self, indices, key_indices): indices = _utils._merge_indices( - self._indices, indices, len(self._dataset), len(self)) + self._indices, indices, len(self._dataset), len(self) + ) key_indices = _utils._merge_key_indices(self._key_indices, key_indices) return self._dataset.get_examples(indices, key_indices) @@ -56,7 +56,6 @@ def convert(self, data): class _SliceHelper(object): - def __init__(self, dataset): self._dataset = dataset diff --git a/pytorch_pfn_extras/dataset/tabular/_transform.py b/pytorch_pfn_extras/dataset/tabular/_transform.py index 0cb21a367..b4fcc8b9f 100644 --- a/pytorch_pfn_extras/dataset/tabular/_transform.py +++ b/pytorch_pfn_extras/dataset/tabular/_transform.py @@ -1,7 +1,6 @@ # mypy: ignore-errors -from pytorch_pfn_extras.dataset.tabular import tabular_dataset -from pytorch_pfn_extras.dataset.tabular import _utils +from pytorch_pfn_extras.dataset.tabular import _utils, tabular_dataset class _TransformBase(tabular_dataset.TabularDataset): @@ -12,14 +11,15 @@ def __init__(self, dataset, keys, transforms): self._transforms = [] for s, t in transforms: if any(k in key_set for k in s[1]): - raise ValueError('Transformations must be disjoint') + raise 
ValueError("Transformations must be disjoint") key_set.update(s[1]) ops_idx = _utils._as_key_indices(s[0], self._dataset.keys) res_idx = _utils._as_key_indices(s[1], keys) self._transforms.append(((ops_idx, res_idx), t)) if key_set != set(keys): raise ValueError( - 'Transformations must produce only all specified keys') + "Transformations must produce only all specified keys" + ) self._keys = keys @@ -70,7 +70,6 @@ def convert(self, data): class _Transform(_TransformBase): - def get_examples(self, indices, key_indices): if key_indices is None: key_indices = range(len(self._keys)) @@ -94,9 +93,7 @@ def get_examples(self, indices, key_indices): out_example = transform(*inputs) elif self._dataset.mode is dict: keys = [self._dataset.keys[i] for i in ops_idx] - out_example = transform( - **dict(zip(keys, inputs)) - ) + out_example = transform(**dict(zip(keys, inputs))) elif self._dataset.mode is None: out_example = transform(*inputs) if isinstance(out_example, tuple): @@ -113,7 +110,8 @@ def get_examples(self, indices, key_indices): # we are slicing the outputs using key_indices # the result key index needs to be recalculated out_examples[key_indices.index(key_index)].append( - out_example[col_index]) + out_example[col_index] + ) elif isinstance(out_example, dict): if hasattr(self, "_mode") and self._mode is not dict: raise ValueError( @@ -138,7 +136,8 @@ def get_examples(self, indices, key_indices): if key_index is None: continue out_examples[key_indices.index(key_index)].append( - out_example[col_index]) + out_example[col_index] + ) return out_examples @@ -147,7 +146,6 @@ def convert(self, data): class _TransformBatch(_TransformBase): - def get_examples(self, indices, key_indices): if indices is None: len_ = len(self) @@ -169,9 +167,7 @@ def get_examples(self, indices, key_indices): out_example = transform(*inputs) elif self._dataset.mode is dict: keys = [self._dataset.keys[i] for i in ops_idx] - out_example = transform( - **dict(zip(keys, inputs)) - ) + out_example = transform(**dict(zip(keys, inputs))) elif self._dataset.mode is None: out_example = transform(*inputs) @@ -192,8 +188,9 @@ def get_examples(self, indices, key_indices): # all the outputs are covered this works but when # we are slicing the outputs using key_indices # the result key index needs to be recalculated - out_examples[key_indices.index(key_index)] = ( - out_example[col_index]) + out_examples[key_indices.index(key_index)] = out_example[ + col_index + ] elif isinstance(out_example, dict): if hasattr(self, "_mode") and self._mode is not dict: raise ValueError( @@ -208,8 +205,9 @@ def get_examples(self, indices, key_indices): if key_index is None: continue key = self._keys[key_index] - out_examples[key_indices.index(key_index)] = ( - out_example[key]) + out_examples[key_indices.index(key_index)] = out_example[ + key + ] else: if hasattr(self, "_mode") and self._mode is not None: raise ValueError( @@ -224,6 +222,7 @@ def get_examples(self, indices, key_indices): for col_index, key_index in enumerate(t_res_idx): if key_index is None: continue - out_examples[key_indices.index(key_index)] = ( - out_example[col_index]) + out_examples[key_indices.index(key_index)] = out_example[ + col_index + ] return tuple(out_examples) diff --git a/pytorch_pfn_extras/dataset/tabular/_utils.py b/pytorch_pfn_extras/dataset/tabular/_utils.py index ed551d7fd..1a4dddf1a 100644 --- a/pytorch_pfn_extras/dataset/tabular/_utils.py +++ b/pytorch_pfn_extras/dataset/tabular/_utils.py @@ -11,8 +11,10 @@ def _as_indices(indices, len_): if 
all(isinstance(index, (bool, np.bool_)) for index in indices): if not len(indices) == len_: - raise ValueError('The number of booleans is ' - 'different from the length of dataset') + raise ValueError( + "The number of booleans is " + "different from the length of dataset" + ) return [i for i, index in enumerate(indices) if index] else: checked_indices = [] @@ -22,8 +24,10 @@ def _as_indices(indices, len_): index += len_ if index < 0 or len_ <= index: raise IndexError( - 'index {} is out of bounds for dataset with size {}' - .format(index, len_)) + "index {} is out of bounds for dataset with size {}".format( + index, len_ + ) + ) checked_indices.append(index) return checked_indices @@ -40,13 +44,15 @@ def _as_key_indices(keys, key_names): key_index += len(key_names) if key_index < 0 or len(key_names) <= key_index: raise IndexError( - 'index {} is out of bounds for keys with size {}'.format( - key, len(key_names))) + "index {} is out of bounds for keys with size {}".format( + key, len(key_names) + ) + ) else: try: key_index = key_names.index(key) except ValueError: - raise KeyError('{} does not exists'.format(key)) + raise KeyError("{} does not exists".format(key)) key_indices.append(key_index) return tuple(key_indices) diff --git a/pytorch_pfn_extras/dataset/tabular/_with_converter.py b/pytorch_pfn_extras/dataset/tabular/_with_converter.py index c8d7b67e4..4782a5c52 100644 --- a/pytorch_pfn_extras/dataset/tabular/_with_converter.py +++ b/pytorch_pfn_extras/dataset/tabular/_with_converter.py @@ -4,7 +4,6 @@ class _WithConverter(tabular_dataset.TabularDataset): - def __init__(self, dataset, converter): self._dataset = dataset self._converter = converter diff --git a/pytorch_pfn_extras/dataset/tabular/from_data.py b/pytorch_pfn_extras/dataset/tabular/from_data.py index 165199b74..6d7c499ab 100644 --- a/pytorch_pfn_extras/dataset/tabular/from_data.py +++ b/pytorch_pfn_extras/dataset/tabular/from_data.py @@ -93,20 +93,20 @@ def from_data(data, *, size=None): def _make_dataset(key, data, size): if isinstance(data, (numpy.ndarray, torch.Tensor)): if key is None: - key = '_{}'.format(id(data)) + key = "_{}".format(id(data)) return _Array(key, data) elif isinstance(data, list): if key is None: - key = '_{}'.format(id(data)) + key = "_{}".format(id(data)) return _List(key, data) elif callable(data): if key is None: - raise ValueError('key(s) must be specified for callable') + raise ValueError("key(s) must be specified for callable") if size is None: - raise ValueError('size must be specified for callable') + raise ValueError("size must be specified for callable") dataset = _Index(size) if isinstance(key, str): - key = key, + key = (key,) if not isinstance(key, tuple): key = tuple(key) data = [((dataset.keys, key), data)] @@ -114,7 +114,6 @@ def _make_dataset(key, data, size): class _Array(tabular_dataset.TabularDataset): - def __init__(self, key, data): self._key = key self._data = data @@ -124,7 +123,7 @@ def __len__(self): @property def keys(self): - return self._key, + return (self._key,) @property def mode(self): @@ -132,7 +131,7 @@ def mode(self): def get_examples(self, indices, key_indices): if key_indices is None: - key_indices = 0, + key_indices = (0,) if indices is None: return (self._data,) * len(key_indices) @@ -141,7 +140,6 @@ def get_examples(self, indices, key_indices): class _List(tabular_dataset.TabularDataset): - def __init__(self, key, data): self._key = key self._data = data @@ -151,7 +149,7 @@ def __len__(self): @property def keys(self): - return self._key, + return (self._key,) 
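For `from_data`, whose dispatch over arrays, lists, and callables is rewrapped above, a usage sketch may help; the column names and the callable are illustrative, and a callable column needs both a key and a `size`, as `_make_dataset` enforces:

    import numpy as np

    import pytorch_pfn_extras as ppe

    x = np.arange(10, dtype=np.float32)

    dataset = ppe.dataset.tabular.from_data(
        {"x": x, "double": lambda i: 2 * i},  # dict keys become column keys
        size=10,                              # required because of the callable column
    )

    print(len(dataset))   # 10
    print(dataset.keys)   # ("x", "double")
    print(dataset[3])     # one example carrying both columns, e.g. {"x": 3.0, "double": 6}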
@property def mode(self): @@ -159,19 +157,19 @@ def mode(self): def get_examples(self, indices, key_indices): if key_indices is None: - key_indices = 0, + key_indices = (0,) if indices is None: return (self._data,) * len(key_indices) elif isinstance(indices, slice): return (self._data[indices],) * len(key_indices) else: - return ([self._data[index] for index in indices],) \ - * len(key_indices) + return ([self._data[index] for index in indices],) * len( + key_indices + ) class _Index(tabular_dataset.TabularDataset): - def __init__(self, size): self._len = size @@ -180,7 +178,7 @@ def __len__(self): @property def keys(self): - return 'index', + return ("index",) @property def mode(self): @@ -194,6 +192,6 @@ def get_examples(self, indices, key_indices): indices = list(range(start, stop, step)) if key_indices is None: - key_indices = 0, + key_indices = (0,) return (indices,) * len(key_indices) diff --git a/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py b/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py index c003cd690..8b7abc212 100644 --- a/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py +++ b/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py @@ -1,9 +1,8 @@ # mypy: ignore-errors import numpy -import torch - import pytorch_pfn_extras as ppe +import torch from torch.utils.data import Dataset @@ -201,8 +200,7 @@ def concat(self, *datasets): Returns: A concatenated dataset. """ - return ppe.dataset.tabular._concat._Concat( - self, *datasets) + return ppe.dataset.tabular._concat._Concat(self, *datasets) def join(self, *datasets): """Stack datasets along columns. @@ -243,8 +241,7 @@ def transform(self, keys, transform): Returns: A transfromed dataset. """ - return ppe.dataset.tabular._transform._Transform( - self, keys, transform) + return ppe.dataset.tabular._transform._Transform(self, keys, transform) def transform_batch(self, keys, transform_batch): """Apply a transform to examples. @@ -274,7 +271,8 @@ def transform_batch(self, keys, transform_batch): A transfromed dataset. """ return ppe.dataset.tabular._transform._TransformBatch( - self, keys, transform_batch) + self, keys, transform_batch + ) def with_converter(self, converter): """Override the behaviour of :meth:`convert`. 
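The dataset-combination methods touched in these hunks (`concat`, `join`, and indexing through `__getitem__`) compose lazily; a short sketch with made-up columns:

    import numpy as np

    import pytorch_pfn_extras as ppe

    base = ppe.dataset.tabular.from_data(
        {"x": np.arange(6, dtype=np.float32), "y": np.arange(6) % 3}
    )
    extra = ppe.dataset.tabular.from_data({"w": np.ones(6, dtype=np.float32)})

    # join() stacks columns of equal-length datasets; overlapping keys raise ValueError.
    joined = base.join(extra)

    # concat() appends rows of datasets that share exactly the same keys.
    doubled = joined.concat(joined)

    # __getitem__ accepts ints, slices, and index lists, as in the hunk below.
    first_three = doubled[0:3]
    picked = doubled[[0, 2, 4]]
    print(len(doubled), doubled.keys)   # 12 ("x", "y", "w")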
@@ -289,7 +287,8 @@ def with_converter(self, converter): """ return ppe.dataset.tabular._with_converter._WithConverter( - self, converter) + self, converter + ) def get_example(self, i): example = self.get_examples([i], None) @@ -320,8 +319,7 @@ def __getitem__(self, index): """ if isinstance(index, slice): current, stop, step = index.indices(len(self)) - return [self.get_example(i) for i in - range(current, stop, step)] + return [self.get_example(i) for i in range(current, stop, step)] elif isinstance(index, list) or isinstance(index, numpy.ndarray): return [self.get_example(i) for i in index] else: diff --git a/pytorch_pfn_extras/distributed/__init__.py b/pytorch_pfn_extras/distributed/__init__.py index 3ae322192..d2d43f287 100644 --- a/pytorch_pfn_extras/distributed/__init__.py +++ b/pytorch_pfn_extras/distributed/__init__.py @@ -1,3 +1,9 @@ -from pytorch_pfn_extras.distributed._dataset_util import create_distributed_subset_indices # NOQA -from pytorch_pfn_extras.distributed._distributed_validation_sampler import DistributedValidationSampler # NOQA -from pytorch_pfn_extras.distributed._initialize import initialize_ompi_environment # NOQA +from pytorch_pfn_extras.distributed._dataset_util import ( # NOQA + create_distributed_subset_indices, +) +from pytorch_pfn_extras.distributed._distributed_validation_sampler import ( # NOQA + DistributedValidationSampler, +) +from pytorch_pfn_extras.distributed._initialize import ( # NOQA + initialize_ompi_environment, +) diff --git a/pytorch_pfn_extras/distributed/_dataset_util.py b/pytorch_pfn_extras/distributed/_dataset_util.py index 2ece1f280..cdd3db571 100644 --- a/pytorch_pfn_extras/distributed/_dataset_util.py +++ b/pytorch_pfn_extras/distributed/_dataset_util.py @@ -5,7 +5,7 @@ def _shared_random_seed() -> int: - seed = torch.randint(0, 2 ** 31, size=()) + seed = torch.randint(0, 2**31, size=()) if torch.distributed.is_initialized(): # type: ignore if torch.distributed.get_backend() == "nccl": # type: ignore seed = seed.cuda() diff --git a/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py b/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py index f4165ce71..bbf1bd2ab 100644 --- a/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py +++ b/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py @@ -1,11 +1,10 @@ -from typing import TypeVar, Optional, Iterator, Sized +from typing import Iterator, Optional, Sized, TypeVar import numpy as np import torch import torch.distributed as dist - -T_co = TypeVar('T_co', covariant=True) +T_co = TypeVar("T_co", covariant=True) class DistributedValidationSampler(torch.utils.data.Sampler): @@ -18,23 +17,31 @@ class DistributedValidationSampler(torch.utils.data.Sampler): so for training do not use this sampler (use PyTorch DistributedSampler instead). 
""" - def __init__(self, - dataset: Sized, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, shuffle: bool = True, - seed: int = 0) -> None: + def __init__( + self, + dataset: Sized, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + ) -> None: if num_replicas is None: if not dist.is_available(): # type: ignore[no-untyped-call] - raise RuntimeError("Requires distributed package to be available") + raise RuntimeError( + "Requires distributed package to be available" + ) num_replicas = dist.get_world_size() # type: ignore[no-untyped-call] if rank is None: if not dist.is_available(): # type: ignore[no-untyped-call] - raise RuntimeError("Requires distributed package to be available") + raise RuntimeError( + "Requires distributed package to be available" + ) rank = dist.get_rank() # type: ignore[no-untyped-call] if rank >= num_replicas or rank < 0: raise ValueError( "Invalid rank {}, rank should be in the interval" - " [0, {}]".format(rank, num_replicas - 1)) + " [0, {}]".format(rank, num_replicas - 1) + ) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -42,7 +49,9 @@ def __init__(self, self.seed = seed self.dataset_len = len(dataset) - self.num_samples = len(np.array_split(range(self.dataset_len), num_replicas)[rank]) + self.num_samples = len( + np.array_split(range(self.dataset_len), num_replicas)[rank] + ) def __iter__(self) -> Iterator[T_co]: if self.shuffle: diff --git a/pytorch_pfn_extras/distributed/_initialize.py b/pytorch_pfn_extras/distributed/_initialize.py index 8f5a64285..9d09c122b 100644 --- a/pytorch_pfn_extras/distributed/_initialize.py +++ b/pytorch_pfn_extras/distributed/_initialize.py @@ -12,7 +12,7 @@ def initialize_ompi_environment( rank: int = 0, local_rank: int = 0, addr: str = "localhost", - port: str = "1234" + port: str = "1234", ) -> Tuple[int, int, int]: """Initialize `torch.distributed` environments with values taken from OpenMPI. 
@@ -42,9 +42,10 @@ def initialize_ompi_environment( addr = e.get("MASTER_ADDR", addr) port = e.get("MASTER_PORT", port) - if backend not in ("gloo" ,"nccl"): + if backend not in ("gloo", "nccl"): raise ValueError( - "Invalid value for backend, only 'gloo' and 'nccl' are supported") + "Invalid value for backend, only 'gloo' and 'nccl' are supported" + ) if init_method == "env": init_method = "env://" e["MASTER_ADDR"] = addr @@ -56,12 +57,12 @@ def initialize_ompi_environment( init_method = f"tcp://{addr}:{port}" else: raise ValueError( - "Invalid value for init_method, only 'env' and 'tcp' are supported") + "Invalid value for init_method, only 'env' and 'tcp' are supported" + ) if world_size > 1 and not torch.distributed.is_initialized(): # type: ignore torch.distributed.init_process_group( # type: ignore - backend, init_method=init_method, - world_size=world_size, rank=rank + backend, init_method=init_method, world_size=world_size, rank=rank ) torch.distributed.barrier() # type: ignore diff --git a/pytorch_pfn_extras/engine.py b/pytorch_pfn_extras/engine.py index ed7a8c58c..3db3eea18 100644 --- a/pytorch_pfn_extras/engine.py +++ b/pytorch_pfn_extras/engine.py @@ -1,48 +1,62 @@ from typing import ( - Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, - Union, TYPE_CHECKING, + TYPE_CHECKING, + Any, + Callable, + Dict, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, ) -import torch - import pytorch_pfn_extras.handler as handler_module +import torch from pytorch_pfn_extras.runtime import runtime_registry from pytorch_pfn_extras.training._transform_model import default_transform_model if TYPE_CHECKING: + from pytorch_pfn_extras import writing from pytorch_pfn_extras.runtime._runtime import DeviceLike from pytorch_pfn_extras.training import extension - from pytorch_pfn_extras.training.trigger import TriggerLike - from pytorch_pfn_extras.training._trainer import Trainer from pytorch_pfn_extras.training._evaluator import Evaluator + from pytorch_pfn_extras.training._trainer import Trainer from pytorch_pfn_extras.training.metrics import MetricType - from pytorch_pfn_extras import writing + from pytorch_pfn_extras.training.trigger import TriggerLike def create_trainer( - models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], - optimizers: Union[torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer]], - max_epochs: int, - *, - extensions: Optional[Sequence[Union['extension.ExtensionLike', - 'extension.ExtensionEntry']]] = None, - out_dir: str = 'result', - stop_trigger: 'TriggerLike' = None, - writer: Optional['writing.Writer'] = None, - evaluator: Optional[Union[ - 'Evaluator', Tuple['Evaluator', 'TriggerLike'], - Mapping[str, Union['Evaluator', Tuple['Evaluator', 'TriggerLike']]] - ]] = None, - device: 'DeviceLike' = 'cpu', - logic: Optional[handler_module.BaseLogic] = None, - transform_model: Callable[ - [str, torch.nn.Module], torch.nn.Module] = default_transform_model, - handler_class: Optional[Type[handler_module.Handler]] = None, - options: Optional[Dict[str, Any]] = None, - runtime_options: Optional[Mapping[str, Any]] = None, - profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] - **kwargs: Any, -) -> 'Trainer': + models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], + optimizers: Union[ + torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer] + ], + max_epochs: int, + *, + extensions: Optional[ + Sequence[Union["extension.ExtensionLike", "extension.ExtensionEntry"]] + ] = None, + out_dir: str = "result", + 
stop_trigger: "TriggerLike" = None, + writer: Optional["writing.Writer"] = None, + evaluator: Optional[ + Union[ + "Evaluator", + Tuple["Evaluator", "TriggerLike"], + Mapping[str, Union["Evaluator", Tuple["Evaluator", "TriggerLike"]]], + ] + ] = None, + device: "DeviceLike" = "cpu", + logic: Optional[handler_module.BaseLogic] = None, + transform_model: Callable[ + [str, torch.nn.Module], torch.nn.Module + ] = default_transform_model, + handler_class: Optional[Type[handler_module.Handler]] = None, + options: Optional[Dict[str, Any]] = None, + runtime_options: Optional[Mapping[str, Any]] = None, + profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] + **kwargs: Any, +) -> "Trainer": """Creates a trainer object. Args: @@ -94,13 +108,14 @@ def create_trainer( options = options.copy() if options else {} # TODO(kmaehashi): deprecate specifying 'runtime' key in options runtime_options = dict( - runtime_options if runtime_options - else options.pop('runtime', {})) + runtime_options if runtime_options else options.pop("runtime", {}) + ) logic = handler_module.Logic() if logic is None else logic handler_class = handler_class if handler_class else handler_module.Handler entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec( - device) + device + ) entry_runtime = entry_runtime_cls(device, runtime_options) handler = handler_class(logic, entry_runtime, {}) @@ -108,14 +123,20 @@ def create_trainer( handler.consume_options(options) logic.consume_options(options) if len(options) > 0: - raise ValueError('Unknown options: ', options) + raise ValueError("Unknown options: ", options) from pytorch_pfn_extras.training._trainer import Trainer + return Trainer( - handler, evaluator=evaluator, - models=models, optimizers=optimizers, max_epochs=max_epochs, - extensions=extensions, out_dir=out_dir, - stop_trigger=stop_trigger, writer=writer, + handler, + evaluator=evaluator, + models=models, + optimizers=optimizers, + max_epochs=max_epochs, + extensions=extensions, + out_dir=out_dir, + stop_trigger=stop_trigger, + writer=writer, transform_model=transform_model, profile=profile, **kwargs, @@ -123,17 +144,17 @@ def create_trainer( def create_evaluator( - models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], - *, - progress_bar: bool = False, - device: 'DeviceLike' = 'cpu', - metrics: Optional[Sequence['MetricType']] = None, - logic: Optional[handler_module.Logic] = None, - handler_class: Optional[Type[handler_module.Handler]] = None, - options: Optional[Dict[str, Any]] = None, - runtime_options: Optional[Mapping[str, Any]] = None, - profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] -) -> 'Evaluator': + models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], + *, + progress_bar: bool = False, + device: "DeviceLike" = "cpu", + metrics: Optional[Sequence["MetricType"]] = None, + logic: Optional[handler_module.Logic] = None, + handler_class: Optional[Type[handler_module.Handler]] = None, + options: Optional[Dict[str, Any]] = None, + runtime_options: Optional[Mapping[str, Any]] = None, + profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] +) -> "Evaluator": """Creates an evaluator object. The return value of this function is expected to be fed to `ppe.engine.create_trainer` as an argument. 
@@ -171,13 +192,14 @@ def create_evaluator( options = options.copy() if options else {} # TODO(kmaehashi): deprecate specifying 'runtime' key in options runtime_options = dict( - runtime_options if runtime_options - else options.pop('runtime', {})) + runtime_options if runtime_options else options.pop("runtime", {}) + ) logic = handler_module.Logic() if logic is None else logic handler_class = handler_class if handler_class else handler_module.Handler entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec( - device) + device + ) entry_runtime = entry_runtime_cls(device, runtime_options) handler = handler_class(logic, entry_runtime, options) @@ -185,9 +207,10 @@ def create_evaluator( handler.consume_options(options) logic.consume_options(options) if len(options) > 0: - raise ValueError('Unknown options: ', options) + raise ValueError("Unknown options: ", options) from pytorch_pfn_extras.training._evaluator import Evaluator + return Evaluator( handler, models=models, diff --git a/pytorch_pfn_extras/handler/__init__.py b/pytorch_pfn_extras/handler/__init__.py index 09014488e..613fe8c5d 100644 --- a/pytorch_pfn_extras/handler/__init__.py +++ b/pytorch_pfn_extras/handler/__init__.py @@ -1,6 +1,15 @@ -from pytorch_pfn_extras.handler._code_block import CodeBlock, update_parameters, forward # NOQA +from pytorch_pfn_extras.handler._code_block import ( # NOQA + CodeBlock, + forward, + update_parameters, +) from pytorch_pfn_extras.handler._handler import BaseHandler, Handler # NOQA -from pytorch_pfn_extras.handler._logic import BaseLogic, Logic, CodeBlockLogic, ClousureLogic # NOQA # Deprecated, only imported for backward compatibility from pytorch_pfn_extras.handler._logic import torch_autocast # NOQA +from pytorch_pfn_extras.handler._logic import ( # NOQA + BaseLogic, + ClousureLogic, + CodeBlockLogic, + Logic, +) diff --git a/pytorch_pfn_extras/handler/_code_block.py b/pytorch_pfn_extras/handler/_code_block.py index 94d4396ad..d2e23f3a6 100644 --- a/pytorch_pfn_extras/handler/_code_block.py +++ b/pytorch_pfn_extras/handler/_code_block.py @@ -21,6 +21,7 @@ class CodeBlock: backprop_to: Name of the values where backpropagation will be stopped. state: Data that can be used during the CodeBlock execution. 
""" + func: Callable optimizers: List[torch.optim.Optimizer] backprop: bool @@ -104,11 +105,11 @@ def forward(block: Callable) -> CodeBlock: if isinstance(block, torch.nn.Module): module = block else: - module = getattr(block, '__self__', None) + module = getattr(block, "__self__", None) assert module is not None func = block state = {} - runtime = getattr(module, '_ppe_runtime', None) + runtime = getattr(module, "_ppe_runtime", None) assert runtime is not None return CodeBlock( diff --git a/pytorch_pfn_extras/handler/_handler.py b/pytorch_pfn_extras/handler/_handler.py index 36820dbba..bc506c62f 100644 --- a/pytorch_pfn_extras/handler/_handler.py +++ b/pytorch_pfn_extras/handler/_handler.py @@ -1,11 +1,19 @@ from typing import ( - Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, - Tuple, Union, TYPE_CHECKING, + TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, ) -import torch - import pytorch_pfn_extras as ppe +import torch from pytorch_pfn_extras import reporting from pytorch_pfn_extras.handler._logic import BaseLogic from pytorch_pfn_extras.training import Evaluator, Trainer @@ -15,13 +23,12 @@ class BaseHandler: - def __init__( - self, - logic: BaseLogic, - options: Dict[str, Any], - *args: Any, - **kwargs: Any + self, + logic: BaseLogic, + options: Dict[str, Any], + *args: Any, + **kwargs: Any, ) -> None: """Base class of Handler. @@ -60,9 +67,7 @@ def train_setup(self, trainer: Trainer, loader: Iterable[Any]) -> None: pass def train_epoch_begin( - self, - trainer: Trainer, - loader: Iterable[Any] + self, trainer: Trainer, loader: Iterable[Any] ) -> None: """A method called when starting a new epoch. @@ -94,9 +99,9 @@ def train_cleanup(self, trainer: Trainer) -> None: pass def train_validation_begin( - self, - trainer: Trainer, - evaluator: Evaluator, + self, + trainer: Trainer, + evaluator: Evaluator, ) -> None: """A method called when starting a validation. @@ -109,9 +114,9 @@ def train_validation_begin( pass def train_validation_end( - self, - trainer: Trainer, - evaluator: Evaluator, + self, + trainer: Trainer, + evaluator: Evaluator, ) -> None: """A method called after validation. @@ -125,11 +130,11 @@ def train_validation_end( pass def train_step( - self, - trainer: Trainer, - batch_idx: int, - batch: Any, - complete_fn: Callable[[int, Any], None], + self, + trainer: Trainer, + batch_idx: int, + batch: Any, + complete_fn: Callable[[int, Any], None], ) -> None: """A training step. @@ -141,11 +146,11 @@ def train_step( pass def train_post_step( - self, - trainer: Trainer, - batch_idx: int, - batch: Any, - outputs: Any, + self, + trainer: Trainer, + batch_idx: int, + batch: Any, + outputs: Any, ) -> None: """A method called after each training step. @@ -156,11 +161,7 @@ def train_post_step( # Called after train_step. pass - def eval_setup( - self, - evaluator: Evaluator, - loader: Iterable[Any] - ) -> None: + def eval_setup(self, evaluator: Evaluator, loader: Iterable[Any]) -> None: """A method called only once when starting a training run. When evaluator is not given, this method is not called. @@ -183,11 +184,11 @@ def eval_loop_begin(self, evaluator: Evaluator) -> None: pass def eval_step( - self, - evaluator: Evaluator, - batch_idx: int, - batch: Any, - complete_fn: Callable[[int, Any], None], + self, + evaluator: Evaluator, + batch_idx: int, + batch: Any, + complete_fn: Callable[[int, Any], None], ) -> None: """Evaluation iteration. 
@@ -209,11 +210,11 @@ def eval_loop_end(self, evaluator: Evaluator) -> None: pass def eval_post_step( - self, - evaluator: Evaluator, - batch_idx: int, - batch: Any, - outputs: Any, + self, + evaluator: Evaluator, + batch_idx: int, + batch: Any, + outputs: Any, ) -> None: """A method called after each evaluation step. @@ -225,16 +226,15 @@ def eval_post_step( pass -ModulesTuple = Tuple[str, torch.nn.Module, 'BaseRuntime'] +ModulesTuple = Tuple[str, torch.nn.Module, "BaseRuntime"] class Handler(BaseHandler): - def __init__( - self, - logic: BaseLogic, - entry_runtime: 'BaseRuntime', - options: Dict[str, Any], + self, + logic: BaseLogic, + entry_runtime: "BaseRuntime", + options: Dict[str, Any], ) -> None: """A set of callback functions to perform device-specific operations. @@ -260,14 +260,14 @@ def __init__( def consume_options(self, options: Dict[str, Any]) -> None: super().consume_options(options) - self._eval_report_keys = options.pop('eval_report_keys', []) - self._train_report_keys = options.pop('train_report_keys', []) + self._eval_report_keys = options.pop("eval_report_keys", []) + self._train_report_keys = options.pop("train_report_keys", []) # Consume this argument for backward compatibility - options.pop('async', False) + options.pop("async", False) def _runtime_iterator( - self, - models: Mapping[str, torch.nn.Module], + self, + models: Mapping[str, torch.nn.Module], ) -> Generator[ModulesTuple, ModulesTuple, None]: if not self._ppe_modules: for n, m in models.items(): @@ -281,18 +281,19 @@ def _runtime_iterator( yield sn, sm, rt def _setup( - self, - models: Mapping[str, torch.nn.Module], - loader: Union[Iterable[Any], Mapping[str, Iterable[Any]]], - optimizers: Optional[Mapping[str, torch.optim.Optimizer]] = None, + self, + models: Mapping[str, torch.nn.Module], + loader: Union[Iterable[Any], Mapping[str, Iterable[Any]]], + optimizers: Optional[Mapping[str, torch.optim.Optimizer]] = None, ) -> None: # This requires loader to be always a dict # should be avoided? if not isinstance(loader, dict): # The default model always has empty name when obtained from the # modules - loaders = {sn: loader - for sn, _, _ in self._runtime_iterator(models)} + loaders = { + sn: loader for sn, _, _ in self._runtime_iterator(models) + } else: loaders = loader if optimizers is None: @@ -310,7 +311,8 @@ def _setup( if len(self._ppe_modules) == 0: raise RuntimeError( - 'call `ppe.to(module, device)` before starting the training') + "call `ppe.to(module, device)` before starting the training" + ) def train_setup(self, trainer: Trainer, loader: Iterable[Any]) -> None: """A method called only once when starting a training run. @@ -334,9 +336,7 @@ def train_cleanup(self, trainer: Trainer) -> None: rt.train_cleanup(sm) def train_epoch_begin( - self, - trainer: Trainer, - loader: Iterable[Any] + self, trainer: Trainer, loader: Iterable[Any] ) -> None: """A method called when starting a new epoch. @@ -361,9 +361,9 @@ def train_epoch_end(self, trainer: Trainer) -> None: self._logic.train_epoch_end(trainer.models, trainer.epoch) def train_validation_begin( - self, - trainer: Trainer, - evaluator: Evaluator, + self, + trainer: Trainer, + evaluator: Evaluator, ) -> None: """A method called when starting a validation. @@ -376,9 +376,9 @@ def train_validation_begin( self._logic.train_validation_begin(evaluator.models) def train_validation_end( - self, - trainer: Trainer, - evaluator: Evaluator, + self, + trainer: Trainer, + evaluator: Evaluator, ) -> None: """A method called after validation. 
@@ -396,11 +396,11 @@ def train_validation_end( self._logic.train_validation_end(evaluator.models) def train_step( - self, - trainer: Trainer, - batch_idx: int, - batch: Any, - complete_fn: Callable[[int, Any], None], + self, + trainer: Trainer, + batch_idx: int, + batch: Any, + complete_fn: Callable[[int, Any], None], ) -> None: """A training step. @@ -418,17 +418,15 @@ def train_step( batch = self._entry_runtime.convert_batch(batch) outs = self._logic.train_step( - trainer.models, trainer.optimizers, batch_idx, batch) + trainer.models, trainer.optimizers, batch_idx, batch + ) self._logic.train_step_optimizers( - trainer.models, trainer.optimizers, batch_idx) + trainer.models, trainer.optimizers, batch_idx + ) complete_fn(batch_idx, outs) - def eval_setup( - self, - evaluator: Evaluator, - loader: Iterable[Any] - ) -> None: + def eval_setup(self, evaluator: Evaluator, loader: Iterable[Any]) -> None: """Called only once when starting a training run. When evaluator is not given, this method is not called. @@ -441,11 +439,11 @@ def eval_setup( self._setup(evaluator.models, loader) def eval_step( - self, - evaluator: Evaluator, - batch_idx: int, - batch: Any, - complete_fn: Callable[[int, Any], None], + self, + evaluator: Evaluator, + batch_idx: int, + batch: Any, + complete_fn: Callable[[int, Any], None], ) -> None: """Evaluation iteration. @@ -466,11 +464,11 @@ def eval_step( complete_fn(batch_idx, outs) def eval_post_step( - self, - evaluator: Evaluator, - batch_idx: int, - batch: Any, - outputs: Any, + self, + evaluator: Evaluator, + batch_idx: int, + batch: Any, + outputs: Any, ) -> None: """A method called after each evaluation step. @@ -497,11 +495,7 @@ def eval_loop_end(self, evaluator: Evaluator) -> None: pass def train_post_step( - self, - trainer: Trainer, - batch_idx: int, - batch: Any, - outputs: Any + self, trainer: Trainer, batch_idx: int, batch: Any, outputs: Any ) -> None: """A method called after each training step. diff --git a/pytorch_pfn_extras/handler/_logic.py b/pytorch_pfn_extras/handler/_logic.py index b9dacc663..af9f78a53 100644 --- a/pytorch_pfn_extras/handler/_logic.py +++ b/pytorch_pfn_extras/handler/_logic.py @@ -1,10 +1,9 @@ import contextlib import dataclasses -from typing import Any, Dict, Generator, Iterable, Mapping, Optional import warnings +from typing import Any, Dict, Generator, Iterable, Mapping, Optional import torch - from pytorch_pfn_extras.handler._code_block import forward, update_parameters from pytorch_pfn_extras.runtime import _autocast @@ -21,7 +20,7 @@ def torch_autocast(enabled: bool = True) -> Generator[None, None, None]: def _normalize_outputs(outputs: Any) -> Dict[str, Any]: target: Dict[str, Any] - if isinstance(outputs, tuple) and hasattr(outputs, '_fields'): + if isinstance(outputs, tuple) and hasattr(outputs, "_fields"): # namedtuple target = outputs._asdict() # type: ignore[attr-defined] elif isinstance(outputs, dict): @@ -50,10 +49,10 @@ def consume_options(self, options: Dict[str, Any]) -> None: pass def train_epoch_begin( - self, - models: Mapping[str, torch.nn.Module], - epoch: int, - loader: Iterable[Any], + self, + models: Mapping[str, torch.nn.Module], + epoch: int, + loader: Iterable[Any], ) -> None: """A method called when starting a new epoch of training. @@ -65,9 +64,9 @@ def train_epoch_begin( pass def train_epoch_end( - self, - models: Mapping[str, torch.nn.Module], - epoch: int, + self, + models: Mapping[str, torch.nn.Module], + epoch: int, ) -> None: """A method called when completing an epoch of training. 
@@ -78,11 +77,11 @@ def train_epoch_end( pass def train_step( - self, - models: Mapping[str, torch.nn.Module], - optimizers: Mapping[str, torch.optim.Optimizer], - batch_idx: int, - batch: Any, + self, + models: Mapping[str, torch.nn.Module], + optimizers: Mapping[str, torch.optim.Optimizer], + batch_idx: int, + batch: Any, ) -> Any: """A method invokes the models forward and backward passes. @@ -102,10 +101,10 @@ def train_step( pass def train_step_optimizers( - self, - models: Mapping[str, torch.nn.Module], - optimizers: Mapping[str, torch.optim.Optimizer], - batch_idx: int, + self, + models: Mapping[str, torch.nn.Module], + optimizers: Mapping[str, torch.optim.Optimizer], + batch_idx: int, ) -> None: """A method in charge of stepping the provided optimizers. @@ -118,8 +117,7 @@ def train_step_optimizers( pass def train_validation_begin( - self, - models: Mapping[str, torch.nn.Module] + self, models: Mapping[str, torch.nn.Module] ) -> None: """A method called when starting a validation. @@ -129,8 +127,8 @@ def train_validation_begin( pass def train_validation_end( - self, - models: Mapping[str, torch.nn.Module], + self, + models: Mapping[str, torch.nn.Module], ) -> None: """A method called when the validation completes. @@ -140,10 +138,10 @@ def train_validation_end( pass def eval_step( - self, - models: Mapping[str, torch.nn.Module], - batch_idx: int, - batch: Any, + self, + models: Mapping[str, torch.nn.Module], + batch_idx: int, + batch: Any, ) -> Any: """A method for an evaluation step. @@ -157,11 +155,10 @@ def eval_step( class Logic(BaseLogic): - def __init__( - self, - model_name: str = 'main', - options: Optional[Dict[str, Any]] = None, + self, + model_name: str = "main", + options: Optional[Dict[str, Any]] = None, ) -> None: """A set of methods that defines the training logic. 
@@ -188,24 +185,29 @@ def __init__( def consume_options(self, options: Dict[str, Any]) -> None: super().consume_options(options) - self.backward_outputs = options.pop('backward_outputs', None) - self._grad_scaler = options.pop('grad_scaler', None) + self.backward_outputs = options.pop("backward_outputs", None) + self._grad_scaler = options.pop("grad_scaler", None) - self._backward_fn = options.pop('backward_function', None) + self._backward_fn = options.pop("backward_function", None) autocast_options = options.pop("autocast", False) if isinstance(autocast_options, bool): - autocast_options = {"enabled": autocast_options, "device_type": "cuda"} + autocast_options = { + "enabled": autocast_options, + "device_type": "cuda", + } self._autocast = _autocast._AutocastManager( autocast_options, self._grad_scaler is not None ) if self._grad_scaler is not None: if not isinstance(self._grad_scaler, torch.cuda.amp.GradScaler): - raise RuntimeError('grad_scaler should be a ' - 'torch.cuda.amp.GradScaler object') + raise RuntimeError( + "grad_scaler should be a " + "torch.cuda.amp.GradScaler object" + ) def _forward(self, model: torch.nn.Module, batch: Any) -> Any: - if isinstance(batch, tuple) and hasattr(batch, '_fields'): + if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple return model(batch) if isinstance(batch, dict): @@ -218,10 +220,16 @@ def _backward(self, outputs: Dict[str, Any]) -> None: to_backward = set() if self.backward_outputs is None: for _, v in outputs.items(): - if isinstance(v, torch.Tensor) and v.grad_fn is not None and ( - ( - v.numel() == 1 - and (v.dtype.is_floating_point or v.dtype.is_complex) + if ( + isinstance(v, torch.Tensor) + and v.grad_fn is not None + and ( + ( + v.numel() == 1 + and ( + v.dtype.is_floating_point or v.dtype.is_complex + ) + ) ) ): to_backward.add(v) @@ -238,8 +246,8 @@ def _backward(self, outputs: Dict[str, Any]) -> None: to_backward.add(v) except KeyError: warnings.warn( - 'Couldn\'t find requested backward value: ' - f'{k} in {outputs.keys()}' + "Couldn't find requested backward value: " + f"{k} in {outputs.keys()}" ) for v in to_backward: @@ -249,10 +257,10 @@ def _backward(self, outputs: Dict[str, Any]) -> None: self._backward_fn(v) def train_epoch_begin( - self, - models: Mapping[str, torch.nn.Module], - epoch: int, - loader: Iterable[Any], + self, + models: Mapping[str, torch.nn.Module], + epoch: int, + loader: Iterable[Any], ) -> None: """A method called when starting a new epoch of training. @@ -263,8 +271,9 @@ def train_epoch_begin( """ model = models[self.model_name] model.train() - if hasattr(loader, 'sampler') and hasattr( - loader.sampler, 'set_epoch'): # type: ignore[attr-defined] + if hasattr(loader, "sampler") and hasattr( + loader.sampler, "set_epoch" + ): # type: ignore[attr-defined] # Needed for `torch.utils.data.DistributedSampler` loader.sampler.set_epoch(epoch) # type: ignore[attr-defined] @@ -273,11 +282,11 @@ def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None: model.eval() def train_step( - self, - models: Mapping[str, torch.nn.Module], - optimizers: Mapping[str, torch.optim.Optimizer], - batch_idx: int, - batch: Any, + self, + models: Mapping[str, torch.nn.Module], + optimizers: Mapping[str, torch.optim.Optimizer], + batch_idx: int, + batch: Any, ) -> Any: """A method invokes the model forward and backward passes. 
@@ -304,15 +313,16 @@ def train_step( ), "loss scaling with multiple outputs is not supported" to_back_outs = { k: self._grad_scaler.scale(v) - for k, v in to_back_outs.items()} + for k, v in to_back_outs.items() + } self._backward(to_back_outs) return outs def train_step_optimizers( - self, - models: Mapping[str, torch.nn.Module], - optimizers: Mapping[str, torch.optim.Optimizer], - batch_idx: int, + self, + models: Mapping[str, torch.nn.Module], + optimizers: Mapping[str, torch.optim.Optimizer], + batch_idx: int, ) -> None: """A method in charge of stepping the provided optimizers. @@ -332,8 +342,8 @@ def train_step_optimizers( optimizer.step() def train_validation_begin( - self, - models: Mapping[str, torch.nn.Module], + self, + models: Mapping[str, torch.nn.Module], ) -> None: """A method called when starting a validation. @@ -348,10 +358,10 @@ def train_validation_end(self, models: Mapping[str, Any]) -> None: model.train() def eval_step( - self, - models: Mapping[str, torch.nn.Module], - batch_idx: int, - batch: Any, + self, + models: Mapping[str, torch.nn.Module], + batch_idx: int, + batch: Any, ) -> Any: """A method for an evaluation step. @@ -368,9 +378,9 @@ def eval_step( class CodeBlockLogic(BaseLogic): def __init__( - self, - model_name: str = 'main', - options: Optional[Dict[str, Any]] = None, + self, + model_name: str = "main", + options: Optional[Dict[str, Any]] = None, ) -> None: """A set of methods that defines the training logic. @@ -388,15 +398,15 @@ def __init__( def consume_options(self, options: Dict[str, Any]) -> None: super().consume_options(options) - self.backward_outputs = options.pop('backward_outputs', None) + self.backward_outputs = options.pop("backward_outputs", None) if self.backward_outputs is not None: assert isinstance(self.backward_outputs, str) def train_epoch_begin( - self, - models: Mapping[str, torch.nn.Module], - epoch: int, - loader: Iterable[Any], + self, + models: Mapping[str, torch.nn.Module], + epoch: int, + loader: Iterable[Any], ) -> None: """A method called when starting a new epoch of training. @@ -407,8 +417,9 @@ def train_epoch_begin( """ model = models[self.model_name] model.train() - if hasattr(loader, 'sampler') and hasattr( - loader.sampler, 'set_epoch'): # type: ignore[attr-defined] + if hasattr(loader, "sampler") and hasattr( + loader.sampler, "set_epoch" + ): # type: ignore[attr-defined] # Needed for `torch.utils.data.DistributedSampler` loader.sampler.set_epoch(epoch) # type: ignore[attr-defined] @@ -417,11 +428,11 @@ def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None: model.eval() def train_step( - self, - models: Mapping[str, torch.nn.Module], - optimizers: Mapping[str, torch.optim.Optimizer], - batch_idx: int, - batch: Any, + self, + models: Mapping[str, torch.nn.Module], + optimizers: Mapping[str, torch.optim.Optimizer], + batch_idx: int, + batch: Any, ) -> Any: """A method invokes the model forward and backward passes. @@ -448,8 +459,8 @@ def train_step( )(batch) def train_validation_begin( - self, - models: Mapping[str, torch.nn.Module], + self, + models: Mapping[str, torch.nn.Module], ) -> None: """A method called when starting a validation. @@ -464,10 +475,10 @@ def train_validation_end(self, models: Mapping[str, Any]) -> None: model.train() def eval_step( - self, - models: Mapping[str, torch.nn.Module], - batch_idx: int, - batch: Any, + self, + models: Mapping[str, torch.nn.Module], + batch_idx: int, + batch: Any, ) -> Any: """A method for an evaluation step. 
@@ -492,18 +503,19 @@ def __float__(self) -> float: class ClousureLogic(Logic): - def consume_options(self, options: Dict[str, Any]) -> None: super().consume_options(options) if self._grad_scaler is not None: - raise RuntimeError('torch.cuda.amp.GradScaler does not support clousure step mode.') + raise RuntimeError( + "torch.cuda.amp.GradScaler does not support clousure step mode." + ) def train_step( - self, - models: Mapping[str, torch.nn.Module], - optimizers: Mapping[str, torch.optim.Optimizer], - batch_idx: int, - batch: Any, + self, + models: Mapping[str, torch.nn.Module], + optimizers: Mapping[str, torch.optim.Optimizer], + batch_idx: int, + batch: Any, ) -> Any: """A method invokes the model forward and backward passes and performs an optimization step. @@ -517,18 +529,21 @@ def train_step( batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor): Input tensors feeded to the model of the current step. """ + def clousure() -> ClousureModelOutput: with self._autocast.autocast(): optimizers[self.model_name].zero_grad() outs = self._forward(models[self.model_name], batch) to_back_outs = _normalize_outputs(outs) if len(to_back_outs) > 1: - raise RuntimeError("Clousure step with multiple outputs is not supported.") + raise RuntimeError( + "Clousure step with multiple outputs is not supported." + ) elif len(to_back_outs) == 0: raise RuntimeError("No backward target found.") self._backward(to_back_outs) - loss, = to_back_outs.values() + (loss,) = to_back_outs.values() return ClousureModelOutput( outs=outs, loss=loss, @@ -537,14 +552,16 @@ def clousure() -> ClousureModelOutput: optimizer = optimizers[self.model_name] clousure_model_output: ClousureModelOutput = optimizer.step(clousure) # type: ignore if not isinstance(clousure_model_output, ClousureModelOutput): - raise RuntimeError(f"{type(clousure_model_output)} type object returned from optimizer.step with clousure. optimizer.step is expected to return ppe.handler.ClousureModelOutput.") + raise RuntimeError( + f"{type(clousure_model_output)} type object returned from optimizer.step with clousure. optimizer.step is expected to return ppe.handler.ClousureModelOutput." + ) return clousure_model_output.outs def train_step_optimizers( - self, - models: Mapping[str, torch.nn.Module], - optimizers: Mapping[str, torch.optim.Optimizer], - batch_idx: int, + self, + models: Mapping[str, torch.nn.Module], + optimizers: Mapping[str, torch.optim.Optimizer], + batch_idx: int, ) -> None: """In clousure mode, the stepping of the optimizer cannot be changed. 
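`ClousureLogic` above targets closure-style optimizers; for reference, the underlying PyTorch pattern it wraps looks roughly like this (illustrative only, with placeholder model and data):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x = torch.randn(16, 4)
t = torch.randn(16, 1)

def closure():
    # LBFGS may invoke this several times per step, which is why the
    # GradScaler combination is rejected in consume_options above.
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), t)
    loss.backward()
    return loss  # ClousureLogic returns a ClousureModelOutput wrapping this value instead

optimizer.step(closure)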
diff --git a/pytorch_pfn_extras/logging.py b/pytorch_pfn_extras/logging.py index fb583f111..be9938a5e 100644 --- a/pytorch_pfn_extras/logging.py +++ b/pytorch_pfn_extras/logging.py @@ -1,22 +1,21 @@ import logging import os - -from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL # NOQA +from logging import CRITICAL, DEBUG, ERROR, INFO, WARNING # NOQA from typing import Optional -_logger_name = 'ppe' -_logger_format = '[%(name)s] %(asctime)s: (%(levelname)s) %(message)s' +_logger_name = "ppe" +_logger_format = "[%(name)s] %(asctime)s: (%(levelname)s) %(message)s" _logger = None def _configure_logging( - *, - filename: Optional[str] = None, - level: str = 'ERROR', - format: str = _logger_format + *, + filename: Optional[str] = None, + level: str = "ERROR", + format: str = _logger_format, ) -> None: global _logger - filename = os.environ.get('PPE_LOG_FILENAME', filename) + filename = os.environ.get("PPE_LOG_FILENAME", filename) if filename is None: handler: logging.Handler = logging.StreamHandler() else: @@ -25,15 +24,20 @@ def _configure_logging( # To dynamically change the level if needed # basicConfig does not allow to change the level right after _logger = logging.getLogger(_logger_name) - level = os.environ.get('PPE_LOG_LEVEL', level) - for lvl in (logging.DEBUG, logging.INFO, - logging.WARNING, logging.ERROR, logging.CRITICAL): + level = os.environ.get("PPE_LOG_LEVEL", level) + for lvl in ( + logging.DEBUG, + logging.INFO, + logging.WARNING, + logging.ERROR, + logging.CRITICAL, + ): if logging.getLevelName(lvl) == level: _logger.setLevel(lvl) break else: _logger.setLevel(logging.INFO) - _logger.warning('invalid PPE_LOG_LEVEL (%s); using INFO', level) + _logger.warning("invalid PPE_LOG_LEVEL (%s); using INFO", level) _logger.addHandler(handler) diff --git a/pytorch_pfn_extras/nn/__init__.py b/pytorch_pfn_extras/nn/__init__.py index 3b2273113..f188e18d1 100644 --- a/pytorch_pfn_extras/nn/__init__.py +++ b/pytorch_pfn_extras/nn/__init__.py @@ -1,10 +1,12 @@ +from pytorch_pfn_extras.nn import parallel # NOQA from pytorch_pfn_extras.nn.modules.ensure_shape import Ensure, ensure # NOQA -from pytorch_pfn_extras.nn.modules.lazy_linear import LazyLinear # NOQA -from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv1d # NOQA -from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv2d # NOQA -from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv3d # NOQA +from pytorch_pfn_extras.nn.modules.extended_sequential import ( # NOQA + ExtendedSequential, +) from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm1d # NOQA from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm2d # NOQA from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm3d # NOQA -from pytorch_pfn_extras.nn.modules.extended_sequential import ExtendedSequential # NOQA -from pytorch_pfn_extras.nn import parallel # NOQA +from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv1d # NOQA +from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv2d # NOQA +from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv3d # NOQA +from pytorch_pfn_extras.nn.modules.lazy_linear import LazyLinear # NOQA diff --git a/pytorch_pfn_extras/nn/modules/ensure_shape.py b/pytorch_pfn_extras/nn/modules/ensure_shape.py index a1f58d036..00043372d 100644 --- a/pytorch_pfn_extras/nn/modules/ensure_shape.py +++ b/pytorch_pfn_extras/nn/modules/ensure_shape.py @@ -18,17 +18,16 @@ class Ensure(torch.nn.Module): """ def __init__( - self, - *, - shape: Optional[Tuple[Optional[int], ...]] = None, 
- dtype: Optional[torch.dtype] = None, - broadcastable: bool = False, - can_cast: bool = False, + self, + *, + shape: Optional[Tuple[Optional[int], ...]] = None, + dtype: Optional[torch.dtype] = None, + broadcastable: bool = False, + can_cast: bool = False, ): super().__init__() # type: ignore[no-untyped-call] if shape is None and dtype is None: - raise ValueError( - 'shape, dtype or both arguments must be specified') + raise ValueError("shape, dtype or both arguments must be specified") self._dtype = dtype self._broadcastable = broadcastable self._can_cast = can_cast @@ -36,10 +35,7 @@ def __init__( # so we can compare the shapes using broadcast semantics c_shape: Optional[Tuple[int, ...]] = None if shape is not None: - non_none_tuple = tuple( - [x if x is not None else 1 - for x in shape] - ) + non_none_tuple = tuple([x if x is not None else 1 for x in shape]) if None in shape: self._broadcastable = True c_shape = non_none_tuple @@ -74,29 +70,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self._broadcastable: if not self._broadcast(t_shape, self._shape): raise ValueError( - f'Shapes {self._shape} and {input.shape} are non' - ' broadcastable') + f"Shapes {self._shape} and {input.shape} are non" + " broadcastable" + ) else: raise ValueError( - f'Expected {self._shape}, input shape is {input.shape}') + f"Expected {self._shape}, input shape is {input.shape}" + ) if self._dtype is not None and input.dtype != self._dtype: if self._can_cast: if not torch.can_cast(input.dtype, self._dtype): raise ValueError( - f'Input dtype {input.dtype} can\'t be casted to' - f' {self._dtype}') + f"Input dtype {input.dtype} can't be casted to" + f" {self._dtype}" + ) else: raise ValueError( - f'Expected {self._dtype}, input dtype is {input.dtype}') + f"Expected {self._dtype}, input dtype is {input.dtype}" + ) return input def ensure( - tensor: torch.Tensor, - shape: Optional[Tuple[Optional[int], ...]] = None, - dtype: Optional[torch.dtype] = None, - broadcastable: bool = False, - can_cast: bool = False + tensor: torch.Tensor, + shape: Optional[Tuple[Optional[int], ...]] = None, + dtype: Optional[torch.dtype] = None, + broadcastable: bool = False, + can_cast: bool = False, ) -> None: """Checks the shape and type of a tensor. @@ -111,5 +111,5 @@ def ensure( can_cast: Check if the input tensor can be casted to the provided type. 
""" Ensure( - shape=shape, dtype=dtype, - broadcastable=broadcastable, can_cast=can_cast)(tensor) + shape=shape, dtype=dtype, broadcastable=broadcastable, can_cast=can_cast + )(tensor) diff --git a/pytorch_pfn_extras/nn/modules/extended_sequential.py b/pytorch_pfn_extras/nn/modules/extended_sequential.py index 222328808..9f9918d55 100644 --- a/pytorch_pfn_extras/nn/modules/extended_sequential.py +++ b/pytorch_pfn_extras/nn/modules/extended_sequential.py @@ -1,10 +1,10 @@ -import torch import copy -from typing import TypeVar import warnings +from typing import TypeVar +import torch -Model = TypeVar('Model', torch.nn.Module, 'ExtendedSequential') +Model = TypeVar("Model", torch.nn.Module, "ExtendedSequential") def _reset_parameters(model: Model) -> Model: @@ -15,38 +15,39 @@ def _reset_parameters(model: Model) -> Model: for submodel in model.values(): _reset_parameters(submodel) else: - if hasattr(model, 'reset_parameters'): + if hasattr(model, "reset_parameters"): model.reset_parameters() # type: ignore [operator] - elif hasattr(model, '_reset_parameters'): + elif hasattr(model, "_reset_parameters"): model._reset_parameters() # type: ignore [operator] else: - if (len(list(model.parameters())) != 0 - or len(list(model.buffers())) != 0): - warnings.warn('Cannot reset the parameters of module {}. ' - 'Consider adding `reset_parameters` or ' - '`_reset_parameters` ' - 'functions to the module'.format(model), - UserWarning) + if ( + len(list(model.parameters())) != 0 + or len(list(model.buffers())) != 0 + ): + warnings.warn( + "Cannot reset the parameters of module {}. " + "Consider adding `reset_parameters` or " + "`_reset_parameters` " + "functions to the module".format(model), + UserWarning, + ) return model class ExtendedSequential(torch.nn.Sequential): - """Sequential module with extended features from chainer. + """Sequential module with extended features from chainer.""" - """ - def _copy_model(self, mode: str) -> 'ExtendedSequential': - if mode == 'init': + def _copy_model(self, mode: str) -> "ExtendedSequential": + if mode == "init": return _reset_parameters(copy.deepcopy(self)) - elif mode == 'copy': + elif mode == "copy": return copy.deepcopy(self) else: # mode == share return copy.copy(self) - def repeat( - self, n_repeat: int, mode: str = 'init' - ) -> 'ExtendedSequential': + def repeat(self, n_repeat: int, mode: str = "init") -> "ExtendedSequential": """Repeats this Sequential multiple times. This method returns a :class:`~torch.nn.Sequential` object which has @@ -85,10 +86,11 @@ def repeat( if n_repeat <= 0: return ExtendedSequential() - if mode not in ['copy', 'share', 'init']: + if mode not in ["copy", "share", "init"]: raise ValueError( - 'The \'mode\' argument should be either \'init\',' - '\'copy\', or \'share\'. But {} was given.'.format(mode)) + "The 'mode' argument should be either 'init'," + "'copy', or 'share'. But {} was given.".format(mode) + ) model_list = [] for _ in range(n_repeat): diff --git a/pytorch_pfn_extras/nn/modules/lazy.py b/pytorch_pfn_extras/nn/modules/lazy.py index 51e8399c2..8c2cf3f4d 100644 --- a/pytorch_pfn_extras/nn/modules/lazy.py +++ b/pytorch_pfn_extras/nn/modules/lazy.py @@ -1,6 +1,6 @@ import inspect -from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING import warnings +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple import torch @@ -58,12 +58,14 @@ def lazy_parmeters_determined(self) -> bool: parameters are determined. Note that this may be called during ``__init__``. 
""" - return self._lazy_ready and all([ - not isinstance(getattr(self, x), UninitializedParameter) - for x in self.lazy_parameter_names]) - - def state_dict( - self: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]: + return self._lazy_ready and all( + [ + not isinstance(getattr(self, x), UninitializedParameter) + for x in self.lazy_parameter_names + ] + ) + + def state_dict(self: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]: """Returns a dictionary containing a whole state of the module. This function overrides the default behavior to exclude uninitialized @@ -81,8 +83,15 @@ def state_dict( return destination # type: ignore[no-any-return] def _lazy_load_hook( # type: ignore[no-untyped-def] - self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): """load_state_dict pre-hook function for lazy buffers and parameters. The purpose of this hook is to check the current state and/or @@ -98,8 +107,9 @@ def _lazy_load_hook( # type: ignore[no-untyped-def] state_initialized = state_dict[key].shape != (0,) if module_initialized and not state_initialized: raise RuntimeError( - 'Can\'t load non-initialized buffers in already ' - 'initialized modules') + "Can't load non-initialized buffers in already " + "initialized modules" + ) elif not module_initialized and state_initialized: # Here we need to avoid a tensor size mismatch # this is a regular tensor without a materialize @@ -114,12 +124,14 @@ def _lazy_load_hook( # type: ignore[no-untyped-def] # parameters (see comments of ``state_dict``). key = prefix + name module_initialized = not isinstance( - getattr(self, name), UninitializedParameter) + getattr(self, name), UninitializedParameter + ) state_initialized = key in state_dict if module_initialized and not state_initialized: raise RuntimeError( - 'Can\'t load uninitialized parameters in already ' - 'initialized modules') + "Can't load uninitialized parameters in already " + "initialized modules" + ) elif not module_initialized and state_initialized: getattr(self, name).materialize(state_dict[key].shape) elif key not in state_dict and not module_initialized: @@ -128,15 +140,15 @@ def _lazy_load_hook( # type: ignore[no-untyped-def] class UninitializedParameter(torch.nn.Parameter): - def __repr__(self) -> str: # type: ignore[override] - return 'Uninitialized lazy parameter' + return "Uninitialized lazy parameter" - def share_memory_(self) -> 'UninitializedParameter': + def share_memory_(self) -> "UninitializedParameter": raise RuntimeError( - 'Can\'t share memory on an unitialized parameter. ' - 'Run forward to initialize the network before calling ' - '`module.share_memory()`.') + "Can't share memory on an unitialized parameter. " + "Run forward to initialize the network before calling " + "`module.share_memory()`." + ) @property def is_leaf(self) -> bool: # type: ignore[override] @@ -145,18 +157,20 @@ def is_leaf(self) -> bool: # type: ignore[override] # for parameters; optimizers check for this attribute and raise an # error if non-leaf tensors are detected. frame = inspect.currentframe() - package_name = frame.f_back.f_globals['__package__'] # type: ignore - if package_name.startswith('torch.optim'): - warnings.warn(''' + package_name = frame.f_back.f_globals["__package__"] # type: ignore + if package_name.startswith("torch.optim"): + warnings.warn( + """ Use of uninitialized lazy parameter in Optimizer has been detected. 
- Maybe you forgot to run forward before passing `module.parameters()` to the optimizer?''') # NOQA + Maybe you forgot to run forward before passing `module.parameters()` to the optimizer?""" + ) # NOQA return True def materialize( - self, - shape: Tuple[int, ...], - device: Optional['DeviceLike'] = None, - dtype: Optional[torch.dtype] = None, + self, + shape: Tuple[int, ...], + device: Optional["DeviceLike"] = None, + dtype: Optional[torch.dtype] = None, ) -> None: r"""Create a Parameter with the same properties of the uninitialized one. Given a shape, it materializes a parameter in the same device diff --git a/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py b/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py index cff046034..a7a97cd1b 100644 --- a/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py +++ b/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py @@ -1,33 +1,35 @@ from typing import Any, Optional import torch - -from pytorch_pfn_extras.nn.modules.lazy import LazyInitializationMixin -from pytorch_pfn_extras.nn.modules.lazy import UninitializedParameter +from pytorch_pfn_extras.nn.modules.lazy import ( + LazyInitializationMixin, + UninitializedParameter, +) class _LazyBatchNorm( # type: ignore[misc] - LazyInitializationMixin, - torch.nn.modules.batchnorm._BatchNorm + LazyInitializationMixin, torch.nn.modules.batchnorm._BatchNorm ): - running_mean: Any running_var: Any - lazy_parameter_names = ('weight', 'bias') + lazy_parameter_names = ("weight", "bias") - def __init__(self, num_features: Optional[int], *args: Any, **kwargs: Any) -> None: + def __init__( + self, num_features: Optional[int], *args: Any, **kwargs: Any + ) -> None: super().__init__(num_features or 0, *args, **kwargs) if not self.affine: raise ValueError( - 'LazyBatchNorm is not compatible with affine=False.' - ' Use the regular BatchNorm layers instead') + "LazyBatchNorm is not compatible with affine=False." + " Use the regular BatchNorm layers instead" + ) # weight and bias are registered in the mixin if num_features is None: self.num_features: Optional[int] = None # type: ignore[assignment] if self.track_running_stats: # these buffers are not always needed # so we avoid explicit initializations - self.lazy_buffer_names = ('running_mean', 'running_var') + self.lazy_buffer_names = ("running_mean", "running_var") def reset_parameters(self) -> None: if self.lazy_parmeters_determined: @@ -46,13 +48,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.track_running_stats: assert isinstance(self.running_mean, torch.Tensor) self.running_mean = torch.zeros( - self.num_features, device=self.running_mean.device, - dtype=self.running_mean.dtype + self.num_features, + device=self.running_mean.device, + dtype=self.running_mean.dtype, ) assert isinstance(self.running_var, torch.Tensor) self.running_var = torch.ones( - self.num_features, device=self.running_var.device, - dtype=self.running_mean.dtype + self.num_features, + device=self.running_var.device, + dtype=self.running_mean.dtype, ) self.reset_parameters() return super().forward(input) @@ -64,6 +68,7 @@ class LazyBatchNorm1d(_LazyBatchNorm, torch.nn.BatchNorm1d): # type: ignore[mis When ``num_features`` is ``None``, it is determined at the first time of the forward step. """ + pass @@ -73,6 +78,7 @@ class LazyBatchNorm2d(_LazyBatchNorm, torch.nn.BatchNorm2d): # type: ignore[mis When ``num_features`` is ``None``, it is determined at the first time of the forward step. 
""" + pass @@ -82,4 +88,5 @@ class LazyBatchNorm3d(_LazyBatchNorm, torch.nn.BatchNorm3d): # type: ignore[mis When ``num_features`` is ``None``, it is determined at the first time of the forward step. """ + pass diff --git a/pytorch_pfn_extras/nn/modules/lazy_conv.py b/pytorch_pfn_extras/nn/modules/lazy_conv.py index 32a9e5869..51f00b666 100644 --- a/pytorch_pfn_extras/nn/modules/lazy_conv.py +++ b/pytorch_pfn_extras/nn/modules/lazy_conv.py @@ -1,18 +1,18 @@ from typing import Any, Optional import torch - from pytorch_pfn_extras.nn.modules.lazy import ( - LazyInitializationMixin, UninitializedParameter + LazyInitializationMixin, + UninitializedParameter, ) class _LazyConvNd(LazyInitializationMixin): - - lazy_parameter_names = ('weight',) + lazy_parameter_names = ("weight",) def __init__( - self: Any, in_channels: Optional[int], *args: Any, **kwargs: Any) -> None: + self: Any, in_channels: Optional[int], *args: Any, **kwargs: Any + ) -> None: super().__init__(in_channels or 0, *args, **kwargs) if in_channels is None: self.in_channels: Optional[int] = None @@ -22,11 +22,17 @@ def forward(self: Any, input: torch.Tensor) -> torch.Tensor: if isinstance(self.weight, UninitializedParameter): self.in_channels = input.shape[1] if self.transposed: - shape = (self.in_channels, self.out_channels // self.groups, - *self.kernel_size) + shape = ( + self.in_channels, + self.out_channels // self.groups, + *self.kernel_size, + ) else: - shape = (self.out_channels, self.in_channels // self.groups, - *self.kernel_size) + shape = ( + self.out_channels, + self.in_channels // self.groups, + *self.kernel_size, + ) self.weight = torch.nn.Parameter(self.weight.new_empty(*shape)) self.reset_parameters() return super().forward(input) # type: ignore @@ -44,6 +50,7 @@ class LazyConv1d(_LazyConvNd, torch.nn.Conv1d): # type: ignore[misc] When ``in_channels`` is ``None``, it is determined at the first time of the forward step. """ + pass @@ -53,6 +60,7 @@ class LazyConv2d(_LazyConvNd, torch.nn.Conv2d): # type: ignore[misc] When ``in_channels`` is ``None``, it is determined at the first time of the forward step. """ + pass @@ -62,4 +70,5 @@ class LazyConv3d(_LazyConvNd, torch.nn.Conv3d): # type: ignore[misc] When ``in_channels`` is ``None``, it is determined at the first time of the forward step. """ + pass diff --git a/pytorch_pfn_extras/nn/modules/lazy_linear.py b/pytorch_pfn_extras/nn/modules/lazy_linear.py index aabc2e892..519dcde4f 100644 --- a/pytorch_pfn_extras/nn/modules/lazy_linear.py +++ b/pytorch_pfn_extras/nn/modules/lazy_linear.py @@ -1,9 +1,10 @@ from typing import Any, Optional import torch - -from pytorch_pfn_extras.nn.modules.lazy import UninitializedParameter -from pytorch_pfn_extras.nn.modules.lazy import LazyInitializationMixin +from pytorch_pfn_extras.nn.modules.lazy import ( + LazyInitializationMixin, + UninitializedParameter, +) class LazyLinear(LazyInitializationMixin, torch.nn.Linear): # type: ignore[misc] @@ -13,9 +14,11 @@ class LazyLinear(LazyInitializationMixin, torch.nn.Linear): # type: ignore[misc the forward step. 
""" - lazy_parameter_names = ('weight',) + lazy_parameter_names = ("weight",) - def __init__(self, in_features: Optional[int], *args: Any, **kwargs: Any) -> None: + def __init__( + self, in_features: Optional[int], *args: Any, **kwargs: Any + ) -> None: super().__init__(in_features or 0, *args, **kwargs) if in_features is None: self.in_features = None # type: ignore[assignment] @@ -24,8 +27,9 @@ def __init__(self, in_features: Optional[int], *args: Any, **kwargs: Any) -> Non def forward(self, input: torch.Tensor) -> torch.Tensor: if isinstance(self.weight, UninitializedParameter): self.in_features = input.shape[-1] - self.weight = torch.nn.Parameter(self.weight.new_empty( - self.out_features, self.in_features)) + self.weight = torch.nn.Parameter( + self.weight.new_empty(self.out_features, self.in_features) + ) self.reset_parameters() return super().forward(input) diff --git a/pytorch_pfn_extras/nn/parallel/__init__.py b/pytorch_pfn_extras/nn/parallel/__init__.py index 68840f079..458e3b5e4 100644 --- a/pytorch_pfn_extras/nn/parallel/__init__.py +++ b/pytorch_pfn_extras/nn/parallel/__init__.py @@ -1 +1,3 @@ -from pytorch_pfn_extras.nn.parallel.distributed import DistributedDataParallel # NOQA +from pytorch_pfn_extras.nn.parallel.distributed import ( # NOQA + DistributedDataParallel, +) diff --git a/pytorch_pfn_extras/nn/parallel/distributed.py b/pytorch_pfn_extras/nn/parallel/distributed.py index 4724cc40a..ab3ccb600 100644 --- a/pytorch_pfn_extras/nn/parallel/distributed.py +++ b/pytorch_pfn_extras/nn/parallel/distributed.py @@ -1,27 +1,34 @@ import logging - -from contextlib import contextmanager +import threading from collections import OrderedDict +from contextlib import contextmanager from typing import ( - Any, Callable, Dict, Generator, List, Mapping, Optional, Sequence, Tuple, - TypeVar, Union + Any, + Callable, + Dict, + Generator, + List, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, ) import torch -from torch import nn +from pytorch_pfn_extras.profiler import record from torch import distributed as dist -from torch.utils import hooks +from torch import nn from torch.autograd import Variable from torch.autograd.profiler import record_function -import threading - -from pytorch_pfn_extras.profiler import record +from torch.utils import hooks logger = logging.getLogger(__name__) Tensors = Union[Tuple[torch.Tensor, ...], torch.Tensor] DistFunc = Callable[[Sequence[torch.Tensor], Optional[dist.ProcessGroup]], None] -HookFun = Callable[['DistributedDataParallel'], None] +HookFun = Callable[["DistributedDataParallel"], None] class _ForEachWrapper: @@ -31,22 +38,25 @@ class _ForEachWrapper: - torch with python for loop - torch._foreach_xxx """ + def __init__(self) -> None: self.flatten = torch._utils._flatten_dense_tensors self.unflatten = torch._utils._unflatten_dense_tensors - self._enable_foreach = (hasattr(torch, "_foreach_add") - and hasattr(torch, '_foreach_zero_')) + self._enable_foreach = hasattr(torch, "_foreach_add") and hasattr( + torch, "_foreach_zero_" + ) if not self._enable_foreach: logger.warning( "torch does not have _foreach_xxx functions." 
- " Please use newer torch") + " Please use newer torch" + ) def multi_tensor_scale( - self, - src: Sequence[torch.Tensor], - dst: Sequence[torch.Tensor], - scale: float, + self, + src: Sequence[torch.Tensor], + dst: Sequence[torch.Tensor], + scale: float, ) -> None: with torch.no_grad(): # type: ignore[no-untyped-call] # _foreach_zero for long type is not supported in CUDA @@ -75,16 +85,18 @@ def get_foreach_wrapper() -> _ForEachWrapper: def _reduce( - values: Sequence[torch.Tensor], - group: Optional[dist.ProcessGroup], + values: Sequence[torch.Tensor], + group: Optional[dist.ProcessGroup], ) -> None: size = sum([v.numel() for v in values]) # flatten values to improve the runtime perfomance of all-reduce - coalesced = torch.empty(size, device=values[0].device, - dtype=values[0].dtype) + coalesced = torch.empty( + size, device=values[0].device, dtype=values[0].dtype + ) coalesced_views = get_foreach_wrapper().unflatten( # type: ignore[no-untyped-call] - coalesced, values) + coalesced, values + ) get_foreach_wrapper().multi_tensor_scale(values, coalesced_views, 1.0) with record( @@ -94,28 +106,32 @@ def _reduce( # unflatten values get_foreach_wrapper().multi_tensor_scale( - coalesced_views, values, - 1.0 / dist.get_world_size(group) # type: ignore[no-untyped-call] + coalesced_views, + values, + 1.0 / dist.get_world_size(group), # type: ignore[no-untyped-call] ) def _broadcast( - values: Sequence[torch.Tensor], - group: Optional[dist.ProcessGroup] + values: Sequence[torch.Tensor], group: Optional[dist.ProcessGroup] ) -> None: with torch.no_grad(): # type: ignore[no-untyped-call] coalesced = get_foreach_wrapper().flatten( # type: ignore[no-untyped-call] - values) + values + ) with record( "torch.distributed.broadcast", use_cuda=torch.cuda.is_available() ): dist.broadcast(coalesced, 0, group=group) # type: ignore[no-untyped-call] src = get_foreach_wrapper().unflatten( # type: ignore[no-untyped-call] - coalesced, values) + coalesced, values + ) get_foreach_wrapper().multi_tensor_scale(src, values, 1.0) -def _group_by_type(values: Sequence[Optional[torch.Tensor]]) -> List[List[torch.Tensor]]: +def _group_by_type( + values: Sequence[Optional[torch.Tensor]], +) -> List[List[torch.Tensor]]: groups: Dict[torch.dtype, List[torch.Tensor]] = {} for value in values: if value is None: @@ -151,19 +167,24 @@ class DistributedDataParallel(nn.Module): broadcast_function: Broadcast function """ - _unused_parameters = ["device_ids", "output_device", "dim", - "find_unused_parameters", "check_reduction", - "gradient_as_bucket_view"] + _unused_parameters = [ + "device_ids", + "output_device", + "dim", + "find_unused_parameters", + "check_reduction", + "gradient_as_bucket_view", + ] def __init__( - self, - module: nn.Module, - broadcast_buffers: bool = True, - negotiate_grads: bool = True, - process_group: Optional[dist.ProcessGroup] = None, - reduce_function: Optional[DistFunc] = None, - broadcast_function: Optional[DistFunc] = None, - **kwargs: Any + self, + module: nn.Module, + broadcast_buffers: bool = True, + negotiate_grads: bool = True, + process_group: Optional[dist.ProcessGroup] = None, + reduce_function: Optional[DistFunc] = None, + broadcast_function: Optional[DistFunc] = None, + **kwargs: Any, ) -> None: """ This module receives keyword arguments for the compatibility with @@ -204,9 +225,9 @@ def __init__( # synchronize initial parameters and buffers params = dict(self.named_parameters()) buffers = dict(self.named_buffers()) - values = \ - [buffers[name] for name in self._sorted_buffer_keys] + \ - 
[params[name] for name in self._sorted_param_keys] + values = [buffers[name] for name in self._sorted_buffer_keys] + [ + params[name] for name in self._sorted_param_keys + ] if dist.is_initialized(): # type: ignore[no-untyped-call] groups = _group_by_type(values) for group in groups: @@ -219,8 +240,7 @@ def __init__( @contextmanager def no_sync(self) -> Generator[None, None, None]: - """A context manager to disable synchronization after backward - """ + """A context manager to disable synchronization after backward""" prev = self._require_sync self._require_sync = False try: @@ -234,13 +254,13 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return self.module(*args, **kwargs) def load_state_dict( - self, - state_dict: 'Mapping[str, torch.Tensor]', - strict: bool = True, + self, + state_dict: "Mapping[str, torch.Tensor]", + strict: bool = True, ) -> None: self.module.load_state_dict(state_dict, strict=strict) # type: ignore[arg-type] - T_destination = TypeVar('T_destination', bound=Mapping[str, torch.Tensor]) + T_destination = TypeVar("T_destination", bound=Mapping[str, torch.Tensor]) def state_dict(self) -> Dict[str, Any]: # type: ignore[override] return self.module.state_dict() @@ -257,10 +277,7 @@ def register_comm_hook(self, hook: HookFun) -> hooks.RemovableHandle: return handle def _backward_hook( - self, - module: torch.nn.Module, - gin: Tensors, - gout: Tensors + self, module: torch.nn.Module, gin: Tensors, gout: Tensors ) -> None: def _synchronize() -> None: if not self._require_sync: @@ -270,14 +287,17 @@ def _synchronize() -> None: hook(self) with record_function( - "ppe.nn.parallel.DistributedDataParallel.synchronize"): + "ppe.nn.parallel.DistributedDataParallel.synchronize" + ): params = dict(self.named_parameters()) if self._negotiate_grads: # find parameters that have gradients has_grads = torch.tensor( - [params[name].grad is not None - for name in self._sorted_param_keys], - device=self._device + [ + params[name].grad is not None + for name in self._sorted_param_keys + ], + device=self._device, ) # cast to long because bool may not be used in all_reduce @@ -288,19 +308,25 @@ def _synchronize() -> None: use_cuda=torch.cuda.is_available(), ): dist.all_reduce( # type: ignore[no-untyped-call] - has_grads, op=dist.ReduceOp.MAX) + has_grads, op=dist.ReduceOp.MAX + ) - for name, has_grad in zip(self._sorted_param_keys, - has_grads.bool().cpu()): + for name, has_grad in zip( + self._sorted_param_keys, has_grads.bool().cpu() + ): # create zero tensor as a gradient if a parameter # does not have the gradient and other processes # require to synchronize this parameter. if has_grad and params[name].grad is None: - params[name].grad = \ - torch.zeros_like(params[name].data) - - grads = [params[name].grad for name in self._sorted_param_keys - if params[name].grad is not None] + params[name].grad = torch.zeros_like( + params[name].data + ) + + grads = [ + params[name].grad + for name in self._sorted_param_keys + if params[name].grad is not None + ] groups = _group_by_type(grads) with record( "pytorch_pfn_extras.nn.parallel." @@ -320,8 +346,7 @@ def _synchronize() -> None: use_cuda=torch.cuda.is_available(), ): for group in groups: - self._broadcast_function( - group, self._process_group) + self._broadcast_function(group, self._process_group) # PyTorch will invoke `_synchronize` after the backward computation. 
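For reference, a minimal sketch of how this DistributedDataParallel wrapper is used (assumes torch.distributed has already been initialized elsewhere; the model and inputs are placeholders):

import torch
import torch.distributed as dist
import pytorch_pfn_extras as ppe

model = torch.nn.Linear(4, 2)
if dist.is_initialized():
    model = ppe.nn.parallel.DistributedDataParallel(model)
    model(torch.randn(8, 4)).sum().backward()  # gradients are all-reduced by the backward hook
    with model.no_sync():
        # synchronization is skipped inside this context, e.g. for gradient accumulation
        model(torch.randn(8, 4)).sum().backward()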
Variable._execution_engine.queue_callback(_synchronize) @@ -344,6 +369,7 @@ def _input_to_device(self, obj: Any) -> Any: if isinstance(obj, list) and len(obj) > 0: return [self._input_to_device(x) for x in obj] if isinstance(obj, dict) and len(obj) > 0: - return {key: self._input_to_device(value) - for key, value in obj.items()} + return { + key: self._input_to_device(value) for key, value in obj.items() + } return obj diff --git a/pytorch_pfn_extras/profiler/__init__.py b/pytorch_pfn_extras/profiler/__init__.py index 4b63035f9..88abf507a 100644 --- a/pytorch_pfn_extras/profiler/__init__.py +++ b/pytorch_pfn_extras/profiler/__init__.py @@ -1,5 +1,5 @@ from pytorch_pfn_extras.profiler._record import record # NOQA from pytorch_pfn_extras.profiler._record import record_function # NOQA from pytorch_pfn_extras.profiler._record import record_iterable # NOQA -from pytorch_pfn_extras.profiler._time_summary import get_time_summary # NOQA from pytorch_pfn_extras.profiler._time_summary import TimeSummary # NOQA +from pytorch_pfn_extras.profiler._time_summary import get_time_summary # NOQA diff --git a/pytorch_pfn_extras/profiler/_record.py b/pytorch_pfn_extras/profiler/_record.py index a2cc14a2b..e77635319 100644 --- a/pytorch_pfn_extras/profiler/_record.py +++ b/pytorch_pfn_extras/profiler/_record.py @@ -1,10 +1,17 @@ -from contextlib import contextmanager import inspect -from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, TYPE_CHECKING import types +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Optional, + TypeVar, +) import torch - from pytorch_pfn_extras.profiler import _time_summary from pytorch_pfn_extras.runtime import runtime_registry @@ -18,7 +25,7 @@ def _infer_tag_name(frame: Optional[types.FrameType], depth: int) -> str: frame = frame.f_back assert frame is not None frame_info = inspect.getframeinfo(frame, context=0) - return '{}:{}:{}'.format( + return "{}:{}:{}".format( inspect.getmodulename(frame_info.filename), frame_info.lineno, frame_info.function, @@ -38,13 +45,12 @@ def complete(self) -> None: @contextmanager def record( - tag: Optional[str], - metric: Optional[str] = None, - use_cuda: bool = False, - enable: bool = True, - device: 'DeviceLike' = 'cpu' + tag: Optional[str], + metric: Optional[str] = None, + use_cuda: bool = False, + enable: bool = True, + device: "DeviceLike" = "cpu", ) -> Generator[_time_summary._ReportNotification, None, None]: - if not enable: yield _DummyReportNotification() return @@ -55,8 +61,7 @@ def record( if metric is None: metric = tag - runtime_cls = runtime_registry.get_runtime_class_for_device_spec( - device) + runtime_cls = runtime_registry.get_runtime_class_for_device_spec(device) runtime_tracer = runtime_cls.trace if use_cuda: @@ -71,13 +76,13 @@ def record( torch.cuda.nvtx.range_pop() # type: ignore[no-untyped-call] -_T = TypeVar('_T') +_T = TypeVar("_T") def record_function( - tag: Optional[str], - use_cuda: bool = False, - enable: bool = True, + tag: Optional[str], + use_cuda: bool = False, + enable: bool = True, ) -> Callable[[Callable[..., _T]], Callable[..., _T]]: def wrapper(f: Callable[..., _T]) -> Callable[..., _T]: def wrapped(*args: Any, **kwargs: Any) -> _T: @@ -90,11 +95,11 @@ def wrapped(*args: Any, **kwargs: Any) -> _T: def record_iterable( - tag: Optional[str], - iter: Iterable[_T], - divide_metric: bool = False, - use_cuda: bool = False, - enable: bool = True, + tag: Optional[str], + iter: Iterable[_T], + divide_metric: bool = False, 
+ use_cuda: bool = False, + enable: bool = True, ) -> Iterable[_T]: if tag is None: tag = _infer_tag_name(inspect.currentframe(), depth=1) diff --git a/pytorch_pfn_extras/profiler/_time_summary.py b/pytorch_pfn_extras/profiler/_time_summary.py index d62bfd076..aa5530bf0 100644 --- a/pytorch_pfn_extras/profiler/_time_summary.py +++ b/pytorch_pfn_extras/profiler/_time_summary.py @@ -1,28 +1,27 @@ import atexit -from contextlib import contextmanager +import multiprocessing as mp import os -import time -from typing import Callable, Dict, Generator, Optional, Tuple -import threading import queue -import multiprocessing as mp -import torch +import threading +import time import weakref +from contextlib import contextmanager +from typing import Callable, Dict, Generator, Optional, Tuple +import torch from pytorch_pfn_extras.reporting import DictSummary - Events = Tuple[torch.cuda.Event, torch.cuda.Event] class _ReportNotification: def __init__( - self, - summary: 'TimeSummary', - tag: str, - use_cuda: bool, - begin_event: Optional[torch.cuda.Event], - begin: float, + self, + summary: "TimeSummary", + tag: str, + use_cuda: bool, + begin_event: Optional[torch.cuda.Event], + begin: float, ) -> None: self._is_completed = True self._summary = summary @@ -36,19 +35,22 @@ def defer(self) -> None: def complete(self) -> None: self._summary.complete_report( - self._tag, self._use_cuda, self._begin_event, self._begin) + self._tag, self._use_cuda, self._begin_event, self._begin + ) class _CPUWorker: def __init__( - self, - add: Callable[[str, float], None], - max_queue_size: int, + self, + add: Callable[[str, float], None], + max_queue_size: int, ) -> None: self._add = add self._max_queue_size = max_queue_size self._initialized = False - self._queue: Optional[mp.JoinableQueue[Optional[Tuple[str, float]]]] = None + self._queue: Optional[ + mp.JoinableQueue[Optional[Tuple[str, float]]] + ] = None self._thread: Optional[threading.Thread] = None self._thread_exited = False @@ -108,17 +110,17 @@ def _worker(self) -> None: class _CUDAWorker: def __init__( - self, - add: Callable[[str, float], None], - max_queue_size: int, + self, + add: Callable[[str, float], None], + max_queue_size: int, ) -> None: self._add = add self._max_queue_size = max_queue_size self._initialized = False self._thread: Optional[threading.Thread] = None - self._queue: Optional['queue.Queue[Optional[_QueueElem]]'] = None + self._queue: Optional["queue.Queue[Optional[_QueueElem]]"] = None self._event_lock = threading.Lock() - self._events: Optional['queue.Queue[torch.cuda.Event]'] = None + self._events: Optional["queue.Queue[torch.cuda.Event]"] = None self._thread_exited = False def initialize(self) -> None: @@ -146,9 +148,9 @@ def synchronize(self) -> None: self._queue.join() def put( - self, - name: str, - events: Tuple[torch.cuda.Event, torch.cuda.Event], + self, + name: str, + events: Tuple[torch.cuda.Event, torch.cuda.Event], ) -> None: assert self._queue is not None assert not self._thread_exited @@ -181,13 +183,14 @@ def get_cuda_event(self) -> torch.cuda.Event: with self._event_lock: if self._events.empty(): event = torch.cuda.Event( # type: ignore[no-untyped-call] - enable_timing=True) + enable_timing=True + ) self._events.put(event) return self._events.get() class _Finalizer: - def __init__(self, ts: 'TimeSummary') -> None: + def __init__(self, ts: "TimeSummary") -> None: self._ts = weakref.ref(ts) def __call__(self) -> None: @@ -209,7 +212,9 @@ class TimeSummary: when the instance is created. 
""" - def __init__(self, *, max_queue_size: int = 1000, auto_init: bool = True) -> None: + def __init__( + self, *, max_queue_size: int = 1000, auto_init: bool = True + ) -> None: self._summary_lock = threading.Lock() self._summary = DictSummary() self._additional_stats: Dict[str, float] = {} @@ -217,7 +222,9 @@ def __init__(self, *, max_queue_size: int = 1000, auto_init: bool = True) -> Non self._cpu_worker = _CPUWorker(self._add_from_worker, max_queue_size) self._cuda_worker: Optional[_CUDAWorker] = None if torch.cuda.is_available(): - self._cuda_worker = _CUDAWorker(self._add_from_worker, max_queue_size) + self._cuda_worker = _CUDAWorker( + self._add_from_worker, max_queue_size + ) self._initialized = False self._master_pid = os.getpid() @@ -242,7 +249,8 @@ def initialize(self) -> None: raise RuntimeError( "TimeSummary must be initialized in the same process as the " "one created the instance. Please call initialize() in the " - "main process.") + "main process." + ) self._cpu_worker.initialize() if self._cuda_worker is not None: self._cuda_worker.initialize() @@ -276,8 +284,8 @@ def add(self, name: str, value: float) -> None: @contextmanager def summary( - self, - clear: bool = False, + self, + clear: bool = False, ) -> Generator[Tuple[DictSummary, Dict[str, float]], None, None]: self.initialize() try: @@ -289,11 +297,11 @@ def summary( self._additional_stats = {} def complete_report( - self, - tag: str, - use_cuda: bool, - begin_event: Optional[torch.cuda.Event], - begin: float, + self, + tag: str, + use_cuda: bool, + begin_event: Optional[torch.cuda.Event], + begin: float, ) -> None: end = time.time() assert self._cpu_worker._queue is not None @@ -305,13 +313,14 @@ def complete_report( end_event = self._cuda_worker.get_cuda_event() end_event.record() # type: ignore[no-untyped-call] self._cuda_worker._queue.put( - (f"{tag}.cuda", (begin_event, end_event))) + (f"{tag}.cuda", (begin_event, end_event)) + ) @contextmanager def report( - self, - tag: str, - use_cuda: bool = False, + self, + tag: str, + use_cuda: bool = False, ) -> Generator[_ReportNotification, None, None]: """Context manager to automatically report execution times. 
@@ -332,7 +341,8 @@ def report( try: begin = time.time() notification = _ReportNotification( - self, tag, use_cuda, begin_event, begin) + self, tag, use_cuda, begin_event, begin + ) yield notification finally: if notification._is_completed: @@ -343,6 +353,6 @@ def report( def get_time_summary() -> TimeSummary: - if not hasattr(_thread_local, 'time_summary'): + if not hasattr(_thread_local, "time_summary"): _thread_local.time_summary = TimeSummary(auto_init=False) return _thread_local.time_summary # type: ignore[no-any-return] diff --git a/pytorch_pfn_extras/reporting.py b/pytorch_pfn_extras/reporting.py index 0676e35ba..a4132500e 100644 --- a/pytorch_pfn_extras/reporting.py +++ b/pytorch_pfn_extras/reporting.py @@ -2,17 +2,25 @@ import contextlib import threading import types +import warnings from typing import ( - Any, Callable, Dict, Generator, List, Mapping, Optional, Sequence, - Tuple, Type, Union, + Any, + Callable, + Dict, + Generator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + overload, ) -from typing import overload -import warnings import numpy import torch - Scalar = Union[torch.Tensor, numpy.ndarray, numpy.floating, float] FloatLikeValue = Union[Scalar, float] Value = Union[Scalar, Callable[[], float]] @@ -33,7 +41,8 @@ def _nograd(value: Value) -> Value: def _nograd( - value: Union[FloatLikeValue, Value]) -> Union[FloatLikeValue, Value]: + value: Union[FloatLikeValue, Value] +) -> Union[FloatLikeValue, Value]: if isinstance(value, torch.Tensor): return value.detach() return value @@ -99,10 +108,10 @@ def __enter__(self) -> None: _get_reporters().append(self) def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[types.TracebackType], + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], ) -> None: """Recovers the previous reporter object to the current.""" _get_reporters().pop() @@ -148,9 +157,7 @@ def add_observer(self, name: str, observer: torch.nn.Module) -> None: self._observer_names[id(observer)] = name def add_observers( - self, - prefix: str, - observers: Sequence[Tuple[str, torch.nn.Module]] + self, prefix: str, observers: Sequence[Tuple[str, torch.nn.Module]] ) -> None: """Registers multiple observers at once. @@ -165,9 +172,9 @@ def add_observers( self._observer_names[id(observer)] = prefix + name def report( - self, - values: Mapping[str, Value], - observer: Optional[torch.nn.Module] = None, + self, + values: Mapping[str, Value], + observer: Optional[torch.nn.Module] = None, ) -> None: """Reports observed values. @@ -193,10 +200,11 @@ def report( observer_id = id(observer) if observer_id not in self._observer_names: raise KeyError( - 'Given observer is not registered to the reporter.') + "Given observer is not registered to the reporter." + ) observer_name = self._observer_names[observer_id] for key, value in values.items(): - name = '%s/%s' % (observer_name, key) + name = "%s/%s" % (observer_name, key) self.observation[name] = value else: self.observation.update(values) @@ -216,8 +224,8 @@ def get_current_reporter() -> Reporter: def report( - values: Mapping[str, Value], - observer: Optional[torch.nn.Module] = None, + values: Mapping[str, Value], + observer: Optional[torch.nn.Module] = None, ) -> None: """Reports observed values with the current reporter object. 
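A minimal sketch of the reporter workflow whose internals are touched above; values reported for a registered observer are stored under "observer_name/key":

import torch
import pytorch_pfn_extras.reporting as reporting

model = torch.nn.Linear(2, 2)
reporter = reporting.Reporter()
reporter.add_observer("main", model)

with reporter:                           # makes this reporter the current one
    reporting.report({"loss": 0.5}, model)

print(reporter.observation)              # {'main/loss': 0.5}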
@@ -352,20 +360,22 @@ def state_dict(self) -> Dict[str, Any]: try: # Save the stats as python scalars in order to avoid # different device errors when loading them back - state = {'_x': float(self._x), - '_x2': float(self._x2), - '_n': int(self._n)} + state = { + "_x": float(self._x), + "_x2": float(self._x2), + "_n": int(self._n), + } except KeyError: - warnings.warn('The previous statistics are not saved.') + warnings.warn("The previous statistics are not saved.") return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: # Casting here is because of backward compatibility # Restore previously taken snapshots with autoload self._add_deferred_values() - self._x = float(_nograd(to_load['_x'])) - self._x2 = float(_nograd(to_load['_x2'])) - self._n = int(_nograd(to_load['_n'])) + self._x = float(_nograd(to_load["_x"])) + self._x2 = float(_nograd(to_load["_x2"])) + self._n = int(_nograd(to_load["_n"])) def __add__(self, other: "Summary") -> "Summary": s = Summary() @@ -404,10 +414,11 @@ def add(self, d: Mapping[str, Union[Value, Tuple[Value, Scalar]]]) -> None: w: Scalar = 1 if isinstance(v, tuple): v, w = v - if not numpy.isscalar(w) and not getattr(w, 'ndim', -1) == 0: + if not numpy.isscalar(w) and not getattr(w, "ndim", -1) == 0: raise ValueError( - 'Given weight to {} was not scalar.'.format(k)) - if callable(v) or numpy.isscalar(v) or getattr(v, 'ndim', -1) == 0: + "Given weight to {} was not scalar.".format(k) + ) + if callable(v) or numpy.isscalar(v) or getattr(v, "ndim", -1) == 0: summaries[k].add(v, weight=w) def compute_mean(self) -> Dict[str, Scalar]: @@ -420,8 +431,10 @@ def compute_mean(self) -> Dict[str, Scalar]: dict: Dictionary of mean values. """ - return {name: summary.compute_mean() - for name, summary in self._summaries.items()} + return { + name: summary.compute_mean() + for name, summary in self._summaries.items() + } def make_statistics(self) -> Dict[str, Scalar]: """Creates a dictionary of statistics. 
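DictSummary, whose add/compute_mean/make_statistics methods appear above, accumulates weighted scalar observations; a small sketch:

from pytorch_pfn_extras.reporting import DictSummary

summary = DictSummary()
summary.add({"loss": 0.6, "acc": (0.50, 4)})  # a (value, weight) tuple gives the entry a weight
summary.add({"loss": 0.4, "acc": (0.75, 4)})

print(summary.compute_mean())     # {'loss': 0.5, 'acc': 0.625}
print(summary.make_statistics())  # the means plus '<key>.std' entries, as built above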
@@ -439,13 +452,14 @@ def make_statistics(self) -> Dict[str, Scalar]: for name, summary in self._summaries.items(): mean, std = summary.make_statistics() stats[name] = mean - stats[name + '.std'] = std + stats[name + ".std"] = std return stats def state_dict(self) -> Dict[str, Any]: return { - name: summ.state_dict() for name, summ in self._summaries.items()} + name: summ.state_dict() for name, summ in self._summaries.items() + } def load_state_dict(self, to_load: Dict[str, Any]) -> None: self._summaries.clear() diff --git a/pytorch_pfn_extras/runtime/__init__.py b/pytorch_pfn_extras/runtime/__init__.py index 40bd83c15..0fb8258f8 100644 --- a/pytorch_pfn_extras/runtime/__init__.py +++ b/pytorch_pfn_extras/runtime/__init__.py @@ -1,6 +1,5 @@ +from pytorch_pfn_extras.runtime._registry import _RuntimeRegistry # NOQA from pytorch_pfn_extras.runtime._runtime import BaseRuntime # NOQA from pytorch_pfn_extras.runtime._runtime import PyTorchRuntime # NOQA -from pytorch_pfn_extras.runtime._registry import _RuntimeRegistry # NOQA - runtime_registry = _RuntimeRegistry(PyTorchRuntime) diff --git a/pytorch_pfn_extras/runtime/_autocast.py b/pytorch_pfn_extras/runtime/_autocast.py index 228f1ce1a..fe86b1b5c 100644 --- a/pytorch_pfn_extras/runtime/_autocast.py +++ b/pytorch_pfn_extras/runtime/_autocast.py @@ -1,13 +1,16 @@ import contextlib from typing import Any, Dict, Generator + from pytorch_pfn_extras._torch_version import requires _cuda_amp_available = False try: import torch.cuda.amp + _cuda_amp_available = torch.cuda.is_available() and hasattr( - torch.cuda.amp, 'autocast') + torch.cuda.amp, "autocast" + ) except ImportError: pass @@ -19,22 +22,23 @@ def __init__( has_grad_scaler: bool, ) -> None: autocast_options = autocast_options.copy() - self._enabled = autocast_options.pop('enabled', True) - self._device_type = autocast_options.pop('device_type', 'cuda') + self._enabled = autocast_options.pop("enabled", True) + self._device_type = autocast_options.pop("device_type", "cuda") self._options = autocast_options self._use_old_ac = not requires("1.10.0") - if ( - self._enabled and self._use_old_ac and self._device_type != 'cuda' - ): - raise RuntimeError("Autocast only work with CUDA devices for PyTorch 1.9") + if self._enabled and self._use_old_ac and self._device_type != "cuda": + raise RuntimeError( + "Autocast only work with CUDA devices for PyTorch 1.9" + ) if not _cuda_amp_available: - if ( - has_grad_scaler - or (self._enabled and self._device_type == "cuda") + if has_grad_scaler or ( + self._enabled and self._device_type == "cuda" ): - raise RuntimeError('Requested AMP features but torch.cuda.amp' - ' is not enabled') + raise RuntimeError( + "Requested AMP features but torch.cuda.amp" + " is not enabled" + ) @contextlib.contextmanager def autocast(self, enabled: bool = True) -> Generator[None, None, None]: diff --git a/pytorch_pfn_extras/runtime/_registry.py b/pytorch_pfn_extras/runtime/_registry.py index 768b7a0c2..295b64201 100644 --- a/pytorch_pfn_extras/runtime/_registry.py +++ b/pytorch_pfn_extras/runtime/_registry.py @@ -1,8 +1,7 @@ from typing import Dict, Type import torch - -from pytorch_pfn_extras.runtime._runtime import DeviceLike, BaseRuntime +from pytorch_pfn_extras.runtime._runtime import BaseRuntime, DeviceLike class _RuntimeRegistry: @@ -11,17 +10,18 @@ def __init__(self, fallback_class: Type[BaseRuntime]): self._fallback_class = fallback_class def register( - self, - device_type: str, - runtime_class: Type[BaseRuntime], + self, + device_type: str, + runtime_class: 
Type[BaseRuntime], ) -> None: self._runtimes[device_type] = runtime_class def get_runtime_class_for_device_spec( - self, device: DeviceLike) -> Type[BaseRuntime]: + self, device: DeviceLike + ) -> Type[BaseRuntime]: if isinstance(device, torch.device): device_type = device.type else: assert isinstance(device, str) - device_type = device.split(':')[0] + device_type = device.split(":")[0] return self._runtimes.get(device_type, self._fallback_class) diff --git a/pytorch_pfn_extras/runtime/_runtime.py b/pytorch_pfn_extras/runtime/_runtime.py index 8e27d284f..e374045da 100644 --- a/pytorch_pfn_extras/runtime/_runtime.py +++ b/pytorch_pfn_extras/runtime/_runtime.py @@ -1,12 +1,18 @@ import contextlib import types - from typing import ( - Any, Dict, Generator, Iterable, Optional, Set, Tuple, Union, TYPE_CHECKING + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + Optional, + Set, + Tuple, + Union, ) import torch - from pytorch_pfn_extras.handler._code_block import CodeBlock from pytorch_pfn_extras.runtime import _autocast @@ -35,7 +41,9 @@ class BaseRuntime: """ def __init__( - self, device_spec: DeviceLike, options: Dict[str, Any], + self, + device_spec: DeviceLike, + options: Dict[str, Any], ) -> None: self.device_spec = device_spec self.options = options @@ -56,7 +64,8 @@ def convert_batch(self, args: Any) -> Any: if isinstance(args, tuple) and hasattr(args, "_fields"): # namedtuple return args._replace( # type: ignore[attr-defined] - **self._convert_batch_dict(args._asdict())) # type: ignore + **self._convert_batch_dict(args._asdict()) # type: ignore + ) if isinstance(args, dict): return self._convert_batch_dict(args) if isinstance(args, (list, tuple)): @@ -142,7 +151,7 @@ def train_epoch_end(self, module: torch.nn.Module) -> None: def train_pre_step( self, - trainer: 'Trainer', + trainer: "Trainer", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -165,7 +174,7 @@ def train_pre_step( def train_post_step( self, - trainer: 'Trainer', + trainer: "Trainer", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -221,7 +230,7 @@ def train_validation_end(self, module: torch.nn.Module) -> None: def eval_pre_step( self, - evaluator: 'Evaluator', + evaluator: "Evaluator", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -241,7 +250,7 @@ def eval_pre_step( def eval_post_step( self, - evaluator: 'Evaluator', + evaluator: "Evaluator", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -262,7 +271,11 @@ def eval_post_step( """ raise NotImplementedError() - def execute(self, code_block: CodeBlock, batch: Any,) -> Any: + def execute( + self, + code_block: CodeBlock, + batch: Any, + ) -> Any: """Method called by the CodeBlocks API to do device dependent execution. Args: @@ -297,7 +310,9 @@ def map( @classmethod @contextlib.contextmanager - def trace(cls, event_name: Optional[str], arg: Any) -> Generator[None, None, None]: + def trace( + cls, event_name: Optional[str], arg: Any + ) -> Generator[None, None, None]: """Context manager for tracing PPE events in the custom device tools. 
Args: @@ -328,13 +343,18 @@ class PyTorchRuntime(BaseRuntime): """ def __init__( - self, device_spec: DeviceLike, options: Dict[str, Any], + self, + device_spec: DeviceLike, + options: Dict[str, Any], ) -> None: super().__init__(device_spec, options) self._grad_scaler = options.get("grad_scaler", None) autocast_options = options.get("autocast", False) if isinstance(autocast_options, bool): - autocast_options = {"enabled": autocast_options, "device_type": "cuda"} + autocast_options = { + "enabled": autocast_options, + "device_type": "cuda", + } self._autocast = _autocast._AutocastManager( autocast_options, self._grad_scaler is not None ) @@ -376,7 +396,7 @@ def train_validation_end(self, module: torch.nn.Module) -> None: def train_pre_step( self, - trainer: 'Trainer', + trainer: "Trainer", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -385,7 +405,7 @@ def train_pre_step( def train_post_step( self, - trainer: 'Trainer', + trainer: "Trainer", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -395,7 +415,7 @@ def train_post_step( def eval_pre_step( self, - evaluator: 'Evaluator', + evaluator: "Evaluator", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -404,7 +424,7 @@ def eval_pre_step( def eval_post_step( self, - evaluator: 'Evaluator', + evaluator: "Evaluator", module: torch.nn.Module, batch_idx: int, batch: Any, @@ -412,7 +432,11 @@ def eval_post_step( ) -> None: pass - def execute(self, code_block: CodeBlock, batch: Any,) -> Any: + def execute( + self, + code_block: CodeBlock, + batch: Any, + ) -> Any: # Run forward, backward and optimize steps depending on codeblock opts if self._grad_scaler is None: @@ -439,10 +463,7 @@ def _scale(x: torch.Tensor) -> torch.Tensor: isinstance(v, torch.Tensor) and v.grad_fn is not None and v.numel() == 1 - and ( - v.dtype.is_floating_point - or v.dtype.is_complex - ) + and (v.dtype.is_floating_point or v.dtype.is_complex) ): _scale(v).backward() # type: ignore[no-untyped-call] else: @@ -483,7 +504,9 @@ def map( @classmethod @contextlib.contextmanager - def trace(cls, event_name: Optional[str], arg: Any) -> Generator[None, None, None]: + def trace( + cls, event_name: Optional[str], arg: Any + ) -> Generator[None, None, None]: """Context manager for tracing PPE events in the custom device tools. 
Args: @@ -515,7 +538,9 @@ def _getstate_without_runtime(self): # type: ignore # remove runtime class and getstate def _remove_runtime_class(state): # type: ignore - state = {k: v for k, v in state.items() if k != _RUNTIME_TAG_NAME} + state = { + k: v for k, v in state.items() if k != _RUNTIME_TAG_NAME + } for k, v in state.items(): if isinstance(v, dict): state[k] = _remove_runtime_class(v) # type: ignore @@ -528,6 +553,7 @@ def _remove_runtime_class(state): # type: ignore return state return _remove_runtime_class(state) # type: ignore + return _getstate_without_runtime getstate = None @@ -537,7 +563,7 @@ def _remove_runtime_class(state): # type: ignore setattr( # NOQA module, "__getstate__", - types.MethodType(mk_getstate(getstate), module) # type: ignore + types.MethodType(mk_getstate(getstate), module), # type: ignore ) diff --git a/pytorch_pfn_extras/runtime/_to.py b/pytorch_pfn_extras/runtime/_to.py index 4db4b0556..32770d9b0 100644 --- a/pytorch_pfn_extras/runtime/_to.py +++ b/pytorch_pfn_extras/runtime/_to.py @@ -1,21 +1,19 @@ from typing import Any, Dict, Optional, Type, TypeVar -import torch - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras.runtime._runtime import DeviceLike, BaseRuntime - +import torch +from pytorch_pfn_extras.runtime._runtime import BaseRuntime, DeviceLike -ModuleOrTensor = TypeVar('ModuleOrTensor', torch.nn.Module, torch.Tensor) +ModuleOrTensor = TypeVar("ModuleOrTensor", torch.nn.Module, torch.Tensor) def to( - module_or_tensor: ModuleOrTensor, - device: DeviceLike, - *, - options: Optional[Dict[str, Any]] = None, - runtime_class: Optional[Type[BaseRuntime]] = None, - config: Optional[Dict[str, Any]] = None, + module_or_tensor: ModuleOrTensor, + device: DeviceLike, + *, + options: Optional[Dict[str, Any]] = None, + runtime_class: Optional[Type[BaseRuntime]] = None, + config: Optional[Dict[str, Any]] = None, ) -> ModuleOrTensor: """A function to transfer the given object to the given device. 
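A minimal sketch of the ppe.to entry point defined here; the prefix of the device string selects the runtime class through the registry shown earlier (the custom device name below is hypothetical):

import torch
import pytorch_pfn_extras as ppe

module = torch.nn.Linear(4, 2)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
module = ppe.to(module, device)  # dispatches on the "cuda"/"cpu" part before ":"

# A new device type would be registered with its own BaseRuntime subclass, e.g.:
# ppe.runtime.runtime_registry.register("mydev", MyRuntime)  # "mydev"/MyRuntime are hypothetical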
@@ -52,7 +50,7 @@ def to( if config is not None: options = config elif config is not None: - raise ValueError('options and config cannot be specified together') + raise ValueError("options and config cannot be specified together") if runtime_class is None: registry = ppe.runtime.runtime_registry @@ -67,4 +65,4 @@ def to( elif isinstance(obj, torch.Tensor): return runtime.move_tensor(obj) else: - raise ValueError('Unsupported type for module_or_tensor') + raise ValueError("Unsupported type for module_or_tensor") diff --git a/pytorch_pfn_extras/torchscript.py b/pytorch_pfn_extras/torchscript.py index f452a71fe..57788379c 100644 --- a/pytorch_pfn_extras/torchscript.py +++ b/pytorch_pfn_extras/torchscript.py @@ -1,18 +1,23 @@ from typing import Any, Callable, List, Tuple + import torch # Run jit pass with post lint -def run_jit_pass(p: Callable, g: torch._C.Graph, *args: Any, **kwargs: Any) -> None: +def run_jit_pass( + p: Callable, g: torch._C.Graph, *args: Any, **kwargs: Any +) -> None: p(g, *args, **kwargs) torch._C._jit_pass_lint(g) -def find_inplace(g: torch._C.Graph) -> Tuple[torch._C.Graph, List[torch._C.Node]]: +def find_inplace( + g: torch._C.Graph, +) -> Tuple[torch._C.Graph, List[torch._C.Node]]: g = g.copy() run_jit_pass(torch._C._jit_pass_inline, g) nodes = [] for n in g.nodes(): - if n.kind().endswith('_'): + if n.kind().endswith("_"): nodes.append(n) return g, nodes diff --git a/pytorch_pfn_extras/training/__init__.py b/pytorch_pfn_extras/training/__init__.py index ffe617918..8fae0df54 100644 --- a/pytorch_pfn_extras/training/__init__.py +++ b/pytorch_pfn_extras/training/__init__.py @@ -1,14 +1,16 @@ -from pytorch_pfn_extras.training.extension import Extension # NOQA -from pytorch_pfn_extras.training.extension import ExtensionEntry # NOQA -from pytorch_pfn_extras.training.extension import make_extension # NOQA +from pytorch_pfn_extras.training import extensions # NOQA +from pytorch_pfn_extras.training._evaluator import DistributedEvaluator # NOQA +from pytorch_pfn_extras.training._evaluator import Evaluator # NOQA +from pytorch_pfn_extras.training._manager_protocol import ( # NOQA + ExtensionsManagerProtocol, +) +from pytorch_pfn_extras.training._trainer import Trainer # NOQA from pytorch_pfn_extras.training.extension import PRIORITY_EDITOR # NOQA from pytorch_pfn_extras.training.extension import PRIORITY_READER # NOQA from pytorch_pfn_extras.training.extension import PRIORITY_WRITER # NOQA -from pytorch_pfn_extras.training import extensions # NOQA -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol # NOQA +from pytorch_pfn_extras.training.extension import Extension # NOQA +from pytorch_pfn_extras.training.extension import ExtensionEntry # NOQA +from pytorch_pfn_extras.training.extension import make_extension # NOQA from pytorch_pfn_extras.training.manager import ExtensionsManager # NOQA from pytorch_pfn_extras.training.manager import IgniteExtensionsManager # NOQA from pytorch_pfn_extras.training.metrics import AccuracyMetric # NOQA -from pytorch_pfn_extras.training._trainer import Trainer # NOQA -from pytorch_pfn_extras.training._evaluator import Evaluator # NOQA -from pytorch_pfn_extras.training._evaluator import DistributedEvaluator # NOQA diff --git a/pytorch_pfn_extras/training/_evaluator.py b/pytorch_pfn_extras/training/_evaluator.py index d48214dbc..f6bcefb79 100644 --- a/pytorch_pfn_extras/training/_evaluator.py +++ b/pytorch_pfn_extras/training/_evaluator.py @@ -1,22 +1,27 @@ import contextlib import queue from typing import ( - Any, 
Callable, Generator, Iterable, Mapping, Optional, Sequence, - Union, TYPE_CHECKING, + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Mapping, + Optional, + Sequence, + Union, ) import torch import torch.distributed - from pytorch_pfn_extras import reporting from pytorch_pfn_extras.training.extensions import evaluator - from pytorch_pfn_extras.training.metrics import Batch as DictBatch if TYPE_CHECKING: from pytorch_pfn_extras.handler import BaseHandler - from pytorch_pfn_extras.training.metrics import MetricType from pytorch_pfn_extras.reporting import Observation + from pytorch_pfn_extras.training.metrics import MetricType @contextlib.contextmanager @@ -27,9 +32,9 @@ def _nullcontext() -> Generator[None, None, None]: @contextlib.contextmanager def _progress_bar( - name: str, - required: bool, - size: int, + name: str, + required: bool, + size: int, ) -> Generator[Callable[[int], None], None, None]: if required: progress = evaluator.IterationStatus(size) @@ -38,6 +43,7 @@ def _progress_bar( def update(i: int) -> None: progress.current_position = i pbar.update() + yield update pbar.close() @@ -47,21 +53,22 @@ def update(i: int) -> None: class Evaluator: def __init__( - self, - handler: 'BaseHandler', - models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], - *, - progress_bar: bool = False, - metrics: Optional[Sequence['MetricType']] = None, - profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] + self, + handler: "BaseHandler", + models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], + *, + progress_bar: bool = False, + metrics: Optional[Sequence["MetricType"]] = None, + profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] ): super().__init__() if not isinstance(models, dict): if not isinstance(models, torch.nn.Module): raise ValueError( - 'model must be an instance of dict or toch.nn.Module') - self.models = {'main': models} + "model must be an instance of dict or toch.nn.Module" + ) + self.models = {"main": models} else: self.models = models @@ -72,8 +79,7 @@ def __init__( self._profile = profile for name, model in self.models.items(): self._reporter.add_observer(name, model) - self._reporter.add_observers( - name, model.named_modules()) + self._reporter.add_observers(name, model.named_modules()) def _process_metrics(self, ins: DictBatch, outs: DictBatch) -> DictBatch: for metric in self._metrics: @@ -81,15 +87,16 @@ def _process_metrics(self, ins: DictBatch, outs: DictBatch) -> DictBatch: return outs def _complete_step( - self, idx: int, outs: DictBatch, *, is_deferred: bool = False + self, idx: int, outs: DictBatch, *, is_deferred: bool = False ) -> None: c_idx = self._idxs.get() # Asure that iterations complete in order if c_idx != idx: raise RuntimeError( - 'Completed a not expected iteration. ' - '{} was expected but completion of {} happened'.format( - c_idx, idx) + "Completed a not expected iteration. " + "{} was expected but completion of {} happened".format( + c_idx, idx + ) ) x = self._inputs.get() observed = self._observed.get() @@ -106,10 +113,7 @@ def _gather_summaries(self) -> None: pass def run( - self, - loader: Iterable[Any], - *, - eval_len: Optional[int] = None + self, loader: Iterable[Any], *, eval_len: Optional[int] = None ) -> None: """Executes the evaluation loop. @@ -120,9 +124,9 @@ def run( The number of iterations per one evaluation epoch. """ # Note: setup_manager is done by the Trainer. 
- self._idxs: 'queue.Queue[int]' = queue.Queue() - self._inputs: 'queue.Queue[DictBatch]' = queue.Queue() - self._observed: 'queue.Queue[Observation]' = queue.Queue() + self._idxs: "queue.Queue[int]" = queue.Queue() + self._inputs: "queue.Queue[DictBatch]" = queue.Queue() + self._observed: "queue.Queue[Observation]" = queue.Queue() if eval_len is None: eval_len = len(loader) # type: ignore[arg-type] @@ -131,7 +135,7 @@ def run( self._summary = reporting.DictSummary() observation: Observation = {} self.handler.eval_loop_begin(self) - self._pbar = _progress_bar('validation', self._progress_bar, eval_len) + self._pbar = _progress_bar("validation", self._progress_bar, eval_len) self._update = self._pbar.__enter__() loader_iter = iter(loader) with self._profile or _nullcontext() as prof: @@ -146,7 +150,8 @@ def run( self._observed.put(observation) with self._reporter.scope(observation): self.handler.eval_step( - self, idx, x, self._complete_step) + self, idx, x, self._complete_step + ) # Some of the DataLoaders might need an explicit break # since they could start cycling on their data if (idx + 1) == eval_len: @@ -165,14 +170,16 @@ def run( class DistributedEvaluator(Evaluator): def __init__( - self, - handler: 'BaseHandler', - models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], - *, - progress_bar: bool = False, - metrics: Optional[Sequence['MetricType']] = None, + self, + handler: "BaseHandler", + models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], + *, + progress_bar: bool = False, + metrics: Optional[Sequence["MetricType"]] = None, ): - super().__init__(handler, models, progress_bar=progress_bar, metrics=metrics) + super().__init__( + handler, models, progress_bar=progress_bar, metrics=metrics + ) if not torch.distributed.is_initialized(): # type: ignore[no-untyped-call] raise RuntimeError("PyTorch distributed module is not initialized.") diff --git a/pytorch_pfn_extras/training/_manager_protocol.py b/pytorch_pfn_extras/training/_manager_protocol.py index c0a9ad401..624acff1d 100644 --- a/pytorch_pfn_extras/training/_manager_protocol.py +++ b/pytorch_pfn_extras/training/_manager_protocol.py @@ -1,17 +1,15 @@ -from typing import Mapping, Optional, TYPE_CHECKING -from typing_extensions import Protocol +from typing import TYPE_CHECKING, Mapping, Optional import torch +from typing_extensions import Protocol if TYPE_CHECKING: + from pytorch_pfn_extras import reporting, writing from pytorch_pfn_extras.training import trigger as trigger_module from pytorch_pfn_extras.training.extension import Extension - from pytorch_pfn_extras import writing - from pytorch_pfn_extras import reporting class ExtensionsManagerProtocol(Protocol): - @property def iteration(self) -> int: ... @@ -53,7 +51,7 @@ def stop_trigger(self) -> bool: ... @property - def _stop_trigger(self) -> 'trigger_module.Trigger': + def _stop_trigger(self) -> "trigger_module.Trigger": ... @property @@ -61,16 +59,16 @@ def out(self) -> str: ... @property - def writer(self) -> Optional['writing.Writer']: + def writer(self) -> Optional["writing.Writer"]: ... @property - def reporter(self) -> 'reporting.Reporter': + def reporter(self) -> "reporting.Reporter": ... - def get_extension(self, name: str) -> 'Extension': + def get_extension(self, name: str) -> "Extension": ... @property - def observation(self) -> 'reporting.Observation': + def observation(self) -> "reporting.Observation": ... 
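The hunks above reformat both `PyTorchRuntime.__init__`, which reads `grad_scaler` and `autocast` entries from its options dict, and the `ppe.runtime._to.to` helper. A minimal usage sketch, assuming `to` is re-exported at the package top level as `ppe.to`, and using only the option keys read in the hunk above (nothing below is introduced or changed by this patch):

import torch
import pytorch_pfn_extras as ppe

# Usage sketch only; it exercises the signatures shown in the hunks above
# and is not part of this patch.
model = torch.nn.Linear(8, 2)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# `options` is forwarded to the runtime. PyTorchRuntime reads an optional
# "grad_scaler" entry and an "autocast" entry that may be a bool or a dict;
# a bool is normalized to {"enabled": ..., "device_type": "cuda"}.
model = ppe.to(model, device, options={"autocast": torch.cuda.is_available()})

A scaler could presumably be passed the same way (for example options={"grad_scaler": torch.cuda.amp.GradScaler()}); the reformatted execute() hunk above is the code path that consumes it when scaling the backward pass.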
diff --git a/pytorch_pfn_extras/training/_trainer.py b/pytorch_pfn_extras/training/_trainer.py index 6e6d274b2..fe1c46d1c 100644 --- a/pytorch_pfn_extras/training/_trainer.py +++ b/pytorch_pfn_extras/training/_trainer.py @@ -4,24 +4,33 @@ import time import warnings from typing import ( - Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union, TYPE_CHECKING + TYPE_CHECKING, + Any, + Dict, + Generator, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, ) +import pytorch_pfn_extras.reporting as reporting import torch - from pytorch_pfn_extras import training +from pytorch_pfn_extras.profiler import record from pytorch_pfn_extras.training import extension as extension from pytorch_pfn_extras.training import trigger as trigger_module -import pytorch_pfn_extras.reporting as reporting -from pytorch_pfn_extras.profiler import record - -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) from pytorch_pfn_extras.training.trigger import Trigger, TriggerLike if TYPE_CHECKING: from pytorch_pfn_extras import handler as handler_module - from pytorch_pfn_extras.training._evaluator import Evaluator from pytorch_pfn_extras.profiler._time_summary import _ReportNotification + from pytorch_pfn_extras.training._evaluator import Evaluator @contextlib.contextmanager @@ -32,67 +41,89 @@ def _nullcontext() -> Generator[None, None, None]: class Trainer: def __init__( - self, - handler: 'handler_module.BaseHandler', - *, - evaluator: Optional[Union[ - 'Evaluator', Tuple['Evaluator', TriggerLike], - Mapping[str, Union['Evaluator', Tuple['Evaluator', TriggerLike]]]]], - models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], - profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] - **kwargs: Any, + self, + handler: "handler_module.BaseHandler", + *, + evaluator: Optional[ + Union[ + "Evaluator", + Tuple["Evaluator", TriggerLike], + Mapping[ + str, Union["Evaluator", Tuple["Evaluator", TriggerLike]] + ], + ] + ], + models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], + profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined] + **kwargs: Any, ): self.handler = handler - self._manager: Optional['training.ExtensionsManager'] = None + self._manager: Optional["training.ExtensionsManager"] = None # The followings are used when setting up a manager instance if not isinstance(models, dict): if not isinstance(models, torch.nn.Module): raise ValueError( - 'model must be an instance of dict or toch.nn.Module') - self._models = {'main': models} + "model must be an instance of dict or toch.nn.Module" + ) + self._models = {"main": models} else: self._models = models self._kwargs = kwargs self._profile = profile - self._enable_profile = kwargs.get('enable_profile', profile is not None) + self._enable_profile = kwargs.get("enable_profile", profile is not None) self._extensions: List[ # list of (args, kwargs) - Tuple[Tuple[ - Union['extension.ExtensionLike', extension.ExtensionEntry], - Optional[str], 'TriggerLike', Optional[int] - ], Dict[str, Any]]] = [] + Tuple[ + Tuple[ + Union["extension.ExtensionLike", extension.ExtensionEntry], + Optional[str], + "TriggerLike", + Optional[int], + ], + Dict[str, Any], + ] + ] = [] self._manager_state: Optional[Dict[str, Any]] = None - self._evaluators: Dict[str, Tuple['Evaluator', TriggerLike]] = {} + self._evaluators: Dict[str, Tuple["Evaluator", TriggerLike]] = {} if evaluator is 
None: evaluator = {} elif not isinstance(evaluator, collections.abc.Mapping): evaluator = {"Evaluator": evaluator} if isinstance(evaluator, collections.abc.Mapping): for n, e in evaluator.items(): - self._evaluators[n] = e if isinstance(e, tuple) else (e, (1, 'epoch')) + self._evaluators[n] = ( + e if isinstance(e, tuple) else (e, (1, "epoch")) + ) self.val_loader = None def extend( - self, - extension: Union['extension.ExtensionLike', extension.ExtensionEntry], - name: Optional[str] = None, - trigger: 'TriggerLike' = None, - priority: Optional[int] = None, - *, - call_before_training: bool = False, - **kwargs: Any, + self, + extension: Union["extension.ExtensionLike", extension.ExtensionEntry], + name: Optional[str] = None, + trigger: "TriggerLike" = None, + priority: Optional[int] = None, + *, + call_before_training: bool = False, + **kwargs: Any, ) -> None: if self._manager is not None: - raise RuntimeError('cannot extend after starting the engine') + raise RuntimeError("cannot extend after starting the engine") self._extensions.append( - ((extension, name, trigger, priority), - dict(call_before_training=call_before_training, **kwargs))) + ( + (extension, name, trigger, priority), + dict(call_before_training=call_before_training, **kwargs), + ) + ) - def _setup_manager(self, iters_per_epoch: int) -> 'training.ExtensionsManager': + def _setup_manager( + self, iters_per_epoch: int + ) -> "training.ExtensionsManager": from pytorch_pfn_extras.training import ExtensionsManager + self._manager = ExtensionsManager( - self._models, iters_per_epoch=iters_per_epoch, **self._kwargs) + self._models, iters_per_epoch=iters_per_epoch, **self._kwargs + ) for ex_args, ex_kwargs in self._extensions: self._manager.extend(*ex_args, **ex_kwargs) if self._manager_state is not None: @@ -100,9 +131,9 @@ def _setup_manager(self, iters_per_epoch: int) -> 'training.ExtensionsManager': return self._manager @property - def manager(self) -> 'training.ExtensionsManager': + def manager(self) -> "training.ExtensionsManager": if self._manager is None: - raise RuntimeError('the engine is not started yet') + raise RuntimeError("the engine is not started yet") return self._manager @property @@ -148,34 +179,37 @@ def stop_trigger(self, trigger: Trigger) -> None: self._stop_trigger = trigger @property - def evaluator(self) -> Optional['Evaluator']: + def evaluator(self) -> Optional["Evaluator"]: if len(self._evaluators) == 0: return None if len(self._evaluators) == 1: return next(iter(self._evaluators.values()))[0] - raise ValueError('multiple evaluators are registered.') + raise ValueError("multiple evaluators are registered.") def get_optimizer(self, name: str) -> torch.optim.Optimizer: return self.manager.optimizers[name] - def set_optimizer(self, name: str, optimizer: torch.optim.Optimizer) -> None: + def set_optimizer( + self, name: str, optimizer: torch.optim.Optimizer + ) -> None: self.manager.optimizers[name] = optimizer # type: ignore[index] def is_epoch_last_iter(self, idx: int) -> bool: return (idx + 1) == (self.manager._iters_per_epoch) def _complete_step( - self, - idx: int, - outs: Any, + self, + idx: int, + outs: Any, ) -> None: c_idx = self._idxs.get() # Asure that iterations complete in order if c_idx != idx: raise RuntimeError( - 'Completed a not expected iteration. ' - '{} was expected but completion of {} happened'.format( - c_idx, idx) + "Completed a not expected iteration. 
" + "{} was expected but completion of {} happened".format( + c_idx, idx + ) ) x = self._inputs.get() begin = self._times.get() @@ -187,12 +221,14 @@ def _complete_step( self.handler.train_post_step(self, idx, x, outs) reporting.report({"elapsed_time": time.time() - begin}) - def run(self, - train_loader: Iterable[Any], - val_loader: Optional[Iterable[Any]] = None, - *, - train_len: Optional[int] = None, - eval_len: Optional[int] = None) -> None: + def run( + self, + train_loader: Iterable[Any], + val_loader: Optional[Iterable[Any]] = None, + *, + train_len: Optional[int] = None, + eval_len: Optional[int] = None, + ) -> None: """Executes the training loop. Args: @@ -222,11 +258,11 @@ def run(self, class _EvaluatorExt: def __init__( - self, - trainer: 'Trainer', - evaluator: 'Evaluator', - val_loader: Optional[Iterable[Any]], - eval_len: Optional[int], + self, + trainer: "Trainer", + evaluator: "Evaluator", + val_loader: Optional[Iterable[Any]], + eval_len: Optional[int], ) -> None: self.needs_model_state = True self._trainer = trainer @@ -238,7 +274,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: evaluator = self._evaluator if self._val_loader is None: raise ValueError('"val_loader" is not given.') - evaluator.handler.train_validation_begin(self._trainer, evaluator) + evaluator.handler.train_validation_begin( + self._trainer, evaluator + ) evaluator.run(self._val_loader, eval_len=self._eval_len) evaluator.handler.train_validation_end(self._trainer, evaluator) @@ -257,11 +295,12 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: if len(self._evaluators) == 0: if val_loader is not None: warnings.warn( - '`val_loader` is given whereas the evaluator is missing.', - UserWarning) + "`val_loader` is given whereas the evaluator is missing.", + UserWarning, + ) else: if val_loader is None: - raise ValueError('`val_loader` is required') + raise ValueError("`val_loader` is required") for _, (evaluator, _) in self._evaluators.items(): evaluator.handler.eval_setup(evaluator, val_loader) @@ -271,27 +310,30 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: # When iterations are completed in the callback # This is needed to avoid being constantly passing parameters - self._idxs: 'queue.Queue[int]' = queue.Queue() - self._inputs: 'queue.Queue[Any]' = queue.Queue() - self._times: 'queue.Queue[float]' = queue.Queue() - self._observed: 'queue.Queue[reporting.Observation]' = queue.Queue() + self._idxs: "queue.Queue[int]" = queue.Queue() + self._inputs: "queue.Queue[Any]" = queue.Queue() + self._times: "queue.Queue[float]" = queue.Queue() + self._observed: "queue.Queue[reporting.Observation]" = ( + queue.Queue() + ) # Iterator must be created after `train_epoch_begin` as it may be # using a DistributedSampler. 
loader_iter = iter(train_loader) - self._profile_records: 'queue.Queue[List[_ReportNotification]]' \ - = queue.Queue() + self._profile_records: "queue.Queue[List[_ReportNotification]]" = ( + queue.Queue() + ) for idx in range(train_len): with record( "pytorch_pfn_extras.training.Trainer:iteration", use_cuda=torch.cuda.is_available(), enable=self._enable_profile, - device=device + device=device, ) as ntf0: try: with record( "pytorch_pfn_extras.training.Trainer:get_data", enable=self._enable_profile, - device=device + device=device, ): x = next(loader_iter) except StopIteration: @@ -299,7 +341,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: with record( "pytorch_pfn_extras.training.Trainer:get_data", enable=self._enable_profile, - device=device + device=device, ): x = next(loader_iter) begin = time.time() @@ -311,19 +353,24 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: "pytorch_pfn_extras.training.Trainer:run_iteration", use_cuda=torch.cuda.is_available(), enable=self._enable_profile, - device=device - ) as ntf1, \ - self.manager.run_iteration(): + device=device, + ) as ntf1, self.manager.run_iteration(): self._observed.put(self.manager.observation) with record( "pytorch_pfn_extras.training.Trainer:train_step", use_cuda=torch.cuda.is_available(), enable=self._enable_profile, - device=device + device=device, ) as ntf2: - self._profile_records.put([ntf0, ntf1, ntf2]) + self._profile_records.put( + [ntf0, ntf1, ntf2] + ) self.handler.train_step( - self, idx, x, complete_fn=self._complete_step) + self, + idx, + x, + complete_fn=self._complete_step, + ) # Check if the callback was called except Exception: # The manager has errored and called the extensions @@ -340,7 +387,10 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: # And will keep yielding results even if the epoch # is completed. 
We forcefully exit at the end of # every epoch - if self.is_epoch_last_iter(idx) or self.manager.stop_trigger: + if ( + self.is_epoch_last_iter(idx) + or self.manager.stop_trigger + ): break # In handlers that support a completely Async model train_epoch_end # Will take care of completing pending work diff --git a/pytorch_pfn_extras/training/_transform_model.py b/pytorch_pfn_extras/training/_transform_model.py index 5fa416e4d..9a0dd4ff1 100644 --- a/pytorch_pfn_extras/training/_transform_model.py +++ b/pytorch_pfn_extras/training/_transform_model.py @@ -1,12 +1,10 @@ import typing import torch -from torch.nn.parallel import DistributedDataParallel - from pytorch_pfn_extras.nn.parallel import ( DistributedDataParallel as PpeDistributedDataParallel, ) - +from torch.nn.parallel import DistributedDataParallel _TransformModel = typing.Callable[[str, torch.nn.Module], torch.nn.Module] diff --git a/pytorch_pfn_extras/training/_trigger_util.py b/pytorch_pfn_extras/training/_trigger_util.py index c35fb6d20..a2e957b64 100644 --- a/pytorch_pfn_extras/training/_trigger_util.py +++ b/pytorch_pfn_extras/training/_trigger_util.py @@ -1,10 +1,13 @@ -from typing import Any, Callable, Dict, Union, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple, Union -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) class Trigger: """Base class for triggers.""" + def may_fire(self, iteration: int, epoch_len: int) -> bool: """Flags if the trigger may fire at the current iteration @@ -24,7 +27,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool: class _CallableTrigger(Trigger): - def __init__(self, func: Callable[[ExtensionsManagerProtocol], bool]) -> None: + def __init__( + self, func: Callable[[ExtensionsManagerProtocol], bool] + ) -> None: self.func = func def __call__(self, manager: ExtensionsManagerProtocol) -> bool: diff --git a/pytorch_pfn_extras/training/_util.py b/pytorch_pfn_extras/training/_util.py index 462dae9cd..ce3759c03 100644 --- a/pytorch_pfn_extras/training/_util.py +++ b/pytorch_pfn_extras/training/_util.py @@ -12,7 +12,7 @@ def _get_ignite_version(version: str) -> List[int]: # major and minor ids can be only integers. # Some examples of versions are: # 0.1.0, 0.1.1, 0.3.0.dev20191007, 0.3.0. 
- version_regexp = r'^[0-9]+\.[0-9]+\.[0-9]+(\.[0-9a-zA-Z]+)?$' + version_regexp = r"^[0-9]+\.[0-9]+\.[0-9]+(\.[0-9a-zA-Z]+)?$" if re.search(version_regexp, version): - return [int(x) for x in version.split('.')[:2]] - raise ValueError('Invalid version format') + return [int(x) for x in version.split(".")[:2]] + raise ValueError("Invalid version format") diff --git a/pytorch_pfn_extras/training/extension.py b/pytorch_pfn_extras/training/extension.py index 05353d9a8..caa79f62e 100644 --- a/pytorch_pfn_extras/training/extension.py +++ b/pytorch_pfn_extras/training/extension.py @@ -1,11 +1,14 @@ -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING import types +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from pytorch_pfn_extras.training import _trigger_util -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) if TYPE_CHECKING: from pytorch_pfn_extras.training._trigger_util import TriggerLike + ExtensionLike = Callable[[ExtensionsManagerProtocol], Any] @@ -46,7 +49,8 @@ class Extension: :meth:`pytorch_pfn_extras.ExtensionsManager.extend` for details. """ - trigger: 'TriggerLike' = (1, 'iteration') + + trigger: "TriggerLike" = (1, "iteration") priority: int = PRIORITY_READER name: Optional[str] = None needs_model_state = False @@ -77,15 +81,18 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> Any: """ raise NotImplementedError( - 'Extension implementation must override __call__.') + "Extension implementation must override __call__." + ) def __getattr__(self, name: str) -> Any: - if name == 'invoke_before_training': + if name == "invoke_before_training": raise AttributeError( - 'invoke_before_training has been removed since Chainer ' - 'v2.0.0. Use Extension.initialize instead.') - raise AttributeError('{} object has no attribute {}'.format( - type(self).__name__, name)) + "invoke_before_training has been removed since Chainer " + "v2.0.0. Use Extension.initialize instead." + ) + raise AttributeError( + "{} object has no attribute {}".format(type(self).__name__, name) + ) def finalize(self, manager: ExtensionsManagerProtocol) -> None: """Finalizes the extension. @@ -113,10 +120,10 @@ def initialize(self, manager: ExtensionsManagerProtocol) -> None: pass def on_error( - self, - manager: ExtensionsManagerProtocol, - exc: Exception, - tb: types.TracebackType, + self, + manager: ExtensionsManagerProtocol, + exc: Exception, + tb: types.TracebackType, ) -> None: """Handles the error raised during training before finalization. 
@@ -148,47 +155,47 @@ def load_state_dict(self, to_load: Dict[str, Any]) -> None: class _WrappedExtension(Extension): - - def __init__(self, ext: 'ExtensionLike') -> None: + def __init__(self, ext: "ExtensionLike") -> None: self._ext = ext - self.trigger = getattr(self._ext, 'trigger', Extension.trigger) - self.priority = getattr(self._ext, 'priority', Extension.priority) + self.trigger = getattr(self._ext, "trigger", Extension.trigger) + self.priority = getattr(self._ext, "priority", Extension.priority) super().__init__() @property def default_name(self) -> str: - return getattr(self._ext, 'default_name', None) or super().default_name + return getattr(self._ext, "default_name", None) or super().default_name def __call__(self, manager: ExtensionsManagerProtocol) -> None: self._ext(manager) def finalize(self, manager: ExtensionsManagerProtocol) -> None: - getattr(self._ext, 'finalize', super().finalize)(manager) + getattr(self._ext, "finalize", super().finalize)(manager) def initialize(self, manager: ExtensionsManagerProtocol) -> None: - getattr(self._ext, 'initialize', super().initialize)(manager) + getattr(self._ext, "initialize", super().initialize)(manager) def on_error( - self, - manager: ExtensionsManagerProtocol, - exc: Exception, - tb: types.TracebackType, + self, + manager: ExtensionsManagerProtocol, + exc: Exception, + tb: types.TracebackType, ) -> None: - getattr(self._ext, 'on_error', super().on_error)(manager, exc, tb) + getattr(self._ext, "on_error", super().on_error)(manager, exc, tb) _OnErrorType = Callable[ - [ExtensionsManagerProtocol, Exception, types.TracebackType], None] + [ExtensionsManagerProtocol, Exception, types.TracebackType], None +] def make_extension( - trigger: 'TriggerLike' = Extension.trigger, - default_name: Optional[str] = None, - priority: int = Extension.priority, - finalizer: 'ExtensionLike' = lambda manager: None, - initializer: 'ExtensionLike' = lambda manager: None, - on_error: _OnErrorType = lambda manager, exc, tb: None, -) -> Callable[['ExtensionLike'], 'ExtensionLike']: + trigger: "TriggerLike" = Extension.trigger, + default_name: Optional[str] = None, + priority: int = Extension.priority, + finalizer: "ExtensionLike" = lambda manager: None, + initializer: "ExtensionLike" = lambda manager: None, + on_error: _OnErrorType = lambda manager, exc, tb: None, +) -> Callable[["ExtensionLike"], "ExtensionLike"]: """Decorator to make given function into an extension. This decorator just adds some attributes to a given function. The value of @@ -210,7 +217,8 @@ def make_extension( called after an error is raised during the training loop. 
""" - def decorator(ext: 'ExtensionLike') -> 'ExtensionLike': + + def decorator(ext: "ExtensionLike") -> "ExtensionLike": ext.trigger = trigger # type: ignore ext.default_name = default_name or ext.__name__ # type: ignore ext.priority = priority # type: ignore @@ -222,7 +230,7 @@ def decorator(ext: 'ExtensionLike') -> 'ExtensionLike': return decorator -def _as_extension(ext: 'ExtensionLike') -> Extension: +def _as_extension(ext: "ExtensionLike") -> Extension: return ext if isinstance(ext, Extension) else _WrappedExtension(ext) @@ -243,39 +251,42 @@ class ExtensionEntry: """ def __init__( - self, - extension: 'ExtensionLike', - *, - name: Optional[str] = None, - priority: Optional[int] = None, - trigger: Optional['TriggerLike'] = None, - call_before_training: bool = False, + self, + extension: "ExtensionLike", + *, + name: Optional[str] = None, + priority: Optional[int] = None, + trigger: Optional["TriggerLike"] = None, + call_before_training: bool = False, ) -> None: self.extension = _as_extension(extension) self.priority = priority or self.extension.priority self.call_before_training = call_before_training self._update_trigger(trigger or self.extension.trigger) - self._update_name(name or self.extension.name or self.extension.default_name) + self._update_name( + name or self.extension.name or self.extension.default_name + ) - def _update_trigger(self, trigger: 'TriggerLike') -> None: + def _update_trigger(self, trigger: "TriggerLike") -> None: self.trigger = _trigger_util.get_trigger(trigger) def _update_name(self, name: str) -> None: - if name == 'training': + if name == "training": raise ValueError( - 'the name "training" is prohibited as an extension name') + 'the name "training" is prohibited as an extension name' + ) self.name = name self.extension.name = name def state_dict(self) -> Dict[str, Any]: state = {} - state['extension'] = self.extension.state_dict() - state['trigger'] = self.trigger.state_dict() + state["extension"] = self.extension.state_dict() + state["trigger"] = self.trigger.state_dict() return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: - if 'extension' in to_load: - self.extension.load_state_dict(to_load['extension']) - if 'trigger' in to_load: - self.trigger.load_state_dict(to_load['trigger']) + if "extension" in to_load: + self.extension.load_state_dict(to_load["extension"]) + if "trigger" in to_load: + self.trigger.load_state_dict(to_load["trigger"]) diff --git a/pytorch_pfn_extras/training/extensions/__init__.py b/pytorch_pfn_extras/training/extensions/__init__.py index 26c1074f3..5a3285a26 100644 --- a/pytorch_pfn_extras/training/extensions/__init__.py +++ b/pytorch_pfn_extras/training/extensions/__init__.py @@ -1,30 +1,62 @@ from pytorch_pfn_extras.training.extensions import snapshot_writers # NOQA +from pytorch_pfn_extras.training.extensions import util as _util from pytorch_pfn_extras.training.extensions._snapshot import snapshot # NOQA -from pytorch_pfn_extras.training.extensions._snapshot import snapshot_object # NOQA +from pytorch_pfn_extras.training.extensions._snapshot import ( # NOQA + snapshot_object, +) from pytorch_pfn_extras.training.extensions.best_value import BestValue # NOQA from pytorch_pfn_extras.training.extensions.best_value import MaxValue # NOQA from pytorch_pfn_extras.training.extensions.best_value import MinValue # NOQA -from pytorch_pfn_extras.training.extensions.evaluator import Evaluator, DistributedEvaluator, IgniteEvaluator # NOQA -from pytorch_pfn_extras.training.extensions.fail_on_non_number import 
FailOnNonNumber # NOQA +from pytorch_pfn_extras.training.extensions.evaluator import ( # NOQA + DistributedEvaluator, + Evaluator, + IgniteEvaluator, +) +from pytorch_pfn_extras.training.extensions.fail_on_non_number import ( # NOQA + FailOnNonNumber, +) from pytorch_pfn_extras.training.extensions.log_report import LogReport # NOQA -from pytorch_pfn_extras.training.extensions.lr_scheduler import LRScheduler # NOQA -from pytorch_pfn_extras.training.extensions.micro_average import MicroAverage # NOQA -from pytorch_pfn_extras.training.extensions.profile_report import ProfileReport # NOQA -from pytorch_pfn_extras.training.extensions.parameter_statistics import ParameterStatistics # NOQA -from pytorch_pfn_extras.training.extensions.plot_report import PlotReport # NOQA -from pytorch_pfn_extras.training.extensions.profile_report import ProfileReport # NOQA -from pytorch_pfn_extras.training.extensions.slack import Slack, SlackWebhook # NOQA -from pytorch_pfn_extras.training.extensions.value_observation import observe_lr # NOQA -from pytorch_pfn_extras.training.extensions.value_observation import observe_value # NOQA -from pytorch_pfn_extras.training.extensions.variable_statistics_plot import VariableStatisticsPlot # NOQA -from pytorch_pfn_extras.training.extensions import util as _util - -from pytorch_pfn_extras.training.extensions.print_report import PrintReport as PrintReportCLI # NOQA -from pytorch_pfn_extras.training.extensions.progress_bar import ProgressBar as ProgressBarCLI # NOQA +from pytorch_pfn_extras.training.extensions.lr_scheduler import ( # NOQA + LRScheduler, +) +from pytorch_pfn_extras.training.extensions.micro_average import ( # NOQA + MicroAverage, +) +from pytorch_pfn_extras.training.extensions.parameter_statistics import ( # NOQA + ParameterStatistics, +) +from pytorch_pfn_extras.training.extensions.plot_report import ( # NOQA + PlotReport, +) +from pytorch_pfn_extras.training.extensions.print_report import ( + PrintReport as PrintReportCLI, # NOQA +) +from pytorch_pfn_extras.training.extensions.profile_report import ( # NOQA + ProfileReport, +) +from pytorch_pfn_extras.training.extensions.progress_bar import ( + ProgressBar as ProgressBarCLI, # NOQA +) +from pytorch_pfn_extras.training.extensions.slack import ( # NOQA + Slack, + SlackWebhook, +) +from pytorch_pfn_extras.training.extensions.value_observation import ( # NOQA + observe_lr, + observe_value, +) +from pytorch_pfn_extras.training.extensions.variable_statistics_plot import ( # NOQA + VariableStatisticsPlot, +) try: - from pytorch_pfn_extras.training.extensions.print_report_notebook import PrintReportNotebook # NOQA - from pytorch_pfn_extras.training.extensions.progress_bar_notebook import ProgressBarNotebook # NOQA + from pytorch_pfn_extras.training.extensions.print_report_notebook import ( # NOQA + PrintReportNotebook, + ) + from pytorch_pfn_extras.training.extensions.progress_bar_notebook import ( # NOQA + ProgressBarNotebook, + ) + _ipython_module_available = True except ImportError: _ipython_module_available = False diff --git a/pytorch_pfn_extras/training/extensions/_snapshot.py b/pytorch_pfn_extras/training/extensions/_snapshot.py index 1fcb275c3..ce6dadcc1 100644 --- a/pytorch_pfn_extras/training/extensions/_snapshot.py +++ b/pytorch_pfn_extras/training/extensions/_snapshot.py @@ -4,18 +4,19 @@ import torch import torch.distributed - -from pytorch_pfn_extras import logging +from pytorch_pfn_extras import logging, writing from pytorch_pfn_extras.training import extension -from 
pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol -from pytorch_pfn_extras import writing - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) logger = logging._get_root_logger() -def _find_snapshot_files(fmt: str, path: str, fs: Any) -> List[Tuple[float, str]]: - '''Only prefix and suffix match +def _find_snapshot_files( + fmt: str, path: str, fs: Any +) -> List[Tuple[float, str]]: + """Only prefix and suffix match TODO(kuenishi): currently clean format string such as "snapshot{.iteration}.npz" can only be parsed, but tricky (or @@ -33,16 +34,20 @@ def _find_snapshot_files(fmt: str, path: str, fs: Any) -> List[Tuple[float, str] A sorted list of pair of ``mtime, filename``, whose file name that matched the format ``fmt`` directly under ``path``. - ''' - prefix = fmt.split('{')[0] - suffix = fmt.split('}')[-1] + """ + prefix = fmt.split("{")[0] + suffix = fmt.split("}")[-1] - matched_files = (file for file in fs.list(path) - if file.startswith(prefix) and file.endswith(suffix)) + matched_files = ( + file + for file in fs.list(path) + if file.startswith(prefix) and file.endswith(suffix) + ) def _prepend_mtime(f: str) -> Any: t = fs.stat(os.path.join(path, f)).last_modified return (t, f) + return sorted(_prepend_mtime(file) for file in matched_files) @@ -63,7 +68,7 @@ def _find_latest_snapshot(fmt: str, path: str, fs: Any) -> Optional[str]: """ snapshot_files = _find_snapshot_files(fmt, path, fs) - logger.debug('found snapshot files {}'.format(snapshot_files)) + logger.debug("found snapshot files {}".format(snapshot_files)) if len(snapshot_files) > 0: _, filename = snapshot_files[-1] return filename @@ -71,7 +76,7 @@ def _find_latest_snapshot(fmt: str, path: str, fs: Any) -> Optional[str]: def _find_stale_snapshots( - fmt: str, path: str, n_retains: int, fs: Any + fmt: str, path: str, n_retains: int, fs: Any ) -> Generator[str, None, None]: """Finds stale snapshots in a directory, retaining several files @@ -100,7 +105,8 @@ def _find_stale_snapshots( def snapshot_object( - target: Any, filename: str, savefun: Any = None, **kwargs: Any) -> '_Snapshot': + target: Any, filename: str, savefun: Any = None, **kwargs: Any +) -> "_Snapshot": """snapshot_object(target, filename, savefun=None, \ *, condition=None, writer=None, snapshot_on_error=False, \ n_retains=-1, autoload=False) @@ -160,22 +166,21 @@ def snapshot_object( - :meth:`pytorch_pfn_extras.training.extensions.snapshot` """ - return snapshot(target=target, filename=filename, savefun=savefun, - **kwargs) + return snapshot(target=target, filename=filename, savefun=savefun, **kwargs) def snapshot( - savefun: Any = None, - filename: str = 'snapshot_iter_{.iteration}', - *, - target: Any = None, - condition: Any = None, - writer: Optional[writing.Writer] = None, - snapshot_on_error: bool = False, - n_retains: int = -1, - autoload: bool = False, - saver_rank: Optional[int] = None, -) -> '_Snapshot': + savefun: Any = None, + filename: str = "snapshot_iter_{.iteration}", + *, + target: Any = None, + condition: Any = None, + writer: Optional[writing.Writer] = None, + snapshot_on_error: bool = False, + n_retains: int = -1, + autoload: bool = False, + saver_rank: Optional[int] = None, +) -> "_Snapshot": """ Returns a trainer extension to take snapshots of the trainer. 
@@ -277,17 +282,31 @@ def __call__(self, x): """ if savefun is not None and writer is not None: raise TypeError( - 'savefun and writer arguments cannot be specified together.') + "savefun and writer arguments cannot be specified together." + ) if saver_rank is None: return _Snapshot( - target=target, condition=condition, writer=writer, - filename=filename, snapshot_on_error=snapshot_on_error, - n_retains=n_retains, autoload=autoload, savefun=savefun) + target=target, + condition=condition, + writer=writer, + filename=filename, + snapshot_on_error=snapshot_on_error, + n_retains=n_retains, + autoload=autoload, + savefun=savefun, + ) return _DistributedSnapshot( - target=target, condition=condition, writer=writer, filename=filename, - snapshot_on_error=snapshot_on_error, n_retains=n_retains, - autoload=autoload, saver_rank=saver_rank, savefun=savefun) + target=target, + condition=condition, + writer=writer, + filename=filename, + snapshot_on_error=snapshot_on_error, + n_retains=n_retains, + autoload=autoload, + saver_rank=saver_rank, + savefun=savefun, + ) def _always_true() -> bool: @@ -308,20 +327,21 @@ class _Snapshot(extension.Extension): The default priority is -100, which is lower than that of most built-in extensions. """ - trigger = 1, 'epoch' + + trigger = 1, "epoch" priority = extension.PRIORITY_SNAPSHOT needs_model_state = True def __init__( - self, - target: Any = None, - condition: Any = None, - writer: Optional[writing.Writer] = None, - filename: str = 'snapshot_iter_{.iteration}', - snapshot_on_error: bool = False, - n_retains: int = -1, - autoload: bool = False, - savefun: Any = None, + self, + target: Any = None, + condition: Any = None, + writer: Optional[writing.Writer] = None, + filename: str = "snapshot_iter_{.iteration}", + snapshot_on_error: bool = False, + n_retains: int = -1, + autoload: bool = False, + savefun: Any = None, ) -> None: if condition is None: condition = _always_true @@ -335,7 +355,8 @@ def __init__( self._savefun = savefun def initialize( # type: ignore[override] - self, manager: ExtensionsManagerProtocol) -> Optional[str]: + self, manager: ExtensionsManagerProtocol + ) -> Optional[str]: target = manager if self._target is None else self._target writer = manager.writer if self.writer is None else self.writer self.writer = writer @@ -347,16 +368,22 @@ def initialize( # type: ignore[override] # terms of mtime, and tries to load it it the target or # manager. assert writer is not None - loaded_fn = _find_latest_snapshot(self.filename, writer.out_dir, writer.fs) + loaded_fn = _find_latest_snapshot( + self.filename, writer.out_dir, writer.fs + ) if loaded_fn: - snapshot_file = writer.fs.open(os.path.join(writer.out_dir, loaded_fn), 'rb') + snapshot_file = writer.fs.open( + os.path.join(writer.out_dir, loaded_fn), "rb" + ) # As described above (at ``autoload`` option), # snapshot files to be autoloaded must be saved by # ``save_npz`` . In order to support general format, # we nned to first reconstruct the design of savefun # and loadfun. 
- state = torch.load(snapshot_file, # type: ignore[no-untyped-call] - map_location=torch.device("cpu")) + state = torch.load( + snapshot_file, # type: ignore[no-untyped-call] + map_location=torch.device("cpu"), + ) if type(target) is dict: for k in target: target[k].load_state_dict(state[k]) @@ -364,9 +391,11 @@ def initialize( # type: ignore[override] target.load_state_dict(state) snapshot_file.close() - if (hasattr(writer, '_add_cleanup_hook') - and self.n_retains > 0 - and isinstance(self.filename, str)): + if ( + hasattr(writer, "_add_cleanup_hook") + and self.n_retains > 0 + and isinstance(self.filename, str) + ): # This block sets a method to automatic cleanup of stale # snapshots, when ``n_retains`` argument is positive # number. When the given snapshot writer is Chainer's @@ -375,8 +404,9 @@ def initialize( # type: ignore[override] # injected here. def _cleanup() -> None: assert writer is not None - files = _find_stale_snapshots(self.filename, writer.out_dir, - self.n_retains, writer.fs) + files = _find_stale_snapshots( + self.filename, writer.out_dir, self.n_retains, writer.fs + ) for file in files: writer.fs.remove(os.path.join(writer.out_dir, file)) @@ -386,10 +416,10 @@ def _cleanup() -> None: return loaded_fn def on_error( - self, - manager: ExtensionsManagerProtocol, - exc: Exception, - tb: types.TracebackType, + self, + manager: ExtensionsManagerProtocol, + exc: Exception, + tb: types.TracebackType, ) -> None: super().on_error(manager, exc, tb) if self._snapshot_on_error: @@ -406,8 +436,7 @@ def _make_snapshot(self, manager: ExtensionsManagerProtocol) -> None: # We need to get a dictionary with the state here if type(target) is dict: - serialized_target = { - k: v.state_dict() for k, v in target.items()} + serialized_target = {k: v.state_dict() for k, v in target.items()} else: serialized_target = target.state_dict() filename = self.filename @@ -417,7 +446,8 @@ def _make_snapshot(self, manager: ExtensionsManagerProtocol) -> None: filename = filename.format(manager) outdir = manager.out writer( # type: ignore - filename, outdir, serialized_target, savefun=self._savefun) + filename, outdir, serialized_target, savefun=self._savefun + ) def finalize(self, manager: ExtensionsManagerProtocol) -> None: self.writer.finalize() # type: ignore @@ -437,34 +467,46 @@ class _DistributedSnapshot(_Snapshot): The default priority is lower than that of most built-in extensions. 
""" - trigger = 1, 'epoch' + + trigger = 1, "epoch" priority = extension.PRIORITY_SNAPSHOT def __init__( - self, - target: Any = None, - condition: Any = None, - writer: Optional[writing.Writer] = None, - filename: str = 'snapshot_iter_{.iteration}', - snapshot_on_error: bool = False, - n_retains: int = -1, - autoload: bool = False, - saver_rank: int = 0, - savefun: Any = None, + self, + target: Any = None, + condition: Any = None, + writer: Optional[writing.Writer] = None, + filename: str = "snapshot_iter_{.iteration}", + snapshot_on_error: bool = False, + n_retains: int = -1, + autoload: bool = False, + saver_rank: int = 0, + savefun: Any = None, ): - super().__init__(target, condition, writer, filename, - snapshot_on_error, n_retains, - autoload, savefun) + super().__init__( + target, + condition, + writer, + filename, + snapshot_on_error, + n_retains, + autoload, + savefun, + ) # To support distributed snapshots if not torch.distributed.is_initialized(): # type: ignore[no-untyped-call] - raise RuntimeError('The Distributed Snapshot extension', - ' requires torch.distributed to be initialized') + raise RuntimeError( + "The Distributed Snapshot extension", + " requires torch.distributed to be initialized", + ) self._saver_rank = saver_rank self._size = torch.distributed.get_world_size() # type: ignore[no-untyped-call] self._rank = torch.distributed.get_rank() # type: ignore[no-untyped-call] if not (0 <= saver_rank < self._size): - raise ValueError('Distributed snapshot requires a saver rank' - ' in the range [0-{})'.format(self._size)) + raise ValueError( + "Distributed snapshot requires a saver rank" + " in the range [0-{})".format(self._size) + ) def __call__(self, manager: ExtensionsManagerProtocol) -> None: if self.condition(): diff --git a/pytorch_pfn_extras/training/extensions/best_value.py b/pytorch_pfn_extras/training/extensions/best_value.py index 0c0f0b4b9..7bec96c42 100644 --- a/pytorch_pfn_extras/training/extensions/best_value.py +++ b/pytorch_pfn_extras/training/extensions/best_value.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING - -from pytorch_pfn_extras.training import extension -from pytorch_pfn_extras.training import triggers -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from pytorch_pfn_extras.training import extension, triggers +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) if TYPE_CHECKING: from pytorch_pfn_extras.training._trigger_util import TriggerLike @@ -24,13 +24,13 @@ class BestValue(extension.Extension): :class:`~pytorch_pfn_extras.triggers.BestValueTrigger`. """ - default_name = 'best_value' + default_name = "best_value" def __init__( - self, - key: str, - compare: Callable[[float, float], bool], - trigger: 'TriggerLike' = (1, 'epoch'), + self, + key: str, + compare: Callable[[float, float], bool], + trigger: "TriggerLike" = (1, "epoch"), ) -> None: self._best_epoch: Optional[int] = None self._best_it: Optional[int] = None @@ -42,8 +42,10 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: def _check_best_value_exists(self) -> None: if self._best_trigger._best_value is None: - raise RuntimeError("Best observation hasn't been obtained. " - "Run the BestValue extension at least once") + raise RuntimeError( + "Best observation hasn't been obtained. 
" + "Run the BestValue extension at least once" + ) @property def best_value(self) -> float: @@ -74,15 +76,15 @@ def best_epoch(self) -> int: def state_dict(self) -> Dict[str, Any]: return { - '_best_trigger': self._best_trigger.state_dict(), - '_best_it': self._best_it, - '_best_epoch': self._best_epoch + "_best_trigger": self._best_trigger.state_dict(), + "_best_it": self._best_it, + "_best_epoch": self._best_epoch, } def load_state_dict(self, to_load: Dict[str, Any]) -> None: - self._best_trigger.load_state_dict(to_load['_best_trigger']) - self._best_it = to_load['_best_it'] - self._best_epoch = to_load['_best_epoch'] + self._best_trigger.load_state_dict(to_load["_best_trigger"]) + self._best_it = to_load["_best_it"] + self._best_epoch = to_load["_best_epoch"] class MaxValue(BestValue): @@ -97,11 +99,12 @@ class MaxValue(BestValue): :class:`~pytorch_pfn_extras.triggers.BestValueTrigger`. """ - default_name = 'max_value' + default_name = "max_value" - def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')): + def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")): super().__init__( - key, lambda max_value, new_value: new_value > max_value, trigger) + key, lambda max_value, new_value: new_value > max_value, trigger + ) class MinValue(BestValue): @@ -116,8 +119,9 @@ class MinValue(BestValue): :class:`~pytorch_pfn_extras.triggers.BestValueTrigger`. """ - default_name = 'min_value' + default_name = "min_value" - def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')): + def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")): super().__init__( - key, lambda min_value, new_value: new_value < min_value, trigger) + key, lambda min_value, new_value: new_value < min_value, trigger + ) diff --git a/pytorch_pfn_extras/training/extensions/evaluator.py b/pytorch_pfn_extras/training/extensions/evaluator.py index ce98498d5..8fd57f650 100644 --- a/pytorch_pfn_extras/training/extensions/evaluator.py +++ b/pytorch_pfn_extras/training/extensions/evaluator.py @@ -1,19 +1,27 @@ import contextlib import datetime from typing import ( - Any, Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union, TYPE_CHECKING, + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + TextIO, + Union, ) import numpy import torch import torch.distributed - from pytorch_pfn_extras import reporting from pytorch_pfn_extras.training import extension +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) from pytorch_pfn_extras.training.extensions import util -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - _MetricType = Callable[[Any, Any, Any], None] _Scalar = Union[torch.Tensor, numpy.ndarray, numpy.floating, float] @@ -83,29 +91,32 @@ class Evaluator(extension.Extension): eval_func: Evaluation function called at each iteration. 
""" - trigger = 1, 'epoch' - default_name = 'validation' + + trigger = 1, "epoch" + default_name = "validation" priority = extension.PRIORITY_WRITER def __init__( - self, - iterator: Union[torch.utils.data.DataLoader[Any], - Dict[str, torch.utils.data.DataLoader[Any]]], - target: Union[torch.nn.Module, Dict[str, torch.nn.Module]], - eval_hook: Optional[Callable[['Evaluator'], None]] = None, - eval_func: Optional[Callable[..., Any]] = None, - **kwargs: Any, + self, + iterator: Union[ + torch.utils.data.DataLoader[Any], + Dict[str, torch.utils.data.DataLoader[Any]], + ], + target: Union[torch.nn.Module, Dict[str, torch.nn.Module]], + eval_hook: Optional[Callable[["Evaluator"], None]] = None, + eval_func: Optional[Callable[..., Any]] = None, + **kwargs: Any, ) -> None: - progress_bar = kwargs.get('progress_bar', False) - metrics = kwargs.get('metrics', []) + progress_bar = kwargs.get("progress_bar", False) + metrics = kwargs.get("metrics", []) if isinstance(iterator, torch.utils.data.DataLoader): - self._iterators = {'main': iterator} + self._iterators = {"main": iterator} else: self._iterators = iterator if isinstance(target, torch.nn.Module): - target = {'main': target} + target = {"main": target} self._targets = target self.name = None @@ -118,7 +129,7 @@ def eval_func(self, *args: Any, **kwargs: Any) -> Any: if self._eval_func: func = self._eval_func else: - func = self._targets['main'] + func = self._targets["main"] return func(*args, **kwargs) def get_iterator(self, name: str) -> torch.utils.data.DataLoader[Any]: @@ -152,8 +163,8 @@ def add_metric(self, metric_fn: _MetricType) -> None: self._metrics.append(metric_fn) def __call__( - self, - manager: Optional[ExtensionsManagerProtocol] = None, + self, + manager: Optional[ExtensionsManagerProtocol] = None, ) -> Optional[Dict[str, _Scalar]]: """Executes the evaluator extension. @@ -175,13 +186,12 @@ def __call__( # set up a reporter reporter = reporting.Reporter() if self.name is not None: - prefix = self.name + '/' + prefix = self.name + "/" else: - prefix = '' + prefix = "" for name, target in self._targets.items(): reporter.add_observer(prefix + name, target) - reporter.add_observers(prefix + name, - target.named_modules()) + reporter.add_observers(prefix + name, target.named_modules()) with reporter: with torch.no_grad(): # type: ignore[no-untyped-call] @@ -190,7 +200,9 @@ def __call__( reporting.report(result) return result - def _gather_summaries(self, summary: reporting.DictSummary) -> reporting.DictSummary: + def _gather_summaries( + self, summary: reporting.DictSummary + ) -> reporting.DictSummary: return summary def evaluate(self) -> Dict[str, _Scalar]: @@ -207,7 +219,7 @@ def evaluate(self) -> Dict[str, _Scalar]: :func:`~pytorch_pfn_extras.report` without specifying any observer. 
""" - iterator = self._iterators['main'] + iterator = self._iterators["main"] if self.eval_hook: self.eval_hook(self) @@ -226,7 +238,7 @@ def evaluate(self) -> Dict[str, _Scalar]: progress.current_position = idx observation: Dict[str, Any] = {} with reporting.report_scope(observation): - if isinstance(batch, tuple) and hasattr(batch, '_fields'): + if isinstance(batch, tuple) and hasattr(batch, "_fields"): outs = self.eval_func(batch) elif isinstance(batch, (tuple, list)): outs = self.eval_func(*batch) @@ -287,31 +299,39 @@ class DistributedEvaluator(Evaluator): """ def __init__( - self, - iterator: Union[torch.utils.data.DataLoader[Any], - Dict[str, torch.utils.data.DataLoader[Any]]], - target: Union[torch.nn.Module, Dict[str, torch.nn.Module]], - eval_hook: Optional[Callable[['Evaluator'], None]] = None, - eval_func: Optional[Callable[..., Any]] = None, - **kwargs: Any, + self, + iterator: Union[ + torch.utils.data.DataLoader[Any], + Dict[str, torch.utils.data.DataLoader[Any]], + ], + target: Union[torch.nn.Module, Dict[str, torch.nn.Module]], + eval_hook: Optional[Callable[["Evaluator"], None]] = None, + eval_func: Optional[Callable[..., Any]] = None, + **kwargs: Any, ) -> None: if not torch.distributed.is_initialized(): # type: ignore[no-untyped-call] - msg = "PyTorch distributed module is not initialized. " \ - "Initialize process group or use non-distributed Evaluator." + msg = ( + "PyTorch distributed module is not initialized. " + "Initialize process group or use non-distributed Evaluator." + ) raise RuntimeError(msg) - if 'progress_bar' in kwargs: + if "progress_bar" in kwargs: rank = torch.distributed.get_rank() # type: ignore[no-untyped-call] - kwargs['progress_bar'] &= (rank == 0) + kwargs["progress_bar"] &= rank == 0 super().__init__(iterator, target, eval_hook, eval_func, **kwargs) - def _gather_summaries(self, summary: reporting.DictSummary) -> reporting.DictSummary: + def _gather_summaries( + self, summary: reporting.DictSummary + ) -> reporting.DictSummary: return sum(_dist_gather(summary), reporting.DictSummary()) @contextlib.contextmanager -def _in_eval_mode(targets: Iterable[torch.nn.Module]) -> Generator[None, None, None]: +def _in_eval_mode( + targets: Iterable[torch.nn.Module], +) -> Generator[None, None, None]: targets = list(targets) was_train = [t.training for t in targets] try: @@ -335,19 +355,22 @@ def epoch_detail(self) -> float: class _IteratorProgressBar(util.ProgressBar): - def __init__( - self, - name: str, - iterator: IterationStatus, - bar_length: int = 50, - out: Optional[TextIO] = None, + self, + name: str, + iterator: IterationStatus, + bar_length: int = 50, + out: Optional[TextIO] = None, ): - if not (hasattr(iterator, 'current_position') - and hasattr(iterator, 'epoch_detail')): - raise TypeError('Iterator must have the following attributes ' - 'to enable a progress bar: ' - 'current_position, epoch_detail') + if not ( + hasattr(iterator, "current_position") + and hasattr(iterator, "epoch_detail") + ): + raise TypeError( + "Iterator must have the following attributes " + "to enable a progress bar: " + "current_position, epoch_detail" + ) self._name = name self._iterator = iterator self._bar_length = bar_length @@ -357,28 +380,34 @@ def __init__( def get_lines(self) -> List[str]: iteration = self._iterator.current_position epoch_detail = self._iterator.epoch_detail - epoch_size = getattr(self._iterator, '_epoch_size', None) + epoch_size = getattr(self._iterator, "_epoch_size", None) lines = [] rate = epoch_detail - marks = '#' * int(rate * 
self._bar_length) - rest_marks = '.' * (self._bar_length - len(marks)) - lines.append('{} [{}{}] {:6.2%}\n'.format( - self._name, marks, rest_marks, rate)) + marks = "#" * int(rate * self._bar_length) + rest_marks = "." * (self._bar_length - len(marks)) + lines.append( + "{} [{}{}] {:6.2%}\n".format(self._name, marks, rest_marks, rate) + ) if epoch_size: - lines.append(f'{{:{len(self._name)}}} / {{}} iterations\n' - .format(iteration, epoch_size)) + lines.append( + f"{{:{len(self._name)}}} / {{}} iterations\n".format( + iteration, epoch_size + ) + ) else: - lines.append(f'{{:{len(self._name)}}} iterations\n' - .format(iteration)) + lines.append( + f"{{:{len(self._name)}}} iterations\n".format(iteration) + ) speed_t, speed_e = self.update_speed(iteration, epoch_detail) estimated_time = (1.0 - epoch_detail) / speed_e - itps = f'{{:{len(self._name)}.5g}} iters/sec.'.format(speed_t) - eta = 'Estimated time to finish: {}.\n' \ - .format(datetime.timedelta(seconds=estimated_time)) + itps = f"{{:{len(self._name)}.5g}} iters/sec.".format(speed_t) + eta = "Estimated time to finish: {}.\n".format( + datetime.timedelta(seconds=estimated_time) + ) lines.append("{} {}".format(itps, eta)) return lines @@ -389,20 +418,21 @@ def get_lines(self) -> List[str]: class IgniteEvaluator(Evaluator): def __init__( - self, - evaluator: 'Engine', - iterator: Union[torch.utils.data.DataLoader[Any], - Dict[str, torch.utils.data.DataLoader[Any]]], - target: Union[torch.nn.Module, Dict[str, torch.nn.Module]], - **kwargs: Any, + self, + evaluator: "Engine", + iterator: Union[ + torch.utils.data.DataLoader[Any], + Dict[str, torch.utils.data.DataLoader[Any]], + ], + target: Union[torch.nn.Module, Dict[str, torch.nn.Module]], + **kwargs: Any, ): super().__init__(iterator, target, None, **kwargs) self.evaluator = evaluator self.set_evaluator_handlers() def set_evaluator_handlers(self) -> None: - from ignite.engine import Engine - from ignite.engine import Events + from ignite.engine import Engine, Events # Register handlers to retrieve the Average metrics and report them @self.evaluator.on(Events.ITERATION_STARTED) @@ -412,6 +442,7 @@ def set_evaluation_started(engine: Engine) -> None: self.cm.__enter__() if self._progress_bar: + @self.evaluator.on(Events.ITERATION_STARTED) def update_progress_bar(engine: Engine) -> None: self.progress.current_position = engine.state.iteration @@ -428,12 +459,11 @@ def set_evaluation_completed(engine: Engine) -> None: with reporting.report_scope(ignite_metrics): metrics = self.evaluator.state.metrics for metric in metrics: - reporting.report( - {'val/{}'.format(metric): metrics[metric]}) + reporting.report({"val/{}".format(metric): metrics[metric]}) self.summary.add(ignite_metrics) def evaluate(self) -> Dict[str, _Scalar]: - iterator = self._iterators['main'] + iterator = self._iterators["main"] self.summary = reporting.DictSummary() self.progress = IterationStatus(len(iterator)) if self._progress_bar: diff --git a/pytorch_pfn_extras/training/extensions/fail_on_non_number.py b/pytorch_pfn_extras/training/extensions/fail_on_non_number.py index c61b428bd..9e47662eb 100644 --- a/pytorch_pfn_extras/training/extensions/fail_on_non_number.py +++ b/pytorch_pfn_extras/training/extensions/fail_on_non_number.py @@ -1,7 +1,8 @@ import torch - from pytorch_pfn_extras.training import extension -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) class 
FailOnNonNumber(extension.Extension): @@ -26,10 +27,12 @@ def __init__(self, *, check_grad: bool = True): def __call__(self, manager: ExtensionsManagerProtocol) -> None: for name, model in manager.models.items(): for param in model.parameters(): - if (not torch.isfinite(param).all() - or (self._check_grad - and param.grad is not None - and not torch.isfinite(param.grad).all())): + if not torch.isfinite(param).all() or ( + self._check_grad + and param.grad is not None + and not torch.isfinite(param.grad).all() + ): raise RuntimeError( - 'Kill the process since parameters in optimizer' - ' \'{}\' diverge. R.I.P.'.format(name)) + "Kill the process since parameters in optimizer" + " '{}' diverge. R.I.P.".format(name) + ) diff --git a/pytorch_pfn_extras/training/extensions/log_report.py b/pytorch_pfn_extras/training/extensions/log_report.py index 0f07032d3..c78448d0e 100644 --- a/pytorch_pfn_extras/training/extensions/log_report.py +++ b/pytorch_pfn_extras/training/extensions/log_report.py @@ -5,7 +5,9 @@ from pytorch_pfn_extras import reporting from pytorch_pfn_extras.training import extension from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) Observation = Mapping[str, reporting.Scalar] @@ -18,37 +20,38 @@ class LogWriterSaveFunc: - def __init__(self, format: str, append: bool) -> None: self._format = format self._append = append def __call__(self, target: Dict[str, Any], file_o: Any) -> None: - if self._format == 'json': + if self._format == "json": if self._append: raise ValueError( - 'LogReport does not support json format with append mode.') + "LogReport does not support json format with append mode." 
+ ) log = json.dumps(target, indent=4) - elif self._format == 'json-lines': + elif self._format == "json-lines": # Add a new line at the end for subsequent appends - log = '\n'.join([json.dumps(x) for x in target]) + '\n' - elif self._format == 'yaml': + log = "\n".join([json.dumps(x) for x in target]) + "\n" + elif self._format == "yaml": import yaml # This is to dump ordered dicts as regular dicts def dict_representer(dumper: Any, data: Any) -> Any: return dumper.represent_dict(data.items()) + yaml.add_representer( # type: ignore[no-untyped-call] - collections.OrderedDict, dict_representer) + collections.OrderedDict, dict_representer + ) # yaml.add_constructor(_mapping_tag, dict_constructor) log = yaml.dump(target) else: - raise ValueError('Unknown format: {}'.format(self._format)) - file_o.write(bytes(log.encode('ascii'))) + raise ValueError("Unknown format: {}".format(self._format)) + file_o.write(bytes(log.encode("ascii"))) class _LogBuffer: - def __init__(self) -> None: self.lookers: Dict[int, int] = {} self._log: List[Observation] = [] @@ -57,22 +60,22 @@ def __init__(self) -> None: def _trim(self) -> None: min_looker_index = min(self.lookers.values()) if min_looker_index > self._offset: - self._log = self._log[min_looker_index - self._offset:] + self._log = self._log[min_looker_index - self._offset :] self._offset = min_looker_index def append(self, observation: Observation) -> None: self._log.append(observation) def _get(self, looker_id: int) -> List[Observation]: - return self._log[self.lookers[looker_id] - self._offset:] + return self._log[self.lookers[looker_id] - self._offset :] def _clear(self, looker_id: int) -> None: if looker_id not in self.lookers: - raise ValueError(f'looker {looker_id} is not registered') + raise ValueError(f"looker {looker_id} is not registered") self.lookers[looker_id] = len(self._log) + self._offset self._trim() - def emit_new_looker(self) -> '_LogLooker': + def emit_new_looker(self) -> "_LogLooker": looker_id = len(self.lookers) assert looker_id not in self.lookers self.lookers[looker_id] = len(self._log) + self._offset @@ -83,7 +86,6 @@ def size(self) -> int: class _LogLooker: - def __init__(self, log_buffer: _LogBuffer, looker_id: int) -> None: self._log_buffer = log_buffer self._looker_id = looker_id @@ -154,14 +156,14 @@ class LogReport(extension.Extension): """ def __init__( - self, - keys: Optional[Iterable[str]] = None, - trigger: trigger_module.TriggerLike = (1, 'epoch'), - postprocess: Optional[Callable[[Mapping[str, Any]], None]] = None, - filename: Optional[str] = None, - append: bool = False, - format: Optional[str] = None, - **kwargs: Any, + self, + keys: Optional[Iterable[str]] = None, + trigger: trigger_module.TriggerLike = (1, "epoch"), + postprocess: Optional[Callable[[Mapping[str, Any]], None]] = None, + filename: Optional[str] = None, + append: bool = False, + format: Optional[str] = None, + **kwargs: Any, ): self._keys = keys self._trigger = trigger_module.get_trigger(trigger) @@ -170,20 +172,20 @@ def __init__( self._log_looker = self._log_buffer.emit_new_looker() # When using a writer, it needs to have a savefun defined # to deal with a string. 
- self._writer = kwargs.get('writer', None) + self._writer = kwargs.get("writer", None) if filename is None: - filename = 'log' + filename = "log" if format is None: - if filename.endswith('.jsonl'): - format = 'json-lines' - elif filename.endswith('.yaml'): - format = 'yaml' + if filename.endswith(".jsonl"): + format = "json-lines" + elif filename.endswith(".yaml"): + format = "yaml" else: - format = 'json' - elif format not in ('json', 'json-lines', 'yaml'): - raise ValueError(f'unsupported log format: {format}') + format = "json" + elif format not in ("json", "json-lines", "yaml"): + raise ValueError(f"unsupported log format: {format}") self._filename = filename self._append = append @@ -210,9 +212,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: for name, value in stats.items(): stats_cpu[name] = float(value) # copy to CPU - stats_cpu['epoch'] = manager.epoch - stats_cpu['iteration'] = manager.iteration - stats_cpu['elapsed_time'] = manager.elapsed_time + stats_cpu["epoch"] = manager.epoch + stats_cpu["iteration"] = manager.iteration + stats_cpu["elapsed_time"] = manager.elapsed_time if self._postprocess is not None: self._postprocess(stats_cpu) @@ -223,8 +225,13 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: log_name = self._filename.format(**stats_cpu) out = manager.out savefun = LogWriterSaveFunc(self._format, self._append) - writer(log_name, out, self._log_looker.get(), - savefun=savefun, append=self._append) + writer( + log_name, + out, + self._log_looker.get(), + savefun=savefun, + append=self._append, + ) if self._append: self._log_looker.clear() @@ -238,26 +245,26 @@ def log(self) -> List[Mapping[str, Any]]: def state_dict(self) -> Dict[str, Any]: state: Dict[str, Any] = {} - if hasattr(self._trigger, 'state_dict'): - state['_trigger'] = self._trigger.state_dict() + if hasattr(self._trigger, "state_dict"): + state["_trigger"] = self._trigger.state_dict() try: - state['_summary'] = self._summary.state_dict() + state["_summary"] = self._summary.state_dict() except KeyError: pass - state['_log'] = json.dumps(self._log_buffer._log) + state["_log"] = json.dumps(self._log_buffer._log) return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: - if hasattr(self._trigger, 'load_state_dict'): - self._trigger.load_state_dict(to_load['_trigger']) - self._summary.load_state_dict(to_load['_summary']) - self._log_buffer._log = json.loads(to_load['_log']) + if hasattr(self._trigger, "load_state_dict"): + self._trigger.load_state_dict(to_load["_trigger"]) + self._summary.load_state_dict(to_load["_summary"]) + self._log_buffer._log = json.loads(to_load["_log"]) def _init_summary(self) -> None: self._summary = reporting.DictSummary() - def to_dataframe(self) -> 'pandas.DataFrame': + def to_dataframe(self) -> "pandas.DataFrame": if not _pandas_available: raise ImportError( "Need to install pandas to use `to_dataframe` method." 
diff --git a/pytorch_pfn_extras/training/extensions/lr_scheduler.py b/pytorch_pfn_extras/training/extensions/lr_scheduler.py index 28be7f758..fc794554a 100644 --- a/pytorch_pfn_extras/training/extensions/lr_scheduler.py +++ b/pytorch_pfn_extras/training/extensions/lr_scheduler.py @@ -2,25 +2,33 @@ from pytorch_pfn_extras.training import extension from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) from torch.optim.lr_scheduler import ReduceLROnPlateau -def _get_value_from_log_report(manager: ExtensionsManagerProtocol, key: Any) -> Any: +def _get_value_from_log_report( + manager: ExtensionsManagerProtocol, key: Any +) -> Any: # Find and return the latest reported "key" from LogReport if key is None: return None if key not in manager.observation: raise ValueError( - '{} is not found in the reported values {}'.format( - key, manager.observation)) + "{} is not found in the reported values {}".format( + key, manager.observation + ) + ) return manager.observation[key] -def _default_stepper(manager: ExtensionsManagerProtocol, scheduler: Any) -> None: +def _default_stepper( + manager: ExtensionsManagerProtocol, scheduler: Any +) -> None: if isinstance(scheduler, ReduceLROnPlateau): - LRScheduler.step_by_value('val/loss')(manager, scheduler) + LRScheduler.step_by_value("val/loss")(manager, scheduler) else: scheduler.step() @@ -43,11 +51,12 @@ class LRScheduler(extension.Extension): """ def __init__( - self, - scheduler: Any, *, - stepper: Any = _default_stepper, - trigger: trigger_module.TriggerLike = (1, 'epoch'), - is_async: bool = True, + self, + scheduler: Any, + *, + stepper: Any = _default_stepper, + trigger: trigger_module.TriggerLike = (1, "epoch"), + is_async: bool = True, ) -> None: self.scheduler = scheduler self.trigger = trigger_module.get_trigger(trigger) @@ -59,12 +68,15 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: @staticmethod def step_by_value(key: Optional[str]) -> Any: - def _stepper(manager: ExtensionsManagerProtocol, scheduler: Any) -> None: + def _stepper( + manager: ExtensionsManagerProtocol, scheduler: Any + ) -> None: scheduler.step(_get_value_from_log_report(manager, key)) + return _stepper def state_dict(self) -> Dict[str, Any]: - return {'scheduler': self.scheduler.state_dict()} + return {"scheduler": self.scheduler.state_dict()} def load_state_dict(self, state: Dict[str, Any]) -> None: - self.scheduler.load_state_dict(state['scheduler']) + self.scheduler.load_state_dict(state["scheduler"]) diff --git a/pytorch_pfn_extras/training/extensions/micro_average.py b/pytorch_pfn_extras/training/extensions/micro_average.py index 10f2df441..984fc00f8 100644 --- a/pytorch_pfn_extras/training/extensions/micro_average.py +++ b/pytorch_pfn_extras/training/extensions/micro_average.py @@ -3,7 +3,9 @@ from pytorch_pfn_extras import reporting from pytorch_pfn_extras.training import extension from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) class MicroAverage(extension.Extension): @@ -64,24 +66,26 @@ class MicroAverage(extension.Extension): priority = extension.PRIORITY_EDITOR def __init__( - self, - numerator_key: str, - denominator_key: str, - result_key: str, - trigger: 
trigger_module.TriggerLike = (1, 'epoch'), + self, + numerator_key: str, + denominator_key: str, + result_key: str, + trigger: trigger_module.TriggerLike = (1, "epoch"), ) -> None: self._trigger = trigger_module.get_trigger(trigger) self._numerator_key = numerator_key self._denominator_key = denominator_key self._result_key = result_key - self._numerator = 0. - self._denominator = 0. + self._numerator = 0.0 + self._denominator = 0.0 def __call__(self, manager: ExtensionsManagerProtocol) -> None: observation: Any = manager.observation - if not (self._numerator_key in observation - and self._denominator_key in observation): + if not ( + self._numerator_key in observation + and self._denominator_key in observation + ): return self._numerator += observation[self._numerator_key] @@ -94,10 +98,12 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: reporting.report({self._result_key: result}) def state_dict(self) -> Dict[str, Any]: - state = {'_numerator': self._numerator, - '_denominator': self._denominator} + state = { + "_numerator": self._numerator, + "_denominator": self._denominator, + } return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: - self._numerator = to_load['_numerator'] - self._denominator = to_load['_denominator'] + self._numerator = to_load["_numerator"] + self._denominator = to_load["_denominator"] diff --git a/pytorch_pfn_extras/training/extensions/parameter_statistics.py b/pytorch_pfn_extras/training/extensions/parameter_statistics.py index 66975cc91..24e515d01 100644 --- a/pytorch_pfn_extras/training/extensions/parameter_statistics.py +++ b/pytorch_pfn_extras/training/extensions/parameter_statistics.py @@ -1,19 +1,19 @@ from typing import Any, Optional import torch - from pytorch_pfn_extras import reporting from pytorch_pfn_extras.training import extension from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) _default_statistics = { - 'mean': lambda x: torch.mean(x), - 'std': lambda x: torch.std(x), - 'min': lambda x: torch.min(x), - 'max': lambda x: torch.max(x), - 'zeros': lambda x: (x == 0).sum(), + "mean": lambda x: torch.mean(x), + "std": lambda x: torch.std(x), + "min": lambda x: torch.min(x), + "max": lambda x: torch.max(x), + "zeros": lambda x: (x == 0).sum(), # 'percentile': lambda x: backend.get_array_module(x).percentile( # x, (0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87)) } @@ -72,41 +72,40 @@ class ParameterStatistics(extension.Extension): (0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87))``) """ - default_name = 'parameter_statistics' + + default_name = "parameter_statistics" priority = extension.PRIORITY_WRITER # prefix ends with a '/' and param_name is preceded by a '/' - report_key_template = ('{prefix}{param_name}/{attr_name}/' - '{function_name}') + report_key_template = "{prefix}{param_name}/{attr_name}/" "{function_name}" default_statistics = _default_statistics def __init__( - self, - links: Any, - statistics: Any = 'default', - report_params: bool = True, - report_grads: bool = True, - prefix: Optional[str] = None, - trigger: trigger_module.TriggerLike = (1, 'epoch'), - skip_nan_params: bool = False, + self, + links: Any, + statistics: Any = "default", + report_params: bool = True, + report_grads: bool = True, + prefix: Optional[str] = None, + trigger: trigger_module.TriggerLike = (1, "epoch"), + skip_nan_params: bool = False, ): - if 
not isinstance(links, (list, tuple)): - links = links, + links = (links,) self._links = links if statistics is None: statistics = {} - elif statistics == 'default': + elif statistics == "default": statistics = self.default_statistics self._statistics = dict(statistics) attrs = [] if report_params: - attrs.append('data') + attrs.append("data") if report_grads: - attrs.append('grad') + attrs.append("grad") self._attrs = attrs self._prefix = prefix @@ -137,24 +136,30 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: # since the statistics function should make no # assumption about the axes params = getattr(param, attr_name).flatten() - if (self._skip_nan_params - and ( - torch.isnan(params).any())): - value: Any = float('nan') + if self._skip_nan_params and ( + torch.isnan(params).any() + ): + value: Any = float("nan") else: value = function(params) key = self.report_key_template.format( - prefix=self._prefix + '/' if self._prefix else '', + prefix=self._prefix + "/" if self._prefix else "", param_name=param_name, attr_name=attr_name, - function_name=function_name + function_name=function_name, ) - if (isinstance(value, torch.Tensor) - and value.numel() > 1): + if ( + isinstance(value, torch.Tensor) + and value.numel() > 1 + ): # Append integer indices to the keys if the # statistic function return multiple values - statistics.update({'{}/{}'.format(key, i): v for - i, v in enumerate(value)}) + statistics.update( + { + "{}/{}".format(key, i): v + for i, v in enumerate(value) + } + ) else: statistics[key] = value diff --git a/pytorch_pfn_extras/training/extensions/plot_report.py b/pytorch_pfn_extras/training/extensions/plot_report.py index a906f6b7d..fb01ec305 100644 --- a/pytorch_pfn_extras/training/extensions/plot_report.py +++ b/pytorch_pfn_extras/training/extensions/plot_report.py @@ -1,20 +1,21 @@ import json -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import warnings +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy - from pytorch_pfn_extras import reporting from pytorch_pfn_extras.training import extension from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) _available = None def matplotlib_savefun(target: Tuple[Any, Any, Any], file_o: Any) -> None: fig, leg, plt = target - fig.savefig(file_o, bbox_extra_artists=(leg,), bbox_inches='tight') + fig.savefig(file_o, bbox_extra_artists=(leg,), bbox_inches="tight") fig.clf() plt.close(fig) @@ -23,6 +24,7 @@ def _try_import_matplotlib() -> None: global matplotlib, _available try: import matplotlib # NOQA + _available = True except (ImportError, TypeError): _available = False @@ -33,10 +35,12 @@ def _check_available() -> None: _try_import_matplotlib() if not _available: - warnings.warn('matplotlib is not installed on your environment, ' - 'so nothing will be plotted at this time. ' - 'Please install matplotlib to plot figures.\n\n' - ' $ pip install matplotlib\n') + warnings.warn( + "matplotlib is not installed on your environment, " + "so nothing will be plotted at this time. 
" + "Please install matplotlib to plot figures.\n\n" + " $ pip install matplotlib\n" + ) class PlotReport(extension.Extension): @@ -114,18 +118,17 @@ class PlotReport(extension.Extension): """ def __init__( - self, - y_keys: Union[Iterable[str], str], - x_key: str = 'iteration', - trigger: trigger_module.TriggerLike = (1, 'epoch'), - postprocess: Any = None, - filename: Optional[str] = None, - marker: str = 'x', - grid: bool = True, - **kwargs: Any, + self, + y_keys: Union[Iterable[str], str], + x_key: str = "iteration", + trigger: trigger_module.TriggerLike = (1, "epoch"), + postprocess: Any = None, + filename: Optional[str] = None, + marker: str = "x", + grid: bool = True, + **kwargs: Any, ): - - file_name = kwargs.get('file_name', 'plot.png') + file_name = kwargs.get("file_name", "plot.png") if filename is None: filename = file_name del file_name # avoid accidental use @@ -144,7 +147,7 @@ def __init__( self._postprocess = postprocess self._init_summary() self._data: Dict[str, List[Tuple[Any, Any]]] = {k: [] for k in y_keys} - self._writer = kwargs.get('writer', None) + self._writer = kwargs.get("writer", None) @staticmethod def available() -> bool: @@ -173,8 +176,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: for name, value in stats.items(): stats_cpu[name] = float(value) # copy to CPU - stats_cpu['epoch'] = manager.epoch - stats_cpu['iteration'] = manager.iteration + stats_cpu["epoch"] = manager.epoch + stats_cpu["iteration"] = manager.iteration x = stats_cpu[self._x_key] data = self._data @@ -200,9 +203,14 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: if self._postprocess is not None: self._postprocess(f, a, summary) leg = a.legend( - bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) - writer(self._file_name, manager.out, (f, leg, plt), # type: ignore - savefun=matplotlib_savefun) + bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0 + ) + writer( + self._file_name, + manager.out, + (f, leg, plt), # type: ignore + savefun=matplotlib_savefun, + ) else: print( f"[WARNING] No data found for key {self._y_keys}, " @@ -215,11 +223,11 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: self._init_summary() def state_dict(self) -> Dict[str, Any]: - state = {'_plot_{}'.format(self._file_name): json.dumps(self._data)} + state = {"_plot_{}".format(self._file_name): json.dumps(self._data)} return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: - key = '_plot_{}'.format(self._file_name) + key = "_plot_{}".format(self._file_name) self._data = json.loads(to_load[key]) def _init_summary(self) -> None: diff --git a/pytorch_pfn_extras/training/extensions/print_report.py b/pytorch_pfn_extras/training/extensions/print_report.py index c2f33bad8..86f5fa11c 100644 --- a/pytorch_pfn_extras/training/extensions/print_report.py +++ b/pytorch_pfn_extras/training/extensions/print_report.py @@ -1,17 +1,20 @@ -from copy import deepcopy import os import sys -from typing import Any, Dict, IO, List, Optional, Sequence, Tuple, Union +from copy import deepcopy +from typing import IO, Any, Dict, List, Optional, Sequence, Tuple, Union from pytorch_pfn_extras.training import extension -from pytorch_pfn_extras.training.extensions import log_report \ - as log_report_module +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) +from pytorch_pfn_extras.training.extensions import ( + log_report as log_report_module, +) from pytorch_pfn_extras.training.extensions import util -from 
pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol def create_header_and_templates( - entries: Sequence[str], + entries: Sequence[str], ) -> Tuple[str, List[Tuple[str, str, str]]]: """Construct header and templates from `entries` @@ -25,35 +28,36 @@ def create_header_and_templates( # format information entry_widths = [max(10, len(s)) for s in entries] - header = ' '.join(('{:%d}' % w for w in entry_widths)).format( - *entries) + '\n' + header = ( + " ".join(("{:%d}" % w for w in entry_widths)).format(*entries) + "\n" + ) templates = [] for entry, w in zip(entries, entry_widths): - templates.append((entry, '{:<%dg} ' % w, ' ' * (w + 2))) + templates.append((entry, "{:<%dg} " % w, " " * (w + 2))) return header, templates def filter_and_sort_entries( - all_entries: List[str], - unit: str = 'epoch', + all_entries: List[str], + unit: str = "epoch", ) -> List[str]: entries = deepcopy(all_entries) # TODO(nakago): sort other entries if necessary - if 'iteration' in entries: + if "iteration" in entries: # move iteration to head - entries.pop(entries.index('iteration')) - if unit == 'iteration': - entries = ['iteration'] + entries - if 'epoch' in entries: + entries.pop(entries.index("iteration")) + if unit == "iteration": + entries = ["iteration"] + entries + if "epoch" in entries: # move epoch to head - entries.pop(entries.index('epoch')) - if unit == 'epoch': - entries = ['epoch'] + entries - if 'elapsed_time' in entries: + entries.pop(entries.index("epoch")) + if unit == "epoch": + entries = ["epoch"] + entries + if "elapsed_time" in entries: # move elapsed_time to tail - entries.pop(entries.index('elapsed_time')) - entries.append('elapsed_time') + entries.pop(entries.index("elapsed_time")) + entries.append("elapsed_time") return entries @@ -76,10 +80,10 @@ class PrintReport(extension.Extension): """ def __init__( - self, - entries: Optional[Sequence[str]] = None, - log_report: Union[str, log_report_module.LogReport] = 'LogReport', - out: IO[Any] = sys.stdout, + self, + entries: Optional[Sequence[str]] = None, + log_report: Union[str, log_report_module.LogReport] = "LogReport", + out: IO[Any] = sys.stdout, ) -> None: if entries is None: self._infer_entries = True @@ -98,21 +102,20 @@ def __init__( self._all_entries: List[str] = [] def get_log_report( - self, - manager: ExtensionsManagerProtocol, + self, + manager: ExtensionsManagerProtocol, ) -> log_report_module.LogReport: log_report = self._log_report if isinstance(log_report, str): ext = manager.get_extension(log_report) if not isinstance(ext, log_report_module.LogReport): - raise TypeError('`log_report` must be LogReport object') + raise TypeError("`log_report` must be LogReport object") return ext elif isinstance(log_report, log_report_module.LogReport): log_report(manager) # update the log report return log_report else: - raise TypeError('log report has a wrong type %s' % - type(log_report)) + raise TypeError("log report has a wrong type %s" % type(log_report)) @property def _log_looker(self) -> log_report_module._LogLooker: @@ -133,12 +136,13 @@ def _update_entries(self, log_report: log_report_module.LogReport) -> None: updated_flag = True if updated_flag: - if hasattr(log_report, '_trigger') and hasattr(log_report._trigger, - 'unit'): + if hasattr(log_report, "_trigger") and hasattr( + log_report._trigger, "unit" + ): unit = log_report._trigger.unit # type: ignore[attr-defined] else: # Failed to infer `unit`, use epoch as default - unit = 'epoch' + unit = "epoch" entries = 
filter_and_sort_entries(self._all_entries, unit=unit) self._entries = entries header, templates = create_header_and_templates(entries) @@ -160,23 +164,23 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: for line in self._log_looker.get(): # delete the printed contents from the current cursor - if os.name == 'nt': + if os.name == "nt": util.erase_console(0, 0) else: - out.write('\033[J') + out.write("\033[J") self._print(line) self._log_looker.clear() def state_dict(self) -> Dict[str, Any]: log_report = self._log_report if isinstance(log_report, log_report_module.LogReport): - return {'_log_report': log_report.state_dict()} + return {"_log_report": log_report.state_dict()} return {} def load_state_dict(self, to_load: Dict[str, Any]) -> None: log_report = self._log_report if isinstance(log_report, log_report_module.LogReport): - log_report.load_state_dict(to_load['_log_report']) + log_report.load_state_dict(to_load["_log_report"]) def _print(self, observation: log_report_module.Observation) -> None: out = self._out @@ -185,6 +189,6 @@ def _print(self, observation: log_report_module.Observation) -> None: out.write(template.format(observation[entry])) else: out.write(empty) - out.write('\n') - if hasattr(out, 'flush'): + out.write("\n") + if hasattr(out, "flush"): out.flush() diff --git a/pytorch_pfn_extras/training/extensions/print_report_notebook.py b/pytorch_pfn_extras/training/extensions/print_report_notebook.py index 2d0c6d928..f09c47b1f 100644 --- a/pytorch_pfn_extras/training/extensions/print_report_notebook.py +++ b/pytorch_pfn_extras/training/extensions/print_report_notebook.py @@ -1,15 +1,16 @@ import sys -from typing import Any, IO, List, Optional, Union +from typing import IO, Any, List, Optional, Union from IPython.display import display from ipywidgets import HTML - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) +from pytorch_pfn_extras.training.extensions import ( + log_report as log_report_module, +) from pytorch_pfn_extras.training.extensions.print_report import PrintReport -from pytorch_pfn_extras.training.extensions import log_report \ - as log_report_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - class PrintReportNotebook(PrintReport): @@ -32,10 +33,10 @@ class PrintReportNotebook(PrintReport): """ def __init__( - self, - entries: Optional[List[str]] = None, - log_report: Union[str, log_report_module.LogReport] = 'LogReport', - out: IO[Any] = sys.stdout, + self, + entries: Optional[List[str]] = None, + log_report: Union[str, log_report_module.LogReport] = "LogReport", + out: IO[Any] = sys.stdout, ) -> None: super(PrintReportNotebook, self).__init__( entries=entries, log_report=log_report, out=out @@ -56,4 +57,4 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: if self._infer_entries: # --- update entries --- self._update_entries(log_report) - self._widget.value = df[self._entries].to_html(index=False, na_rep='') + self._widget.value = df[self._entries].to_html(index=False, na_rep="") diff --git a/pytorch_pfn_extras/training/extensions/profile_report.py b/pytorch_pfn_extras/training/extensions/profile_report.py index fdc97c055..95514d275 100644 --- a/pytorch_pfn_extras/training/extensions/profile_report.py +++ b/pytorch_pfn_extras/training/extensions/profile_report.py @@ -1,13 +1,15 @@ -from collections import OrderedDict import json +from collections import OrderedDict from typing import Any, Dict, Iterable, List, Optional from pytorch_pfn_extras 
import reporting +from pytorch_pfn_extras.profiler._time_summary import get_time_summary from pytorch_pfn_extras.training import extension -from pytorch_pfn_extras.training.extensions import log_report from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol -from pytorch_pfn_extras.profiler._time_summary import get_time_summary +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) +from pytorch_pfn_extras.training.extensions import log_report class ProfileReport(extension.Extension): @@ -44,15 +46,16 @@ class ProfileReport(extension.Extension): header (str): header string templates (str): template string for print values. """ + def __init__( - self, - store_keys: Optional[Iterable[str]] = None, - report_keys: Optional[Iterable[str]] = None, - trigger: trigger_module.TriggerLike = (1, "epoch"), - filename: Optional[str] = None, - append: bool = False, - format: Optional[str] = None, - **kwargs: Any, + self, + store_keys: Optional[Iterable[str]] = None, + report_keys: Optional[Iterable[str]] = None, + trigger: trigger_module.TriggerLike = (1, "epoch"), + filename: Optional[str] = None, + append: bool = False, + format: Optional[str] = None, + **kwargs: Any, ): self.time_summary = get_time_summary() # Initializes global TimeSummary. @@ -73,15 +76,15 @@ def __init__( filename = log_name del log_name # avoid accidental use self._log_name = filename - self._writer = kwargs.get('writer', None) + self._writer = kwargs.get("writer", None) if format is None and filename is not None: - if filename.endswith('.jsonl'): - format = 'json-lines' - elif filename.endswith('.yaml'): - format = 'yaml' + if filename.endswith(".jsonl"): + format = "json-lines" + elif filename.endswith(".yaml"): + format = "yaml" else: - format = 'json' + format = "json" self._append = append self._format = format @@ -114,7 +117,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: stats_cpu["elapsed_time"] = manager.elapsed_time # Recreate dict to fix order of logs out = OrderedDict( - [(k, stats_cpu[k]) for k in sorted(stats_cpu.keys())]) + [(k, stats_cpu[k]) for k in sorted(stats_cpu.keys())] + ) self._log.append(out) @@ -123,9 +127,15 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: log_name = self._log_name.format(**out) assert self._format is not None savefun = log_report.LogWriterSaveFunc( - self._format, self._append) - writer(log_name, out, self._log, # type: ignore - savefun=savefun, append=self._append) + self._format, self._append + ) + writer( + log_name, + out, + self._log, # type: ignore + savefun=savefun, + append=self._append, + ) if self._append: self._log = [] diff --git a/pytorch_pfn_extras/training/extensions/progress_bar.py b/pytorch_pfn_extras/training/extensions/progress_bar.py index 5c32d0944..4859e6be7 100644 --- a/pytorch_pfn_extras/training/extensions/progress_bar.py +++ b/pytorch_pfn_extras/training/extensions/progress_bar.py @@ -3,8 +3,10 @@ from typing import Any, List, Optional from pytorch_pfn_extras.training import extension +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) from pytorch_pfn_extras.training.extensions import util -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol class ProgressBar(extension.Extension): @@ -28,18 +30,19 @@ class ProgressBar(extension.Extension): """ def __init__( - self, - training_length: Any = None, - update_interval: int 
= 100, - bar_length: int = 50, - out: Any = sys.stdout, + self, + training_length: Any = None, + update_interval: int = 100, + bar_length: int = 50, + out: Any = sys.stdout, ): self._training_length = training_length self._update_interval = update_interval self._bar_length = bar_length self._out = out self._pbar = _ManagerProgressBar( - self._training_length, self._bar_length, self._out) + self._training_length, self._bar_length, self._out + ) def __call__(self, manager: ExtensionsManagerProtocol) -> None: if self._pbar.manager is None: @@ -55,7 +58,6 @@ def finalize(self, manager: ExtensionsManagerProtocol) -> None: class _ManagerProgressBar(util.ProgressBar): - def __init__(self, training_length: Any, bar_length: int, out: Any) -> None: super().__init__(out) self.training_length = training_length @@ -75,37 +77,46 @@ def get_lines(self) -> List[str]: self.training_length = t.get_training_length() # type: ignore[attr-defined] length, unit = self.training_length - if unit == 'iteration': + if unit == "iteration": rate = iteration / length else: rate = epoch / length rate = min(rate, 1.0) bar_length = self.bar_length - marks = '#' * int(rate * bar_length) - lines.append(' total [{}{}] {:6.2%}\n'.format( - marks, '.' * (bar_length - len(marks)), rate)) + marks = "#" * int(rate * bar_length) + lines.append( + " total [{}{}] {:6.2%}\n".format( + marks, "." * (bar_length - len(marks)), rate + ) + ) epoch_rate = epoch - int(epoch) - marks = '#' * int(epoch_rate * bar_length) - lines.append('this epoch [{}{}] {:6.2%}\n'.format( - marks, '.' * (bar_length - len(marks)), epoch_rate)) + marks = "#" * int(epoch_rate * bar_length) + lines.append( + "this epoch [{}{}] {:6.2%}\n".format( + marks, "." * (bar_length - len(marks)), epoch_rate + ) + ) if self.progress_template is None: self.progress_template = ( - '{0.iteration:10} iter, {0.epoch} epoch / %s %ss\n' % - self.training_length) + "{0.iteration:10} iter, {0.epoch} epoch / %s %ss\n" + % self.training_length + ) progress = self.progress_template.format(self.manager) lines.append(progress) speed_t, speed_e = self.update_speed(iteration, epoch) - if unit == 'iteration': + if unit == "iteration": estimated_time = (length - iteration) / speed_t else: estimated_time = (length - epoch) / speed_e estimated_time = max(estimated_time, 0.0) - lines.append('{:10.5g} iters/sec. Estimated time to finish: {}.\n' - .format(speed_t, - datetime.timedelta(seconds=estimated_time))) + lines.append( + "{:10.5g} iters/sec. 
Estimated time to finish: {}.\n".format( + speed_t, datetime.timedelta(seconds=estimated_time) + ) + ) return lines diff --git a/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py b/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py index 8207222a0..438c8f9b4 100644 --- a/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py +++ b/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py @@ -5,9 +5,10 @@ from IPython.display import display from ipywidgets import HTML, FloatProgress, HBox, VBox # NOQA - from pytorch_pfn_extras.training import extension, trigger # NOQA -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) class ProgressBarNotebook(extension.Extension): @@ -34,11 +35,11 @@ class ProgressBarNotebook(extension.Extension): """ def __init__( - self, - training_length: Any = None, - update_interval: int = 100, - bar_length: int = 50, - out: Any = sys.stdout, + self, + training_length: Any = None, + update_interval: int = 100, + bar_length: int = 50, + out: Any = sys.stdout, ): self._training_length = training_length if training_length is not None: @@ -46,26 +47,31 @@ def __init__( self._update_interval = update_interval self._recent_timing: List[Tuple[float, float, float]] = [] - self._total_bar = FloatProgress(description='total', - min=0, max=1, value=0, - bar_style='info') + self._total_bar = FloatProgress( + description="total", min=0, max=1, value=0, bar_style="info" + ) self._total_html = HTML() - self._epoch_bar = FloatProgress(description='this epoch', - min=0, max=1, value=0, - bar_style='info') + self._epoch_bar = FloatProgress( + description="this epoch", min=0, max=1, value=0, bar_style="info" + ) self._epoch_html = HTML() self._status_html = HTML() - self._widget = VBox([HBox([self._total_bar, self._total_html]), - HBox([self._epoch_bar, self._epoch_html]), - self._status_html]) + self._widget = VBox( + [ + HBox([self._total_bar, self._total_html]), + HBox([self._epoch_bar, self._epoch_html]), + self._status_html, + ] + ) def initialize(self, manager: ExtensionsManagerProtocol) -> None: if self._training_length is None: t = manager._stop_trigger if not isinstance(t, trigger.IntervalTrigger): raise TypeError( - 'cannot retrieve the training length from %s' % type(t)) + "cannot retrieve the training length from %s" % type(t) + ) self._training_length = t.period, t.unit self._init_status_template() @@ -77,7 +83,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: iteration, epoch_detail = manager.iteration, manager.epoch_detail - if unit == 'iteration': + if unit == "iteration": is_finished = iteration == length else: is_finished = epoch_detail == length @@ -87,8 +93,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: def finalize(self, manager: ExtensionsManagerProtocol) -> None: if self._total_bar.value != 1: - self._total_bar.bar_style = 'warning' - self._epoch_bar.bar_style = 'warning' + self._total_bar.bar_style = "warning" + self._epoch_bar.bar_style = "warning" @property def widget(self) -> VBox: @@ -102,7 +108,7 @@ def update(self, iteration: int, epoch_detail: float) -> None: recent_timing.append((iteration, epoch_detail, now)) - if unit == 'iteration': + if unit == "iteration": rate = iteration / length else: rate = epoch_detail / length @@ -113,12 +119,13 @@ def update(self, iteration: int, epoch_detail: float) -> None: self._epoch_bar.value = epoch_rate 
self._epoch_html.value = "{:6.2%}".format(epoch_rate) - status = self._status_template.format(iteration=iteration, - epoch=int(epoch_detail)) + status = self._status_template.format( + iteration=iteration, epoch=int(epoch_detail) + ) if rate == 1: - self._total_bar.bar_style = 'success' - self._epoch_bar.bar_style = 'success' + self._total_bar.bar_style = "success" + self._epoch_bar.bar_style = "success" old_t, old_e, old_sec = recent_timing[0] span = now - old_sec @@ -126,16 +133,16 @@ def update(self, iteration: int, epoch_detail: float) -> None: speed_t = (iteration - old_t) / span speed_e = (epoch_detail - old_e) / span else: - speed_t = float('inf') - speed_e = float('inf') + speed_t = float("inf") + speed_e = float("inf") - if unit == 'iteration': + if unit == "iteration": estimated_time = (length - iteration) / speed_t else: estimated_time = (length - epoch_detail) / speed_e - estimate = ('{:10.5g} iters/sec. Estimated time to finish: {}.' - .format(speed_t, - datetime.timedelta(seconds=estimated_time))) + estimate = "{:10.5g} iters/sec. Estimated time to finish: {}.".format( + speed_t, datetime.timedelta(seconds=estimated_time) + ) self._status_html.value = status + estimate @@ -144,5 +151,6 @@ def update(self, iteration: int, epoch_detail: float) -> None: def _init_status_template(self) -> None: self._status_template = ( - '{iteration:10} iter, {epoch} epoch / %s %ss
\n' % - self._training_length) + "{iteration:10} iter, {epoch} epoch / %s %ss\n
" + % self._training_length + ) diff --git a/pytorch_pfn_extras/training/extensions/slack.py b/pytorch_pfn_extras/training/extensions/slack.py index 2319492a9..0c01e0c09 100644 --- a/pytorch_pfn_extras/training/extensions/slack.py +++ b/pytorch_pfn_extras/training/extensions/slack.py @@ -1,24 +1,25 @@ import getpass -import os import json -import urllib.request +import os import shlex -import sys import socket +import sys import traceback import types -from typing import Any, Callable, Optional, Sequence, Union +import urllib.request import warnings - +from typing import Any, Callable, Optional, Sequence, Union from pytorch_pfn_extras.training import extension from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) from pytorch_pfn_extras.training._trigger_util import TriggerLike - try: import slack_sdk + _slack_sdk_available = True except ImportError: _slack_sdk_available = False @@ -29,51 +30,49 @@ def _failsafe(func: Callable[[], Any]) -> str: try: return str(func()) except Exception: - return 'UNKNOWN' + return "UNKNOWN" _identity = ( - f'{_failsafe(getpass.getuser)}@{_failsafe(socket.gethostname)} ' - f'[PID {_failsafe(os.getpid)}]') + f"{_failsafe(getpass.getuser)}@{_failsafe(socket.gethostname)} " + f"[PID {_failsafe(os.getpid)}]" +) def _default_msg( - manager: ExtensionsManagerProtocol, - context: Any, + manager: ExtensionsManagerProtocol, + context: Any, ) -> str: - return f'Epoch #{manager.epoch}' + return f"Epoch #{manager.epoch}" def _default_start_msg( - manager: ExtensionsManagerProtocol, - context: Any, + manager: ExtensionsManagerProtocol, + context: Any, ) -> str: - cmdline = ' '.join([shlex.quote(x) for x in sys.argv]) - return ( - f'🏃 *Training started! {_identity}*\n' - f'Command: `{cmdline}`' - ) + cmdline = " ".join([shlex.quote(x) for x in sys.argv]) + return f"🏃 *Training started! {_identity}*\n" f"Command: `{cmdline}`" def _default_end_msg( - manager: ExtensionsManagerProtocol, - context: Any, + manager: ExtensionsManagerProtocol, + context: Any, ) -> str: - return f'✅ *Training finished! {_identity}*' + return f"✅ *Training finished! {_identity}*" def _default_error_msg( - manager: ExtensionsManagerProtocol, - exc: Exception, - context: Any, + manager: ExtensionsManagerProtocol, + exc: Exception, + context: Any, ) -> str: return ( - f'❌ *Error during training. {_identity}*\n' - f'{type(exc).__name__}: {exc}\n' - 'Traceback:\n' - '```\n' - ''.join(traceback.format_tb(exc.__traceback__)).strip() + '\n' - '```' + f"❌ *Error during training. 
{_identity}*\n" + f"{type(exc).__name__}: {exc}\n" + "Traceback:\n" + "```\n" + "".join(traceback.format_tb(exc.__traceback__)).strip() + "\n" + "```" ) @@ -107,8 +106,7 @@ def _default_error_msg( class _SlackBase(extension.Extension): - - trigger: TriggerLike = (1, 'epoch') + trigger: TriggerLike = (1, "epoch") default_msg = _default_msg default_start_msg = _default_start_msg @@ -132,25 +130,25 @@ def _upload_files(self, filenames: Sequence[str]) -> Sequence[str]: raise NotImplementedError def _format( - self, - msg: Union[_MessageFunc, str], - default: Optional[_MessageFunc], - manager: ExtensionsManagerProtocol, + self, + msg: Union[_MessageFunc, str], + default: Optional[_MessageFunc], + manager: ExtensionsManagerProtocol, ) -> str: - default_str = '' if default is None else default(manager, self._context) + default_str = "" if default is None else default(manager, self._context) if isinstance(msg, str): return msg.format( manager=manager, context=self._context, default=default_str, - **manager.observation + **manager.observation, ) return msg(manager, self._context) def _format_error( - self, - manager: ExtensionsManagerProtocol, - error: Exception, + self, + manager: ExtensionsManagerProtocol, + error: Exception, ) -> str: msg = self._error_msg assert msg is not None @@ -159,7 +157,7 @@ def _format_error( manager=manager, context=self._context, default=_default_error_msg(manager, error, self._context), - error=error + error=error, ) return msg(manager, error, self._context) @@ -172,30 +170,30 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: pass elif isinstance(self._filenames, Sequence): filenames = [ - self._format(f, None, manager) for f in self._filenames] + self._format(f, None, manager) for f in self._filenames + ] else: # callable filenames = self._filenames(manager, self._context) - needs_upload = ( - len(filenames) != 0 - and (self._upload_trigger is None - or self._upload_trigger(manager))) + needs_upload = len(filenames) != 0 and ( + self._upload_trigger is None or self._upload_trigger(manager) + ) if self._msg is None and not needs_upload: # The message is not set and no files to upload. 
return - text = '' + text = "" if self._msg is not None: text = self._format(self._msg, _default_msg, manager) # TODO(kmaehashi): keep track of already uploaded files and warn # TODO(kmaehashi): warn too many or too large files - attachments = '' + attachments = "" if needs_upload: permalinks = self._upload_files(filenames) - attachments = ''.join([f'<{link}| >' for link in permalinks]) + attachments = "".join([f"<{link}| >" for link in permalinks]) self._post_message(text + attachments) @@ -203,19 +201,21 @@ def initialize(self, manager: ExtensionsManagerProtocol) -> None: if not self._available or self._start_msg is None: return self._post_message( - self._format(self._start_msg, _default_start_msg, manager)) + self._format(self._start_msg, _default_start_msg, manager) + ) def finalize(self, manager: ExtensionsManagerProtocol) -> None: if not self._available or self._end_msg is None: return self._post_message( - self._format(self._end_msg, _default_end_msg, manager)) + self._format(self._end_msg, _default_end_msg, manager) + ) def on_error( - self, - manager: ExtensionsManagerProtocol, - exc: Exception, - tb: types.TracebackType + self, + manager: ExtensionsManagerProtocol, + exc: Exception, + tb: types.TracebackType, ) -> None: if not self._available or self._error_msg is None: return @@ -223,7 +223,8 @@ def on_error( class Slack(_SlackBase): - __doc__ = """An extension to communicate with Slack. + __doc__ = ( + """An extension to communicate with Slack. .. admonition:: Example @@ -236,7 +237,9 @@ class Slack(_SlackBase): ... filenames=["result/statistics.png"], ... upload_trigger=(max_epoch, 'epoch'), ... ) - """ + _message_spec_doc + """ + """ + + _message_spec_doc + + """ This extension can upload files along with the message when triggered. ``filenames`` can be a list of filenames (the same formatting rule as ``msg`` apply), or a callable taking (ExtensionsManager, context) and @@ -273,17 +276,18 @@ class Slack(_SlackBase): variable ``SLACK_BOT_TOKEN`` will be used. Optional, default is ``None``. """ + ) - trigger: TriggerLike = (1, 'epoch') + trigger: TriggerLike = (1, "epoch") def __init__( self, channel: str, msg: Optional[Union[str, _MessageFunc]] = None, *, - start_msg: Optional[Union[str, _MessageFunc]] = '{default}', - end_msg: Optional[Union[str, _MessageFunc]] = '{default}', - error_msg: Optional[Union[str, _ErrorMessageFunc]] = '{default}', + start_msg: Optional[Union[str, _MessageFunc]] = "{default}", + end_msg: Optional[Union[str, _MessageFunc]] = "{default}", + error_msg: Optional[Union[str, _ErrorMessageFunc]] = "{default}", thread: bool = True, filenames: Optional[Union[Sequence[str], _FilenamesFunc]] = None, upload_trigger: Optional[TriggerLike] = None, @@ -294,8 +298,9 @@ def __init__( if not _slack_sdk_available: self._available = False warnings.warn( - '`slack_sdk` package is unavailable. ' - 'The Slack extension will do nothing.') + "`slack_sdk` package is unavailable. " + "The Slack extension will do nothing." 
+ ) return self._channel = channel @@ -311,10 +316,11 @@ def __init__( self._upload_trigger = trigger_module.get_trigger(upload_trigger) if token is None: - token = os.environ.get('SLACK_BOT_TOKEN', None) + token = os.environ.get("SLACK_BOT_TOKEN", None) if token is None: raise RuntimeError( - 'A bot `token` is needed for communicating with Slack') + "A bot `token` is needed for communicating with Slack" + ) self._client = slack_sdk.WebClient(token=token) self._thread_ts: Optional[str] = None @@ -324,11 +330,12 @@ def _upload_files(self, filenames: Sequence[str]) -> Sequence[str]: for filename in filenames: response = self._client.files_upload(file=filename) assert response.get("ok") # type: ignore[no-untyped-call] - permalinks.append(response['file']['permalink']) + permalinks.append(response["file"]["permalink"]) except Exception as e: warnings.warn( - f'Slack upload failed: {type(e).__name__}: {e} ' - f'[{filenames}]') + f"Slack upload failed: {type(e).__name__}: {e} " + f"[{filenames}]" + ) return permalinks def _post_message(self, text: str) -> None: @@ -344,12 +351,13 @@ def _post_message(self, text: str) -> None: self._thread_ts = ts except Exception as e: warnings.warn( - f'Slack post failed: {type(e).__name__}: {e} ' - f'[{text}]') + f"Slack post failed: {type(e).__name__}: {e} " f"[{text}]" + ) class SlackWebhook(_SlackBase): - __doc__ = """An extension to communicate with Slack using Incoming Webhook. + __doc__ = ( + """An extension to communicate with Slack using Incoming Webhook. .. admonition:: Example @@ -358,7 +366,9 @@ class SlackWebhook(_SlackBase): ... msg="Epoch #{manager.epoch}: loss = {val/loss}", ... end_msg="{default} \\n <@username> Check out the result!", ... ) - """ + _message_spec_doc + """ + """ + + _message_spec_doc + + """ Args: url (str): Incoming webhook URL to send messages. msg (str, callable, or None): A message to be sent when triggered. @@ -373,15 +383,16 @@ class SlackWebhook(_SlackBase): context (object): Any arbitrary user object you will need when generating a message. 
""" + ) def __init__( self, url: str, msg: Optional[Union[str, _MessageFunc]] = None, *, - start_msg: Optional[Union[str, _MessageFunc]] = '{default}', - end_msg: Optional[Union[str, _MessageFunc]] = '{default}', - error_msg: Optional[Union[str, _ErrorMessageFunc]] = '{default}', + start_msg: Optional[Union[str, _MessageFunc]] = "{default}", + end_msg: Optional[Union[str, _MessageFunc]] = "{default}", + error_msg: Optional[Union[str, _ErrorMessageFunc]] = "{default}", context: Any = None, ) -> None: super().__init__() @@ -393,12 +404,12 @@ def __init__( self._context = context def _post_message(self, text: str) -> None: - payload = json.dumps({'text': text}).encode('utf-8') - request_headers = {'Content-Type': 'application/json; charset=utf-8'} + payload = json.dumps({"text": text}).encode("utf-8") + request_headers = {"Content-Type": "application/json; charset=utf-8"} request = urllib.request.Request( url=self._url, data=payload, - method='POST', + method="POST", headers=request_headers, ) try: @@ -406,5 +417,6 @@ def _post_message(self, text: str) -> None: assert 200 <= response.status < 300, response except Exception as e: warnings.warn( - f'Slack WebHook request failed: {type(e).__name__}: {e} ' - f'[{text}]') + f"Slack WebHook request failed: {type(e).__name__}: {e} " + f"[{text}]" + ) diff --git a/pytorch_pfn_extras/training/extensions/snapshot_writers.py b/pytorch_pfn_extras/training/extensions/snapshot_writers.py index 988c86df5..e9ce7f7fb 100644 --- a/pytorch_pfn_extras/training/extensions/snapshot_writers.py +++ b/pytorch_pfn_extras/training/extensions/snapshot_writers.py @@ -1,2 +1,3 @@ from pytorch_pfn_extras.writing import * # NOQA + # TODO(ecastill) deprecate this diff --git a/pytorch_pfn_extras/training/extensions/util.py b/pytorch_pfn_extras/training/extensions/util.py index eebf8fab3..37b42c134 100644 --- a/pytorch_pfn_extras/training/extensions/util.py +++ b/pytorch_pfn_extras/training/extensions/util.py @@ -4,11 +4,13 @@ import time from typing import Deque, Optional, Sequence, TextIO, Tuple -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) try: from IPython import get_ipython + _ipython_available = True except ImportError: _ipython_available = False @@ -16,11 +18,11 @@ def _is_notebook() -> bool: if _ipython_available and get_ipython() is not None: - return 'IPKernelApp' in get_ipython().config + return "IPKernelApp" in get_ipython().config return False -if os.name == 'nt': +if os.name == "nt": import ctypes from ctypes import windll # type: ignore [attr-defined] @@ -29,10 +31,13 @@ def _is_notebook() -> bool: _COORD = ctypes.wintypes._COORD class _CONSOLE_SCREEN_BUFFER_INFO(ctypes.Structure): - _fields_ = [('dwSize', _COORD), ('dwCursorPosition', _COORD), - ('wAttributes', ctypes.c_ushort), - ('srWindow', ctypes.wintypes.SMALL_RECT), - ('dwMaximumWindowSize', _COORD)] + _fields_ = [ + ("dwSize", _COORD), + ("dwCursorPosition", _COORD), + ("wAttributes", ctypes.c_ushort), + ("srWindow", ctypes.wintypes.SMALL_RECT), + ("dwMaximumWindowSize", _COORD), + ] def set_console_cursor_position(x: int, y: int) -> None: """Set relative cursor position from current position to (x,y)""" @@ -64,29 +69,31 @@ def erase_console(x: int, y: int, mode: int = 0) -> None: cur_pos = csbi.dwCursorPosition wr = ctypes.c_ulong() if mode == 0: - num = csbi.srWindow.Right * ( - csbi.srWindow.Bottom - cur_pos.Y) - cur_pos.X + num = ( + csbi.srWindow.Right * 
(csbi.srWindow.Bottom - cur_pos.Y) + - cur_pos.X + ) windll.kernel32.FillConsoleOutputCharacterA( - whnd, ord(' '), num, cur_pos, ctypes.byref(wr)) + whnd, ord(" "), num, cur_pos, ctypes.byref(wr) + ) elif mode == 1: num = cur_pos.X windll.kernel32.FillConsoleOutputCharacterA( - whnd, ord(' '), num, _COORD(0, cur_pos.Y), ctypes.byref(wr)) + whnd, ord(" "), num, _COORD(0, cur_pos.Y), ctypes.byref(wr) + ) elif mode == 2: - os.system('cls') + os.system("cls") class ProgressBar: - def __init__(self, out: Optional[TextIO] = None) -> None: self._out = sys.stdout if out is None else out - self._recent_timing: Deque[Tuple[int, float, float]] = collections.deque( - [], maxlen=100) + self._recent_timing: Deque[ + Tuple[int, float, float] + ] = collections.deque([], maxlen=100) def update_speed( - self, - iteration: int, - epoch_detail: float + self, iteration: int, epoch_detail: float ) -> Tuple[float, float]: now = time.time() self._recent_timing.append((iteration, epoch_detail, now)) @@ -96,16 +103,15 @@ def update_speed( speed_t = (iteration - old_t) / span speed_e = (epoch_detail - old_e) / span else: - speed_t = float('inf') - speed_e = float('inf') + speed_t = float("inf") + speed_e = float("inf") return speed_t, speed_e def get_lines(self) -> Sequence[str]: raise NotImplementedError def update( - self, - manager: Optional[ExtensionsManagerProtocol] = None + self, manager: Optional[ExtensionsManagerProtocol] = None ) -> None: self.erase_console() @@ -121,18 +127,18 @@ def close(self) -> None: self.flush() def erase_console(self) -> None: - if os.name == 'nt': + if os.name == "nt": erase_console(0, 0) else: - self._out.write('\033[J') + self._out.write("\033[J") def move_cursor_up(self, n: int) -> None: # move the cursor to the head of the progress bar - if os.name == 'nt': - set_console_cursor_position(0, - n) + if os.name == "nt": + set_console_cursor_position(0, -n) else: - self._out.write('\033[{:d}A'.format(n)) + self._out.write("\033[{:d}A".format(n)) def flush(self) -> None: - if hasattr(self._out, 'flush'): + if hasattr(self._out, "flush"): self._out.flush() diff --git a/pytorch_pfn_extras/training/extensions/value_observation.py b/pytorch_pfn_extras/training/extensions/value_observation.py index 906650b70..0a615db9c 100644 --- a/pytorch_pfn_extras/training/extensions/value_observation.py +++ b/pytorch_pfn_extras/training/extensions/value_observation.py @@ -1,14 +1,15 @@ from typing import Any, Callable import torch.optim - from pytorch_pfn_extras.training import extension -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) def observe_value( - observation_key: str, - target_func: Callable[[ExtensionsManagerProtocol], Any], + observation_key: str, + target_func: Callable[[ExtensionsManagerProtocol], Any], ) -> Callable[[ExtensionsManagerProtocol], None]: """Returns an extension to continuously record a value. @@ -26,17 +27,20 @@ def observe_value( .ExtensionsManager>` method. 
""" + @extension.make_extension( - trigger=(1, 'epoch'), priority=extension.PRIORITY_WRITER) + trigger=(1, "epoch"), priority=extension.PRIORITY_WRITER + ) def _observe_value(manager: ExtensionsManagerProtocol) -> None: manager.observation[observation_key] = target_func(manager) + return _observe_value def observe_lr( - optimizer: torch.optim.Optimizer, - param_group: int = 0, - observation_key: str = 'lr', + optimizer: torch.optim.Optimizer, + param_group: int = 0, + observation_key: str = "lr", ) -> Any: """Returns an extension to record the learning rate. @@ -57,4 +61,5 @@ def observe_lr( """ return observe_value( observation_key, - lambda manager: optimizer.param_groups[param_group]['lr']) + lambda manager: optimizer.param_groups[param_group]["lr"], + ) diff --git a/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py b/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py index 7f9b9105e..a54d0446f 100644 --- a/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py +++ b/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py @@ -1,13 +1,13 @@ -from typing import Any, Dict, Optional, Tuple, Union import warnings +from typing import Any, Dict, Optional, Tuple, Union import numpy import torch - from pytorch_pfn_extras.training import extension from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) matplotlib: Any = None _available: Optional[bool] = None @@ -16,10 +16,13 @@ _plot_common_kwargs: Any = None -def percentile(a: torch.Tensor, q: Union[float, Tuple[float, ...]], axis: int) -> Any: +def percentile( + a: torch.Tensor, q: Union[float, Tuple[float, ...]], axis: int +) -> Any: # fallback to numpy return torch.Tensor( - numpy.percentile(a.cpu().numpy(), q, axis)) # type: ignore[no-untyped-call] + numpy.percentile(a.cpu().numpy(), q, axis) + ) # type: ignore[no-untyped-call] def matplotlib_savefun(target: Tuple[Any, Any], file_o: Any) -> None: @@ -34,20 +37,24 @@ def _try_import_matplotlib() -> None: global _plot_color, _plot_color_trans, _plot_common_kwargs try: import matplotlib + _available = True except ImportError: _available = False if _available: - if hasattr(matplotlib.colors, 'to_rgba'): + if hasattr(matplotlib.colors, "to_rgba"): _to_rgba = matplotlib.colors.to_rgba else: # For matplotlib 1.x _to_rgba = matplotlib.colors.ColorConverter().to_rgba - _plot_color = _to_rgba('#1f77b4') # C0 color + _plot_color = _to_rgba("#1f77b4") # C0 color _plot_color_trans = _plot_color[:3] + (0.2,) # apply alpha _plot_common_kwargs = { - 'alpha': 0.2, 'linewidth': 0, 'color': _plot_color_trans} + "alpha": 0.2, + "linewidth": 0, + "color": _plot_color_trans, + } def _check_available() -> None: @@ -55,10 +62,12 @@ def _check_available() -> None: _try_import_matplotlib() if not _available: - warnings.warn('matplotlib is not installed on your environment, ' - 'so nothing will be plotted at this time. ' - 'Please install matplotlib to plot figures.\n\n' - ' $ pip install matplotlib\n') + warnings.warn( + "matplotlib is not installed on your environment, " + "so nothing will be plotted at this time. 
" + "Please install matplotlib to plot figures.\n\n" + " $ pip install matplotlib\n" + ) def _unpack_variables(x: Any, memo: Any = None) -> Any: @@ -79,10 +88,10 @@ class Reservoir: """Reservoir sample with a fixed sized buffer.""" def __init__( - self, - size: int, - data_shape: Tuple[int, ...], - dtype: Any = numpy.float32, + self, + size: int, + data_shape: Tuple[int, ...], + dtype: Any = numpy.float32, ) -> None: self.size = size self.data = numpy.zeros((size,) + data_shape, dtype=dtype) @@ -93,15 +102,17 @@ def add(self, x: Any, idx: Any = None) -> None: if self.counter < self.size: self.data[self.counter] = x self.idxs[self.counter] = idx or self.counter - elif self.counter >= self.size and \ - numpy.random.random() < self.size / float(self.counter + 1): + elif ( + self.counter >= self.size + and numpy.random.random() < self.size / float(self.counter + 1) + ): i = numpy.random.randint(self.size) self.data[i] = x self.idxs[i] = idx or self.counter self.counter += 1 def get_data(self) -> Tuple[Any, Any]: - idxs = self.idxs[:min(self.counter, self.size)] + idxs = self.idxs[: min(self.counter, self.size)] sorted_args = numpy.argsort(idxs) return idxs[sorted_args], self.data[sorted_args] @@ -111,20 +122,22 @@ class Statistician: """Helper to compute basic NumPy-like statistics.""" def __init__( - self, - collect_mean: bool, - collect_std: bool, - percentile_sigmas: Union[float, Tuple[float, ...]], + self, + collect_mean: bool, + collect_std: bool, + percentile_sigmas: Union[float, Tuple[float, ...]], ) -> None: self.collect_mean = collect_mean self.collect_std = collect_std self.percentile_sigmas = percentile_sigmas - def __call__(self, x: Any, axis: Any = 0, dtype: Any = None) -> Dict[str, Any]: + def __call__( + self, x: Any, axis: Any = 0, dtype: Any = None + ) -> Dict[str, Any]: if axis is None: axis = tuple(range(x.ndim)) elif not isinstance(axis, (tuple, list)): - axis = axis, + axis = (axis,) return self.collect(x, axis) @@ -132,14 +145,14 @@ def collect(self, x: Any, axis: int) -> Dict[str, Any]: out = dict() if self.collect_mean: - out['mean'] = x.mean(axis=axis) + out["mean"] = x.mean(axis=axis) if self.collect_std: - out['std'] = x.std(axis=axis) + out["std"] = x.std(axis=axis) if self.percentile_sigmas: p = percentile(x, self.percentile_sigmas, axis=axis) - out['percentile'] = p + out["percentile"] = p return out @@ -215,26 +228,34 @@ class VariableStatisticsPlot(extension.Extension): """ def __init__( - self, - targets: Any, - max_sample_size: int = 1000, - report_data: bool = True, - report_grad: bool = True, - plot_mean: bool = True, - plot_std: bool = True, - percentile_sigmas: Union[float, Tuple[float, ...]] = ( - 0, 0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87, 100), - trigger: trigger_module.TriggerLike = (1, 'epoch'), - filename: Optional[str] = None, - figsize: Optional[Tuple[int, ...]] = None, - marker: Optional[str] = None, - grid: bool = True, - **kwargs: Any, + self, + targets: Any, + max_sample_size: int = 1000, + report_data: bool = True, + report_grad: bool = True, + plot_mean: bool = True, + plot_std: bool = True, + percentile_sigmas: Union[float, Tuple[float, ...]] = ( + 0, + 0.13, + 2.28, + 15.87, + 50, + 84.13, + 97.72, + 99.87, + 100, + ), + trigger: trigger_module.TriggerLike = (1, "epoch"), + filename: Optional[str] = None, + figsize: Optional[Tuple[int, ...]] = None, + marker: Optional[str] = None, + grid: bool = True, + **kwargs: Any, ): - _check_available() - file_name = kwargs.get('file_name', 'statistics.png') + file_name = kwargs.get("file_name", 
"statistics.png") if filename is None: filename = file_name del file_name # avoid accidental use @@ -242,24 +263,27 @@ def __init__( self._vars = _unpack_variables(targets) if not self._vars: raise ValueError( - 'Need at least one variables for which to collect statistics.' - '\nActual: 0 <= 0') + "Need at least one variables for which to collect statistics." + "\nActual: 0 <= 0" + ) if not any((plot_mean, plot_std, bool(percentile_sigmas))): - raise ValueError('Nothing to plot') + raise ValueError("Nothing to plot") self._keys = [] if report_data: - self._keys.append('data') + self._keys.append("data") if report_grad: - self._keys.append('grad') + self._keys.append("grad") self._report_data = report_data self._report_grad = report_grad self._statistician = Statistician( - collect_mean=plot_mean, collect_std=plot_std, - percentile_sigmas=percentile_sigmas) + collect_mean=plot_mean, + collect_std=plot_std, + percentile_sigmas=percentile_sigmas, + ) self._plot_mean = plot_mean self._plot_std = plot_std @@ -270,7 +294,7 @@ def __init__( self._figsize = figsize self._marker = marker self._grid = grid - self._writer = kwargs.get('writer', None) + self._writer = kwargs.get("writer", None) if not self._plot_percentile: n_percentile = 0 @@ -280,7 +304,9 @@ def __init__( else: n_percentile = len(percentile_sigmas) self._data_shape = ( - len(self._keys), int(plot_mean) + int(plot_std) + n_percentile) + len(self._keys), + int(plot_mean) + int(plot_std) + n_percentile, + ) self._samples = Reservoir(max_sample_size, data_shape=self._data_shape) @staticmethod @@ -309,15 +335,17 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: stat_list = [] if self._plot_mean: stat_list.append( - numpy.atleast_1d(stat_dict['mean'].cpu().numpy())) + numpy.atleast_1d(stat_dict["mean"].cpu().numpy()) + ) if self._plot_std: stat_list.append( - numpy.atleast_1d(stat_dict['std'].cpu().numpy())) + numpy.atleast_1d(stat_dict["std"].cpu().numpy()) + ) if self._plot_percentile: - stat_list.append( - numpy.atleast_1d(stat_dict['percentile'])) + stat_list.append(numpy.atleast_1d(stat_dict["percentile"])) stats[i] = numpy.concatenate( # type: ignore[no-untyped-call] - stat_list, axis=0) + stat_list, axis=0 + ) self._samples.add(stats, idx=manager.iteration) @@ -325,16 +353,18 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: self.save_plot_using_module(plt, manager) def save_plot_using_module( - self, - plt: Any, - manager: ExtensionsManagerProtocol, + self, + plt: Any, + manager: ExtensionsManagerProtocol, ) -> None: - nrows = int(self._plot_mean or self._plot_std) \ - + int(self._plot_percentile) + nrows = int(self._plot_mean or self._plot_std) + int( + self._plot_percentile + ) ncols = len(self._keys) fig, axes = plt.subplots( - nrows, ncols, figsize=self._figsize, sharex=True) + nrows, ncols, figsize=self._figsize, sharex=True + ) if not isinstance(axes, numpy.ndarray): # single subplot axes = numpy.asarray([axes]) @@ -362,17 +392,26 @@ def save_plot_using_module( if self._plot_mean or self._plot_std: if self._plot_mean and self._plot_std: ax.errorbar( - idxs, data[:, col, 0], data[:, col, 1], - color=_plot_color, ecolor=_plot_color_trans, - label='mean, std', marker=self._marker) + idxs, + data[:, col, 0], + data[:, col, 1], + color=_plot_color, + ecolor=_plot_color_trans, + label="mean, std", + marker=self._marker, + ) else: if self._plot_mean: - label = 'mean' + label = "mean" elif self._plot_std: - label = 'std' + label = "std" ax.plot( - idxs, data[:, col, 0], color=_plot_color, 
label=label, - marker=self._marker) + idxs, + data[:, col, 0], + color=_plot_color, + label=label, + marker=self._marker, + ) row += 1 if self._plot_percentile: @@ -384,22 +423,27 @@ def save_plot_using_module( # percentile is the mid percentile and the number of # percentiles are odd ax.plot( - idxs, data[:, col, offset + i], color=_plot_color, - label='percentile', marker=self._marker) + idxs, + data[:, col, offset + i], + color=_plot_color, + label="percentile", + marker=self._marker, + ) else: if i == n_percentile_mid_floor: # Last percentiles and the number of all # percentiles are even - label = 'percentile' + label = "percentile" else: - label = '_nolegend_' + label = "_nolegend_" ax.fill_between( idxs, data[:, col, offset + i], data[:, col, -i - 1], label=label, - **_plot_common_kwargs) - ax.set_xlabel('iteration') + **_plot_common_kwargs, + ) + ax.set_xlabel("iteration") for ax in axes.ravel(): ax.legend() @@ -407,8 +451,12 @@ def save_plot_using_module( ax.grid() ax.set_axisbelow(True) - writer(self._filename, manager.out, (fig, plt), # type: ignore - savefun=matplotlib_savefun) + writer( + self._filename, + manager.out, + (fig, plt), # type: ignore + savefun=matplotlib_savefun, + ) def finalize(self, manager: ExtensionsManagerProtocol) -> None: if self._writer is not None: diff --git a/pytorch_pfn_extras/training/manager.py b/pytorch_pfn_extras/training/manager.py index 28dbc7f57..3ba206e86 100644 --- a/pytorch_pfn_extras/training/manager.py +++ b/pytorch_pfn_extras/training/manager.py @@ -1,30 +1,36 @@ import collections import contextlib import copy -from pytorch_pfn_extras.profiler import record import time +import warnings from typing import ( - Any, Dict, Generator, Mapping, Optional, Sequence, Union, TYPE_CHECKING + TYPE_CHECKING, + Any, + Dict, + Generator, + Mapping, + Optional, + Sequence, + Union, ) -import warnings - -import torch import pytorch_pfn_extras -from pytorch_pfn_extras import writing -from pytorch_pfn_extras import reporting +import torch +from pytorch_pfn_extras import reporting, writing +from pytorch_pfn_extras.profiler import record +from pytorch_pfn_extras.training import _util as util_module from pytorch_pfn_extras.training import extension as extension_module from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training import _util as util_module from pytorch_pfn_extras.training._transform_model import ( - default_transform_model, _TransformModel, + _TransformModel, + default_transform_model, ) _get_time = time.perf_counter class _ManagerProxy: - def __init__(self, manager: '_BaseExtensionsManager') -> None: + def __init__(self, manager: "_BaseExtensionsManager") -> None: self._manager = manager @property @@ -99,26 +105,27 @@ class _BaseExtensionsManager: """ def __init__( - self, - models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], - optimizers: Union[torch.optim.Optimizer, - Mapping[str, torch.optim.Optimizer]], - max_epochs: int, - extensions: Optional[Sequence['extension_module.ExtensionLike']], - out_dir: str, - writer: Optional[writing.Writer], - stop_trigger: 'trigger_module.TriggerLike' = None, - transform_model: _TransformModel = default_transform_model, - enable_profile: bool = False, + self, + models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], + optimizers: Union[ + torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer] + ], + max_epochs: int, + extensions: Optional[Sequence["extension_module.ExtensionLike"]], + out_dir: str, + writer: Optional[writing.Writer], + 
stop_trigger: "trigger_module.TriggerLike" = None, + transform_model: _TransformModel = default_transform_model, + enable_profile: bool = False, ) -> None: if extensions is None: extensions = [] if stop_trigger is None: self._stop_trigger = trigger_module.get_trigger( - (max_epochs, 'epoch')) + (max_epochs, "epoch") + ) else: - self._stop_trigger = trigger_module.get_trigger( - stop_trigger) + self._stop_trigger = trigger_module.get_trigger(stop_trigger) if writer is None: writer = writing.SimpleWriter(out_dir=out_dir) # triggers are stateful, so we need to make a copy for internal use @@ -142,22 +149,22 @@ def __init__( else: if not isinstance(models, torch.nn.Module): raise ValueError( - 'model must be an instance of dict or toch.nn.Module') - self._models = {'main': models} + "model must be an instance of dict or toch.nn.Module" + ) + self._models = {"main": models} if isinstance(optimizers, collections.abc.Mapping): self._optimizers = optimizers else: # TODO(ecastill) Optimizer type is not checked because of tests # using mocks and other classes - self._optimizers = {'main': optimizers} + self._optimizers = {"main": optimizers} for name, model in self._models.items(): # TODO we should not initialize extensions at this point # so, we cannot use `self.models` model = self._transform_model(name, model) self.reporter.add_observer(name, model) - self.reporter.add_observers( - name, model.named_modules()) + self.reporter.add_observers(name, model.named_modules()) self._finalized = False self.max_epochs = max_epochs self._start_iteration = 0 @@ -165,7 +172,8 @@ def __init__( self._start_time: Optional[float] = None self.__iters_per_epoch: Optional[int] = None self._extensions: Dict[ - str, extension_module.ExtensionEntry] = collections.OrderedDict() + str, extension_module.ExtensionEntry + ] = collections.OrderedDict() for ext in extensions: self.extend(ext) @@ -189,16 +197,18 @@ def _check_model_available(self) -> None: if self._model_available: return raise RuntimeError( - 'Models cannot be accessed from extensions in this iteration. ' - 'Extensions accessing models must declare ' - '`needs_model_state = True` attribute.') + "Models cannot be accessed from extensions in this iteration. " + "Extensions accessing models must declare " + "`needs_model_state = True` attribute." + ) @property def models(self) -> Mapping[str, torch.nn.Module]: self.start_extensions() self._check_model_available() - models = {k: self._transform_model(k, v) - for k, v in self._models.items()} + models = { + k: self._transform_model(k, v) for k, v in self._models.items() + } return models @property @@ -216,7 +226,8 @@ def optimizers(self) -> Mapping[str, torch.optim.Optimizer]: def elapsed_time(self) -> float: if self._start_time is None: raise RuntimeError( - 'Unavailable until the initial run_iteration call.') + "Unavailable until the initial run_iteration call." + ) return _get_time() - self._start_time @property @@ -251,22 +262,21 @@ def out(self) -> str: return self.writer.out_dir @property - def updater(self) -> '_BaseExtensionsManager': + def updater(self) -> "_BaseExtensionsManager": warnings.warn( - 'The `updater` attribute has been deprecated in v0.3.0.' - ' Use `iteration`, `epoch`, and `epoch_detail` attributes in' - ' `ExtensionsManager` instead of attributes under `updater`.' 
- ' You may also need to update the filename template specified to' - ' snapshot extensions (e.g., from ' - '`snapshot_iter_{.updater.iteration}` to' - ' `snapshot_iter_{.iteration}`).', DeprecationWarning) + "The `updater` attribute has been deprecated in v0.3.0." + " Use `iteration`, `epoch`, and `epoch_detail` attributes in" + " `ExtensionsManager` instead of attributes under `updater`." + " You may also need to update the filename template specified to" + " snapshot extensions (e.g., from " + "`snapshot_iter_{.updater.iteration}` to" + " `snapshot_iter_{.iteration}`).", + DeprecationWarning, + ) return self def _prepare_for_training( - self, - start_iteration: int, - start_execution: int, - iters_per_epoch: int + self, start_iteration: int, start_execution: int, iters_per_epoch: int ) -> None: self.iteration = start_iteration self.execution = start_execution @@ -281,15 +291,14 @@ def start_extensions(self) -> None: exts = self._extensions extension_order = sorted( - exts.keys(), - key=lambda name: exts[name].priority, reverse=True) - self.extensions = [(name, exts[name]) - for name in extension_order] + exts.keys(), key=lambda name: exts[name].priority, reverse=True + ) + self.extensions = [(name, exts[name]) for name in extension_order] # invoke initializer of each extension for _, entry in self.extensions: initializer = entry.extension.initialize - finished = getattr(entry.trigger, 'finished', False) + finished = getattr(entry.trigger, "finished", False) if not finished: initializer(self) @@ -301,17 +310,17 @@ def start_extensions(self) -> None: entry.extension(self) def extend( - self, - extension: Union[ - 'extension_module.ExtensionLike', - 'extension_module.ExtensionEntry', - ], - name: Optional[str] = None, - trigger: 'trigger_module.TriggerLike' = None, - priority: Optional[int] = None, - *, - call_before_training: Optional[bool] = None, - **kwargs: Dict[str, Any], + self, + extension: Union[ + "extension_module.ExtensionLike", + "extension_module.ExtensionEntry", + ], + name: Optional[str] = None, + trigger: "trigger_module.TriggerLike" = None, + priority: Optional[int] = None, + *, + call_before_training: Optional[bool] = None, + **kwargs: Dict[str, Any], ) -> None: """Registers an extension to the manager. 
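# --- Editor's note (illustrative, not part of this diff) ---
# The deprecation warning above means snapshot filename templates should
# stop reaching through `.updater` and format against the manager itself:
from pytorch_pfn_extras.training import extensions

extensions.snapshot(filename="snapshot_iter_{.updater.iteration}")  # deprecated template
extensions.snapshot(filename="snapshot_iter_{.iteration}")          # preferred template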
@@ -353,7 +362,8 @@ def extend( """ if self._start_extensions_called: raise RuntimeError( - 'extend called after the extensions were initialized') + "extend called after the extensions were initialized" + ) if isinstance(extension, extension_module.ExtensionEntry): entry = extension @@ -373,7 +383,7 @@ def extend( ordinal = 0 while modified_name in self._extensions: ordinal += 1 - modified_name = '%s_%d' % (name, ordinal) + modified_name = "%s_%d" % (name, ordinal) entry._update_name(modified_name) self._extensions[modified_name] = entry @@ -392,7 +402,7 @@ def get_extension(self, name: str) -> extension_module.Extension: if name in extensions: return extensions[name].extension else: - raise ValueError('extension %s not found' % name) + raise ValueError("extension %s not found" % name) def _run_on_error(self, exc: Exception) -> None: if not self._run_on_error_called: @@ -428,15 +438,15 @@ def run_extensions(self) -> None: to_run.append((name, entry.extension)) else: with record( - f'pytorch_pfn_extras.training.ExtensionsManager' - f'.run_extensions:{name}', + f"pytorch_pfn_extras.training.ExtensionsManager" + f".run_extensions:{name}", enable=self._enable_profile, ): entry.extension(self) for name, extension in to_run: with record( - f'pytorch_pfn_extras.training.ExtensionsManager' - f'.run_extensions:{name}', + f"pytorch_pfn_extras.training.ExtensionsManager" + f".run_extensions:{name}", enable=self._enable_profile, ): extension(self) @@ -452,9 +462,10 @@ def needs_model_state(self, iteration: Optional[int] = None) -> bool: # is increased just right before calling extensions iteration = self.iteration + 1 for _, entry in self._extensions.items(): - needs_state = getattr(entry.extension, 'needs_model_state', False) - if (needs_state and entry.trigger.may_fire( - iteration, self._iters_per_epoch)): + needs_state = getattr(entry.extension, "needs_model_state", False) + if needs_state and entry.trigger.may_fire( + iteration, self._iters_per_epoch + ): return True return False @@ -469,57 +480,66 @@ def _finalize_extensions(self) -> None: pass def state_dict( - self, + self, ) -> Dict[str, Any]: to_save: Dict[str, Any] = {} - to_save['_start_iteration'] = self.iteration - to_save['_start_execution'] = self.execution + to_save["_start_iteration"] = self.iteration + to_save["_start_execution"] = self.execution # Use self.models to apply transform_model - to_save['models'] = { - name: self.models[name].state_dict() - for name in self.models} - to_save['optimizers'] = {name: self._optimizers[name].state_dict() - for name in self._optimizers} - to_save['extensions'] = {name: self._extensions[name].state_dict() - for name in self._extensions} - to_save['ppe_version'] = pytorch_pfn_extras.__version__ + to_save["models"] = { + name: self.models[name].state_dict() for name in self.models + } + to_save["optimizers"] = { + name: self._optimizers[name].state_dict() + for name in self._optimizers + } + to_save["extensions"] = { + name: self._extensions[name].state_dict() + for name in self._extensions + } + to_save["ppe_version"] = pytorch_pfn_extras.__version__ return to_save def _check_snapshot_version(self, ppe_version: Optional[str]) -> None: must_warn = ppe_version is None or ( - ppe_version != pytorch_pfn_extras.__version__) + ppe_version != pytorch_pfn_extras.__version__ + ) if not must_warn: return - msg = ('You are trying to load a snapshot file taken using a different ' - 'PPE version.\n') + msg = ( + "You are trying to load a snapshot file taken using a different " + "PPE version.\n" + ) if 
ppe_version is not None: - msg += (f'Snapshot taken with PPE {ppe_version} but ' - f'currently using PPE {pytorch_pfn_extras.__version__}') + msg += ( + f"Snapshot taken with PPE {ppe_version} but " + f"currently using PPE {pytorch_pfn_extras.__version__}" + ) warnings.warn(msg) def load_state_dict( - self, - to_load: Dict[str, Any], + self, + to_load: Dict[str, Any], ) -> None: - self._check_snapshot_version(to_load.get('ppe_version', None)) - self._start_iteration = to_load['_start_iteration'] + self._check_snapshot_version(to_load.get("ppe_version", None)) + self._start_iteration = to_load["_start_iteration"] self.iteration = self._start_iteration - self._start_execution = to_load.get('_start_execution', self.iteration) + self._start_execution = to_load.get("_start_execution", self.iteration) self.execution = self._start_execution for name in self.models: # TODO(ecastill) map_loc when loading the model and DDP check # Use self.models to apply transform_model - self.models[name].load_state_dict(to_load['models'][name]) + self.models[name].load_state_dict(to_load["models"][name]) for name in self._optimizers: - self._optimizers[name].load_state_dict(to_load['optimizers'][name]) + self._optimizers[name].load_state_dict(to_load["optimizers"][name]) for name in self._extensions: - self._extensions[name].load_state_dict(to_load['extensions'][name]) + self._extensions[name].load_state_dict(to_load["extensions"][name]) class ExtensionsManager(_BaseExtensionsManager): @@ -545,33 +565,43 @@ class ExtensionsManager(_BaseExtensionsManager): """ def __init__( - self, - models: Union[torch.nn.Module, Dict[str, torch.nn.Module]], - optimizers: Union[torch.optim.Optimizer, Dict[str, torch.optim.Optimizer]], - max_epochs: int, - *, - iters_per_epoch: int, - extensions: Optional[Sequence['extension_module.ExtensionLike']] = None, - out_dir: str = 'result', - stop_trigger: 'trigger_module.TriggerLike' = None, - writer: Optional[writing.Writer] = None, - transform_model: _TransformModel = lambda n, x: x, - enable_profile: bool = False, + self, + models: Union[torch.nn.Module, Dict[str, torch.nn.Module]], + optimizers: Union[ + torch.optim.Optimizer, Dict[str, torch.optim.Optimizer] + ], + max_epochs: int, + *, + iters_per_epoch: int, + extensions: Optional[Sequence["extension_module.ExtensionLike"]] = None, + out_dir: str = "result", + stop_trigger: "trigger_module.TriggerLike" = None, + writer: Optional[writing.Writer] = None, + transform_model: _TransformModel = lambda n, x: x, + enable_profile: bool = False, ) -> None: super().__init__( - models, optimizers, max_epochs, extensions, - out_dir, writer, stop_trigger, transform_model, enable_profile) + models, + optimizers, + max_epochs, + extensions, + out_dir, + writer, + stop_trigger, + transform_model, + enable_profile, + ) if iters_per_epoch < 1: raise ValueError( - 'iters_per_epoch must be an integer >= 1 ({} given)'.format( - iters_per_epoch)) + "iters_per_epoch must be an integer >= 1 ({} given)".format( + iters_per_epoch + ) + ) self._prepare_for_training(0, 0, iters_per_epoch) @contextlib.contextmanager def run_iteration( - self, - *, - step_optimizers: Optional[Sequence[str]] = None + self, *, step_optimizers: Optional[Sequence[str]] = None ) -> Generator[None, None, None]: """Context manager to run an iteration. 
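# --- Illustrative usage sketch (editor's note, not part of this diff) ---
# A minimal training loop driving the ExtensionsManager defined above.
# `model`, `optimizer`, `loss_fn` and `train_loader` are assumed to exist.
import pytorch_pfn_extras as ppe

manager = ppe.training.ExtensionsManager(
    model, optimizer, max_epochs=10,
    iters_per_epoch=len(train_loader),
    out_dir="result",
)

for _ in range(10):  # max_epochs
    for x, t in train_loader:
        # step_optimizers names the optimizers for which the manager calls
        # zero_grad()/step() around this block ("main" is the default name
        # given to a single optimizer).
        with manager.run_iteration(step_optimizers=["main"]):
            loss = loss_fn(model(x), t)
            ppe.reporting.report({"train/loss": float(loss)})
            loss.backward()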
@@ -583,7 +613,7 @@ def run_iteration( to call `zero_grad` and `step` """ if self._finalized: - raise RuntimeError('Attempted to run a finalized manager') + raise RuntimeError("Attempted to run a finalized manager") if self._start_time is None: self._start_time = _get_time() self.start_extensions() @@ -639,36 +669,47 @@ class IgniteExtensionsManager(_BaseExtensionsManager): enable_profile (bool): Flag to enable/disable profiling of iterations. Default is `False`. """ + def __init__( - self, - engine: 'ignite.engine.Engine', - models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], - optimizers: Union[torch.optim.Optimizer, - Mapping[str, torch.optim.Optimizer]], - max_epochs: int, - *, - extensions: Optional[Sequence['extension_module.ExtensionLike']] = None, - out_dir: str = 'result', - writer: Optional[writing.Writer] = None, - enable_profile: bool = False, + self, + engine: "ignite.engine.Engine", + models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]], + optimizers: Union[ + torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer] + ], + max_epochs: int, + *, + extensions: Optional[Sequence["extension_module.ExtensionLike"]] = None, + out_dir: str = "result", + writer: Optional[writing.Writer] = None, + enable_profile: bool = False, ) -> None: import ignite + if not isinstance(engine, ignite.engine.Engine): raise TypeError("Argument 'engine' must be of ignite.Engine type.") - if (util_module._get_ignite_version(ignite.__version__) - < util_module._get_ignite_version('0.3.0')): - raise ImportError('Ignite version found {}. ' - 'Required is >=0.3.0'.format(ignite.__version__)) + if util_module._get_ignite_version( + ignite.__version__ + ) < util_module._get_ignite_version("0.3.0"): + raise ImportError( + "Ignite version found {}. 
" + "Required is >=0.3.0".format(ignite.__version__) + ) super().__init__( - models, optimizers, max_epochs, extensions, out_dir, writer, - enable_profile=enable_profile) + models, + optimizers, + max_epochs, + extensions, + out_dir, + writer, + enable_profile=enable_profile, + ) self.engine = engine self._start_epoch = 0 # Used to correctly restore snapshots self.set_ignite_handlers() def set_ignite_handlers(self) -> None: - from ignite.engine import Engine - from ignite.engine import Events + from ignite.engine import Engine, Events # Set a handler that sets the reporter scope on every iteration @self.engine.on(Events.ITERATION_STARTED) @@ -689,7 +730,8 @@ def set_training_started(engine: Engine) -> None: self._start_time = _get_time() # Initialize manager again after all state is restored self._prepare_for_training( - start_iteration, start_iteration, iters_per_epoch) + start_iteration, start_iteration, iters_per_epoch + ) # Make all the next # handlers to be executed after user defined ones @@ -710,13 +752,13 @@ def set_extensions_cleanup(engine: Engine) -> None: def state_dict(self) -> Dict[str, Any]: to_save = super().state_dict() - to_save['_epoch_length'] = self.engine.state.epoch_length - to_save['_start_iteration'] = self.engine.state.iteration + to_save["_epoch_length"] = self.engine.state.epoch_length + to_save["_start_iteration"] = self.engine.state.iteration return to_save def load_state_dict( - self, - to_load: Dict[str, Any], + self, + to_load: Dict[str, Any], ) -> None: super().load_state_dict(to_load) - self._start_epoch = self._start_iteration // to_load['_epoch_length'] + self._start_epoch = self._start_iteration // to_load["_epoch_length"] diff --git a/pytorch_pfn_extras/training/metrics.py b/pytorch_pfn_extras/training/metrics.py index 7e7321663..4139e0ca3 100644 --- a/pytorch_pfn_extras/training/metrics.py +++ b/pytorch_pfn_extras/training/metrics.py @@ -2,7 +2,6 @@ import torch - Batch = Dict[str, torch.Tensor] MetricType = Callable[[Batch, Batch], Batch] @@ -17,12 +16,13 @@ class AccuracyMetric: .. 
seealso: :func:`pytorch_pfn_extras.engine.create_evaluator` """ + def __init__(self, label_key: str, output_key: str) -> None: self.label_key = label_key self.output_key = output_key def _preprocess_input( - self, batch: Batch, out: Batch + self, batch: Batch, out: Batch ) -> Tuple[torch.Tensor, int, torch.Tensor]: labels = batch[self.label_key].cpu() n_output = labels.shape[0] diff --git a/pytorch_pfn_extras/training/trigger.py b/pytorch_pfn_extras/training/trigger.py index cbfb573e9..5a7bd42de 100644 --- a/pytorch_pfn_extras/training/trigger.py +++ b/pytorch_pfn_extras/training/trigger.py @@ -1,8 +1,12 @@ from pytorch_pfn_extras.training._trigger_util import Trigger # NOQA -from pytorch_pfn_extras.training._trigger_util import get_trigger # NOQA from pytorch_pfn_extras.training._trigger_util import TriggerFunc # NOQA from pytorch_pfn_extras.training._trigger_util import TriggerLike # NOQA +from pytorch_pfn_extras.training._trigger_util import get_trigger # NOQA +from pytorch_pfn_extras.training._trigger_util import ( # NOQA + _never_fire_trigger, +) # For backward compatibility -from pytorch_pfn_extras.training.triggers.interval_trigger import IntervalTrigger # NOQA -from pytorch_pfn_extras.training._trigger_util import _never_fire_trigger # NOQA +from pytorch_pfn_extras.training.triggers.interval_trigger import ( # NOQA + IntervalTrigger, +) diff --git a/pytorch_pfn_extras/training/triggers/__init__.py b/pytorch_pfn_extras/training/triggers/__init__.py index 074978b92..99602a71c 100644 --- a/pytorch_pfn_extras/training/triggers/__init__.py +++ b/pytorch_pfn_extras/training/triggers/__init__.py @@ -1,9 +1,21 @@ # import classes and functions -from pytorch_pfn_extras.training.triggers.early_stopping_trigger import EarlyStoppingTrigger # NOQA -from pytorch_pfn_extras.training.triggers.interval_trigger import IntervalTrigger # NOQA -from pytorch_pfn_extras.training.triggers.manual_schedule_trigger import ManualScheduleTrigger # NOQA -from pytorch_pfn_extras.training.triggers.minmax_value_trigger import BestValueTrigger # NOQA -from pytorch_pfn_extras.training.triggers.minmax_value_trigger import MaxValueTrigger # NOQA -from pytorch_pfn_extras.training.triggers.minmax_value_trigger import MinValueTrigger # NOQA -from pytorch_pfn_extras.training.triggers.once_trigger import OnceTrigger # NOQA -from pytorch_pfn_extras.training.triggers.time_trigger import TimeTrigger # NOQA +from pytorch_pfn_extras.training.triggers.early_stopping_trigger import ( # NOQA + EarlyStoppingTrigger, +) +from pytorch_pfn_extras.training.triggers.interval_trigger import ( # NOQA + IntervalTrigger, +) +from pytorch_pfn_extras.training.triggers.manual_schedule_trigger import ( # NOQA + ManualScheduleTrigger, +) +from pytorch_pfn_extras.training.triggers.minmax_value_trigger import ( # NOQA + BestValueTrigger, + MaxValueTrigger, + MinValueTrigger, +) +from pytorch_pfn_extras.training.triggers.once_trigger import ( # NOQA + OnceTrigger, +) +from pytorch_pfn_extras.training.triggers.time_trigger import ( # NOQA + TimeTrigger, +) diff --git a/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py b/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py index 855a4ebec..7262e2e43 100644 --- a/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py +++ b/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py @@ -1,15 +1,18 @@ import operator -from typing import Tuple, TYPE_CHECKING import warnings +from typing import TYPE_CHECKING, Tuple from pytorch_pfn_extras import reporting from 
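# --- Illustrative usage sketch (editor's note, not part of this diff) ---
# AccuracyMetric above compares a batch's label entry against the model
# output entry. The create_evaluator call and its `metrics=` argument are
# assumptions about the engine API; the key names "t" and "y" are arbitrary.
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training.metrics import AccuracyMetric

evaluator = ppe.engine.create_evaluator(
    model,
    device="cuda:0",
    metrics=[AccuracyMetric(label_key="t", output_key="y")],
)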
pytorch_pfn_extras.training import trigger -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) if TYPE_CHECKING: - from pytorch_pfn_extras.training._trigger_util import TriggerLike - from pytorch_pfn_extras.training._trigger_util import UnitLiteral + from pytorch_pfn_extras.training._trigger_util import ( + TriggerLike, + UnitLiteral, + ) class EarlyStoppingTrigger(trigger.Trigger): @@ -46,13 +49,13 @@ class EarlyStoppingTrigger(trigger.Trigger): """ def __init__( - self, - check_trigger: 'TriggerLike' = (1, 'epoch'), - monitor: str = 'main/loss', - patience: int = 3, - mode: str = 'auto', - verbose: bool = False, - max_trigger: Tuple[int, 'UnitLiteral'] = (100, 'epoch'), + self, + check_trigger: "TriggerLike" = (1, "epoch"), + monitor: str = "main/loss", + patience: int = 3, + mode: str = "auto", + verbose: bool = False, + max_trigger: Tuple[int, "UnitLiteral"] = (100, "epoch"), ) -> None: self.count = 0 self.patience = patience @@ -64,14 +67,14 @@ def __init__( self._init_summary() - if mode == 'max': + if mode == "max": self._compare = operator.gt - elif mode == 'min': + elif mode == "min": self._compare = operator.lt else: - if 'accuracy' in monitor: + if "accuracy" in monitor: self._compare = operator.gt else: @@ -79,13 +82,13 @@ def __init__( if self._compare == operator.gt: if verbose: - print('early stopping: operator is greater') - self.best = float('-inf') + print("early stopping: operator is greater") + self.best = float("-inf") else: if verbose: - print('early stopping: operator is less') - self.best = float('inf') + print("early stopping: operator is less") + self.best = float("inf") def __call__(self, manager: ExtensionsManagerProtocol) -> bool: """Decides whether the training loop should be stopped. 
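# --- Illustrative usage sketch (editor's note, not part of this diff) ---
# Using the EarlyStoppingTrigger above as the stop trigger of an
# ExtensionsManager. The monitored key ("val/main/loss" here) is an
# assumption: it must match a value reported during evaluation.
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training.triggers import EarlyStoppingTrigger

stop = EarlyStoppingTrigger(
    monitor="val/main/loss",
    patience=3,                  # tolerated checks without improvement
    mode="min",                  # a loss should decrease
    max_trigger=(100, "epoch"),  # hard upper bound on training length
)

manager = ppe.training.ExtensionsManager(
    model, optimizer, 100,
    iters_per_epoch=len(train_loader),
    stop_trigger=stop,
)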
@@ -114,7 +117,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool: return False if self.monitor not in observation.keys(): - warnings.warn('{} is not in observation'.format(self.monitor)) + warnings.warn("{} is not in observation".format(self.monitor)) return False stat = self._summary.compute_mean() @@ -130,7 +133,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool: if self._stop_condition(): if self.verbose: - print('Epoch {}: early stopping'.format(manager.epoch)) + print("Epoch {}: early stopping".format(manager.epoch)) return True return False diff --git a/pytorch_pfn_extras/training/triggers/interval_trigger.py b/pytorch_pfn_extras/training/triggers/interval_trigger.py index 1a1847a81..58dcad773 100644 --- a/pytorch_pfn_extras/training/triggers/interval_trigger.py +++ b/pytorch_pfn_extras/training/triggers/interval_trigger.py @@ -1,8 +1,9 @@ -from typing import Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple from pytorch_pfn_extras.training import trigger -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) if TYPE_CHECKING: from pytorch_pfn_extras.training._trigger_util import UnitLiteral @@ -30,10 +31,11 @@ class IntervalTrigger(trigger.Trigger): """ - def __init__(self, period: float, unit: 'UnitLiteral'): - if unit not in ('epoch', 'iteration'): + def __init__(self, period: float, unit: "UnitLiteral"): + if unit not in ("epoch", "iteration"): raise ValueError( - 'Trigger unit must be either \'epoch\' or \'iteration\'.') + "Trigger unit must be either 'epoch' or 'iteration'." + ) self.period = period self.unit = unit @@ -64,17 +66,17 @@ def __str__(self) -> str: Returns: str: IntervalTrigger(, '') """ - return '{}({}, \'{}\')'.format( + return "{}({}, '{}')".format( self.__class__.__name__, self.period, self.unit ) def may_fire(self, iteration: int, epoch_length: int) -> bool: if iteration == 0: - if self.unit == 'epoch': + if self.unit == "epoch": return epoch_length == 0 else: return self.period == 0 - if self.unit == 'epoch': + if self.unit == "epoch": fire = (iteration % (epoch_length * self.period)) == 0 else: fire = (iteration % self.period) == 0 diff --git a/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py b/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py index c9662d1f0..4914755ee 100644 --- a/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py +++ b/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py @@ -1,8 +1,9 @@ -from typing import Sequence, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Sequence, Union from pytorch_pfn_extras.training import trigger -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) if TYPE_CHECKING: from pytorch_pfn_extras.training._trigger_util import UnitLiteral @@ -27,12 +28,15 @@ class ManualScheduleTrigger(trigger.Trigger): """ - def __init__(self, points: Union[float, Sequence[float]], unit: 'UnitLiteral'): - if unit not in ('epoch', 'iteration'): + def __init__( + self, points: Union[float, Sequence[float]], unit: "UnitLiteral" + ): + if unit not in ("epoch", "iteration"): raise ValueError( - 'Trigger unit must be either \'epoch\' or \'iteration\'.') + "Trigger unit must be either 'epoch' or 'iteration'." 
+ ) - self.points = (points if isinstance(points, list) else [points]) + self.points = points if isinstance(points, list) else [points] self.unit = unit def __call__(self, manager: ExtensionsManagerProtocol) -> bool: @@ -53,9 +57,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool: return fire def may_fire(self, iteration: int, epoch_length: int) -> bool: - if self.unit == 'epoch': - fire = any( - int(p * epoch_length) == iteration for p in self.points) + if self.unit == "epoch": + fire = any(int(p * epoch_length) == iteration for p in self.points) else: fire = any(p == iteration for p in self.points) return fire diff --git a/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py b/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py index d0e74902c..dcf286f5f 100644 --- a/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py +++ b/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from pytorch_pfn_extras import reporting from pytorch_pfn_extras.training import trigger as trigger_module -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol - +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) if TYPE_CHECKING: from pytorch_pfn_extras.training._trigger_util import TriggerLike @@ -26,10 +27,10 @@ class BestValueTrigger(trigger_module.Trigger): """ def __init__( - self, - key: str, - compare: Callable[[float, float], bool], - trigger: 'TriggerLike' = (1, 'epoch'), + self, + key: str, + compare: Callable[[float, float], bool], + trigger: "TriggerLike" = (1, "epoch"), ) -> None: self._key = key self._best_value: Optional[float] = None @@ -63,9 +64,12 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool: stats = summary.compute_mean() if key not in stats: - raise KeyError('Key "{}" not found in the observation ' - '(current available keys are {})' - .format(key, list(stats.keys()))) + raise KeyError( + 'Key "{}" not found in the observation ' + "(current available keys are {})".format( + key, list(stats.keys()) + ) + ) value = float(stats[key]) # copy to CPU self._init_summary() @@ -79,15 +83,17 @@ def _init_summary(self) -> None: self._summary = reporting.DictSummary() def state_dict(self) -> Dict[str, Any]: - state = {'interval_trigger': self._interval_trigger.state_dict(), - '_summary': self._summary.state_dict(), - '_best_value': self._best_value} + state = { + "interval_trigger": self._interval_trigger.state_dict(), + "_summary": self._summary.state_dict(), + "_best_value": self._best_value, + } return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: - self._interval_trigger.load_state_dict(to_load['interval_trigger']) - self._summary.load_state_dict(to_load['_summary']) - self._best_value = to_load['_best_value'] + self._interval_trigger.load_state_dict(to_load["interval_trigger"]) + self._summary.load_state_dict(to_load["_summary"]) + self._best_value = to_load["_best_value"] def may_fire(self, iteration: int, epoch_length: int) -> bool: return self._interval_trigger.may_fire(iteration, epoch_length) @@ -110,9 +116,10 @@ class MaxValueTrigger(BestValueTrigger): """ - def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')): + def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")): super().__init__( - key, lambda max_value, new_value: new_value > max_value, trigger) + key, 
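# --- Illustrative usage sketch (editor's note, not part of this diff) ---
# The trigger classes above can be passed wherever a TriggerLike is
# accepted, for example as the `trigger=` argument of manager.extend().
# `manager`, the "val/acc" key and the snapshot filename are assumptions.
from pytorch_pfn_extras.training import extensions, triggers

# Fire every five epochs.
every_5_epochs = triggers.IntervalTrigger(5, "epoch")

# Fire exactly at epochs 10, 30 and 60.
milestones = triggers.ManualScheduleTrigger([10, 30, 60], "epoch")

# Fire only when the reported "val/acc" value reaches a new maximum,
# checked once per epoch; useful for keeping only the best snapshot.
best_acc = triggers.MaxValueTrigger("val/acc", trigger=(1, "epoch"))

manager.extend(extensions.snapshot(filename="best.pt"), trigger=best_acc)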
lambda max_value, new_value: new_value > max_value, trigger + ) class MinValueTrigger(BestValueTrigger): @@ -132,6 +139,7 @@ class MinValueTrigger(BestValueTrigger): """ - def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')): + def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")): super().__init__( - key, lambda min_value, new_value: new_value < min_value, trigger) + key, lambda min_value, new_value: new_value < min_value, trigger + ) diff --git a/pytorch_pfn_extras/training/triggers/once_trigger.py b/pytorch_pfn_extras/training/triggers/once_trigger.py index 538ab3ee0..011ff0481 100644 --- a/pytorch_pfn_extras/training/triggers/once_trigger.py +++ b/pytorch_pfn_extras/training/triggers/once_trigger.py @@ -1,7 +1,9 @@ from typing import Any, Dict from pytorch_pfn_extras.training import trigger -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) class OnceTrigger(trigger.Trigger): @@ -38,11 +40,11 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool: return fire def state_dict(self) -> Dict[str, Any]: - state = {'_flag_first': self._flag_first} + state = {"_flag_first": self._flag_first} return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: - self._flag_first = to_load['_flag_first'] + self._flag_first = to_load["_flag_first"] def may_fire(self, iteration: int, epoch_length: int) -> bool: return not (self._flag_first or self._flag_resumed) diff --git a/pytorch_pfn_extras/training/triggers/time_trigger.py b/pytorch_pfn_extras/training/triggers/time_trigger.py index 4d74100e4..e4006ac05 100644 --- a/pytorch_pfn_extras/training/triggers/time_trigger.py +++ b/pytorch_pfn_extras/training/triggers/time_trigger.py @@ -1,7 +1,9 @@ from typing import Any, Dict from pytorch_pfn_extras.training import trigger -from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol +from pytorch_pfn_extras.training._manager_protocol import ( + ExtensionsManagerProtocol, +) class TimeTrigger(trigger.Trigger): @@ -27,8 +29,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool: return False def state_dict(self) -> Dict[str, Any]: - state = {'next_time': self._next_time} + state = {"next_time": self._next_time} return state def load_state_dict(self, to_load: Dict[str, Any]) -> None: - self._next_time = to_load['next_time'] + self._next_time = to_load["next_time"] diff --git a/pytorch_pfn_extras/utils/checkpoint.py b/pytorch_pfn_extras/utils/checkpoint.py index 685843183..0af2efdd4 100644 --- a/pytorch_pfn_extras/utils/checkpoint.py +++ b/pytorch_pfn_extras/utils/checkpoint.py @@ -16,20 +16,22 @@ class _CheckpointFunction(torch.utils.checkpoint.CheckpointFunction): can help deal with incorrect values in the BatchNormalization persistent parameters. 
""" + @staticmethod def forward( # type: ignore[override] - ctx: Any, - run_function: Any, - preserve_rng_state: bool, - *args: Any, + ctx: Any, + run_function: Any, + preserve_rng_state: bool, + *args: Any, ) -> Any: _patch_bn_momentum(run_function) return super(_CheckpointFunction, _CheckpointFunction).forward( - ctx, run_function, preserve_rng_state, *args) + ctx, run_function, preserve_rng_state, *args + ) def _patch_bn_momentum(module: torch.nn.Module) -> None: - if not hasattr(module, '_bn_momentum_patched'): + if not hasattr(module, "_bn_momentum_patched"): if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): return if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): @@ -43,8 +45,9 @@ def _patch_bn_momentum(module: torch.nn.Module) -> None: def checkpoint(function: torch.nn.Module, *args: Any, **kwargs: Any) -> Any: # Hack to mix *args with **kwargs in a python 2.7-compliant way - preserve = kwargs.pop('preserve_rng_state', True) + preserve = kwargs.pop("preserve_rng_state", True) if kwargs: raise ValueError( - 'Unexpected keyword arguments: ' + ','.join(arg for arg in kwargs)) + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) return _CheckpointFunction.apply(function, preserve, *args) # type: ignore[no-untyped-call] diff --git a/pytorch_pfn_extras/utils/comparer.py b/pytorch_pfn_extras/utils/comparer.py index aeaa091d1..30afe1650 100644 --- a/pytorch_pfn_extras/utils/comparer.py +++ b/pytorch_pfn_extras/utils/comparer.py @@ -1,40 +1,46 @@ import collections +import concurrent.futures import pathlib import re import threading import weakref -import concurrent.futures from typing import ( - Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, - Tuple, Type, Union, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, ) +import pytorch_pfn_extras import torch.nn import torch.testing - -import pytorch_pfn_extras from pytorch_pfn_extras import handler as _handler_module from pytorch_pfn_extras.handler import _logic -from pytorch_pfn_extras.training import _trainer +from pytorch_pfn_extras.training import _evaluator, _trainer from pytorch_pfn_extras.training import manager as manager_module -from pytorch_pfn_extras.training import _evaluator from pytorch_pfn_extras.training import trigger as trigger_module - _thread_local = threading.local() _intermediate_prefix = "intermedaite:" class _ComparableHandler(_handler_module.BaseHandler): def __init__( - self, - handler: _handler_module.BaseHandler, - name: str, - get_target_cb: Any, - compare_cb: Any, - trigger: Optional[trigger_module.Trigger] = None, - *, - dir: Optional[str] = None, + self, + handler: _handler_module.BaseHandler, + name: str, + get_target_cb: Any, + compare_cb: Any, + trigger: Optional[trigger_module.Trigger] = None, + *, + dir: Optional[str] = None, ) -> None: self._handler = handler self._get_target_cb = get_target_cb @@ -52,7 +58,8 @@ def train_setup(self, trainer: _trainer._Trainer, loader: Any) -> None: self._handler.train_setup(trainer, loader) def train_epoch_begin( - self, trainer: _trainer._Trainer, loader: Iterable[Any]) -> None: + self, trainer: _trainer._Trainer, loader: Iterable[Any] + ) -> None: self._handler.train_epoch_begin(trainer, loader) _thread_local.handler = self self._epoch = trainer.epoch @@ -61,31 +68,33 @@ def train_epoch_end(self, trainer: _trainer._Trainer) -> None: self._handler.train_epoch_end(trainer) def train_validation_begin( - self, trainer: _trainer._Trainer, evaluator: 
_evaluator.Evaluator) -> None: + self, trainer: _trainer._Trainer, evaluator: _evaluator.Evaluator + ) -> None: self._handler.train_validation_begin(trainer, evaluator) _thread_local.handler = self def train_validation_end( - self, trainer: _trainer._Trainer, evaluator: _evaluator.Evaluator) -> None: + self, trainer: _trainer._Trainer, evaluator: _evaluator.Evaluator + ) -> None: self._handler.train_validation_end(trainer, evaluator) def train_step( - self, - trainer: _trainer._Trainer, - batch_idx: int, - batch: Any, - complete_fn: Callable[[int, Any], None], + self, + trainer: _trainer._Trainer, + batch_idx: int, + batch: Any, + complete_fn: Callable[[int, Any], None], ) -> None: self._batch_idx = batch_idx self._reset_intermediate_values() self._handler.train_step(trainer, batch_idx, batch, complete_fn) def train_post_step( - self, - trainer: _trainer.Trainer, - batch_idx: int, - batch: Any, - outputs: Any + self, + trainer: _trainer.Trainer, + batch_idx: int, + batch: Any, + outputs: Any, ) -> None: class _ManagerProxy(manager_module._ManagerProxy): @property @@ -100,7 +109,8 @@ def iteration(self) -> int: self._compare(trainer, batch_idx, outputs) def eval_setup( - self, evaluator: _evaluator.Evaluator, loader: Iterable[Any]) -> None: + self, evaluator: _evaluator.Evaluator, loader: Iterable[Any] + ) -> None: self._handler.eval_setup(evaluator, loader) def eval_loop_begin(self, evaluator: _evaluator.Evaluator) -> None: @@ -108,11 +118,11 @@ def eval_loop_begin(self, evaluator: _evaluator.Evaluator) -> None: _thread_local.handler = self def eval_step( - self, - evaluator: _evaluator.Evaluator, - batch_idx: int, - batch: Any, - complete_fn: Callable[[int, Any], None], + self, + evaluator: _evaluator.Evaluator, + batch_idx: int, + batch: Any, + complete_fn: Callable[[int, Any], None], ) -> None: self._batch_idx = batch_idx self._reset_intermediate_values() @@ -122,11 +132,11 @@ def eval_loop_end(self, evaluator: _evaluator.Evaluator) -> None: self._handler.eval_loop_end(evaluator) def eval_post_step( - self, - evaluator: _evaluator.Evaluator, - batch_idx: int, - batch: Any, - outputs: Any, + self, + evaluator: _evaluator.Evaluator, + batch_idx: int, + batch: Any, + outputs: Any, ) -> None: self._handler.eval_post_step(evaluator, batch_idx, batch, outputs) self._compare(evaluator, batch_idx, outputs) @@ -146,12 +156,13 @@ def _add_intermediate_value(self, name: str, value: torch.Tensor) -> None: values = self._intermediate_values[self._batch_idx] counts = self._intermediate_counts[self._batch_idx] value = value.detach() - count = counts.get('name', 0) - counts['name'] = count + 1 - name = _intermediate_prefix + name + f'_{count}' + count = counts.get("name", 0) + counts["name"] = count + 1 + name = _intermediate_prefix + name + f"_{count}" # Defer importing onnx for performance and to avoid linkage issues. 
import pytorch_pfn_extras.onnx + if pytorch_pfn_extras.onnx.available: pytorch_pfn_extras.onnx.as_output(name, value) values[name] = value @@ -159,10 +170,12 @@ def _add_intermediate_value(self, name: str, value: torch.Tensor) -> None: def _overwrite_handler(engine: Any, *args: Any, **kwargs: Any) -> None: engine.handler = _ComparableHandler(engine.handler, *args, **kwargs) - evaluator = getattr(engine, 'evaluator', None) + evaluator = getattr(engine, "evaluator", None) if evaluator is not None: # For trainer with evaluator - evaluator.handler = _ComparableHandler(evaluator.handler, *args, **kwargs) + evaluator.handler = _ComparableHandler( + evaluator.handler, *args, **kwargs + ) _CompareFn = Callable[[str, str, str, Any, Any], None] @@ -170,9 +183,9 @@ def _overwrite_handler(engine: Any, *args: Any, **kwargs: Any) -> None: def get_default_comparer( - rtol: float = 1e-04, - atol: float = 0, - equal_nan: bool = True, + rtol: float = 1e-04, + atol: float = 0, + equal_nan: bool = True, ) -> _CompareFn: """Creates default comparer function. @@ -184,9 +197,10 @@ def get_default_comparer( atol (float): Absolute tolerance. equal_nan (bool): If ``True``, NaNs will be ignored. """ + def compare_fn( - backend1: str, backend2: str, name: str, - val1: Any, val2: Any) -> None: + backend1: str, backend2: str, name: str, val1: Any, val2: Any + ) -> None: # TODO select the device where # the tensors will be compared? if isinstance(val1, torch.Tensor): @@ -208,17 +222,17 @@ def compare_fn( def _compare_targets( - compare_fn: _CompareFn, - targets: Dict[str, Any], - baseline: Optional[str], - batch_idx: int, + compare_fn: _CompareFn, + targets: Dict[str, Any], + baseline: Optional[str], + batch_idx: int, ) -> None: names = list(targets.keys()) if baseline is None: baseline = names[0] keys = sorted(targets[baseline].keys()) - err_msg = '' + err_msg = "" for backend in set(names) - set([baseline]): for val_name in keys: out1 = targets[baseline][val_name] @@ -228,18 +242,19 @@ def _compare_targets( except AssertionError as e: err_msg += ( f"Comparing '{baseline}' and '{backend}' in '{val_name}'\n" - f"{str(e)}\n") + f"{str(e)}\n" + ) if err_msg: - raise AssertionError(f'Batch: {batch_idx}\n' + str(err_msg)) + raise AssertionError(f"Batch: {batch_idx}\n" + str(err_msg)) class _ComparerBase: def __init__( - self, - engines: Mapping[str, _Engine], - *, - compare_fn: _CompareFn = _default_comparer, - concurrency: Optional[int] = None, + self, + engines: Mapping[str, _Engine], + *, + compare_fn: _CompareFn = _default_comparer, + concurrency: Optional[int] = None, ) -> None: e_type = type(next(iter(engines.values()))) if e_type not in ( @@ -257,16 +272,19 @@ def __init__( self.compare_fn = compare_fn self._finalized = False self._semaphore = threading.Semaphore( - len(engines) if concurrency is None else concurrency) + len(engines) if concurrency is None else concurrency + ) self.targets: Dict[str, Dict[str, Any]] = {} self._iters: Dict[str, int] = {} # engines must be a dict for name, engine in engines.items(): - _overwrite_handler(engine, name, self._get_target, self.compare_targets) + _overwrite_handler( + engine, name, self._get_target, self.compare_targets + ) def _assert_incompatible_trigger(self, condition: bool) -> None: if not condition: - raise ValueError('Engines have different triggers.') + raise ValueError("Engines have different triggers.") def run_engine(self, engine: _Engine, loaders: Any) -> None: try: @@ -307,26 +325,27 @@ def compare(self, loaders: Any, n_iters: Optional[int] = None) -> None: ) 
as executor: futures = [] for name, engine in self.engines.items(): - futures.append(executor.submit( - self.run_engine, engine, loaders[name])) + futures.append( + executor.submit(self.run_engine, engine, loaders[name]) + ) for future in concurrent.futures.as_completed(futures): future.result() def _get_target( - self, - handle: _handler_module.BaseHandler, - engine: _Engine, - batch_idx: int, - outputs: Any, + self, + handle: _handler_module.BaseHandler, + engine: _Engine, + batch_idx: int, + outputs: Any, ) -> Dict[str, Any]: - raise NotImplementedError('Comparers must override _get_target') + raise NotImplementedError("Comparers must override _get_target") def compare_targets( - self, - name: str, - engine: _Engine, - batch_idx: int, - target: Dict[str, Any], + self, + name: str, + engine: _Engine, + batch_idx: int, + target: Dict[str, Any], ) -> None: self._iters[name] += 1 if (self.n_iters is None) or (self._iters[name] % self.n_iters == 0): @@ -335,7 +354,9 @@ def compare_targets( self.targets[name] = target if len(self.targets.keys()) == len(self.engines.keys()): # all outputs have been filled, lets compare and reset - _compare_targets(self.compare_fn, self.targets, None, batch_idx) + _compare_targets( + self.compare_fn, self.targets, None, batch_idx + ) self.targets = {} self._assert_incompatible_trigger(not self._finalized) # Excplicitly synchronize @@ -346,12 +367,12 @@ def compare_targets( class OutputsComparer(_ComparerBase): def __init__( - self, - engines: Mapping[str, _Engine], - to_compare_keys: Optional[Sequence[str]] = None, - *, - compare_fn: _CompareFn = _default_comparer, - concurrency: Optional[int] = None, + self, + engines: Mapping[str, _Engine], + to_compare_keys: Optional[Sequence[str]] = None, + *, + compare_fn: _CompareFn = _default_comparer, + concurrency: Optional[int] = None, ) -> None: """A class for comparison of iteration outputs. @@ -378,15 +399,17 @@ def __init__( >>> comp.compare({"cpu": loader, "gpu": loader}]) """ # If to_compare_key is None, then we compare all - super().__init__(engines, compare_fn=compare_fn, concurrency=concurrency) + super().__init__( + engines, compare_fn=compare_fn, concurrency=concurrency + ) self.to_compare_keys = to_compare_keys def _get_target( - self, - handle: _handler_module.BaseHandler, - engine: _Engine, - batch_idx: int, - outputs: Any, + self, + handle: _handler_module.BaseHandler, + engine: _Engine, + batch_idx: int, + outputs: Any, ) -> Dict[str, Any]: keys = ( self.to_compare_keys @@ -398,12 +421,12 @@ def _get_target( class ModelComparer(_ComparerBase): def __init__( - self, - engines: Mapping[str, _Engine], - to_compare_keys: Optional[Sequence[str]] = None, - *, - compare_fn: _CompareFn = _default_comparer, - concurrency: Optional[int] = None, + self, + engines: Mapping[str, _Engine], + to_compare_keys: Optional[Sequence[str]] = None, + *, + compare_fn: _CompareFn = _default_comparer, + concurrency: Optional[int] = None, ): """A class for comparison of iteration model parameters. 
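# --- Illustrative usage sketch (editor's note, not part of this diff) ---
# Combining get_default_comparer and OutputsComparer from above to check
# that two engines produce matching outputs within a looser tolerance.
# The trainer/loader objects and the "out" key are assumptions; the
# compare() call mirrors the docstring example above.
from pytorch_pfn_extras.utils import comparer

compare_fn = comparer.get_default_comparer(rtol=1e-2, atol=1e-2, equal_nan=True)
comp = comparer.OutputsComparer(
    {"cpu": trainer_cpu, "gpu": trainer_gpu},
    to_compare_keys=["out"],
    compare_fn=compare_fn,
)
comp.compare({"cpu": loader, "gpu": loader})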
@@ -430,7 +453,9 @@ def __init__( >>> comp.compare({"cpu": loader, "gpu": loader}]) """ # If to_compare_key is None, then we compare all - super().__init__(engines, compare_fn=compare_fn, concurrency=concurrency) + super().__init__( + engines, compare_fn=compare_fn, concurrency=concurrency + ) self.to_compare_keys = to_compare_keys self._preprocessed_keys: Optional[List[str]] = None @@ -447,16 +472,17 @@ def _preprocess_keys(self, sdict: Dict[str, Any]) -> None: matched = True if not matched: raise ValueError( - f'didnt find a match for {tc_k} in the model') + f"didnt find a match for {tc_k} in the model" + ) def _get_target( - self, - handle: _handler_module.BaseHandler, - engine: _Engine, - batch_idx: int, - outputs: Any, + self, + handle: _handler_module.BaseHandler, + engine: _Engine, + batch_idx: int, + outputs: Any, ) -> Dict[str, Any]: - sdict = engine.models['main'].state_dict() + sdict = engine.models["main"].state_dict() if self._preprocessed_keys is None: self._preprocess_keys(sdict) assert self._preprocessed_keys is not None @@ -465,9 +491,10 @@ def _get_target( # New comparer interface + def _filter( - keys: Union[bool, str, Sequence[str]], - get_dict: Callable[[], Dict[str, Any]], + keys: Union[bool, str, Sequence[str]], + get_dict: Callable[[], Dict[str, Any]], ) -> Dict[str, Any]: if keys is False: return {} @@ -487,23 +514,22 @@ def _filter( ret[sd_k] = sdict[sd_k] break else: - raise ValueError(f'didnt find a match for {tc_k} in the model') + raise ValueError(f"didnt find a match for {tc_k} in the model") return ret - raise ValueError(f'Unsupported type: {type(keys)}') + raise ValueError(f"Unsupported type: {type(keys)}") class Comparer: - def __init__( - self, - *, - trigger: Optional[trigger_module.TriggerLike] = None, - compare_fn: _CompareFn = _default_comparer, - concurrency: Optional[int] = None, - outputs: Union[bool, str, Sequence[str]] = True, - params: Union[bool, str, Sequence[str]] = False, - baseline: Optional[str] = None, + self, + *, + trigger: Optional[trigger_module.TriggerLike] = None, + compare_fn: _CompareFn = _default_comparer, + concurrency: Optional[int] = None, + outputs: Union[bool, str, Sequence[str]] = True, + params: Union[bool, str, Sequence[str]] = False, + baseline: Optional[str] = None, ) -> None: """A class for comparison of iteration outputs and model parameters. 
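# --- Editor's usage sketch (not part of the diff) ---------------------------------
# How the OutputsComparer / ModelComparer classes above are typically driven: one
# engine per backend, all fed the same loader, compared every n_iters iterations.
# trainer_cpu, trainer_gpu and loader below are placeholders, not names from this diff.
comp = comparer.OutputsComparer(
    {"cpu": trainer_cpu, "gpu": trainer_gpu},
    to_compare_keys=("loss",),  # None compares every output key
    compare_fn=comparer.get_default_comparer(rtol=1e-2, atol=1e-2),
)
comp.compare({"cpu": loader, "gpu": loader}, n_iters=10)

# ModelComparer exposes the same interface but compares model parameters
# (state_dict entries) instead of iteration outputs.
comparer.ModelComparer({"cpu": trainer_cpu, "gpu": trainer_gpu}).compare(
    {"cpu": loader, "gpu": loader}
)
# -----------------------------------------------------------------------------------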
@@ -559,21 +585,24 @@ def __init__( self._trigger = trigger_module.get_trigger(trigger) def _get_target( - self, - handler: _ComparableHandler, - engine: _Engine, - batch_idx: int, - outputs: Dict[str, Any], + self, + handler: _ComparableHandler, + engine: _Engine, + batch_idx: int, + outputs: Dict[str, Any], ) -> Dict[str, Any]: targets = {} outputs = _filter(self._output_keys, lambda: outputs) - targets.update({ - k if k.startswith(_intermediate_prefix) else 'output:' + k: v - for k, v in outputs.items()}) + targets.update( + { + k if k.startswith(_intermediate_prefix) else "output:" + k: v + for k, v in outputs.items() + } + ) targets.update(handler._intermediate_values.pop(batch_idx)) - params = _filter(self._param_keys, engine.models['main'].state_dict) - targets.update({'param:' + k: v for k, v in params.items()}) + params = _filter(self._param_keys, engine.models["main"].state_dict) + targets.update({"param:" + k: v for k, v in params.items()}) return targets def _assert_incompatible_trigger(self, condition: bool) -> None: @@ -581,22 +610,23 @@ def _assert_incompatible_trigger(self, condition: bool) -> None: raise ValueError("Engines have different triggers.") def _get_filename( - self, engine: _Engine, handler_name: str, batch_idx: int) -> str: - name = f'dump_{self._count:08}' - name += '_' + type(engine).__name__ + self, engine: _Engine, handler_name: str, batch_idx: int + ) -> str: + name = f"dump_{self._count:08}" + name += "_" + type(engine).__name__ orig_engine, _, _ = self._engines[handler_name] epoch = orig_engine.handler._epoch # type: ignore if epoch is not None: - name += f'_epoch_{epoch}' - name += f'_iter_{batch_idx}' + name += f"_epoch_{epoch}" + name += f"_iter_{batch_idx}" return name def _compare_targets( - self, - name: str, - engine: _Engine, - batch_idx: int, - target: Dict[str, Any], + self, + name: str, + engine: _Engine, + batch_idx: int, + target: Dict[str, Any], ) -> None: # Save the outputs of this iteration with self._report_lock: @@ -604,7 +634,8 @@ def _compare_targets( if len(self._targets.keys()) == len(self._engines.keys()): # all outputs have been filled, lets compare and reset _compare_targets( - self._compare_fn, self._targets, self._baseline, batch_idx) + self._compare_fn, self._targets, self._baseline, batch_idx + ) self._targets = {} self._count += 1 self._assert_incompatible_trigger(not self._finalized) @@ -619,11 +650,11 @@ def _compare_targets( self._semaphore.acquire() def add_engine( - self, - name: str, - engine: _Engine, - *args: Any, - **kwargs: Any, + self, + name: str, + engine: _Engine, + *args: Any, + **kwargs: Any, ) -> None: """Add an engine to compare variables. 
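# --- Editor's usage sketch (not part of the diff) ---------------------------------
# The newer Comparer interface shown above registers each engine together with the
# arguments for its run() call; compare() then executes the registered engines in
# parallel threads and applies compare_fn on every trigger. trainer_cpu, trainer_gpu
# and loader are placeholders, not names from this diff.
comp = comparer.Comparer(outputs=True, params=False, baseline="cpu")
comp.add_engine("cpu", trainer_cpu, loader, loader)  # extra args go to engine.run()
comp.add_engine("gpu", trainer_gpu, loader, loader)
comp.compare()
# -----------------------------------------------------------------------------------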
@@ -649,7 +680,8 @@ def add_engine( raise ValueError(f"Engine named {name} already registered") _overwrite_handler( - engine, name, self._get_target, self._compare_targets, self._trigger) + engine, name, self._get_target, self._compare_targets, self._trigger + ) self._engines[name] = engine, args, kwargs @@ -666,18 +698,20 @@ def add_dump(self, name: str, dir: str) -> None: self._engines[name] = engine, (), {} def _dump_targets( - self, - name: str, - engine: _Engine, - batch_idx: int, - target: Dict[str, Any], + self, + name: str, + engine: _Engine, + batch_idx: int, + target: Dict[str, Any], ) -> None: name = self._get_filename(engine, name, batch_idx) assert isinstance(engine.handler, _ComparableHandler) - torch.save(target, f'{engine.handler._dir}/{name}') + torch.save(target, f"{engine.handler._dir}/{name}") self._count += 1 - def dump(self, engine: _Engine, dir: str, *args: Any, **kwargs: Any) -> None: + def dump( + self, engine: _Engine, dir: str, *args: Any, **kwargs: Any + ) -> None: """Add an engine to compare variables. Args: @@ -691,10 +725,15 @@ def dump(self, engine: _Engine, dir: str, *args: Any, **kwargs: Any) -> None: pathlib.Path(dir).mkdir(parents=True, exist_ok=True) self._count = 0 - name = '__dump' + name = "__dump" _overwrite_handler( - engine, name, self._get_target, self._dump_targets, - self._trigger, dir=dir) + engine, + name, + self._get_target, + self._dump_targets, + self._trigger, + dir=dir, + ) self._engines[name] = engine, args, {} engine.run(*args, **kwargs) @@ -719,14 +758,16 @@ def _run_engine(self, engine: _Engine, args: Any, kwargs: Any) -> None: self._semaphore.release() def compare(self) -> None: - """Compares outputs. - """ + """Compares outputs.""" self._count = 0 n_workers = len(self._engines) self._barrier = threading.Barrier(n_workers) self._semaphore = threading.Semaphore( - n_workers if self._concurrency is None else self._concurrency) - with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor: + n_workers if self._concurrency is None else self._concurrency + ) + with concurrent.futures.ThreadPoolExecutor( + max_workers=n_workers + ) as executor: futures = [] for _, (engine, args, kwargs) in self._engines.items(): futures.append(executor.submit(self._run_engine, engine, args, kwargs)) # type: ignore[arg-type] @@ -735,7 +776,7 @@ def compare(self) -> None: def intermediate_value(name: str, value: torch.Tensor) -> None: - if not hasattr(_thread_local, 'handler'): + if not hasattr(_thread_local, "handler"): return _thread_local.handler._add_intermediate_value(name, value) @@ -751,10 +792,11 @@ def run(self) -> None: assert comparer is not None for path in sorted(pathlib.Path(self._dir).iterdir()): filename = path.name - if filename.startswith('dump_'): + if filename.startswith("dump_"): target = torch.load(path) # type: ignore[no-untyped-call] - iter_str = '_iter_' + iter_str = "_iter_" pos = filename.find(iter_str) - batch_idx = int(filename[pos + len(iter_str):]) + batch_idx = int(filename[pos + len(iter_str) :]) comparer._compare_targets( - self.name, None, batch_idx, target) # type: ignore + self.name, None, batch_idx, target # type: ignore + ) diff --git a/pytorch_pfn_extras/writing/__init__.py b/pytorch_pfn_extras/writing/__init__.py index 3ae2e3f9f..4e8411ef2 100644 --- a/pytorch_pfn_extras/writing/__init__.py +++ b/pytorch_pfn_extras/writing/__init__.py @@ -1,9 +1,11 @@ -from pytorch_pfn_extras.writing._writer_base import Writer # NOQA -from pytorch_pfn_extras.writing._writer_base import StandardWriter # NOQA -from 
pytorch_pfn_extras.writing._simple_writer import SimpleWriter # NOQA -from pytorch_pfn_extras.writing._parallel_writer import ThreadWriter # NOQA from pytorch_pfn_extras.writing._parallel_writer import ProcessWriter # NOQA +from pytorch_pfn_extras.writing._parallel_writer import ThreadWriter # NOQA +from pytorch_pfn_extras.writing._queue_writer import ProcessQueueWriter # NOQA from pytorch_pfn_extras.writing._queue_writer import QueueWriter # NOQA from pytorch_pfn_extras.writing._queue_writer import ThreadQueueWriter # NOQA -from pytorch_pfn_extras.writing._queue_writer import ProcessQueueWriter # NOQA -from pytorch_pfn_extras.writing._tensorboard_writer import TensorBoardWriter # NOQA +from pytorch_pfn_extras.writing._simple_writer import SimpleWriter # NOQA +from pytorch_pfn_extras.writing._tensorboard_writer import ( # NOQA + TensorBoardWriter, +) +from pytorch_pfn_extras.writing._writer_base import StandardWriter # NOQA +from pytorch_pfn_extras.writing._writer_base import Writer # NOQA diff --git a/pytorch_pfn_extras/writing/_parallel_writer.py b/pytorch_pfn_extras/writing/_parallel_writer.py index e40f37b3a..2c784e15a 100644 --- a/pytorch_pfn_extras/writing/_parallel_writer.py +++ b/pytorch_pfn_extras/writing/_parallel_writer.py @@ -1,12 +1,14 @@ import multiprocessing -import threading import sys +import threading from typing import Any, Optional import torch - from pytorch_pfn_extras.writing._writer_base import ( - StandardWriter, _TargetType, _SaveFun, _FileSystem, + StandardWriter, + _FileSystem, + _SaveFun, + _TargetType, ) @@ -21,47 +23,51 @@ class ThreadWriter(StandardWriter[threading.Thread]): """ def __init__( - self, - savefun: _SaveFun = torch.save, - fs: _FileSystem = None, - out_dir: str = '', - **kwds: Any + self, + savefun: _SaveFun = torch.save, + fs: _FileSystem = None, + out_dir: str = "", + **kwds: Any, ) -> None: super().__init__(savefun=savefun, fs=fs, out_dir=out_dir, **kwds) def _save_with_exitcode( - self, - filename: str, - out_dir: str, - target: _TargetType, - savefun: _SaveFun, - append: bool, - **savefun_kwargs: Any, + self, + filename: str, + out_dir: str, + target: _TargetType, + savefun: _SaveFun, + append: bool, + **savefun_kwargs: Any, ) -> None: try: self.save( - filename, out_dir, target, savefun, append, **savefun_kwargs) + filename, out_dir, target, savefun, append, **savefun_kwargs + ) except Exception as e: thread = threading.current_thread() thread.exitcode = -1 # type: ignore[attr-defined] print( f'Error: ThreadWriter failed in thread "{thread.name}": ' - f'{type(e).__name__}: {str(e)}', file=sys.stderr) + f"{type(e).__name__}: {str(e)}", + file=sys.stderr, + ) def create_worker( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False, - **savefun_kwargs: Any, + self, + filename: str, + out_dir: str, + target: _TargetType, + *, + savefun: Optional[_SaveFun] = None, + append: bool = False, + **savefun_kwargs: Any, ) -> threading.Thread: return threading.Thread( target=self._save_with_exitcode, args=(filename, out_dir, target, savefun, append), - kwargs=savefun_kwargs) + kwargs=savefun_kwargs, + ) class ProcessWriter(StandardWriter[multiprocessing.Process]): @@ -80,25 +86,26 @@ class ProcessWriter(StandardWriter[multiprocessing.Process]): """ def __init__( - self, - savefun: _SaveFun = torch.save, - fs: _FileSystem = None, - out_dir: str = '', - **kwds: Any, + self, + savefun: _SaveFun = torch.save, + fs: _FileSystem = None, + out_dir: str = "", + **kwds: Any, ) -> None: 
super().__init__(savefun=savefun, fs=fs, out_dir=out_dir, **kwds) def create_worker( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False, - **savefun_kwargs: Any, + self, + filename: str, + out_dir: str, + target: _TargetType, + *, + savefun: Optional[_SaveFun] = None, + append: bool = False, + **savefun_kwargs: Any, ) -> multiprocessing.Process: return multiprocessing.Process( target=self.save, args=(filename, out_dir, target, savefun, append), - kwargs=savefun_kwargs) + kwargs=savefun_kwargs, + ) diff --git a/pytorch_pfn_extras/writing/_queue_writer.py b/pytorch_pfn_extras/writing/_queue_writer.py index 885ad2d4a..b4348570b 100644 --- a/pytorch_pfn_extras/writing/_queue_writer.py +++ b/pytorch_pfn_extras/writing/_queue_writer.py @@ -4,15 +4,19 @@ from typing import Generic, Optional, Tuple import torch - +from pytorch_pfn_extras.writing._simple_writer import SimpleWriter from pytorch_pfn_extras.writing._writer_base import ( - Writer, _TargetType, _SaveFun, _TaskFun, _Worker, _FileSystem, + Writer, + _FileSystem, + _SaveFun, + _TargetType, + _TaskFun, + _Worker, ) -from pytorch_pfn_extras.writing._simple_writer import SimpleWriter - -_QueUnit = Optional[Tuple[ - _TaskFun, str, str, _TargetType, Optional[_SaveFun], bool]] +_QueUnit = Optional[ + Tuple[_TaskFun, str, str, _TargetType, Optional[_SaveFun], bool] +] class QueueWriter(Writer, Generic[_Worker]): @@ -41,11 +45,11 @@ class QueueWriter(Writer, Generic[_Worker]): """ def __init__( - self, - savefun: _SaveFun = torch.save, - fs: _FileSystem = None, - out_dir: str = '', - task: Optional[_TaskFun] = None, + self, + savefun: _SaveFun = torch.save, + fs: _FileSystem = None, + out_dir: str = "", + task: Optional[_TaskFun] = None, ) -> None: super().__init__(fs=fs, out_dir=out_dir) self._started = False @@ -60,28 +64,29 @@ def __init__( self._started = True def __call__( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False + self, + filename: str, + out_dir: str, + target: _TargetType, + *, + savefun: Optional[_SaveFun] = None, + append: bool = False, ) -> None: assert not self._finalized self._queue.put( - (self._task, filename, out_dir, target, savefun, append)) + (self._task, filename, out_dir, target, savefun, append) + ) def create_task(self, savefun: _SaveFun) -> _TaskFun: return SimpleWriter(savefun=savefun) - def create_queue(self) -> 'queue.Queue[_QueUnit]': + def create_queue(self) -> "queue.Queue[_QueUnit]": raise NotImplementedError - def create_consumer(self, q: 'queue.Queue[_QueUnit]') -> _Worker: + def create_consumer(self, q: "queue.Queue[_QueUnit]") -> _Worker: raise NotImplementedError - def consume(self, q: 'queue.Queue[_QueUnit]') -> None: + def consume(self, q: "queue.Queue[_QueUnit]") -> None: while True: task = q.get() if task is None: @@ -89,7 +94,8 @@ def consume(self, q: 'queue.Queue[_QueUnit]') -> None: return else: task[0]( - task[1], task[2], task[3], savefun=task[4], append=task[5]) + task[1], task[2], task[3], savefun=task[4], append=task[5] + ) q.task_done() def finalize(self) -> None: @@ -116,18 +122,18 @@ class ThreadQueueWriter(QueueWriter[threading.Thread]): """ def __init__( - self, - savefun: _SaveFun = torch.save, - fs: _FileSystem = None, - out_dir: str = '', - task: Optional[_TaskFun] = None + self, + savefun: _SaveFun = torch.save, + fs: _FileSystem = None, + out_dir: str = "", + task: Optional[_TaskFun] = None, ) -> None: 
super().__init__(savefun=savefun, fs=fs, task=task, out_dir=out_dir) - def create_queue(self) -> 'queue.Queue[_QueUnit]': + def create_queue(self) -> "queue.Queue[_QueUnit]": return queue.Queue() - def create_consumer(self, q: 'queue.Queue[_QueUnit]') -> threading.Thread: + def create_consumer(self, q: "queue.Queue[_QueUnit]") -> threading.Thread: return threading.Thread(target=self.consume, args=(q,)) @@ -149,16 +155,18 @@ class ProcessQueueWriter(QueueWriter[multiprocessing.Process]): """ def __init__( - self, - savefun: _SaveFun = torch.save, - fs: _FileSystem = None, - out_dir: str = '', - task: Optional[_TaskFun] = None + self, + savefun: _SaveFun = torch.save, + fs: _FileSystem = None, + out_dir: str = "", + task: Optional[_TaskFun] = None, ) -> None: super().__init__(savefun=savefun, fs=fs, out_dir=out_dir, task=task) - def create_queue(self) -> 'queue.Queue[_QueUnit]': + def create_queue(self) -> "queue.Queue[_QueUnit]": return multiprocessing.JoinableQueue() - def create_consumer(self, q: 'queue.Queue[_QueUnit]') -> multiprocessing.Process: + def create_consumer( + self, q: "queue.Queue[_QueUnit]" + ) -> multiprocessing.Process: return multiprocessing.Process(target=self.consume, args=(q,)) diff --git a/pytorch_pfn_extras/writing/_simple_writer.py b/pytorch_pfn_extras/writing/_simple_writer.py index 9d14a3f4e..b26db45a7 100644 --- a/pytorch_pfn_extras/writing/_simple_writer.py +++ b/pytorch_pfn_extras/writing/_simple_writer.py @@ -1,9 +1,11 @@ from typing import Any, Optional import torch - from pytorch_pfn_extras.writing._writer_base import ( - Writer, _TargetType, _SaveFun, _FileSystem + Writer, + _FileSystem, + _SaveFun, + _TargetType, ) @@ -29,24 +31,24 @@ class SimpleWriter(Writer): """ def __init__( - self, - savefun: _SaveFun = torch.save, - fs: _FileSystem = None, - out_dir: str = '', - **kwds: Any, + self, + savefun: _SaveFun = torch.save, + fs: _FileSystem = None, + out_dir: str = "", + **kwds: Any, ) -> None: super().__init__(fs=fs, out_dir=out_dir) self._savefun = savefun self._kwds = kwds def __call__( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False + self, + filename: str, + out_dir: str, + target: _TargetType, + *, + savefun: Optional[_SaveFun] = None, + append: bool = False, ) -> None: if savefun is None: savefun = self._savefun diff --git a/pytorch_pfn_extras/writing/_tensorboard_writer.py b/pytorch_pfn_extras/writing/_tensorboard_writer.py index aa99f5d22..8e004e972 100644 --- a/pytorch_pfn_extras/writing/_tensorboard_writer.py +++ b/pytorch_pfn_extras/writing/_tensorboard_writer.py @@ -1,13 +1,15 @@ -from typing import Any, KeysView, Optional import warnings +from typing import Any, KeysView, Optional from pytorch_pfn_extras.writing._writer_base import ( - _TargetType, _SaveFun, _FileSystem + _FileSystem, + _SaveFun, + _TargetType, ) class TensorBoardWriter(object): - """ Writer that sends statistics to TensorBoard. + """Writer that sends statistics to TensorBoard. This class contains a `torch.utils.tensorboard.SummaryWriter` object that is used to send the collected statistics to TensorBoard. @@ -20,38 +22,40 @@ class TensorBoardWriter(object): stats (list): List of statistic keys. kwds: Passed as an additional arguments to SummaryWriter. 
""" + def __init__( - self, - savefun: Optional[_SaveFun] = None, - fs: _FileSystem = None, - out_dir: str = '', - stats: Optional[KeysView[str]] = None, - **kwds: Any + self, + savefun: Optional[_SaveFun] = None, + fs: _FileSystem = None, + out_dir: str = "", + stats: Optional[KeysView[str]] = None, + **kwds: Any, ) -> None: self._writer = None try: import torch.utils.tensorboard except ImportError: warnings.warn( - 'tensorboard is unavailable. ' - 'TensorBoardWriter will do nothing.') + "tensorboard is unavailable. " + "TensorBoardWriter will do nothing." + ) return self._stats = stats - self._writer = ( - torch.utils.tensorboard.SummaryWriter( # type: ignore[no-untyped-call] - log_dir=out_dir, **kwds)) + self._writer = torch.utils.tensorboard.SummaryWriter( # type: ignore[no-untyped-call] + log_dir=out_dir, **kwds + ) def __del__(self) -> None: self.finalize() def __call__( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False, + self, + filename: str, + out_dir: str, + target: _TargetType, + *, + savefun: Optional[_SaveFun] = None, + append: bool = False, ) -> None: """Sends the statistics to the TensorBoard. @@ -71,14 +75,15 @@ def __call__( stats_cpu = target[-1] if not isinstance(stats_cpu, dict): - raise TypeError('target must be dict or list of dicts') + raise TypeError("target must be dict or list of dicts") keys = stats_cpu.keys() if self._stats is not None: keys = self._stats # type: ignore[assignment] for key in keys: value = stats_cpu[key] self._writer.add_scalar( # type: ignore[no-untyped-call] - key, value, stats_cpu['iteration']) + key, value, stats_cpu["iteration"] + ) def finalize(self) -> None: if self._writer is not None: diff --git a/pytorch_pfn_extras/writing/_writer_base.py b/pytorch_pfn_extras/writing/_writer_base.py index b5d41a908..10f96c58d 100644 --- a/pytorch_pfn_extras/writing/_writer_base.py +++ b/pytorch_pfn_extras/writing/_writer_base.py @@ -1,23 +1,32 @@ -import multiprocessing import io +import multiprocessing import os import shutil import sys import threading import types from typing import ( - Any, Callable, Generic, IO, Iterator, List, Mapping, Optional, Sequence, - Type, TypeVar, Union, + IO, + Any, + Callable, + Generic, + Iterator, + List, + Mapping, + Optional, + Sequence, + Type, + TypeVar, + Union, ) import torch - _TargetType = Union[Sequence[Any], Mapping[str, Any]] _SaveFun = Callable[..., None] _HookFun = Callable[[], None] _TaskFun = Callable[..., None] -_Worker = TypeVar('_Worker', threading.Thread, multiprocessing.Process) +_Worker = TypeVar("_Worker", threading.Thread, multiprocessing.Process) _FileSystem = Any @@ -40,6 +49,7 @@ class _PosixFileSystem(object): This class currently abstracts POSIX """ + def __init__(self, root: Optional[str] = None) -> None: if root is None: self._root = os.getcwd() @@ -50,10 +60,11 @@ def get_actual_path(self, path: str) -> str: return os.path.join(self.root, path) def _wrap_fileobject( - self, - file_obj: IO[Any], - file_path: str, - *args: Any, **kwargs: Any, + self, + file_obj: IO[Any], + file_path: str, + *args: Any, + **kwargs: Any, ) -> IO[Any]: return file_obj @@ -66,35 +77,51 @@ def root(self, root: str) -> None: self._root = root def open( - self, - file_path: str, - mode: str = 'r', - buffering: int = -1, - encoding: Optional[str] = None, - errors: Optional[str] = None, - newline: Optional[str] = None, - closefd: bool = True, - opener: Optional[Callable[[str, int], int]] = None, + self, + file_path: str, + mode: str = 
"r", + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + closefd: bool = True, + opener: Optional[Callable[[str, int], int]] = None, ) -> IO[Any]: file_path = self.get_actual_path(file_path) - file_obj = io.open(file_path, mode, - buffering, encoding, errors, - newline, closefd, opener) + file_obj = io.open( + file_path, + mode, + buffering, + encoding, + errors, + newline, + closefd, + opener, + ) return self._wrap_fileobject( - file_obj, file_path, mode, buffering, encoding, - errors, newline, closefd, opener) + file_obj, + file_path, + mode, + buffering, + encoding, + errors, + newline, + closefd, + opener, + ) def list( - self, - path_or_prefix: Optional[str] = None, - recursive: bool = False, + self, + path_or_prefix: Optional[str] = None, + recursive: bool = False, ) -> Iterator[str]: if path_or_prefix is not None: path_or_prefix = self.get_actual_path(path_or_prefix) if recursive: if path_or_prefix is None: raise ValueError( - "'path_or_prefix' must not be none in recursive mode.") + "'path_or_prefix' must not be none in recursive mode." + ) path_or_prefix = path_or_prefix.rstrip("/") # plus 1 to include the trailing slash prefix_end_index = len(path_or_prefix) + 1 @@ -104,7 +131,9 @@ def list( yield file.name def _recursive_list( - self, prefix_end_index: int, path: str, + self, + prefix_end_index: int, + path: str, ) -> Iterator[str]: path = self.get_actual_path(path) for file in os.scandir(path): @@ -119,14 +148,14 @@ def stat(self, path: str) -> _PosixFileStat: def close(self) -> None: pass - def __enter__(self) -> '_PosixFileSystem': + def __enter__(self) -> "_PosixFileSystem": return self def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[types.TracebackType], + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], ) -> None: pass @@ -135,20 +164,17 @@ def isdir(self, file_path: str) -> bool: return os.path.isdir(path) def mkdir( - self, - file_path: str, - mode: int = 0o777, - *args: Any, - dir_fd: Optional[int] = None + self, + file_path: str, + mode: int = 0o777, + *args: Any, + dir_fd: Optional[int] = None, ) -> None: file_path = self.get_actual_path(file_path) return os.mkdir(file_path, mode, *args, dir_fd=dir_fd) def makedirs( - self, - file_path: str, - mode: int = 0o777, - exist_ok: bool = False + self, file_path: str, mode: int = 0o777, exist_ok: bool = False ) -> None: file_path = self.get_actual_path(file_path) return os.makedirs(file_path, mode, exist_ok) @@ -160,9 +186,11 @@ def rename(self, src: str, dst: str) -> None: try: return os.replace(src, dst) except OSError: - print('Destination {} is a directory ' - 'but source is not'.format(dst), - file=sys.stderr) + print( + "Destination {} is a directory " + "but source is not".format(dst), + file=sys.stderr, + ) raise def remove(self, file_path: str, recursive: bool = False) -> None: @@ -195,9 +223,9 @@ class Writer: """ def __init__( - self, - fs: _FileSystem = None, - out_dir: str = '', + self, + fs: _FileSystem = None, + out_dir: str = "", ) -> None: self._post_save_hooks: List[_HookFun] = [] self.fs = fs or _PosixFileSystem() @@ -205,13 +233,13 @@ def __init__( self._initialized = False def __call__( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False + self, + filename: str, + out_dir: str, + target: _TargetType, + 
*, + savefun: Optional[_SaveFun] = None, + append: bool = False, ) -> None: """Does the actual writing to the file. @@ -250,13 +278,13 @@ def finalize(self) -> None: pass def save( - self, - filename: str, - out_dir: str, - target: _TargetType, - savefun: _SaveFun, - append: bool, - **savefun_kwargs: Any, + self, + filename: str, + out_dir: str, + target: _TargetType, + savefun: _SaveFun, + append: bool, + **savefun_kwargs: Any, ) -> None: out_dir = self.out_dir if not self._initialized: @@ -265,19 +293,19 @@ def save( dest = os.path.join(out_dir, filename) if append: - with self.fs.open(dest, 'ab') as f: + with self.fs.open(dest, "ab") as f: # HDFS does not support overwrite savefun(target, f, **savefun_kwargs) else: # Some filesystems are not compatible with temp folders, etc # so we rely on raw temp files - prefix = 'tmp_{}'.format(filename) + prefix = "tmp_{}".format(filename) tmppath = os.path.join(out_dir, prefix) make_backup = self.fs.exists(dest) - with self.fs.open(tmppath, 'wb') as f: + with self.fs.open(tmppath, "wb") as f: savefun(target, f, **savefun_kwargs) if make_backup: - bak = '{}.bak'.format(dest) + bak = "{}.bak".format(dest) # Check if another backup file exists # due to some unexpected termination of an earlier # process @@ -331,11 +359,11 @@ class StandardWriter(Writer, Generic[_Worker]): """ def __init__( - self, - savefun: _SaveFun = torch.save, - fs: _FileSystem = None, - out_dir: str = '', - **kwds: Any, + self, + savefun: _SaveFun = torch.save, + fs: _FileSystem = None, + out_dir: str = "", + **kwds: Any, ) -> None: super().__init__(fs=fs, out_dir=out_dir) self._savefun = savefun @@ -345,13 +373,13 @@ def __init__( self._finalized = False def __call__( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False + self, + filename: str, + out_dir: str, + target: _TargetType, + *, + savefun: Optional[_SaveFun] = None, + append: bool = False, ) -> None: assert not self._finalized if savefun is None: @@ -360,21 +388,26 @@ def __call__( self.finalize() self._filename = filename self._worker = self.create_worker( - filename, out_dir, target, - savefun=savefun, append=append, **self._kwds) + filename, + out_dir, + target, + savefun=savefun, + append=append, + **self._kwds, + ) self._worker.start() self._started = True self._finalized = False def create_worker( - self, - filename: str, - out_dir: str, - target: _TargetType, - *, - savefun: Optional[_SaveFun] = None, - append: bool = False, - **savefun_kwargs: Any, + self, + filename: str, + out_dir: str, + target: _TargetType, + *, + savefun: Optional[_SaveFun] = None, + append: bool = False, + **savefun_kwargs: Any, ) -> _Worker: """Creates a worker for the snapshot. 
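# --- Editor's usage sketch (not part of the diff) ---------------------------------
# The writer protocol reformatted above: a writer is called with
# (filename, out_dir, target) and persists `target` through its savefun
# (torch.save by default). SimpleWriter does this synchronously, while ThreadWriter
# and ProcessWriter hand the same save() off to a worker built by create_worker().
# The temporary directory below is only for illustration.
import tempfile
import torch
from pytorch_pfn_extras.writing import SimpleWriter

out_dir = tempfile.mkdtemp()
writer = SimpleWriter(out_dir=out_dir)
writer("snapshot_iter_1", out_dir, {"model": torch.nn.Linear(2, 2).state_dict()})
writer.finalize()
# -----------------------------------------------------------------------------------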
@@ -390,13 +423,13 @@ def finalize(self) -> None: return if self._worker is None: - raise RuntimeError('worker is not created') + raise RuntimeError("worker is not created") try: if self._started and not self._finalized: self._worker.join() - exitcode = getattr(self._worker, 'exitcode', 0) + exitcode = getattr(self._worker, "exitcode", 0) if exitcode != 0: - raise RuntimeError(f'exit code is non-zero: {exitcode}') + raise RuntimeError(f"exit code is non-zero: {exitcode}") finally: self._started = False self._finalized = True diff --git a/setup.py b/setup.py index 71b32251a..2868a27a0 100644 --- a/setup.py +++ b/setup.py @@ -1,25 +1,25 @@ import os -import setuptools +import setuptools here = os.path.abspath(os.path.dirname(__file__)) # Get __version__ variable -exec(open(os.path.join(here, 'pytorch_pfn_extras', '_version.py')).read()) +exec(open(os.path.join(here, "pytorch_pfn_extras", "_version.py")).read()) setuptools.setup( - name='pytorch-pfn-extras', - version=__version__, # NOQA - description='Supplementary components to accelerate research and ' - 'development in PyTorch.', - author='Preferred Networks, Inc.', - license='MIT License', - install_requires=['numpy', 'packaging', 'torch', 'typing-extensions>=3.10'], + name="pytorch-pfn-extras", + version=__version__, # NOQA + description="Supplementary components to accelerate research and " + "development in PyTorch.", + author="Preferred Networks, Inc.", + license="MIT License", + install_requires=["numpy", "packaging", "torch", "typing-extensions>=3.10"], extras_require={ - 'test': ['pytest', 'onnxruntime', 'torchvision'], - 'onnx': ['onnx'], + "test": ["pytest", "onnxruntime", "torchvision"], + "onnx": ["onnx"], }, - python_requires='>=3.6.0', - packages=setuptools.find_packages(exclude=['tests', 'tests.*']), - package_data={'pytorch_pfn_extras': ['py.typed']}, + python_requires=">=3.6.0", + packages=setuptools.find_packages(exclude=["tests", "tests.*"]), + package_data={"pytorch_pfn_extras": ["py.typed"]}, ) diff --git a/stubs/torch/_C/__init__.pyi b/stubs/torch/_C/__init__.pyi index bfe8ef95a..fb7e05143 100644 --- a/stubs/torch/_C/__init__.pyi +++ b/stubs/torch/_C/__init__.pyi @@ -1,34 +1,63 @@ # @generated from torch/_C/__init__.pyi.in # flake8: noqa -import torch -from torch.package import PackageExporter -from torch import Tensor, inf -from torch.autograd.graph import Node as _Node +import builtins from enum import Enum from pathlib import Path from typing import ( - Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List, - NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union, - Literal, Generic, Set, AnyStr) + Any, + AnyStr, + BinaryIO, + Callable, + ContextManager, + Dict, + Generic, + Iterable, + Iterator, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, +) +import torch +from torch import Tensor, inf +from torch.autograd.graph import Node as _Node +from torch.package import PackageExporter +from torch.storage import TypedStorage from torch.types import ( - _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt, _dispatchkey + Device, + Number, + Storage, + SymInt, + _bool, + _device, + _dispatchkey, + _dtype, + _float, + _int, + _layout, + _qscheme, + _size, ) -from torch.storage import TypedStorage - -import builtins - -# This module is defined in torch/csrc/Module.cpp -from . import _nn as _nn -from . import _onnx as _onnx -from . import _VariableFunctions as _VariableFunctions from . 
import _functorch as _functorch from . import _lazy as _lazy from . import _lazy_ts_backend as _lazy_ts_backend +from . import _nn as _nn +from . import _onnx as _onnx +from . import _VariableFunctions as _VariableFunctions + +# This module is defined in torch/csrc/Module.cpp -T = TypeVar('T') +T = TypeVar("T") S = TypeVar("S", bound="torch.Tensor") # Defined in torch/csrc/Device.cpp @@ -41,7 +70,6 @@ class device: # THPDevice_pynew @overload def __init__(self, device: Union[_device, _int, str]) -> None: ... - @overload def __init__(self, type: str, index: _int) -> None: ... @@ -49,9 +77,7 @@ class device: # def __call__(self, func: T) -> T: ... def __enter__(self) -> "device": ... - def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... - def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce # Defined in torch/csrc/Stream.cpp @@ -60,7 +86,7 @@ class Stream: device_index: _int device_type: _int - device: device # The device of the stream + device: device # The device of the stream ... @@ -70,10 +96,8 @@ class Size(Tuple[_int, ...]): @overload # type: ignore[override] def __getitem__(self: Size, key: _int) -> _int: ... - @overload def __getitem__(self: Size, key: slice) -> Size: ... - def numel(self: Size) -> _int: ... ... @@ -107,7 +131,6 @@ class finfo: @overload def __init__(self, dtype: _dtype) -> None: ... - @overload def __init__(self) -> None: ... @@ -139,21 +162,20 @@ quint4x2: dtype = ... quint2x4: dtype = ... # Defined in torch/csrc/Layout.cpp -class layout: - ... +class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp def DisableTorchFunction(): ... def DisableTorchFunctionSubclass(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp -strided : layout = ... -sparse_coo : layout = ... -sparse_csr : layout = ... -sparse_csc : layout = ... -sparse_bsr : layout = ... -sparse_bsc : layout = ... -_mkldnn : layout = ... +strided: layout = ... +sparse_coo: layout = ... +sparse_csr: layout = ... +sparse_csc: layout = ... +sparse_bsr: layout = ... +sparse_bsc: layout = ... +_mkldnn: layout = ... # Defined in torch/csrc/MemoryFormat.cpp class memory_format: ... @@ -175,45 +197,42 @@ per_channel_symmetric: qscheme = ... per_channel_affine_float_qparams: qscheme = ... # Defined in torch/csrc/autograd/python_function.cpp -class _FunctionBase: - ... +class _FunctionBase: ... # Defined in torch/csrc/autograd/python_legacy_variable.cpp class _LegacyVariableBase(Tensor): # inherits from Tensor to appease mypy def __init__( self, - data: Optional[Tensor]=..., - requires_grad: Optional[_bool]=..., - volatile: Optional[_bool]=..., - _grad_fn: Optional[_FunctionBase]=... + data: Optional[Tensor] = ..., + requires_grad: Optional[_bool] = ..., + volatile: Optional[_bool] = ..., + _grad_fn: Optional[_FunctionBase] = ..., ) -> None: ... # Defined in torch/csrc/jit/python/init.cpp class IODescriptor: ... - class JITException: ... class Future: - def __init__(self, devices: List[device]) -> None: ... - def done(self) -> _bool: ... - def value(self) -> Any: ... - def wait(self) -> Any: ... - def add_done_callback(self, callback: Callable) -> None: ... - def then(self, callback: Callable) -> Future: ... - def set_result(self, result: Any) -> None: ... - def _set_unwrap_func(self, callback: Callable) -> None: ... + def __init__(self, devices: List[device]) -> None: ... + def done(self) -> _bool: ... + def value(self) -> Any: ... + def wait(self) -> Any: ... + def add_done_callback(self, callback: Callable) -> None: ... + def then(self, callback: Callable) -> Future: ... 
+ def set_result(self, result: Any) -> None: ... + def _set_unwrap_func(self, callback: Callable) -> None: ... class _Await: - def __init__(self) -> None: ... - def fn(self) -> Callable: ... - def args(self) -> Tuple[Any, ...]: ... - def is_nowait(self) -> _bool: ... + def __init__(self) -> None: ... + def fn(self) -> Callable: ... + def args(self) -> Tuple[Any, ...]: ... + def is_nowait(self) -> _bool: ... def _jit_set_num_profiled_runs(num: _size) -> _size: ... # Defined in torch/csrc/jit/passes/mobile_optimizer_type.h -class _MobileOptimizerType: - ... +class _MobileOptimizerType: ... CONV_BN_FUSION: _MobileOptimizerType INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType @@ -229,50 +248,66 @@ def _awaitable_wait(aw: _Await) -> Any: ... def _awaitable_nowait(x: Any) -> _Await: ... def _collect_all(futures: List[Future]) -> Future: ... def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ... - def unify_type_list(types: List[JitType]) -> JitType: ... -def _freeze_module(module: ScriptModule, - preserved_attrs: List[str] = [], - freeze_interfaces: _bool = True, - preserveParameters: _bool = True) -> ScriptModule: ... -def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ... -def _jit_pass_optimize_for_inference(module: 'torch.jit.ScriptModule', - other_methods: List[str] = []) -> None: ... +def _freeze_module( + module: ScriptModule, + preserved_attrs: List[str] = [], + freeze_interfaces: _bool = True, + preserveParameters: _bool = True, +) -> ScriptModule: ... +def _jit_pass_optimize_frozen_graph( + Graph, optimize_numerics: _bool = True +) -> None: ... +def _jit_pass_optimize_for_inference( + module: "torch.jit.ScriptModule", other_methods: List[str] = [] +) -> None: ... def _jit_pass_fold_frozen_conv_bn(graph: Graph): ... def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ... def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ... def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ... def _jit_pass_concat_frozen_linear(graph: Graph): ... def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ... -def _jit_pass_transpose_frozen_linear(graph:Graph): ... -def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ... - +def _jit_pass_transpose_frozen_linear(graph: Graph): ... +def _jit_pass_remove_dropout(module: "torch.jit.ScriptModule"): ... def _is_tracing() -> _bool: ... def _jit_init() -> _bool: ... def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ... def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ... def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ... -def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, Callable, List[Any]]: ... +def _get_operation_overload( + op_name: str, op_overload_name: str +) -> Tuple[Callable, Callable, List[Any]]: ... def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ... -def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule', - optimization_blocklist: Set[_MobileOptimizerType], - preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... -def _clone_module_with_class(module: 'torch.jit.ScriptModule', - ignored_methods: List[AnyStr], - ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ... -def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule', - optimization_blocklist: Set[_MobileOptimizerType], - preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... 
-def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule', - preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... +def _jit_pass_optimize_for_mobile( + module: "torch.jit.ScriptModule", + optimization_blocklist: Set[_MobileOptimizerType], + preserved_methods: List[AnyStr], +) -> "torch.jit.ScriptModule": ... +def _clone_module_with_class( + module: "torch.jit.ScriptModule", + ignored_methods: List[AnyStr], + ignored_attributes: List[AnyStr], +) -> "torch.jit.ScriptModule": ... +def _jit_pass_vulkan_optimize_for_mobile( + module: "torch.jit.ScriptModule", + optimization_blocklist: Set[_MobileOptimizerType], + preserved_methods: List[AnyStr], +) -> "torch.jit.ScriptModule": ... +def _jit_pass_metal_optimize_for_mobile( + module: "torch.jit.ScriptModule", preserved_methods: List[AnyStr] +) -> "torch.jit.ScriptModule": ... def _jit_pass_inline(Graph) -> None: ... def _jit_pass_constant_propagation(Graph) -> None: ... def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ... -def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ... +def _jit_register_decomposition_for_schema( + schema: FunctionSchema, Graph +) -> None: ... def _jit_erase_non_input_shape_information(Graph) -> None: ... -def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ... +def _jit_get_schemas_for_operator(name: str) -> List[FunctionSchema]: ... def _jit_get_all_schemas() -> List[FunctionSchema]: ... -def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ... +def _jit_check_alias_annotation( + g: Graph, args: Tuple[Any, ...], unqualified_op_name: str +): ... def _jit_can_fuse_on_cpu() -> _bool: ... def _jit_can_fuse_on_gpu() -> _bool: ... def _jit_can_fuse_on_cpu_legacy() -> _bool: ... @@ -295,36 +330,48 @@ def _jit_cat_wo_conditionals(optimize_cat: _bool): ... def _jit_opt_conditionals(opt_conds: _bool): ... def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ... def _jit_pass_erase_shape_information(graph: Graph): ... -def _jit_pass_fold_convbn(module: 'torch.jit.ScriptModule'): ... -def _jit_pass_insert_observers(module: 'torch.jit.ScriptModule', - method_name: str, - qconfig_dict: Dict[str, Any], - inplace: _bool, - quant_type: _int): ... -def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule', - method_name: str, - inplace: _bool, - debug: _bool, - quant_type: _int): ... -def _jit_pass_insert_quant_dequant_for_ondevice_ptq(module: 'torch.jit.ScriptModule', - method_name: str, - inplace: _bool, - debug: _bool, - quant_type: _int): ... -def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule', - quant_type: _int, - preserved_attrs: Sequence[str]): ... -def _jit_pass_quant_finalize_for_ondevice_ptq(module: 'torch.jit.ScriptModule', - quant_type: _int, - method_name: str): ... -def _jit_pass_insert_observer_method_for_ondevice_ptq(module: 'torch.jit.ScriptModule', - method_name: str, - qconfig_dict: Dict[str, Any], - inplace: _bool, - quant_type: _int): ... +def _jit_pass_fold_convbn(module: "torch.jit.ScriptModule"): ... +def _jit_pass_insert_observers( + module: "torch.jit.ScriptModule", + method_name: str, + qconfig_dict: Dict[str, Any], + inplace: _bool, + quant_type: _int, +): ... +def _jit_pass_insert_quant_dequant( + module: "torch.jit.ScriptModule", + method_name: str, + inplace: _bool, + debug: _bool, + quant_type: _int, +): ... 
+def _jit_pass_insert_quant_dequant_for_ondevice_ptq( + module: "torch.jit.ScriptModule", + method_name: str, + inplace: _bool, + debug: _bool, + quant_type: _int, +): ... +def _jit_pass_quant_finalize( + module: "torch.jit.ScriptModule", + quant_type: _int, + preserved_attrs: Sequence[str], +): ... +def _jit_pass_quant_finalize_for_ondevice_ptq( + module: "torch.jit.ScriptModule", quant_type: _int, method_name: str +): ... +def _jit_pass_insert_observer_method_for_ondevice_ptq( + module: "torch.jit.ScriptModule", + method_name: str, + qconfig_dict: Dict[str, Any], + inplace: _bool, + quant_type: _int, +): ... def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ... def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ... -def _jit_set_fusion_strategy(strategy: List[Tuple[str, _int]]) -> List[Tuple[str, _int]]: ... +def _jit_set_fusion_strategy( + strategy: List[Tuple[str, _int]] +) -> List[Tuple[str, _int]]: ... def _jit_try_infer_type(obj: Any) -> InferredType: ... def _jit_get_trigger_value(trigger_name: str) -> _int: ... @@ -333,23 +380,43 @@ ResolutionCallback = Callable[[str], Callable[..., Any]] # Defined in torch/csrc/jit/python/script_init.cpp # and torch/csrc/jit/python/init.cpp -def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ... +def _create_function_from_graph( + qualname: str, graph: Graph +) -> ScriptFunction: ... def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ... def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ... def _jit_assert_is_instance(obj: Any, type: JitType): ... def _jit_clear_class_registry() -> None: ... -def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ... +def _jit_set_emit_hooks( + ModuleHook: Optional[Callable], FunctionHook: Optional[Callable] +) -> None: ... def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ... -def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ... -def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ... +def _load_for_lite_interpreter( + filename: Union[str, Path], map_location: Union[_device, str, None] +): ... +def _load_for_lite_interpreter_from_buffer( + buffer: BinaryIO, map_location: Union[_device, str, None] +): ... def _export_operator_list(module: LiteScriptModule): ... -def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ... +def _quantize_ondevice_ptq_dynamic( + module: LiteScriptModule, method_name: str +): ... def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ... def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ... -def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ... -def _backport_for_mobile_from_buffer(buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int) -> None: ... -def _backport_for_mobile_to_buffer(filename_input: Union[str, Path], to_version: _int) -> bytes:... -def _backport_for_mobile_from_buffer_to_buffer(buffer: BinaryIO, to_version: _int) -> bytes:... +def _backport_for_mobile( + filename_input: Union[str, Path], + filename_output: Union[str, Path], + to_version: _int, +) -> None: ... +def _backport_for_mobile_from_buffer( + buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int +) -> None: ... 
+def _backport_for_mobile_to_buffer( + filename_input: Union[str, Path], to_version: _int +) -> bytes: ... +def _backport_for_mobile_from_buffer_to_buffer( + buffer: BinaryIO, to_version: _int +) -> bytes: ... def _get_model_ops_and_info(filename: Union[str, Path]): ... def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ... def _get_mobile_model_contained_types(filename: Union[str, Path]): ... @@ -365,7 +432,7 @@ def _create_function_from_trace( var_lookup_fn: Callable[[Tensor], str], strict: _bool, force_outplace: _bool, - argument_names: List[str] + argument_names: List[str], ) -> Tuple[Graph, Stack]: ... def _create_function_from_trace_with_dict( qualname: str, @@ -374,7 +441,7 @@ def _create_function_from_trace_with_dict( var_lookup_fn: Callable[[Tensor], str], strict: _bool, force_outplace: _bool, - argument_names: List[str] + argument_names: List[str], ) -> Tuple[Graph, Stack]: ... def _jit_is_script_object(obj: Any) -> _bool: ... def _last_executed_optimized_graph() -> Graph: ... @@ -383,25 +450,46 @@ def _get_upgraders_map_size() -> _int: ... def _dump_upgraders_map() -> Dict[str, str]: ... def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ... def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ... -def merge_type_from_type_comment(decl: Decl, type_annotation_decl: Decl, is_method: _bool) -> Decl: ... +def merge_type_from_type_comment( + decl: Decl, type_annotation_decl: Decl, is_method: _bool +) -> Decl: ... def parse_ir(input: str, parse_tensor_constants: _bool) -> Graph: ... def parse_schema(schema: str) -> FunctionSchema: ... def get_device(input: Tensor) -> _int: ... - -def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallback) -> JitType: ... +def _resolve_type_from_object( + obj: Any, range: SourceRange, rcb: ResolutionCallback +) -> JitType: ... def _create_module_with_type(ty: JitType) -> ScriptModule: ... def _create_object_with_type(ty: ClassType) -> ScriptObject: ... def _run_emit_module_hook(m: ScriptModule): ... -def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ... - +def _replace_overloaded_method_decl( + overload_decl: Decl, implementation_def: Def, new_name: str +) -> Def: ... def _jit_pass_lower_all_tuples(graph: Graph) -> None: ... -def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ... -def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, params_dict: Dict[str, IValue], opset_version: _int) -> None: ... -def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool, is_script: _bool, opset_version: _int) -> None: ... -def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Optional[ScriptModule] = None) -> None: ... +def _jit_pass_onnx_set_dynamic_input_shape( + graph: Graph, + dynamic_axes: Dict[str, Dict[_int, str]], + input_names: List[str], +) -> None: ... +def _jit_pass_onnx_graph_shape_type_inference( + graph: Graph, params_dict: Dict[str, IValue], opset_version: _int +) -> None: ... +def _jit_pass_onnx_assign_output_shape( + graph: Graph, + tensors: List[Tensor], + desc: IODescriptor, + onnx_shape_inference: _bool, + is_script: _bool, + opset_version: _int, +) -> None: ... +def _jit_pass_onnx_remove_inplace_ops_for_onnx( + graph: Graph, module: Optional[ScriptModule] = None +) -> None: ... def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ... 
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ... -def _jit_pass_peephole(graph: Graph, disable_shape_peepholes: _bool = False) -> None: ... +def _jit_pass_peephole( + graph: Graph, disable_shape_peepholes: _bool = False +) -> None: ... def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ... def _jit_pass_fuse_addmm(graph: Graph) -> None: ... def _jit_pass_onnx_preprocess(graph: Graph) -> None: ... @@ -409,102 +497,133 @@ def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ... def _jit_pass_onnx_remove_print(graph: Graph) -> None: ... def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ... def _jit_pass_onnx_unpack_quantized_weights( - graph: Graph, - paramsDict: Dict[str, IValue], - caffe2: _bool + graph: Graph, paramsDict: Dict[str, IValue], caffe2: _bool ) -> Dict[str, IValue]: ... def _jit_pass_onnx_quantization_insert_permutes( - graph: Graph, - paramsDict: Dict[str, IValue] + graph: Graph, paramsDict: Dict[str, IValue] ) -> Dict[str, IValue]: ... -def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ... -def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ... +def _jit_pass_custom_pattern_based_rewrite_graph( + pattern: str, fused_node_name: str, graph: Graph +) -> None: ... +def _jit_onnx_list_model_parameters( + module: ScriptModule, +) -> Tuple[ScriptModule, List[IValue]]: ... def _jit_pass_erase_number_types(graph: Graph) -> None: ... def _jit_pass_onnx_lint(graph: Graph) -> None: ... -def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ... -def _jit_pass_onnx_scalar_type_analysis(graph: Graph, lowprecision_cast: _bool, opset_version: _int) -> None: ... -def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ... -def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ... +def _jit_pass_onnx( + graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes +) -> Graph: ... +def _jit_pass_onnx_scalar_type_analysis( + graph: Graph, lowprecision_cast: _bool, opset_version: _int +) -> None: ... +def _jit_pass_onnx_peephole( + graph: Graph, opset_version: _int, fixed_batch_size: _bool +) -> None: ... +def _jit_pass_dce_allow_deleting_nodes_with_side_effects( + graph: Graph, +) -> None: ... def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ... -def _jit_pass_onnx_function_extraction(graph: Graph, module_names : Set[str], param_names : List[str]) -> Dict[Node, Dict[str, str]]: ... +def _jit_pass_onnx_function_extraction( + graph: Graph, module_names: Set[str], param_names: List[str] +) -> Dict[Node, Dict[str, str]]: ... def _jit_pass_onnx_clear_scope_records() -> None: ... -def _jit_pass_onnx_track_scope_attributes(graph: Graph, onnx_attrs: Dict[str, Any]) -> None: ... +def _jit_pass_onnx_track_scope_attributes( + graph: Graph, onnx_attrs: Dict[str, Any] +) -> None: ... def _jit_is_onnx_log_enabled() -> _bool: ... def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ... def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ... def _jit_onnx_log(*args: Any) -> None: ... -def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ... +def _jit_pass_lower_graph( + graph: Graph, m: Module +) -> Tuple[Graph, List[IValue]]: ... def _jit_pass_inline_fork_wait(graph: Graph) -> None: ... 
-def _jit_pass_onnx_deduplicate_initializers(graph: Graph, params_dict: Dict[str, IValue], is_train: _bool) -> Dict[str, IValue]: ... -def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ... -def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ... -def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ... +def _jit_pass_onnx_deduplicate_initializers( + graph: Graph, params_dict: Dict[str, IValue], is_train: _bool +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_eval_peephole( + graph: Graph, paramsDict: Dict[str, IValue] +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_constant_fold( + graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int +) -> Dict[str, IValue]: ... +def _jit_pass_onnx_eliminate_unused_items( + graph: Graph, paramsDict: Dict[str, IValue] +) -> Dict[str, IValue]: ... def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ... -def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ... +def _jit_pass_filter_non_tensor_arguments( + params: Dict[str, IValue] +) -> Dict[str, Tensor]: ... def _jit_decay_packed_param_input_types(graph: Graph) -> None: ... -def _jit_pass_onnx_node_shape_type_inference(n: Node, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ... -def _jit_onnx_convert_pattern_from_subblock(block: Block, n: Node, env: Dict[Value, Value]) -> List[Value]: ... +def _jit_pass_onnx_node_shape_type_inference( + n: Node, paramsDict: Dict[str, IValue], opset_version: _int +) -> None: ... +def _jit_onnx_convert_pattern_from_subblock( + block: Block, n: Node, env: Dict[Value, Value] +) -> List[Value]: ... def _jit_pass_onnx_block( old_block: Block, new_block: Block, operator_export_type: _onnx.OperatorExportTypes, env: Dict[Value, Value], - is_sub_block: _bool + is_sub_block: _bool, ) -> Dict[Value, Value]: ... -def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ... -def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> List[Value]: ... -def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ... - +def _jit_pass_onnx_assign_scoped_names_for_node_and_value( + graph: Graph, +) -> None: ... +def _jit_pass_fixup_onnx_controlflow_node( + n: Node, opset_version: _int +) -> List[Value]: ... +def _jit_onnx_create_full_scope_name( + class_name: str, variable_name: str +) -> str: ... def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ... - def _generate_upgraders_graph() -> Dict[str, Graph]: ... - def _calculate_package_version_based_on_upgraders(val: _bool): ... - def _get_version_calculator_flag() -> _bool: ... - -def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ... +def _jit_script_interface_compile( + name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool +): ... def _jit_script_compile_overload( qualname: str, overload_decl: Decl, implementation_def: Def, rcb: ResolutionCallback, implementation_defaults: Dict[str, Any], - signature: Any + signature: Any, ): ... def _jit_script_compile( qual_name: str, definition: Def, rcb: ResolutionCallback, - defaults: Dict[str, Any] + defaults: Dict[str, Any], ): ... 
def _jit_script_class_compile( qual_name: str, definition: ClassDef, defaults: Dict[str, Dict[str, Any]], - rcb: ResolutionCallback + rcb: ResolutionCallback, ): ... def _parse_source_def(src: str) -> Def: ... def import_ir_module( cu: CompilationUnit, filename: Union[str, Path], map_location: Union[_device, str, None], - extra_files: Dict[str, Any] + extra_files: Dict[str, Any], ) -> ScriptModule: ... def import_ir_module_from_buffer( cu: CompilationUnit, buffer: BinaryIO, map_location: Union[_device, str, None], - extra_files: Dict[str, Any] + extra_files: Dict[str, Any], ) -> ScriptModule: ... def _import_ir_module_from_package( cu: CompilationUnit, reader: PyTorchFileReader, storage_context: DeserializationStorageContext, map_location: Union[_device, str, None], - ts_id: str + ts_id: str, ) -> ScriptModule: ... - def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ... def _check_onnx_proto(proto: str) -> None: ... def _propagate_and_assign_input_shapes( @@ -512,12 +631,11 @@ def _propagate_and_assign_input_shapes( inputs: Tuple[Tensor, ...], param_count_list: List[_int], with_grad: _bool, - propagate: _bool + propagate: _bool, ) -> Graph: ... # Defined in torch/csrc/jit/runtime/graph_executor.h -class GraphExecutorState: - ... +class GraphExecutorState: ... # Defined in torch/torch/csrc/jit/ir/alias_analysis.h class AliasDb: @@ -539,7 +657,7 @@ class Use: # Defined in torch/csrc/jit/ir/ir.h class Value: - def type(self)-> JitType: ... + def type(self) -> JitType: ... def setType(self, t: JitType) -> Value: ... def setTypeAs(self, other: Value) -> Value: ... def inferTypeFrom(self, t: Tensor) -> None: ... @@ -677,31 +795,34 @@ class Graph: def setInsertPoint(self, n: Union[Block, Node]) -> None: ... def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ... def insertPoint(self) -> Node: ... - def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ... + def insertGraph( + self, callee: Graph, inputs: List[Value] + ) -> List[Value]: ... def makeMultiOutputIntoTuple(self) -> None: ... def copy(self) -> Graph: ... def create(self, name: str, *args, num_outputs: _int = 1) -> Node: ... - def op(self, kind: str, *args: Any, **kwargs: Any) -> Union[Value, Sequence[Value]]: ... + def op( + self, kind: str, *args: Any, **kwargs: Any + ) -> Union[Value, Sequence[Value]]: ... ... - # Defined in torch/aten/src/ATen/core/alias_info.h class AliasInfo: is_write: _bool before_set: Set[str] after_set: Set[str] - # Defined in torch/aten/src/ATen/core/function_schema.h class Argument: name: str type: JitType default_value: Optional[Any] def has_default_value(self) -> _bool: ... - kwarg_only : _bool + kwarg_only: _bool is_out: _bool alias_info: Optional[AliasInfo] ... + class FunctionSchema: arguments: List[Argument] returns: List[Argument] @@ -713,26 +834,28 @@ class _UpgraderEntry: bumped_at_version: _int upgrader_name: str old_schema: str - def __init__(self, bumped_at_version: _int, upgrader_name: str, old_schema: str) -> None: ... + def __init__( + self, bumped_at_version: _int, upgrader_name: str, old_schema: str + ) -> None: ... class _UpgraderRange: min_version: _int max_version: _int def _get_max_operator_version() -> _int: ... - def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ... - def _get_upgrader_ranges(name: str) -> List[_UpgraderRange]: ... - -def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ... 
- +def _test_only_add_entry_to_op_version( + op_name: str, entry: _UpgraderEntry +) -> None: ... def _test_only_remove_entry_to_op_version(op_name: str) -> None: ... # Defined in torch/csrc/jit/python/script_init.cpp class ScriptModuleSerializer: def __init__(self, export_writer: PyTorchFileWriter) -> None: ... - def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ... + def serialize( + self, model: ScriptModule, script_module_id: _int + ) -> None: ... def write_files(self) -> None: ... def storage_context(self) -> SerializationStorageContext: ... ... @@ -759,13 +882,19 @@ class ConcreteModuleTypeBuilder: def set_module_list(self): ... def set_parameter_list(self): ... def set_parameter_dict(self): ... - def add_attribute(self, name: str, ty: JitType, is_param: _bool, is_buffer: _bool): ... + def add_attribute( + self, name: str, ty: JitType, is_param: _bool, is_buffer: _bool + ): ... def add_module(self, name: str, meta: ConcreteModuleType): ... def add_constant(self, name: str, value: Any): ... - def add_overload(self, method_name: str, overloaded_method_names: List[str]): ... + def add_overload( + self, method_name: str, overloaded_method_names: List[str] + ): ... def add_builtin_function(self, name: str, symbol_name: str): ... def add_failed_attribute(self, name: str, failure_reason: str): ... - def add_function_attribute(self, name: str, ty: JitType, func: Callable[..., Any]): ... + def add_function_attribute( + self, name: str, ty: JitType, func: Callable[..., Any] + ): ... def add_ignored_attribute(self, name: str): ... def add_ignored_attributes(self, names: List[str]): ... def add_forward_hook(self, hook: Callable[..., Any]): ... @@ -773,8 +902,7 @@ class ConcreteModuleTypeBuilder: class ConcreteModuleType: def get_constants(self) -> Dict[str, Any]: ... - def equals(self, other: 'ConcreteModuleType') -> _bool: ... - + def equals(self, other: "ConcreteModuleType") -> _bool: ... @staticmethod def from_jit_type(ty: JitType) -> ConcreteModuleType: ... @@ -784,18 +912,21 @@ class CallStack: class ErrorReport: def __init__(self, range: SourceRange) -> None: ... def what(self) -> str: ... - @staticmethod def call_stack() -> str: ... class CompilationUnit: - def __init__(self, lang: str=..., _frames_up: _int=...) -> None: ... + def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ... def find_function(self, name: str) -> ScriptFunction: ... def __getattr__(self, name: str) -> ScriptFunction: ... - def define(self, script: str, rcb: ResolutionCallback=..., _frames_up: _int=...): ... + def define( + self, script: str, rcb: ResolutionCallback = ..., _frames_up: _int = ... + ): ... def get_interface(self, name: str) -> InterfaceType: ... def get_functions(self) -> List[ScriptFunction]: ... - def create_function(self, name: str, graph: Graph, shouldMangle: _bool=...) -> ScriptFunction: ... + def create_function( + self, name: str, graph: Graph, shouldMangle: _bool = ... + ) -> ScriptFunction: ... def get_class(self, name: str) -> ClassType: ... class ScriptObject: @@ -842,16 +973,19 @@ class BufferDict: def __init__(self, mod: ScriptModule) -> None: ... # Defined in torch/csrc/jit/api/module.h -class Module: - ... +class Module: ... # Defined in torch/csrc/Module.cpp -def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension +def _initExtension( + shm_manager_path: str, +) -> None: ... # THPModule_initExtension def _autograd_init() -> _bool: ... # THPAutograd_initExtension def _add_docstr(obj: T, doc_obj: str) -> T: ... 
# THPModule_addDocStr def _init_names(arg: Sequence[Type]) -> None: ... # THPModule_initNames def _has_distributed() -> _bool: ... # THPModule_hasDistributed -def _set_default_tensor_type(type) -> None: ... # THPModule_setDefaultTensorType +def _set_default_tensor_type( + type, +) -> None: ... # THPModule_setDefaultTensorType def _set_default_dtype(d: _dtype) -> None: ... # THPModule_setDefaultDtype def _infer_size(arg1: Size, arg2: Size) -> Size: ... # THPModule_inferSize def _crash_if_csrc_asan() -> _int: ... # THPModule_crashIfCsrcASAN @@ -860,56 +994,101 @@ def _crash_if_aten_asan() -> _int: ... # THPModule_crashIfATenASAN def _show_config() -> str: ... # THPModule_showConfig def _cxx_flags() -> str: ... # THPModule_cxxFlags def _parallel_info() -> str: ... # THPModule_parallelInfo -def _set_backcompat_broadcast_warn(arg: _bool) -> None: ... # THPModule_setBackcompatBroadcastWarn -def _get_backcompat_broadcast_warn() -> _bool: ... # THPModule_getBackcompatBroadcastWarn -def _set_backcompat_keepdim_warn(arg: _bool) -> None: ... # THPModule_setBackcompatKeepdimWarn -def _get_backcompat_keepdim_warn() -> _bool: ... # THPModule_getBackcompatKeepdimWarn +def _set_backcompat_broadcast_warn( + arg: _bool, +) -> None: ... # THPModule_setBackcompatBroadcastWarn +def _get_backcompat_broadcast_warn() -> ( + _bool +): ... # THPModule_getBackcompatBroadcastWarn +def _set_backcompat_keepdim_warn( + arg: _bool, +) -> None: ... # THPModule_setBackcompatKeepdimWarn +def _get_backcompat_keepdim_warn() -> ( + _bool +): ... # THPModule_getBackcompatKeepdimWarn def get_num_thread() -> _int: ... # THPModule_getNumThreads def set_num_threads(nthreads: _int) -> None: ... # THPModule_setNumThreads def get_num_interop_threads() -> _int: ... # THPModule_getNumInteropThreads -def set_num_interop_threads(nthreads: _int) -> None: ... # THPModule_setNumInteropThreads +def set_num_interop_threads( + nthreads: _int, +) -> None: ... # THPModule_setNumInteropThreads def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash -def _get_mem_efficient_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP -def _set_sdp_use_mem_efficient(arg: _bool) -> None: ... # THPModule_setSDPUseMemEfficient +def _get_mem_efficient_sdp_enabled() -> ( + _bool +): ... # THPModule_userEnabledMathSDP +def _set_sdp_use_mem_efficient( + arg: _bool, +) -> None: ... # THPModule_setSDPUseMemEfficient def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn -def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn +def _set_mkldnn_enabled( + arg: _bool, +) -> None: ... # THPModule_setUserEnabledMkldnn def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN -def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN -def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms -def _get_deterministic_algorithms_warn_only() -> _bool: ... 
# THPModule_deterministicAlgorithmsWarnOnly -def _set_deterministic_algorithms(mode: _bool, *, warn_only: _bool=...) -> None: ... # THPModule_setDeterministicAlgorithms +def _set_cudnn_deterministic( + arg: _bool, +) -> None: ... # THPModule_setDeterministicCuDNN +def _get_deterministic_algorithms() -> ( + _bool +): ... # THPModule_deterministicAlgorithms +def _get_deterministic_algorithms_warn_only() -> ( + _bool +): ... # THPModule_deterministicAlgorithmsWarnOnly +def _set_deterministic_algorithms( + mode: _bool, *, warn_only: _bool = ... +) -> None: ... # THPModule_setDeterministicAlgorithms def _get_warnAlways() -> _bool: ... # THPModule_warnAlways def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN -def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN +def _set_cudnn_allow_tf32( + arg: _bool, +) -> None: ... # THPModule_setAllowTF32CuDNN def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS -def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS -def _get_float32_matmul_precision() -> str: ... #THPModule_float32MatmulPrecision -def _set_float32_matmul_precision(arg: str) -> None: ... #THPModule_setFloat32MatmulPrecision -def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModule_allowFP16ReductionCuBLAS -def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS -def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... #THPModule_allowBF16ReductionCuBLAS -def _set_cublas_allow_bf16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowBF16ReductionCuBLAS +def _set_cublas_allow_tf32( + arg: _bool, +) -> None: ... # THPModule_setAllowTF32CuBLAS +def _get_float32_matmul_precision() -> ( + str +): ... # THPModule_float32MatmulPrecision +def _set_float32_matmul_precision( + arg: str, +) -> None: ... # THPModule_setFloat32MatmulPrecision +def _get_cublas_allow_fp16_reduced_precision_reduction() -> ( + _bool +): ... # THPModule_allowFP16ReductionCuBLAS +def _set_cublas_allow_fp16_reduced_precision_reduction( + arg: _bool, +) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS +def _get_cublas_allow_bf16_reduced_precision_reduction() -> ( + _bool +): ... # THPModule_allowBF16ReductionCuBLAS +def _set_cublas_allow_bf16_reduced_precision_reduction( + arg: _bool, +) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS def _set_conj(x: Tensor, conj: _bool) -> None: ... def _set_neg(x: Tensor, neg: _bool) -> None: ... def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ... def _meta_in_tls_dispatch_include() -> _bool: ... def _select_conv_backend(*args, **kwargs) -> ConvBackend: ... -def _conv_determine_backend_memory_format(input: Tensor, weight: Tensor, backend: ConvBackend) -> memory_format: ... +def _conv_determine_backend_memory_format( + input: Tensor, weight: Tensor, backend: ConvBackend +) -> memory_format: ... def _has_storage(x: Tensor) -> _bool: ... def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ... + # NB: There is no Capsule type in typing, see # https://code.activestate.com/lists/python-dev/139675/ def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack -def _get_cpp_backtrace(frames_to_skip: _int, maximum_number_of_frames: _int) -> str: ... 
# THPModule_getCppBacktrace +def _get_cpp_backtrace( + frames_to_skip: _int, maximum_number_of_frames: _int +) -> str: ... # THPModule_getCppBacktrace def set_flush_denormal(arg: _bool) -> _bool: ... # THPModule_setFlushDenormal def get_default_dtype() -> _dtype: ... # THPModule_getDefaultDtype def _get_default_device() -> str: ... # THPModule_getDefaultDevice @@ -917,34 +1096,60 @@ def _get_qengine() -> _int: ... # THPModule_qEngine def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK -def _check_sparse_tensor_invariants() -> _bool: ... # THPModule_checkSparseTensorInvariants -def _set_check_sparse_tensor_invariants(arg: _bool) -> None: ... # THPModule_setCheckSparseTensorInvariants -def _set_default_mobile_cpu_allocator() -> None: ... # THPModule_setDefaultMobileCPUAllocator -def _unset_default_mobile_cpu_allocator() -> None: ... # THPModule_unsetDefaultMobileCPUAllocator -def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction -def _has_torch_function(args: Iterable[Any]) -> _bool: ... # THPModule_has_torch_function -def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary -def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic -def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting -def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting +def _check_sparse_tensor_invariants() -> ( + _bool +): ... # THPModule_checkSparseTensorInvariants +def _set_check_sparse_tensor_invariants( + arg: _bool, +) -> None: ... # THPModule_setCheckSparseTensorInvariants +def _set_default_mobile_cpu_allocator() -> ( + None +): ... # THPModule_setDefaultMobileCPUAllocator +def _unset_default_mobile_cpu_allocator() -> ( + None +): ... # THPModule_unsetDefaultMobileCPUAllocator +def _is_torch_function_enabled() -> ( + _bool +): ... # THPModule_isEnabledTorchFunction +def _has_torch_function( + args: Iterable[Any], +) -> _bool: ... # THPModule_has_torch_function +def _has_torch_function_unary( + Any, +) -> _bool: ... # THPModule_has_torch_function_unary +def _has_torch_function_variadic( + *args: Any, +) -> _bool: ... # THPModule_has_torch_function_variadic +def _vmapmode_increment_nesting() -> ( + _int +): ... # THPModule_vmapmode_increment_nesting +def _vmapmode_decrement_nesting() -> ( + _int +): ... # THPModule_vmapmode_decrement_nesting def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython def _demangle(str) -> str: ... # c10::demangle -def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_torch_function -def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function +def _disabled_torch_function_impl( + func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict +) -> Any: ... # THPModule_disable_torch_function +def _disabled_torch_dispatch_impl( + func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict +) -> Any: ... # THPModule_disable_dispatch_function def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ... def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ... 
+ class _LinalgBackend: Default: _LinalgBackend Cusolver: _LinalgBackend Magma: _LinalgBackend -class ConvBackend(Enum): - ... +class ConvBackend(Enum): ... # Defined in `valgrind.h` and `callgrind.h` respecitively. def _valgrind_supported_platform() -> _bool: ... # NVALGRIND def _valgrind_toggle() -> None: ... # CALLGRIND_TOGGLE_COLLECT -def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS +def _valgrind_toggle_and_dump_stats() -> ( + None +): ... # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS has_openmp: _bool has_mkl: _bool @@ -985,16 +1190,16 @@ def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ... def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ... def __set_forward_AD_enabled(enabled: _bool) -> None: ... def __is_forward_AD_enabled() -> _bool: ... -def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ... +def _register_default_hooks( + pack_hook: Callable, unpack_hook: Callable +) -> None: ... def _reset_default_hooks() -> None: ... - -def _is_torch_function_mode_enabled()-> _bool: ... +def _is_torch_function_mode_enabled() -> _bool: ... def _set_torch_function_mode(cls: Any) -> None: ... def _push_on_torch_function_stack(cls: Any) -> None: ... def _pop_torch_function_stack() -> Any: ... def _get_function_stack_at(idx: _int) -> Any: ... def _len_torch_function_stack() -> _int: ... - def _set_torch_dispatch_mode(cls: Any) -> None: ... def _push_on_torch_dispatch_stack(cls: Any) -> None: ... def _pop_torch_dispatch_stack() -> Any: ... @@ -1017,14 +1222,9 @@ class _ViewReplayEnabled: def __init__(self, mode: _bool) -> None: ... # Defined in torch/csrc/jit/python/script_init.cpp -class LoggerBase: - ... - -class NoopLogger(LoggerBase): - ... - -class LockingLogger(LoggerBase): - ... +class LoggerBase: ... +class NoopLogger(LoggerBase): ... +class LockingLogger(LoggerBase): ... class AggregationType(Enum): SUM = 0 @@ -1032,13 +1232,15 @@ class AggregationType(Enum): class FileCheck: def run(self, test_string: str) -> None: ... - def check(self, test_string: str) -> 'FileCheck': ... - def check_not(self, test_string: str) -> 'FileCheck': ... - def check_same(self, test_string: str) -> 'FileCheck': ... - def check_next(self, test_string: str) -> 'FileCheck': ... - def check_count(self, test_string: str, count: _int, exactly: _bool = False) -> 'FileCheck': ... - def check_dag(self, test_string: str) -> 'FileCheck': ... - def check_source_highlighted(self, test_string: str) -> 'FileCheck': ... + def check(self, test_string: str) -> "FileCheck": ... + def check_not(self, test_string: str) -> "FileCheck": ... + def check_same(self, test_string: str) -> "FileCheck": ... + def check_next(self, test_string: str) -> "FileCheck": ... + def check_count( + self, test_string: str, count: _int, exactly: _bool = False + ) -> "FileCheck": ... + def check_dag(self, test_string: str) -> "FileCheck": ... + def check_source_highlighted(self, test_string: str) -> "FileCheck": ... ... # Defined in torch/csrc/jit/python/init.cpp @@ -1055,7 +1257,9 @@ class PyTorchFileWriter: def __init__(self, name: str) -> None: ... @overload def __init__(self, buffer: BinaryIO) -> None: ... - def write_record(self, name: str, data: Union[bytes, _int], size: _int) -> None: ... + def write_record( + self, name: str, data: Union[bytes, _int], size: _int + ) -> None: ... def write_end_of_file(self) -> None: ... def set_min_version(self, version: _int) -> None: ... def get_all_written_records(self) -> List[str]: ... 
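A minimal usage sketch for the chained FileCheck matcher declared in the hunk above (assuming the runtime torch._C module exposes FileCheck as in upstream PyTorch; the scripted function below is purely illustrative): each check_* call returns the same matcher object, and run() then scans the target string for the registered patterns in order.

import torch
from torch._C import FileCheck

def net(x):
    return x.relu() + 1.0

# Scripting produces a TorchScript graph whose textual IR can be pattern-checked.
graph = torch.jit.script(net).graph

fc = FileCheck()
fc.check("aten::relu").check_count("aten::add", 1, exactly=True)
fc.run(str(graph))  # raises if a pattern is missing or the count differs
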
@@ -1087,7 +1291,6 @@ class Generator: def seed(self) -> _int: ... def initial_seed(self) -> _int: ... - # Defined in torch/csrc/utils/python_dispatch.cpp class _DispatchOperatorHandle: @@ -1096,27 +1299,51 @@ class _DispatchOperatorHandle: class _DispatchModule: def def_(self, schema: str, alias: str = "") -> _DispatchModule: ... def def_legacy(self, schema: str) -> _DispatchModule: ... - def def_name_t_t(self, name: str, dispatch: str, debug: str = "default_def_name_t_t") -> _DispatchModule: ... - def def_schema_t_t(self, schema: str, dispatch: str, alias: str, debug: str = "default_def_schema_t_t") -> _DispatchModule: ... - def impl_t_t(self, name: str, dispatch: str, debug: str = "impl_t_t") -> _DispatchModule: ... - def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ... + def def_name_t_t( + self, name: str, dispatch: str, debug: str = "default_def_name_t_t" + ) -> _DispatchModule: ... + def def_schema_t_t( + self, + schema: str, + dispatch: str, + alias: str, + debug: str = "default_def_schema_t_t", + ) -> _DispatchModule: ... + def impl_t_t( + self, name: str, dispatch: str, debug: str = "impl_t_t" + ) -> _DispatchModule: ... + def impl( + self, name: str, dispatch: str, func: Callable + ) -> _DispatchModule: ... def define(self, schema: str, alias: str = "") -> _DispatchModule: ... def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ... -def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> _DispatchModule: ... +def _dispatch_library( + kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0 +) -> _DispatchModule: ... def _dispatch_dump(name: str) -> str: ... def _dispatch_dump_table(name: str) -> str: ... def _dispatch_check_invariants(name: str) -> None: ... def _dispatch_check_all_invariants() -> None: ... def _dispatch_has_kernel(name: str) -> _bool: ... -def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: _dispatchkey) -> _bool: ... -def _dispatch_has_kernel_for_any_dispatch_key(name: str, dispatch_key_set: DispatchKeySet) -> _bool: ... -def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: _dispatchkey) -> _bool: ... +def _dispatch_has_kernel_for_dispatch_key( + name: str, dispatch: _dispatchkey +) -> _bool: ... +def _dispatch_has_kernel_for_any_dispatch_key( + name: str, dispatch_key_set: DispatchKeySet +) -> _bool: ... +def _dispatch_has_computed_kernel_for_dispatch_key( + name: str, dispatch: _dispatchkey +) -> _bool: ... def _dispatch_find_dangling_impls() -> List[str]: ... def _dispatch_get_all_op_names() -> List[str]: ... -def _dispatch_tls_set_dispatch_key_excluded(dispatch: _dispatchkey, val: _bool) -> None: ... +def _dispatch_tls_set_dispatch_key_excluded( + dispatch: _dispatchkey, val: _bool +) -> None: ... def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ... -def _dispatch_tls_set_dispatch_key_included(dispatch: _dispatchkey, val: _bool) -> None: ... +def _dispatch_tls_set_dispatch_key_included( + dispatch: _dispatchkey, val: _bool +) -> None: ... def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ... def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ... def _dispatch_key_name(dispatch: _dispatchkey) -> str: ... @@ -1241,14 +1468,19 @@ class DispatchKeySet: def __repr__(self) -> str: ... _dispatch_autogradother_backends: DispatchKeySet + def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ... 
def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ... def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ... -def _dispatch_get_backend_keyset_from_autograd(dispatch: _dispatchkey) -> DispatchKeySet: ... +def _dispatch_get_backend_keyset_from_autograd( + dispatch: _dispatchkey, +) -> DispatchKeySet: ... def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ... def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ... def _dispatch_tls_local_include_set() -> DispatchKeySet: ... -def _dispatch_is_included_in_alias(dispatch_a: _dispatchkey, dispatch_b: _dispatchkey) -> _bool: ... +def _dispatch_is_included_in_alias( + dispatch_a: _dispatchkey, dispatch_b: _dispatchkey +) -> _bool: ... class ExcludeDispatchKeyGuard: pass @@ -1256,9 +1488,12 @@ class ExcludeDispatchKeyGuard: class _AutoDispatchBelowAutograd: pass -def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ... -def _dispatch_get_registrations_for_dispatch_key(dispatch_key: str = "") -> List[str]: ... - +def _dispatch_print_registrations_for_dispatch_key( + dispatch_key: str = "", +) -> None: ... +def _dispatch_get_registrations_for_dispatch_key( + dispatch_key: str = "", +) -> List[str]: ... def _are_functorch_transforms_active() -> _bool: ... # Define in torch/csrc/autograd/init.cpp @@ -1270,7 +1505,6 @@ class _EnablePythonDispatcher: def _set_python_dispatcher(dispatcher: object) -> None: ... - # Defined in torch/csrc/utils/init.cpp class BenchmarkConfig: num_calling_threads: _int @@ -1353,7 +1587,9 @@ class _TensorBase(metaclass=_TensorMeta): def __float__(self) -> builtins.float: ... def __floordiv__(self, other: Any) -> Tensor: ... def __ge__(self, other: Any) -> Tensor: ... - def __getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor: ... + def __getitem__( + self, indices: Union[None, _int, slice, Tensor, List, Tuple] + ) -> Tensor: ... def __gt__(self, other: Any) -> Tensor: ... def __iadd__(self, other: Any) -> Tensor: ... @overload @@ -1374,13 +1610,13 @@ class _TensorBase(metaclass=_TensorMeta): def __imul__(self, other: Any) -> Tensor: ... def __index__(self) -> builtins.int: ... @overload - def __init__(self, *args: Any, device: Device=None) -> None: ... + def __init__(self, *args: Any, device: Device = None) -> None: ... @overload def __init__(self, storage: Storage) -> None: ... @overload def __init__(self, other: Tensor) -> None: ... @overload - def __init__(self, size: _size, *, device: Device=None) -> None: ... + def __init__(self, size: _size, *, device: Device = None) -> None: ... def __int__(self) -> builtins.int: ... def __invert__(self) -> Tensor: ... @overload @@ -1439,7 +1675,11 @@ class _TensorBase(metaclass=_TensorMeta): def __rsub__(self, other: Any) -> Tensor: ... def __rtruediv__(self, other: Any) -> Tensor: ... def __rxor__(self, other: Any) -> Tensor: ... - def __setitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple], val: Union[Tensor, Number]) -> None: ... + def __setitem__( + self, + indices: Union[None, _int, slice, Tensor, List, Tuple], + val: Union[Tensor, Number], + ) -> None: ... def __sub__(self, other: Any) -> Tensor: ... def __truediv__(self, other: Any) -> Tensor: ... @overload @@ -1448,9 +1688,25 @@ class _TensorBase(metaclass=_TensorMeta): def __xor__(self, other: Number) -> Tensor: ... @overload def __xor__(self, other: Any) -> Tensor: ... - def _addmm_activation(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1, use_gelu: _bool=False) -> Tensor: ... 
- def _autocast_to_full_precision(self, cuda_enabled: _bool, cpu_enabled: _bool) -> Tensor: ... - def _autocast_to_reduced_precision(self, cuda_enabled: _bool, cpu_enabled: _bool, cuda_dtype: _dtype, cpu_dtype: _dtype) -> Tensor: ... + def _addmm_activation( + self, + mat1: Tensor, + mat2: Tensor, + *, + beta: Number = 1, + alpha: Number = 1, + use_gelu: _bool = False, + ) -> Tensor: ... + def _autocast_to_full_precision( + self, cuda_enabled: _bool, cpu_enabled: _bool + ) -> Tensor: ... + def _autocast_to_reduced_precision( + self, + cuda_enabled: _bool, + cpu_enabled: _bool, + cuda_dtype: _dtype, + cpu_dtype: _dtype, + ) -> Tensor: ... def _coalesced_(self, coalesced: _bool) -> Tensor: ... def _conj(self) -> Tensor: ... def _conj_physical(self) -> Tensor: ... @@ -1461,12 +1717,19 @@ class _TensorBase(metaclass=_TensorMeta): def _is_any_true(self) -> Tensor: ... def _is_view(self) -> _bool: ... def _is_zerotensor(self) -> _bool: ... - def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False, dispatch_device: _bool=False, device_for_backend_keys: Optional[_device] = None) -> Tensor: ... + def _make_subclass( + cls, + data: Tensor, + require_grad: _bool = False, + dispatch_strides: _bool = False, + dispatch_device: _bool = False, + device_for_backend_keys: Optional[_device] = None, + ) -> Tensor: ... def _neg_view(self) -> Tensor: ... def _nested_tensor_size(self) -> Tensor: ... def _nested_tensor_strides(self) -> Tensor: ... def _nnz(self) -> _int: ... - def _to_dense(self, dtype: Optional[_dtype]=None) -> Tensor: ... + def _to_dense(self, dtype: Optional[_dtype] = None) -> Tensor: ... def _values(self) -> Tensor: ... def abs(self) -> Tensor: ... def abs_(self) -> Tensor: ... @@ -1476,43 +1739,108 @@ class _TensorBase(metaclass=_TensorMeta): def acos_(self) -> Tensor: ... def acosh(self) -> Tensor: ... def acosh_(self) -> Tensor: ... - def add(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ... - def add_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1) -> Tensor: ... - def addbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def addbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def addcdiv(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ... - def addcdiv_(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ... - def addcmul(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ... - def addcmul_(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ... - def addmm(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def addmm_(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def addmv(self, mat: Tensor, vec: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def addmv_(self, mat: Tensor, vec: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def addr(self, vec1: Tensor, vec2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def addr_(self, vec1: Tensor, vec2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... + def add( + self, + other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], + *, + alpha: Optional[Number] = 1, + out: Optional[Tensor] = None, + ) -> Tensor: ... 
+ def add_( + self, + other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], + *, + alpha: Optional[Number] = 1, + ) -> Tensor: ... + def addbmm( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number = 1, + alpha: Number = 1, + ) -> Tensor: ... + def addbmm_( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number = 1, + alpha: Number = 1, + ) -> Tensor: ... + def addcdiv( + self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1 + ) -> Tensor: ... + def addcdiv_( + self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1 + ) -> Tensor: ... + def addcmul( + self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1 + ) -> Tensor: ... + def addcmul_( + self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1 + ) -> Tensor: ... + def addmm( + self, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1 + ) -> Tensor: ... + def addmm_( + self, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1 + ) -> Tensor: ... + def addmv( + self, mat: Tensor, vec: Tensor, *, beta: Number = 1, alpha: Number = 1 + ) -> Tensor: ... + def addmv_( + self, mat: Tensor, vec: Tensor, *, beta: Number = 1, alpha: Number = 1 + ) -> Tensor: ... + def addr( + self, vec1: Tensor, vec2: Tensor, *, beta: Number = 1, alpha: Number = 1 + ) -> Tensor: ... + def addr_( + self, vec1: Tensor, vec2: Tensor, *, beta: Number = 1, alpha: Number = 1 + ) -> Tensor: ... def adjoint(self) -> Tensor: ... def align_as(self, other: Tensor) -> Tensor: ... @overload - def align_to(self, order: Sequence[Union[str, ellipsis, None]], ellipsis_idx: _int) -> Tensor: ... + def align_to( + self, order: Sequence[Union[str, ellipsis, None]], ellipsis_idx: _int + ) -> Tensor: ... @overload - def align_to(self, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ... + def align_to( + self, names: Sequence[Union[str, ellipsis, None]] + ) -> Tensor: ... @overload def all(self) -> Tensor: ... @overload - def all(self, dim: _int, keepdim: _bool=False) -> Tensor: ... + def all(self, dim: _int, keepdim: _bool = False) -> Tensor: ... @overload - def all(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> Tensor: ... - def allclose(self, other: Tensor, rtol: _float=1e-05, atol: _float=1e-08, equal_nan: _bool=False) -> _bool: ... - def amax(self, dim: Union[_int, _size]=(), keepdim: _bool=False) -> Tensor: ... - def amin(self, dim: Union[_int, _size]=(), keepdim: _bool=False) -> Tensor: ... - def aminmax(self, *, dim: Optional[_int]=None, keepdim: _bool=False) -> torch.return_types.aminmax: ... + def all( + self, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> Tensor: ... + def allclose( + self, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, + ) -> _bool: ... + def amax( + self, dim: Union[_int, _size] = (), keepdim: _bool = False + ) -> Tensor: ... + def amin( + self, dim: Union[_int, _size] = (), keepdim: _bool = False + ) -> Tensor: ... + def aminmax( + self, *, dim: Optional[_int] = None, keepdim: _bool = False + ) -> torch.return_types.aminmax: ... def angle(self) -> Tensor: ... @overload def any(self) -> Tensor: ... @overload - def any(self, dim: _int, keepdim: _bool=False) -> Tensor: ... + def any(self, dim: _int, keepdim: _bool = False) -> Tensor: ... @overload - def any(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> Tensor: ... + def any( + self, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> Tensor: ... def apply_(self, callable: Callable) -> Tensor: ... 
def arccos(self) -> Tensor: ... def arccos_(self) -> Tensor: ... @@ -1528,18 +1856,42 @@ class _TensorBase(metaclass=_TensorMeta): def arctan_(self) -> Tensor: ... def arctanh(self) -> Tensor: ... def arctanh_(self) -> Tensor: ... - def argmax(self, dim: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ... - def argmin(self, dim: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ... - @overload - def argsort(self, *, stable: _bool, dim: _int=-1, descending: _bool=False) -> Tensor: ... - @overload - def argsort(self, dim: _int=-1, descending: _bool=False) -> Tensor: ... - @overload - def argsort(self, dim: Union[str, ellipsis, None], descending: _bool=False) -> Tensor: ... + def argmax( + self, dim: Optional[_int] = None, keepdim: _bool = False + ) -> Tensor: ... + def argmin( + self, dim: Optional[_int] = None, keepdim: _bool = False + ) -> Tensor: ... + @overload + def argsort( + self, *, stable: _bool, dim: _int = -1, descending: _bool = False + ) -> Tensor: ... + @overload + def argsort(self, dim: _int = -1, descending: _bool = False) -> Tensor: ... + @overload + def argsort( + self, dim: Union[str, ellipsis, None], descending: _bool = False + ) -> Tensor: ... def argwhere(self) -> Tensor: ... - def as_strided(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]]=None) -> Tensor: ... - def as_strided_(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]]=None) -> Tensor: ... - def as_strided_scatter(self, src: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]]=None) -> Tensor: ... + def as_strided( + self, + size: Sequence[Union[_int, SymInt]], + stride: Sequence[Union[_int, SymInt]], + storage_offset: Optional[Union[_int, SymInt]] = None, + ) -> Tensor: ... + def as_strided_( + self, + size: Sequence[Union[_int, SymInt]], + stride: Sequence[Union[_int, SymInt]], + storage_offset: Optional[Union[_int, SymInt]] = None, + ) -> Tensor: ... + def as_strided_scatter( + self, + src: Tensor, + size: Sequence[Union[_int, SymInt]], + stride: Sequence[Union[_int, SymInt]], + storage_offset: Optional[Union[_int, SymInt]] = None, + ) -> Tensor: ... def as_subclass(self, cls: Type[S]) -> S: ... def asin(self) -> Tensor: ... def asin_(self) -> Tensor: ... @@ -1551,18 +1903,40 @@ class _TensorBase(metaclass=_TensorMeta): def atan_(self) -> Tensor: ... def atanh(self) -> Tensor: ... def atanh_(self) -> Tensor: ... - def baddbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - def baddbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... - @overload - def bernoulli(self, *, generator: Optional[Generator]=None) -> Tensor: ... - @overload - def bernoulli(self, p: _float, *, generator: Optional[Generator]=None) -> Tensor: ... - @overload - def bernoulli_(self, p: Tensor, *, generator: Optional[Generator]=None) -> Tensor: ... - @overload - def bernoulli_(self, p: _float=0.5, *, generator: Optional[Generator]=None) -> Tensor: ... + def baddbmm( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number = 1, + alpha: Number = 1, + ) -> Tensor: ... + def baddbmm_( + self, + batch1: Tensor, + batch2: Tensor, + *, + beta: Number = 1, + alpha: Number = 1, + ) -> Tensor: ... + @overload + def bernoulli(self, *, generator: Optional[Generator] = None) -> Tensor: ... 
+ @overload + def bernoulli( + self, p: _float, *, generator: Optional[Generator] = None + ) -> Tensor: ... + @overload + def bernoulli_( + self, p: Tensor, *, generator: Optional[Generator] = None + ) -> Tensor: ... + @overload + def bernoulli_( + self, p: _float = 0.5, *, generator: Optional[Generator] = None + ) -> Tensor: ... def bfloat16(self) -> Tensor: ... - def bincount(self, weights: Optional[Tensor]=None, minlength: _int=0) -> Tensor: ... + def bincount( + self, weights: Optional[Tensor] = None, minlength: _int = 0 + ) -> Tensor: ... @overload def bitwise_and(self, other: Tensor) -> Tensor: ... @overload @@ -1612,24 +1986,42 @@ class _TensorBase(metaclass=_TensorMeta): @overload def broadcast_to(self, *size: _int) -> Tensor: ... def byte(self) -> Tensor: ... - def cauchy_(self, median: _float=0, sigma: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ... + def cauchy_( + self, + median: _float = 0, + sigma: _float = 1, + *, + generator: Optional[Generator] = None, + ) -> Tensor: ... def ccol_indices(self) -> Tensor: ... def ceil(self) -> Tensor: ... def ceil_(self) -> Tensor: ... - def chalf(self, *, memory_format: Optional[memory_format]=None) -> Tensor: ... + def chalf( + self, *, memory_format: Optional[memory_format] = None + ) -> Tensor: ... def char(self) -> Tensor: ... - def cholesky(self, upper: _bool=False) -> Tensor: ... - def cholesky_inverse(self, upper: _bool=False) -> Tensor: ... - def cholesky_solve(self, input2: Tensor, upper: _bool=False) -> Tensor: ... - def chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]: ... + def cholesky(self, upper: _bool = False) -> Tensor: ... + def cholesky_inverse(self, upper: _bool = False) -> Tensor: ... + def cholesky_solve( + self, input2: Tensor, upper: _bool = False + ) -> Tensor: ... + def chunk(self, chunks: _int, dim: _int = 0) -> List[Tensor]: ... @overload - def clamp(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ... + def clamp( + self, min: Optional[Tensor] = None, max: Optional[Tensor] = None + ) -> Tensor: ... @overload - def clamp(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ... + def clamp( + self, min: Optional[Number] = None, max: Optional[Number] = None + ) -> Tensor: ... @overload - def clamp_(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ... + def clamp_( + self, min: Optional[Tensor] = None, max: Optional[Tensor] = None + ) -> Tensor: ... @overload - def clamp_(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ... + def clamp_( + self, min: Optional[Number] = None, max: Optional[Number] = None + ) -> Tensor: ... @overload def clamp_max(self, max: Tensor) -> Tensor: ... @overload @@ -1647,21 +2039,31 @@ class _TensorBase(metaclass=_TensorMeta): @overload def clamp_min_(self, min: Number) -> Tensor: ... @overload - def clip(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ... + def clip( + self, min: Optional[Tensor] = None, max: Optional[Tensor] = None + ) -> Tensor: ... @overload - def clip(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ... + def clip( + self, min: Optional[Number] = None, max: Optional[Number] = None + ) -> Tensor: ... @overload - def clip_(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ... + def clip_( + self, min: Optional[Tensor] = None, max: Optional[Tensor] = None + ) -> Tensor: ... @overload - def clip_(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ... 
- def clone(self, *, memory_format: Optional[memory_format]=None) -> Tensor: ... + def clip_( + self, min: Optional[Number] = None, max: Optional[Number] = None + ) -> Tensor: ... + def clone( + self, *, memory_format: Optional[memory_format] = None + ) -> Tensor: ... def coalesce(self) -> Tensor: ... def col_indices(self) -> Tensor: ... def conj(self) -> Tensor: ... def conj_physical(self) -> Tensor: ... def conj_physical_(self) -> Tensor: ... def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ... - def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ... + def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ... @overload def copysign(self, other: Tensor) -> Tensor: ... @overload @@ -1676,40 +2078,70 @@ class _TensorBase(metaclass=_TensorMeta): def cosh(self) -> Tensor: ... def cosh_(self) -> Tensor: ... @overload - def count_nonzero(self, dim: Optional[_int]=None) -> Tensor: ... + def count_nonzero(self, dim: Optional[_int] = None) -> Tensor: ... @overload def count_nonzero(self, dim: _size) -> Tensor: ... @overload def count_nonzero(self, *dim: _int) -> Tensor: ... - def cov(self, *, correction: _int=1, fweights: Optional[Tensor]=None, aweights: Optional[Tensor]=None) -> Tensor: ... + def cov( + self, + *, + correction: _int = 1, + fweights: Optional[Tensor] = None, + aweights: Optional[Tensor] = None, + ) -> Tensor: ... def cpu(self) -> Tensor: ... - def cross(self, other: Tensor, dim: Optional[_int]=None) -> Tensor: ... + def cross(self, other: Tensor, dim: Optional[_int] = None) -> Tensor: ... def crow_indices(self) -> Tensor: ... - def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ... + def cuda( + self, + device: Optional[Union[_device, _int, str]] = None, + non_blocking: _bool = False, + ) -> Tensor: ... @overload def cummax(self, dim: _int) -> torch.return_types.cummax: ... @overload - def cummax(self, dim: Union[str, ellipsis, None]) -> torch.return_types.cummax: ... + def cummax( + self, dim: Union[str, ellipsis, None] + ) -> torch.return_types.cummax: ... @overload def cummin(self, dim: _int) -> torch.return_types.cummin: ... @overload - def cummin(self, dim: Union[str, ellipsis, None]) -> torch.return_types.cummin: ... + def cummin( + self, dim: Union[str, ellipsis, None] + ) -> torch.return_types.cummin: ... @overload - def cumprod(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumprod( + self, dim: _int, *, dtype: Optional[_dtype] = None + ) -> Tensor: ... @overload - def cumprod(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumprod( + self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None + ) -> Tensor: ... @overload - def cumprod_(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumprod_( + self, dim: _int, *, dtype: Optional[_dtype] = None + ) -> Tensor: ... @overload - def cumprod_(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumprod_( + self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None + ) -> Tensor: ... @overload - def cumsum(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumsum( + self, dim: _int, *, dtype: Optional[_dtype] = None + ) -> Tensor: ... @overload - def cumsum(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumsum( + self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None + ) -> Tensor: ... 
@overload - def cumsum_(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumsum_( + self, dim: _int, *, dtype: Optional[_dtype] = None + ) -> Tensor: ... @overload - def cumsum_(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ... + def cumsum_( + self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None + ) -> Tensor: ... def data_ptr(self) -> _int: ... def deg2rad(self) -> Tensor: ... def deg2rad_(self) -> Tensor: ... @@ -1718,35 +2150,72 @@ class _TensorBase(metaclass=_TensorMeta): def det(self) -> Tensor: ... def detach(self) -> Tensor: ... def detach_(self) -> Tensor: ... - def diag(self, diagonal: _int=0) -> Tensor: ... - def diag_embed(self, offset: _int=0, dim1: _int=-2, dim2: _int=-1) -> Tensor: ... - def diagflat(self, offset: _int=0) -> Tensor: ... + def diag(self, diagonal: _int = 0) -> Tensor: ... + def diag_embed( + self, offset: _int = 0, dim1: _int = -2, dim2: _int = -1 + ) -> Tensor: ... + def diagflat(self, offset: _int = 0) -> Tensor: ... @overload - def diagonal(self, *, outdim: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None], dim2: Union[str, ellipsis, None], offset: _int=0) -> Tensor: ... - @overload - def diagonal(self, offset: _int=0, dim1: _int=0, dim2: _int=1) -> Tensor: ... - def diagonal_scatter(self, src: Tensor, offset: _int=0, dim1: _int=0, dim2: _int=1) -> Tensor: ... - def diff(self, n: _int=1, dim: _int=-1, prepend: Optional[Tensor]=None, append: Optional[Tensor]=None) -> Tensor: ... + def diagonal( + self, + *, + outdim: Union[str, ellipsis, None], + dim1: Union[str, ellipsis, None], + dim2: Union[str, ellipsis, None], + offset: _int = 0, + ) -> Tensor: ... + @overload + def diagonal( + self, offset: _int = 0, dim1: _int = 0, dim2: _int = 1 + ) -> Tensor: ... + def diagonal_scatter( + self, src: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1 + ) -> Tensor: ... + def diff( + self, + n: _int = 1, + dim: _int = -1, + prepend: Optional[Tensor] = None, + append: Optional[Tensor] = None, + ) -> Tensor: ... def digamma(self) -> Tensor: ... def digamma_(self) -> Tensor: ... def dim(self) -> _int: ... - def dist(self, other: Tensor, p: Number=2) -> Tensor: ... - def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ... - def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ... + def dist(self, other: Tensor, p: Number = 2) -> Tensor: ... + def div( + self, + other: Union[Tensor, Number], + *, + rounding_mode: Optional[str] = None, + ) -> Tensor: ... + def div_( + self, + other: Union[Tensor, Number], + *, + rounding_mode: Optional[str] = None, + ) -> Tensor: ... @overload def divide(self, other: Tensor) -> Tensor: ... @overload - def divide(self, other: Tensor, *, rounding_mode: Optional[str]) -> Tensor: ... + def divide( + self, other: Tensor, *, rounding_mode: Optional[str] + ) -> Tensor: ... @overload - def divide(self, other: Number, *, rounding_mode: Optional[str]) -> Tensor: ... + def divide( + self, other: Number, *, rounding_mode: Optional[str] + ) -> Tensor: ... @overload def divide(self, other: Number) -> Tensor: ... @overload def divide_(self, other: Tensor) -> Tensor: ... @overload - def divide_(self, other: Tensor, *, rounding_mode: Optional[str]) -> Tensor: ... + def divide_( + self, other: Tensor, *, rounding_mode: Optional[str] + ) -> Tensor: ... @overload - def divide_(self, other: Number, *, rounding_mode: Optional[str]) -> Tensor: ... 
+ def divide_( + self, other: Number, *, rounding_mode: Optional[str] + ) -> Tensor: ... @overload def divide_(self, other: Number) -> Tensor: ... def dot(self, tensor: Tensor) -> Tensor: ... @@ -1778,28 +2247,48 @@ class _TensorBase(metaclass=_TensorMeta): def exp2_(self) -> Tensor: ... def exp_(self) -> Tensor: ... @overload - def expand(self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool=False) -> Tensor: ... + def expand( + self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool = False + ) -> Tensor: ... @overload - def expand(self, *size: _int, implicit: _bool=False) -> Tensor: ... + def expand(self, *size: _int, implicit: _bool = False) -> Tensor: ... def expand_as(self, other: Tensor) -> Tensor: ... def expm1(self) -> Tensor: ... def expm1_(self) -> Tensor: ... - def exponential_(self, lambd: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ... + def exponential_( + self, lambd: _float = 1, *, generator: Optional[Generator] = None + ) -> Tensor: ... @overload def fill_(self, value: Tensor) -> Tensor: ... @overload def fill_(self, value: Number) -> Tensor: ... - def fill_diagonal_(self, fill_value: Number, wrap: _bool=False) -> Tensor: ... + def fill_diagonal_( + self, fill_value: Number, wrap: _bool = False + ) -> Tensor: ... def fix(self) -> Tensor: ... def fix_(self) -> Tensor: ... @overload - def flatten(self, start_dim: _int=0, end_dim: _int=-1) -> Tensor: ... + def flatten(self, start_dim: _int = 0, end_dim: _int = -1) -> Tensor: ... @overload - def flatten(self, start_dim: _int, end_dim: _int, out_dim: Union[str, ellipsis, None]) -> Tensor: ... + def flatten( + self, + start_dim: _int, + end_dim: _int, + out_dim: Union[str, ellipsis, None], + ) -> Tensor: ... @overload - def flatten(self, start_dim: Union[str, ellipsis, None], end_dim: Union[str, ellipsis, None], out_dim: Union[str, ellipsis, None]) -> Tensor: ... + def flatten( + self, + start_dim: Union[str, ellipsis, None], + end_dim: Union[str, ellipsis, None], + out_dim: Union[str, ellipsis, None], + ) -> Tensor: ... @overload - def flatten(self, dims: Sequence[Union[str, ellipsis, None]], out_dim: Union[str, ellipsis, None]) -> Tensor: ... + def flatten( + self, + dims: Sequence[Union[str, ellipsis, None]], + out_dim: Union[str, ellipsis, None], + ) -> Tensor: ... @overload def flip(self, dims: _size) -> Tensor: ... @overload @@ -1817,8 +2306,15 @@ class _TensorBase(metaclass=_TensorMeta): def float_power_(self, exponent: Number) -> Tensor: ... def floor(self) -> Tensor: ... def floor_(self) -> Tensor: ... - def floor_divide(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor]=None) -> Tensor: ... - def floor_divide_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: ... + def floor_divide( + self, + other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], + *, + out: Optional[Tensor] = None, + ) -> Tensor: ... + def floor_divide_( + self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat] + ) -> Tensor: ... def fmax(self, other: Tensor) -> Tensor: ... def fmin(self, other: Tensor) -> Tensor: ... @overload @@ -1833,9 +2329,17 @@ class _TensorBase(metaclass=_TensorMeta): def frac_(self) -> Tensor: ... def frexp(self) -> torch.return_types.frexp: ... @overload - def gather(self, dim: _int, index: Tensor, *, sparse_grad: _bool=False) -> Tensor: ... + def gather( + self, dim: _int, index: Tensor, *, sparse_grad: _bool = False + ) -> Tensor: ... 
@overload - def gather(self, dim: Union[str, ellipsis, None], index: Tensor, *, sparse_grad: _bool=False) -> Tensor: ... + def gather( + self, + dim: Union[str, ellipsis, None], + index: Tensor, + *, + sparse_grad: _bool = False, + ) -> Tensor: ... def gcd(self, other: Tensor) -> Tensor: ... def gcd_(self, other: Tensor) -> Tensor: ... @overload @@ -1846,7 +2350,9 @@ class _TensorBase(metaclass=_TensorMeta): def ge_(self, other: Tensor) -> Tensor: ... @overload def ge_(self, other: Number) -> Tensor: ... - def geometric_(self, p: _float, *, generator: Optional[Generator]=None) -> Tensor: ... + def geometric_( + self, p: _float, *, generator: Optional[Generator] = None + ) -> Tensor: ... def geqrf(self) -> torch.return_types.geqrf: ... def ger(self, vec2: Tensor) -> Tensor: ... def get_device(self) -> _int: ... @@ -1875,15 +2381,30 @@ class _TensorBase(metaclass=_TensorMeta): @overload def gt_(self, other: Number) -> Tensor: ... def half(self) -> Tensor: ... - def hardshrink(self, lambd: Number=0.5) -> Tensor: ... + def hardshrink(self, lambd: Number = 0.5) -> Tensor: ... def has_names(self) -> _bool: ... def heaviside(self, values: Tensor) -> Tensor: ... def heaviside_(self, values: Tensor) -> Tensor: ... - def histc(self, bins: _int=100, min: Number=0, max: Number=0) -> Tensor: ... + def histc( + self, bins: _int = 100, min: Number = 0, max: Number = 0 + ) -> Tensor: ... @overload - def histogram(self, bins: Tensor, *, weight: Optional[Tensor]=None, density: _bool=False) -> torch.return_types.histogram: ... + def histogram( + self, + bins: Tensor, + *, + weight: Optional[Tensor] = None, + density: _bool = False, + ) -> torch.return_types.histogram: ... @overload - def histogram(self, bins: _int=100, *, range: Optional[Sequence[_float]]=None, weight: Optional[Tensor]=None, density: _bool=False) -> torch.return_types.histogram: ... + def histogram( + self, + bins: _int = 100, + *, + range: Optional[Sequence[_float]] = None, + weight: Optional[Tensor] = None, + density: _bool = False, + ) -> torch.return_types.histogram: ... @overload def hsplit(self, sections: _int) -> List[Tensor]: ... @overload @@ -1899,42 +2420,101 @@ class _TensorBase(metaclass=_TensorMeta): def igammac(self, other: Tensor) -> Tensor: ... def igammac_(self, other: Tensor) -> Tensor: ... @overload - def index_add(self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number=1) -> Tensor: ... + def index_add( + self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number = 1 + ) -> Tensor: ... @overload - def index_add(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor, *, alpha: Number=1) -> Tensor: ... - def index_add_(self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number=1) -> Tensor: ... - @overload - def index_copy(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: ... - @overload - def index_copy(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: ... - @overload - def index_copy_(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: ... - @overload - def index_copy_(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: ... + def index_add( + self, + dim: Union[str, ellipsis, None], + index: Tensor, + source: Tensor, + *, + alpha: Number = 1, + ) -> Tensor: ... + def index_add_( + self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number = 1 + ) -> Tensor: ... + @overload + def index_copy( + self, dim: _int, index: Tensor, source: Tensor + ) -> Tensor: ... 
+ @overload + def index_copy( + self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor + ) -> Tensor: ... + @overload + def index_copy_( + self, dim: _int, index: Tensor, source: Tensor + ) -> Tensor: ... + @overload + def index_copy_( + self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor + ) -> Tensor: ... @overload def index_fill(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: ... @overload - def index_fill(self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: ... + def index_fill( + self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor + ) -> Tensor: ... @overload def index_fill(self, dim: _int, index: Tensor, value: Number) -> Tensor: ... @overload - def index_fill(self, dim: Union[str, ellipsis, None], index: Tensor, value: Number) -> Tensor: ... + def index_fill( + self, dim: Union[str, ellipsis, None], index: Tensor, value: Number + ) -> Tensor: ... @overload - def index_fill_(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: ... + def index_fill_( + self, dim: _int, index: Tensor, value: Tensor + ) -> Tensor: ... @overload - def index_fill_(self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: ... + def index_fill_( + self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor + ) -> Tensor: ... @overload - def index_fill_(self, dim: _int, index: Tensor, value: Number) -> Tensor: ... + def index_fill_( + self, dim: _int, index: Tensor, value: Number + ) -> Tensor: ... @overload - def index_fill_(self, dim: Union[str, ellipsis, None], index: Tensor, value: Number) -> Tensor: ... - def index_put(self, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool=False) -> Tensor: ... - def index_put_(self, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool=False) -> Tensor: ... - def index_reduce(self, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ... - def index_reduce_(self, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ... + def index_fill_( + self, dim: Union[str, ellipsis, None], index: Tensor, value: Number + ) -> Tensor: ... + def index_put( + self, + indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], + values: Tensor, + accumulate: _bool = False, + ) -> Tensor: ... + def index_put_( + self, + indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], + values: Tensor, + accumulate: _bool = False, + ) -> Tensor: ... + def index_reduce( + self, + dim: _int, + index: Tensor, + source: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: ... + def index_reduce_( + self, + dim: _int, + index: Tensor, + source: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: ... @overload def index_select(self, dim: _int, index: Tensor) -> Tensor: ... @overload - def index_select(self, dim: Union[str, ellipsis, None], index: Tensor) -> Tensor: ... + def index_select( + self, dim: Union[str, ellipsis, None], index: Tensor + ) -> Tensor: ... def indices(self) -> Tensor: ... def inner(self, other: Tensor) -> Tensor: ... def int(self) -> Tensor: ... @@ -1957,7 +2537,9 @@ class _TensorBase(metaclass=_TensorMeta): is_nested: _bool def is_nonzero(self) -> _bool: ... is_ort: _bool - def is_pinned(self, device: Optional[Union[_device, str, None]]=None) -> _bool: ... + def is_pinned( + self, device: Optional[Union[_device, str, None]] = None + ) -> _bool: ... 
is_quantized: _bool def is_same_size(self, other: Tensor) -> _bool: ... def is_set_to(self, tensor: Tensor) -> _bool: ... @@ -1965,20 +2547,41 @@ class _TensorBase(metaclass=_TensorMeta): is_sparse: _bool is_sparse_csr: _bool is_vulkan: _bool - def isclose(self, other: Tensor, rtol: _float=1e-05, atol: _float=1e-08, equal_nan: _bool=False) -> Tensor: ... + def isclose( + self, + other: Tensor, + rtol: _float = 1e-05, + atol: _float = 1e-08, + equal_nan: _bool = False, + ) -> Tensor: ... def isfinite(self) -> Tensor: ... def isinf(self) -> Tensor: ... def isnan(self) -> Tensor: ... def isneginf(self) -> Tensor: ... def isposinf(self) -> Tensor: ... def isreal(self) -> Tensor: ... - def istft(self, n_fft: _int, hop_length: Optional[_int]=None, win_length: Optional[_int]=None, window: Optional[Tensor]=None, center: _bool=True, normalized: _bool=False, onesided: Optional[_bool]=None, length: Optional[_int]=None, return_complex: _bool=False) -> Tensor: ... + def istft( + self, + n_fft: _int, + hop_length: Optional[_int] = None, + win_length: Optional[_int] = None, + window: Optional[Tensor] = None, + center: _bool = True, + normalized: _bool = False, + onesided: Optional[_bool] = None, + length: Optional[_int] = None, + return_complex: _bool = False, + ) -> Tensor: ... def item(self) -> Number: ... def kron(self, other: Tensor) -> Tensor: ... @overload - def kthvalue(self, k: _int, dim: _int=-1, keepdim: _bool=False) -> torch.return_types.kthvalue: ... + def kthvalue( + self, k: _int, dim: _int = -1, keepdim: _bool = False + ) -> torch.return_types.kthvalue: ... @overload - def kthvalue(self, k: _int, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.kthvalue: ... + def kthvalue( + self, k: _int, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> torch.return_types.kthvalue: ... def lcm(self, other: Tensor) -> Tensor: ... def lcm_(self, other: Tensor) -> Tensor: ... def ldexp(self, other: Tensor) -> Tensor: ... @@ -2025,11 +2628,21 @@ class _TensorBase(metaclass=_TensorMeta): def log2(self) -> Tensor: ... def log2_(self) -> Tensor: ... def log_(self) -> Tensor: ... - def log_normal_(self, mean: _float=1, std: _float=2, *, generator: Optional[Generator]=None) -> Tensor: ... - @overload - def log_softmax(self, dim: _int, dtype: Optional[_dtype]=None) -> Tensor: ... - @overload - def log_softmax(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ... + def log_normal_( + self, + mean: _float = 1, + std: _float = 2, + *, + generator: Optional[Generator] = None, + ) -> Tensor: ... + @overload + def log_softmax( + self, dim: _int, dtype: Optional[_dtype] = None + ) -> Tensor: ... + @overload + def log_softmax( + self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None + ) -> Tensor: ... def logaddexp(self, other: Tensor) -> Tensor: ... def logaddexp2(self, other: Tensor) -> Tensor: ... @overload @@ -2045,12 +2658,16 @@ class _TensorBase(metaclass=_TensorMeta): def logical_or_(self, other: Tensor) -> Tensor: ... def logical_xor(self, other: Tensor) -> Tensor: ... def logical_xor_(self, other: Tensor) -> Tensor: ... - def logit(self, eps: Optional[_float]=None) -> Tensor: ... - def logit_(self, eps: Optional[_float]=None) -> Tensor: ... + def logit(self, eps: Optional[_float] = None) -> Tensor: ... + def logit_(self, eps: Optional[_float] = None) -> Tensor: ... @overload - def logsumexp(self, dim: Union[_int, _size], keepdim: _bool=False) -> Tensor: ... 
+ def logsumexp( + self, dim: Union[_int, _size], keepdim: _bool = False + ) -> Tensor: ... @overload - def logsumexp(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool=False) -> Tensor: ... + def logsumexp( + self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False + ) -> Tensor: ... def long(self) -> Tensor: ... @overload def lt(self, other: Tensor) -> Tensor: ... @@ -2082,36 +2699,64 @@ class _TensorBase(metaclass=_TensorMeta): @overload def max(self, other: Tensor) -> Tensor: ... @overload - def max(self, dim: _int, keepdim: _bool=False) -> torch.return_types.max: ... + def max( + self, dim: _int, keepdim: _bool = False + ) -> torch.return_types.max: ... @overload - def max(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.max: ... + def max( + self, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> torch.return_types.max: ... def maximum(self, other: Tensor) -> Tensor: ... @overload - def mean(self, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def mean(self, *, dtype: Optional[_dtype] = None) -> Tensor: ... @overload - def mean(self, dim: Optional[Union[_int, _size]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def mean( + self, + dim: Optional[Union[_int, _size]], + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... @overload - def mean(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def mean( + self, + dim: Sequence[Union[str, ellipsis, None]], + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... @overload def median(self) -> Tensor: ... @overload - def median(self, dim: _int, keepdim: _bool=False) -> torch.return_types.median: ... + def median( + self, dim: _int, keepdim: _bool = False + ) -> torch.return_types.median: ... @overload - def median(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.median: ... + def median( + self, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> torch.return_types.median: ... @overload def min(self) -> Tensor: ... @overload def min(self, other: Tensor) -> Tensor: ... @overload - def min(self, dim: _int, keepdim: _bool=False) -> torch.return_types.min: ... + def min( + self, dim: _int, keepdim: _bool = False + ) -> torch.return_types.min: ... @overload - def min(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.min: ... + def min( + self, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> torch.return_types.min: ... def minimum(self, other: Tensor) -> Tensor: ... def mm(self, mat2: Tensor) -> Tensor: ... @overload - def mode(self, dim: _int=-1, keepdim: _bool=False) -> torch.return_types.mode: ... + def mode( + self, dim: _int = -1, keepdim: _bool = False + ) -> torch.return_types.mode: ... @overload - def mode(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.mode: ... + def mode( + self, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> torch.return_types.mode: ... @overload def moveaxis(self, source: _int, destination: _int) -> Tensor: ... @overload @@ -2121,9 +2766,22 @@ class _TensorBase(metaclass=_TensorMeta): @overload def movedim(self, source: _size, destination: _size) -> Tensor: ... def msort(self) -> Tensor: ... - def mul(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor]=None) -> Tensor: ... 
- def mul_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: ... - def multinomial(self, num_samples: _int, replacement: _bool=False, *, generator: Optional[Generator]=None) -> Tensor: ... + def mul( + self, + other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], + *, + out: Optional[Tensor] = None, + ) -> Tensor: ... + def mul_( + self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat] + ) -> Tensor: ... + def multinomial( + self, + num_samples: _int, + replacement: _bool = False, + *, + generator: Optional[Generator] = None, + ) -> Tensor: ... @overload def multiply(self, other: Tensor) -> Tensor: ... @overload @@ -2135,25 +2793,71 @@ class _TensorBase(metaclass=_TensorMeta): def mv(self, vec: Tensor) -> Tensor: ... def mvlgamma(self, p: _int) -> Tensor: ... def mvlgamma_(self, p: _int) -> Tensor: ... - def nan_to_num(self, nan: Optional[_float]=None, posinf: Optional[_float]=None, neginf: Optional[_float]=None) -> Tensor: ... - def nan_to_num_(self, nan: Optional[_float]=None, posinf: Optional[_float]=None, neginf: Optional[_float]=None) -> Tensor: ... - def nanmean(self, dim: Optional[Union[_int, _size]]=None, keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def nan_to_num( + self, + nan: Optional[_float] = None, + posinf: Optional[_float] = None, + neginf: Optional[_float] = None, + ) -> Tensor: ... + def nan_to_num_( + self, + nan: Optional[_float] = None, + posinf: Optional[_float] = None, + neginf: Optional[_float] = None, + ) -> Tensor: ... + def nanmean( + self, + dim: Optional[Union[_int, _size]] = None, + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... @overload def nanmedian(self) -> Tensor: ... @overload - def nanmedian(self, dim: _int, keepdim: _bool=False) -> torch.return_types.nanmedian: ... - @overload - def nanmedian(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.nanmedian: ... - @overload - def nanquantile(self, q: Tensor, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ... + def nanmedian( + self, dim: _int, keepdim: _bool = False + ) -> torch.return_types.nanmedian: ... @overload - def nanquantile(self, q: _float, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ... - def nansum(self, dim: Optional[Union[_int, _size]]=None, keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def nanmedian( + self, dim: Union[str, ellipsis, None], keepdim: _bool = False + ) -> torch.return_types.nanmedian: ... @overload - def narrow(self, dim: _int, start: Tensor, length: Union[_int, SymInt]) -> Tensor: ... - @overload - def narrow(self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: ... - def narrow_copy(self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: ... + def nanquantile( + self, + q: Tensor, + dim: Optional[_int] = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: ... + @overload + def nanquantile( + self, + q: _float, + dim: Optional[_int] = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: ... + def nansum( + self, + dim: Optional[Union[_int, _size]] = None, + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... + @overload + def narrow( + self, dim: _int, start: Tensor, length: Union[_int, SymInt] + ) -> Tensor: ... 
+ @overload + def narrow( + self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt] + ) -> Tensor: ... + def narrow_copy( + self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt] + ) -> Tensor: ... def ndimension(self) -> _int: ... @overload def ne(self, other: Tensor) -> Tensor: ... @@ -2169,37 +2873,126 @@ class _TensorBase(metaclass=_TensorMeta): def negative_(self) -> Tensor: ... def nelement(self) -> _int: ... @overload - def new(self, *args: Any, device: Device=None) ->Tensor: ... + def new(self, *args: Any, device: Device = None) -> Tensor: ... @overload def new(self, storage: Storage) -> Tensor: ... @overload def new(self, other: Tensor) -> Tensor: ... @overload - def new(self, size: _size, *, device: Device=None) -> Tensor: ... - @overload - def new_empty(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... - @overload - def new_empty(self, *size: _int, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... - def new_empty_strided(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... - def new_full(self, size: Sequence[Union[_int, SymInt]], fill_value: Number, *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... - @overload - def new_ones(self, size: _size, dtype: Optional[_dtype]=None, device: Device=None, requires_grad: _bool=False) -> Tensor: ... + def new(self, size: _size, *, device: Device = None) -> Tensor: ... @overload - def new_ones(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... - @overload - def new_ones(self, *size: _int, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... - def new_tensor(self, data: Any, dtype: Optional[_dtype]=None, device: Device=None, requires_grad: _bool=False) -> Tensor: ... + def new_empty( + self, + size: Sequence[Union[_int, SymInt]], + *, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... + @overload + def new_empty( + self, + *size: _int, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... 
+ def new_empty_strided( + self, + size: Sequence[Union[_int, SymInt]], + stride: Sequence[Union[_int, SymInt]], + *, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... + def new_full( + self, + size: Sequence[Union[_int, SymInt]], + fill_value: Number, + *, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... + @overload + def new_ones( + self, + size: _size, + dtype: Optional[_dtype] = None, + device: Device = None, + requires_grad: _bool = False, + ) -> Tensor: ... @overload - def new_zeros(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... + def new_ones( + self, + size: Sequence[Union[_int, SymInt]], + *, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... + @overload + def new_ones( + self, + *size: _int, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... + def new_tensor( + self, + data: Any, + dtype: Optional[_dtype] = None, + device: Device = None, + requires_grad: _bool = False, + ) -> Tensor: ... @overload - def new_zeros(self, *size: _int, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ... + def new_zeros( + self, + size: Sequence[Union[_int, SymInt]], + *, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... + @overload + def new_zeros( + self, + *size: _int, + dtype: Optional[_dtype] = None, + layout: Optional[_layout] = None, + device: Optional[Union[_device, str, None]] = None, + pin_memory: Optional[_bool] = False, + requires_grad: Optional[_bool] = False, + ) -> Tensor: ... def nextafter(self, other: Tensor) -> Tensor: ... def nextafter_(self, other: Tensor) -> Tensor: ... @overload - def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ... + def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ... @overload def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ... - def normal_(self, mean: _float=0, std: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ... + def normal_( + self, + mean: _float = 0, + std: _float = 1, + *, + generator: Optional[Generator] = None, + ) -> Tensor: ... @overload def not_equal(self, other: Tensor) -> Tensor: ... @overload @@ -2209,16 +3002,24 @@ class _TensorBase(metaclass=_TensorMeta): @overload def not_equal_(self, other: Number) -> Tensor: ... def numel(self) -> _int: ... - def numpy(self, *, force: _bool=False) -> Any: ... + def numpy(self, *, force: _bool = False) -> Any: ... 
def orgqr(self, input2: Tensor) -> Tensor: ... - def ormqr(self, input2: Tensor, input3: Tensor, left: _bool=True, transpose: _bool=False) -> Tensor: ... + def ormqr( + self, + input2: Tensor, + input3: Tensor, + left: _bool = True, + transpose: _bool = False, + ) -> Tensor: ... def outer(self, vec2: Tensor) -> Tensor: ... @overload def permute(self, dims: _size) -> Tensor: ... @overload def permute(self, *dims: _int) -> Tensor: ... - def pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor: ... - def pinverse(self, rcond: _float=1e-15) -> Tensor: ... + def pin_memory( + self, device: Optional[Union[_device, str, None]] = None + ) -> Tensor: ... + def pinverse(self, rcond: _float = 1e-15) -> Tensor: ... def polygamma(self, n: _int) -> Tensor: ... def polygamma_(self, n: _int) -> Tensor: ... def positive(self) -> Tensor: ... @@ -2232,37 +3033,77 @@ class _TensorBase(metaclass=_TensorMeta): def pow_(self, exponent: Number) -> Tensor: ... def prelu(self, weight: Tensor) -> Tensor: ... @overload - def prod(self, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def prod(self, *, dtype: Optional[_dtype] = None) -> Tensor: ... @overload - def prod(self, dim: _int, keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def prod( + self, + dim: _int, + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... @overload - def prod(self, dim: Union[str, ellipsis, None], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... - def put(self, index: Tensor, source: Tensor, accumulate: _bool=False) -> Tensor: ... - def put_(self, index: Tensor, source: Tensor, accumulate: _bool=False) -> Tensor: ... + def prod( + self, + dim: Union[str, ellipsis, None], + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... + def put( + self, index: Tensor, source: Tensor, accumulate: _bool = False + ) -> Tensor: ... + def put_( + self, index: Tensor, source: Tensor, accumulate: _bool = False + ) -> Tensor: ... def q_per_channel_axis(self) -> _int: ... def q_per_channel_scales(self) -> Tensor: ... def q_per_channel_zero_points(self) -> Tensor: ... def q_scale(self) -> _float: ... def q_zero_point(self) -> _int: ... - def qr(self, some: _bool=True) -> torch.return_types.qr: ... + def qr(self, some: _bool = True) -> torch.return_types.qr: ... def qscheme(self) -> _qscheme: ... @overload - def quantile(self, q: Tensor, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ... - @overload - def quantile(self, q: _float, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ... + def quantile( + self, + q: Tensor, + dim: Optional[_int] = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: ... + @overload + def quantile( + self, + q: _float, + dim: Optional[_int] = None, + keepdim: _bool = False, + *, + interpolation: str = "linear", + ) -> Tensor: ... def rad2deg(self) -> Tensor: ... def rad2deg_(self) -> Tensor: ... @overload - def random_(self, *, generator: Optional[Generator]=None) -> Tensor: ... - @overload - def random_(self, from_: _int, to: Optional[_int], *, generator: Optional[Generator]=None) -> Tensor: ... + def random_(self, *, generator: Optional[Generator] = None) -> Tensor: ... @overload - def random_(self, to: _int, *, generator: Optional[Generator]=None) -> Tensor: ... + def random_( + self, + from_: _int, + to: Optional[_int], + *, + generator: Optional[Generator] = None, + ) -> Tensor: ... 
+ @overload + def random_( + self, to: _int, *, generator: Optional[Generator] = None + ) -> Tensor: ... def ravel(self) -> Tensor: ... def reciprocal(self) -> Tensor: ... def reciprocal_(self) -> Tensor: ... def record_stream(self, s: Stream) -> None: ... - def refine_names(self, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ... + def refine_names( + self, names: Sequence[Union[str, ellipsis, None]] + ) -> Tensor: ... def relu(self) -> Tensor: ... def relu_(self) -> Tensor: ... @overload @@ -2273,8 +3114,12 @@ class _TensorBase(metaclass=_TensorMeta): def remainder_(self, other: Tensor) -> Tensor: ... @overload def remainder_(self, other: Number) -> Tensor: ... - def rename(self, names: Optional[Sequence[Union[str, ellipsis, None]]]) -> Tensor: ... - def rename_(self, names: Optional[Sequence[Union[str, ellipsis, None]]]) -> Tensor: ... + def rename( + self, names: Optional[Sequence[Union[str, ellipsis, None]]] + ) -> Tensor: ... + def rename_( + self, names: Optional[Sequence[Union[str, ellipsis, None]]] + ) -> Tensor: ... def renorm(self, p: Number, dim: _int, maxnorm: Number) -> Tensor: ... def renorm_(self, p: Number, dim: _int, maxnorm: Number) -> Tensor: ... @overload @@ -2282,26 +3127,52 @@ class _TensorBase(metaclass=_TensorMeta): @overload def repeat(self, *repeats: _int) -> Tensor: ... @overload - def repeat_interleave(self, repeats: Tensor, dim: Optional[_int]=None, *, output_size: Optional[_int]=None) -> Tensor: ... + def repeat_interleave( + self, + repeats: Tensor, + dim: Optional[_int] = None, + *, + output_size: Optional[_int] = None, + ) -> Tensor: ... @overload - def repeat_interleave(self, repeats: Union[_int, SymInt], dim: Optional[_int]=None, *, output_size: Optional[_int]=None) -> Tensor: ... - def requires_grad_(self, mode: _bool=True) -> Tensor: ... + def repeat_interleave( + self, + repeats: Union[_int, SymInt], + dim: Optional[_int] = None, + *, + output_size: Optional[_int] = None, + ) -> Tensor: ... + def requires_grad_(self, mode: _bool = True) -> Tensor: ... @overload def reshape(self, shape: Sequence[Union[_int, SymInt]]) -> Tensor: ... @overload def reshape(self, *shape: _int) -> Tensor: ... def reshape_as(self, other: Tensor) -> Tensor: ... @overload - def resize_(self, size: Sequence[Union[_int, SymInt]], *, memory_format: Optional[memory_format]=None) -> Tensor: ... - @overload - def resize_(self, *size: _int, memory_format: Optional[memory_format]=None) -> Tensor: ... - def resize_as_(self, the_template: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ... + def resize_( + self, + size: Sequence[Union[_int, SymInt]], + *, + memory_format: Optional[memory_format] = None, + ) -> Tensor: ... + @overload + def resize_( + self, *size: _int, memory_format: Optional[memory_format] = None + ) -> Tensor: ... + def resize_as_( + self, + the_template: Tensor, + *, + memory_format: Optional[memory_format] = None, + ) -> Tensor: ... def resize_as_sparse_(self, the_template: Tensor) -> Tensor: ... def resolve_conj(self) -> Tensor: ... def resolve_neg(self) -> Tensor: ... def retain_grad(self) -> None: ... - def roll(self, shifts: Union[_int, _size], dims: Union[_int, _size]=()) -> Tensor: ... - def rot90(self, k: _int=1, dims: _size=(0,1)) -> Tensor: ... + def roll( + self, shifts: Union[_int, _size], dims: Union[_int, _size] = () + ) -> Tensor: ... + def rot90(self, k: _int = 1, dims: _size = (0, 1)) -> Tensor: ... @overload def round(self) -> Tensor: ... 
@overload @@ -2316,37 +3187,77 @@ class _TensorBase(metaclass=_TensorMeta): @overload def scatter(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ... @overload - def scatter(self, dim: _int, index: Tensor, src: Tensor, *, reduce: str) -> Tensor: ... + def scatter( + self, dim: _int, index: Tensor, src: Tensor, *, reduce: str + ) -> Tensor: ... @overload - def scatter(self, dim: _int, index: Tensor, value: Number, *, reduce: str) -> Tensor: ... + def scatter( + self, dim: _int, index: Tensor, value: Number, *, reduce: str + ) -> Tensor: ... @overload - def scatter(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ... + def scatter( + self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor + ) -> Tensor: ... @overload def scatter(self, dim: _int, index: Tensor, value: Number) -> Tensor: ... @overload - def scatter(self, dim: Union[str, ellipsis, None], index: Tensor, value: Number) -> Tensor: ... + def scatter( + self, dim: Union[str, ellipsis, None], index: Tensor, value: Number + ) -> Tensor: ... @overload def scatter_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ... @overload - def scatter_(self, dim: _int, index: Tensor, src: Tensor, *, reduce: str) -> Tensor: ... + def scatter_( + self, dim: _int, index: Tensor, src: Tensor, *, reduce: str + ) -> Tensor: ... @overload - def scatter_(self, dim: _int, index: Tensor, value: Number, *, reduce: str) -> Tensor: ... + def scatter_( + self, dim: _int, index: Tensor, value: Number, *, reduce: str + ) -> Tensor: ... @overload def scatter_(self, dim: _int, index: Tensor, value: Number) -> Tensor: ... @overload def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ... @overload - def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ... + def scatter_add( + self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor + ) -> Tensor: ... def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ... - def scatter_reduce(self, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ... - def scatter_reduce_(self, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ... + def scatter_reduce( + self, + dim: _int, + index: Tensor, + src: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: ... + def scatter_reduce_( + self, + dim: _int, + index: Tensor, + src: Tensor, + reduce: str, + *, + include_self: _bool = True, + ) -> Tensor: ... @overload def select(self, dim: _int, index: Union[_int, SymInt]) -> Tensor: ... @overload - def select(self, dim: Union[str, ellipsis, None], index: _int) -> Tensor: ... - def select_scatter(self, src: Tensor, dim: _int, index: Union[_int, SymInt]) -> Tensor: ... + def select( + self, dim: Union[str, ellipsis, None], index: _int + ) -> Tensor: ... + def select_scatter( + self, src: Tensor, dim: _int, index: Union[_int, SymInt] + ) -> Tensor: ... @overload - def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ... + def set_( + self, + storage: Union[Storage, TypedStorage], + offset: _int, + size: _size, + stride: _size, + ) -> Tensor: ... @overload def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ... def sgn(self) -> Tensor: ... @@ -2367,30 +3278,63 @@ class _TensorBase(metaclass=_TensorMeta): def size(self) -> Size: ... @overload def size(self, dim: _int) -> _int: ... 
- def slice_scatter(self, src: Tensor, dim: _int=0, start: Optional[Union[_int, SymInt]]=None, end: Optional[Union[_int, SymInt]]=None, step: Union[_int, SymInt]=1) -> Tensor: ... + def slice_scatter( + self, + src: Tensor, + dim: _int = 0, + start: Optional[Union[_int, SymInt]] = None, + end: Optional[Union[_int, SymInt]] = None, + step: Union[_int, SymInt] = 1, + ) -> Tensor: ... def slogdet(self) -> torch.return_types.slogdet: ... def smm(self, mat2: Tensor) -> Tensor: ... @overload - def softmax(self, dim: _int, dtype: Optional[_dtype]=None) -> Tensor: ... - @overload - def softmax(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ... + def softmax(self, dim: _int, dtype: Optional[_dtype] = None) -> Tensor: ... @overload - def sort(self, *, stable: Optional[_bool], dim: _int=-1, descending: _bool=False) -> torch.return_types.sort: ... + def softmax( + self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None + ) -> Tensor: ... @overload - def sort(self, dim: _int=-1, descending: _bool=False) -> torch.return_types.sort: ... + def sort( + self, + *, + stable: Optional[_bool], + dim: _int = -1, + descending: _bool = False, + ) -> torch.return_types.sort: ... @overload - def sort(self, *, stable: Optional[_bool], dim: Union[str, ellipsis, None], descending: _bool=False) -> torch.return_types.sort: ... + def sort( + self, dim: _int = -1, descending: _bool = False + ) -> torch.return_types.sort: ... @overload - def sort(self, dim: Union[str, ellipsis, None], descending: _bool=False) -> torch.return_types.sort: ... + def sort( + self, + *, + stable: Optional[_bool], + dim: Union[str, ellipsis, None], + descending: _bool = False, + ) -> torch.return_types.sort: ... + @overload + def sort( + self, dim: Union[str, ellipsis, None], descending: _bool = False + ) -> torch.return_types.sort: ... def sparse_dim(self) -> _int: ... def sparse_mask(self, mask: Tensor) -> Tensor: ... - def sparse_resize_(self, size: _size, sparse_dim: _int, dense_dim: _int) -> Tensor: ... - def sparse_resize_and_clear_(self, size: _size, sparse_dim: _int, dense_dim: _int) -> Tensor: ... - @overload - def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ... - @overload - def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ... - def split_with_sizes(self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int=0) -> List[Tensor]: ... + def sparse_resize_( + self, size: _size, sparse_dim: _int, dense_dim: _int + ) -> Tensor: ... + def sparse_resize_and_clear_( + self, size: _size, sparse_dim: _int, dense_dim: _int + ) -> Tensor: ... + @overload + def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ... + @overload + def split( + self, split_size: Tuple[_int, ...], dim: _int = 0 + ) -> Sequence[Tensor]: ... + def split_with_sizes( + self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0 + ) -> List[Tensor]: ... def sqrt(self) -> Tensor: ... def sqrt_(self) -> Tensor: ... def square(self) -> Tensor: ... @@ -2415,17 +3359,41 @@ class _TensorBase(metaclass=_TensorMeta): def squeeze_(self, *dim: _int) -> Tensor: ... @overload def squeeze_(self, dim: Union[str, ellipsis, None]) -> Tensor: ... - def sspaddmm(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... + def sspaddmm( + self, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1 + ) -> Tensor: ... @overload - def std(self, dim: Optional[Union[_int, _size]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ... 
+ def std( + self, + dim: Optional[Union[_int, _size]], + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: ... @overload - def std(self, dim: Optional[Union[_int, _size]]=None, *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ... + def std( + self, + dim: Optional[Union[_int, _size]] = None, + *, + correction: Optional[_int] = None, + keepdim: _bool = False, + ) -> Tensor: ... @overload - def std(self, unbiased: _bool=True) -> Tensor: ... + def std(self, unbiased: _bool = True) -> Tensor: ... @overload - def std(self, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ... + def std( + self, + dim: Sequence[Union[str, ellipsis, None]], + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: ... @overload - def std(self, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ... + def std( + self, + dim: Sequence[Union[str, ellipsis, None]], + *, + correction: Optional[_int] = None, + keepdim: _bool = False, + ) -> Tensor: ... def untyped_storage(self) -> Storage: ... def storage_offset(self) -> _int: ... def storage_type(self) -> Storage: ... @@ -2433,27 +3401,52 @@ class _TensorBase(metaclass=_TensorMeta): def stride(self) -> Tuple[_int, ...]: ... @overload def stride(self, _int) -> _int: ... - def sub(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ... - def sub_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1) -> Tensor: ... + def sub( + self, + other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], + *, + alpha: Optional[Number] = 1, + out: Optional[Tensor] = None, + ) -> Tensor: ... + def sub_( + self, + other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], + *, + alpha: Optional[Number] = 1, + ) -> Tensor: ... @overload - def subtract(self, other: Tensor, *, alpha: Number=1) -> Tensor: ... + def subtract(self, other: Tensor, *, alpha: Number = 1) -> Tensor: ... @overload - def subtract(self, other: Number, alpha: Number=1) -> Tensor: ... + def subtract(self, other: Number, alpha: Number = 1) -> Tensor: ... @overload - def subtract_(self, other: Tensor, *, alpha: Number=1) -> Tensor: ... + def subtract_(self, other: Tensor, *, alpha: Number = 1) -> Tensor: ... @overload - def subtract_(self, other: Number, alpha: Number=1) -> Tensor: ... + def subtract_(self, other: Number, alpha: Number = 1) -> Tensor: ... @overload - def sum(self, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def sum(self, *, dtype: Optional[_dtype] = None) -> Tensor: ... @overload - def sum(self, dim: Optional[Union[_int, _size]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def sum( + self, + dim: Optional[Union[_int, _size]], + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... @overload - def sum(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ... + def sum( + self, + dim: Sequence[Union[str, ellipsis, None]], + keepdim: _bool = False, + *, + dtype: Optional[_dtype] = None, + ) -> Tensor: ... @overload def sum_to_size(self, size: _size) -> Tensor: ... @overload def sum_to_size(self, *size: _int) -> Tensor: ... - def svd(self, some: _bool=True, compute_uv: _bool=True) -> torch.return_types.svd: ... + def svd( + self, some: _bool = True, compute_uv: _bool = True + ) -> torch.return_types.svd: ... 
def swapaxes(self, axis0: _int, axis1: _int) -> Tensor: ... def swapaxes_(self, axis0: _int, axis1: _int) -> Tensor: ... def swapdims(self, dim0: _int, dim1: _int) -> Tensor: ... @@ -2461,86 +3454,178 @@ class _TensorBase(metaclass=_TensorMeta): def t(self) -> Tensor: ... def t_(self) -> Tensor: ... def take(self, index: Tensor) -> Tensor: ... - def take_along_dim(self, indices: Tensor, dim: Optional[_int]=None) -> Tensor: ... + def take_along_dim( + self, indices: Tensor, dim: Optional[_int] = None + ) -> Tensor: ... def tan(self) -> Tensor: ... def tan_(self) -> Tensor: ... def tanh(self) -> Tensor: ... def tanh_(self) -> Tensor: ... @overload - def tensor_split(self, indices: Sequence[Union[_int, SymInt]], dim: _int=0) -> List[Tensor]: ... + def tensor_split( + self, indices: Sequence[Union[_int, SymInt]], dim: _int = 0 + ) -> List[Tensor]: ... @overload - def tensor_split(self, tensor_indices_or_sections: Tensor, dim: _int=0) -> List[Tensor]: ... + def tensor_split( + self, tensor_indices_or_sections: Tensor, dim: _int = 0 + ) -> List[Tensor]: ... @overload - def tensor_split(self, sections: Union[_int, SymInt], dim: _int=0) -> List[Tensor]: ... + def tensor_split( + self, sections: Union[_int, SymInt], dim: _int = 0 + ) -> List[Tensor]: ... @overload def tile(self, dims: _size) -> Tensor: ... @overload def tile(self, *dims: _int) -> Tensor: ... @overload - def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ... + def to( + self, dtype: _dtype, non_blocking: _bool = False, copy: _bool = False + ) -> Tensor: ... @overload - def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ... - @overload - def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ... - def to_dense(self, dtype: Optional[_dtype]=None) -> Tensor: ... - def to_mkldnn(self, dtype: Optional[_dtype]=None) -> Tensor: ... - def to_padded_tensor(self, padding: _float, output_size: Optional[Sequence[Union[_int, SymInt]]]=None) -> Tensor: ... + def to( + self, + device: Optional[Union[_device, str]] = None, + dtype: Optional[_dtype] = None, + non_blocking: _bool = False, + copy: _bool = False, + ) -> Tensor: ... + @overload + def to( + self, other: Tensor, non_blocking: _bool = False, copy: _bool = False + ) -> Tensor: ... + def to_dense(self, dtype: Optional[_dtype] = None) -> Tensor: ... + def to_mkldnn(self, dtype: Optional[_dtype] = None) -> Tensor: ... + def to_padded_tensor( + self, + padding: _float, + output_size: Optional[Sequence[Union[_int, SymInt]]] = None, + ) -> Tensor: ... @overload - def to_sparse(self, *, layout: Optional[_layout]=None, blocksize: Optional[Union[_int, _size]]=None, dense_dim: Optional[_int]=None) -> Tensor: ... + def to_sparse( + self, + *, + layout: Optional[_layout] = None, + blocksize: Optional[Union[_int, _size]] = None, + dense_dim: Optional[_int] = None, + ) -> Tensor: ... @overload def to_sparse(self, sparse_dim: _int) -> Tensor: ... - def to_sparse_bsc(self, blocksize: Union[_int, _size], dense_dim: Optional[_int]=None) -> Tensor: ... - def to_sparse_bsr(self, blocksize: Union[_int, _size], dense_dim: Optional[_int]=None) -> Tensor: ... - def to_sparse_csc(self, dense_dim: Optional[_int]=None) -> Tensor: ... - def to_sparse_csr(self, dense_dim: Optional[_int]=None) -> Tensor: ... + def to_sparse_bsc( + self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None + ) -> Tensor: ... 
+ def to_sparse_bsr( + self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None + ) -> Tensor: ... + def to_sparse_csc(self, dense_dim: Optional[_int] = None) -> Tensor: ... + def to_sparse_csr(self, dense_dim: Optional[_int] = None) -> Tensor: ... def tolist(self) -> List: ... - def topk(self, k: _int, dim: _int=-1, largest: _bool=True, sorted: _bool=True) -> torch.return_types.topk: ... + def topk( + self, + k: _int, + dim: _int = -1, + largest: _bool = True, + sorted: _bool = True, + ) -> torch.return_types.topk: ... def trace(self) -> Tensor: ... @overload def transpose(self, dim0: _int, dim1: _int) -> Tensor: ... @overload - def transpose(self, dim0: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None]) -> Tensor: ... + def transpose( + self, dim0: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None] + ) -> Tensor: ... def transpose_(self, dim0: _int, dim1: _int) -> Tensor: ... - def triangular_solve(self, A: Tensor, upper: _bool=True, transpose: _bool=False, unitriangular: _bool=False) -> torch.return_types.triangular_solve: ... - def tril(self, diagonal: _int=0) -> Tensor: ... - def tril_(self, diagonal: _int=0) -> Tensor: ... - def triu(self, diagonal: _int=0) -> Tensor: ... - def triu_(self, diagonal: _int=0) -> Tensor: ... - def true_divide(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor]=None) -> Tensor: ... - def true_divide_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: ... + def triangular_solve( + self, + A: Tensor, + upper: _bool = True, + transpose: _bool = False, + unitriangular: _bool = False, + ) -> torch.return_types.triangular_solve: ... + def tril(self, diagonal: _int = 0) -> Tensor: ... + def tril_(self, diagonal: _int = 0) -> Tensor: ... + def triu(self, diagonal: _int = 0) -> Tensor: ... + def triu_(self, diagonal: _int = 0) -> Tensor: ... + def true_divide( + self, + other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], + *, + out: Optional[Tensor] = None, + ) -> Tensor: ... + def true_divide_( + self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat] + ) -> Tensor: ... def trunc(self) -> Tensor: ... def trunc_(self) -> Tensor: ... @overload - def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ... + def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ... @overload - def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ... + def type( + self, dtype: Union[str, _dtype], non_blocking: _bool = False + ) -> Tensor: ... def type_as(self, other: Tensor) -> Tensor: ... @overload - def unbind(self, dim: _int=0) -> List[Tensor]: ... + def unbind(self, dim: _int = 0) -> List[Tensor]: ... @overload def unbind(self, dim: Union[str, ellipsis, None]) -> List[Tensor]: ... @overload - def unflatten(self, dim: Union[str, ellipsis, None], sizes: _size, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ... + def unflatten( + self, + dim: Union[str, ellipsis, None], + sizes: _size, + names: Sequence[Union[str, ellipsis, None]], + ) -> Tensor: ... @overload def unflatten(self, dim: _int, sizes: _size) -> Tensor: ... def unfold(self, dimension: _int, size: _int, step: _int) -> Tensor: ... - def uniform_(self, from_: _float=0, to: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ... - def unsafe_chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]: ... - def unsafe_split(self, split_size: Union[_int, SymInt], dim: _int=0) -> List[Tensor]: ... 
- def unsafe_split_with_sizes(self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int=0) -> List[Tensor]: ... + def uniform_( + self, + from_: _float = 0, + to: _float = 1, + *, + generator: Optional[Generator] = None, + ) -> Tensor: ... + def unsafe_chunk(self, chunks: _int, dim: _int = 0) -> List[Tensor]: ... + def unsafe_split( + self, split_size: Union[_int, SymInt], dim: _int = 0 + ) -> List[Tensor]: ... + def unsafe_split_with_sizes( + self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0 + ) -> List[Tensor]: ... def unsqueeze(self, dim: _int) -> Tensor: ... def unsqueeze_(self, dim: _int) -> Tensor: ... def values(self) -> Tensor: ... @overload - def var(self, dim: Optional[Union[_int, _size]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ... + def var( + self, + dim: Optional[Union[_int, _size]], + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: ... @overload - def var(self, dim: Optional[Union[_int, _size]]=None, *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ... + def var( + self, + dim: Optional[Union[_int, _size]] = None, + *, + correction: Optional[_int] = None, + keepdim: _bool = False, + ) -> Tensor: ... @overload - def var(self, unbiased: _bool=True) -> Tensor: ... + def var(self, unbiased: _bool = True) -> Tensor: ... @overload - def var(self, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ... + def var( + self, + dim: Sequence[Union[str, ellipsis, None]], + unbiased: _bool = True, + keepdim: _bool = False, + ) -> Tensor: ... @overload - def var(self, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ... + def var( + self, + dim: Sequence[Union[str, ellipsis, None]], + *, + correction: Optional[_int] = None, + keepdim: _bool = False, + ) -> Tensor: ... def vdot(self, other: Tensor) -> Tensor: ... @overload def view(self, dtype: _dtype) -> Tensor: ... @@ -2599,10 +3684,14 @@ def _cuda_synchronize() -> None: ... def _cuda_ipc_collect() -> None: ... def _cuda_getArchFlags() -> Optional[str]: ... def _cuda_init() -> None: ... -def _cuda_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ... +def _cuda_setStream( + stream_id: _int, device_index: _int, device_type: _int +) -> None: ... def _cuda_getCompiledVersion() -> _int: ... def _cuda_cudaHostAllocator() -> _int: ... -def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ... +def _cuda_cudaCachingAllocator_raw_alloc( + size: _int, cuda_stream: _int +) -> _int: ... def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ... def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ... def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ... @@ -2611,56 +3700,74 @@ def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ... def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ... def _cuda_resetPeakMemoryStats(device: _int) -> None: ... def _cuda_memorySnapshot() -> Dict[str, Any]: ... -def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ... +def _cuda_recordMemoryHistory( + enabled: _bool, + record_context: _bool, + record_context_cpp: _bool, + alloc_trace_max_entries: _int, + alloc_trace_record_context: _bool, +) -> None: ... def _cuda_getAllocatorBackend() -> str: ... -class _cuda_CUDAAllocator: - ... +class _cuda_CUDAAllocator: ... 
-def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ... +def _cuda_customAllocator( + alloc_fn: _int, free_fn: _int +) -> _cuda_CUDAAllocator: ... def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ... def _cuda_getAllocator() -> _cuda_CUDAAllocator: ... def _cuda_lock_mutex() -> None: ... def _cuda_unlock_mutex() -> None: ... def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ... -def _cuda_jiterator_compile_and_launch_kernel(code_string: str, - kernel_name: str, - return_by_ref: _bool, - num_outputs: _int, - tensors: Tuple, - kwargs: Dict[str, Union[_int, _float, _bool]]) -> Tensor: ... +def _cuda_jiterator_compile_and_launch_kernel( + code_string: str, + kernel_name: str, + return_by_ref: _bool, + num_outputs: _int, + tensors: Tuple, + kwargs: Dict[str, Union[_int, _float, _bool]], +) -> Tensor: ... def _cuda_get_cudnn_benchmark_limit() -> _int: ... def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ... def _nccl_version() -> _int: ... def _nccl_unique_id() -> bytes: ... def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ... -def _nccl_reduce(input: Sequence[Tensor], - output: Tensor, - root: _int, - op: _int, - streams: Optional[Sequence[_CudaStreamBase]], - comms: Optional[Sequence[object]]) -> None: ... -def _nccl_all_reduce(input: Sequence[Tensor], - output: Sequence[Tensor], - op: _int, - streams: Optional[Sequence[_CudaStreamBase]], - comms: Optional[Sequence[object]]) -> None: ... -def _nccl_broadcast(input: Sequence[Tensor], - root: _int, - streams: Optional[Sequence[_CudaStreamBase]], - comms: Optional[Sequence[object]]) -> None: ... -def _nccl_all_gather(input: Sequence[Tensor], - output: Sequence[Tensor], - streams: Optional[Sequence[_CudaStreamBase]], - comms: Optional[Sequence[object]]) -> None: ... -def _nccl_reduce_scatter(input: Sequence[Tensor], - output: Sequence[Tensor], - op: _int, - streams: Optional[Sequence[_CudaStreamBase]], - comms: Optional[Sequence[object]]) -> None: ... +def _nccl_reduce( + input: Sequence[Tensor], + output: Tensor, + root: _int, + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _nccl_all_reduce( + input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _nccl_broadcast( + input: Sequence[Tensor], + root: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _nccl_all_gather( + input: Sequence[Tensor], + output: Sequence[Tensor], + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... +def _nccl_reduce_scatter( + input: Sequence[Tensor], + output: Sequence[Tensor], + op: _int, + streams: Optional[Sequence[_CudaStreamBase]], + comms: Optional[Sequence[object]], +) -> None: ... def _rocm_is_backward_pass() -> _bool: ... - class _CudaDeviceProperties: name: str major: _int @@ -2672,17 +3779,31 @@ class _CudaDeviceProperties: # Defined in torch/csrc/cuda/python_comm.cpp def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ... -def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ... +def _broadcast_out( + tensor: Tensor, out_tensors: List[Tensor] +) -> List[Tensor]: ... 
def _broadcast_coalesced( - tensors: List[Tensor], - devices: List[_int], - buffer_size: _int + tensors: List[Tensor], devices: List[_int], buffer_size: _int ) -> List[List[Tensor]]: ... - -def _scatter(tensor: Tensor, devices: List[_int], chunk_sizes: Optional[List[_int]], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ... -def _scatter_out(tensor: Tensor, out_tensors: List[Tensor], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ... -def _gather(tensors: List[Tensor], dim: _int, destination_index: Optional[_int]) -> Tensor: ... -def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ... +def _scatter( + tensor: Tensor, + devices: List[_int], + chunk_sizes: Optional[List[_int]], + dim: _int, + streams: Optional[List[Stream]], +) -> List[Tensor]: ... +def _scatter_out( + tensor: Tensor, + out_tensors: List[Tensor], + dim: _int, + streams: Optional[List[Stream]], +) -> List[Tensor]: ... +def _gather( + tensors: List[Tensor], dim: _int, destination_index: Optional[_int] +) -> Tensor: ... +def _gather_out( + tensors: List[Tensor], out_tensor: Tensor, dim: _int +) -> Tensor: ... # Defined in torch/csrc/cuda/Stream.cpp class _CudaStreamBase: @@ -2694,7 +3815,13 @@ class _CudaStreamBase: cuda_stream: _int priority: _int - def __new__(self, priority: _int = 0, stream_id: _int = 0, device_index: _int = 0, stream_ptr: _int = 0) -> _CudaStreamBase: ... + def __new__( + self, + priority: _int = 0, + stream_id: _int = 0, + device_index: _int = 0, + stream_ptr: _int = 0, + ) -> _CudaStreamBase: ... def query(self) -> _bool: ... def synchronize(self) -> None: ... def priority_range(self) -> Tuple[_int, _int]: ... @@ -2704,9 +3831,16 @@ class _CudaEventBase: device: _device cuda_event: _int - def __new__(cls, enable_timing: _bool = False, blocking: _bool = False, interprocess: _bool = False) -> _CudaEventBase: ... + def __new__( + cls, + enable_timing: _bool = False, + blocking: _bool = False, + interprocess: _bool = False, + ) -> _CudaEventBase: ... @classmethod - def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ... + def from_ipc_handle( + cls, device: _device, ipc_handle: bytes + ) -> _CudaEventBase: ... def record(self, stream: _CudaStreamBase) -> None: ... def wait(self, stream: _CudaStreamBase) -> None: ... def query(self) -> _bool: ... @@ -2716,24 +3850,29 @@ class _CudaEventBase: # Defined in torch/csrc/cuda/Graph.cpp class _CUDAGraph: - def capture_begin(self, - pool: Optional[Tuple[_int, _int]]=...) -> None: ... + def capture_begin( + self, pool: Optional[Tuple[_int, _int]] = ... + ) -> None: ... def capture_end(self) -> None: ... def replay(self) -> None: ... def reset(self) -> None: ... def pool(self) -> Tuple[_int, _int]: ... def enable_debug_mode(self) -> None: ... - def debug_dump(self, - debug_path: str) -> None: ... + def debug_dump(self, debug_path: str) -> None: ... def _cuda_isCurrentStreamCapturing() -> _bool: ... - def _graph_pool_handle() -> Tuple[_int, _int]: ... # Defined in torch/csrc/DataLoader.cpp -def _set_worker_signal_handlers(*arg: Any) -> None: ... # THPModule_setWorkerSignalHandlers -def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ... # THPModule_setWorkerPIDs -def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs +def _set_worker_signal_handlers( + *arg: Any, +) -> None: ... # THPModule_setWorkerSignalHandlers +def _set_worker_pids( + key: _int, child_pids: Tuple[_int, ...] +) -> None: ... 
# THPModule_setWorkerPIDs +def _remove_worker_pids( + loader_id: _int, +) -> None: ... # THPModule_removeWorkerPIDs def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails # Defined in torch/csrc/jit/python/python_tracer.cpp @@ -2752,19 +3891,19 @@ def _create_graph_by_tracing( strict: Any, force_outplace: Any, self: Any = None, - argument_names: List[str] = [] + argument_names: List[str] = [], ) -> Tuple[Graph, Stack]: ... def _tracer_warn_use_python(): ... def _get_tracing_state() -> TracingState: ... # Defined in torch/csrc/jit/python/python_ir.cpp # Not actually defined in python_ir.cpp, not sure where they are. -class IValue: - ... +class IValue: ... + Stack = List[IValue] class JitType: - annotation_str : str + annotation_str: str def isSubtypeOf(self, other: JitType) -> _bool: ... def with_dtype(self, dtype: _dtype) -> JitType: ... def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ... @@ -2779,7 +3918,7 @@ class InferredType: def success(self) -> _bool: ... def reason(self) -> str: ... -R = TypeVar('R', bound=JitType) +R = TypeVar("R", bound=JitType) class AnyType(JitType): @staticmethod @@ -2824,7 +3963,6 @@ class StreamObjType(JitType): class ListType(JitType): def __init__(self, a: JitType) -> None: ... def getElementType(self) -> JitType: ... - @staticmethod def ofInts() -> ListType: ... @staticmethod @@ -2861,7 +3999,6 @@ class InterfaceType(JitType): class OptionalType(JitType, Generic[R]): def __init__(self, a: JitType) -> None: ... def getElementType(self) -> JitType: ... - @staticmethod def ofTensor() -> OptionalType: ... @@ -2881,16 +4018,17 @@ class EnumType(JitType): self, qualified_name: str, value_type: JitType, - enum_names_values: List[Any] - ) -> None: - ... + enum_names_values: List[Any], + ) -> None: ... class TensorType(JitType): @classmethod def get(cls) -> TensorType: ... @classmethod def getInferred(cls) -> TensorType: ... - def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ... + def with_sizes( + self, other: Optional[List[Optional[_int]]] + ) -> TensorType: ... def sizes(self) -> Optional[List[_int]]: ... def varyingSizes(self) -> Optional[List[Optional[_int]]]: ... def strides(self) -> Optional[List[_int]]: ... @@ -2901,24 +4039,19 @@ class TensorType(JitType): def create_from_tensor(t: Tensor) -> TensorType: ... # Defined in torch/csrc/jit/python/python_tree_views.cpp -class SourceRange: - ... - -class TreeView: - ... +class SourceRange: ... +class TreeView: ... class Ident(TreeView): @property def name(self) -> str: ... -class ClassDef(TreeView): - ... +class ClassDef(TreeView): ... class Def(TreeView): def name(self) -> Ident: ... -class Decl(TreeView): - ... +class Decl(TreeView): ... # Defined in torch/csrc/distributed/rpc/init.cpp def _rpc_init() -> _bool: ... @@ -2931,7 +4064,6 @@ def _c10d_init() -> _bool: ... # Defined in torch/csrc/distributed/rpc/testing/init.cpp def _faulty_agent_init() -> _bool: ... - def _enable_minidumps(directory: str) -> None: ... def _disable_minidumps() -> None: ... def _enable_minidumps_on_exceptions() -> None: ... @@ -2940,7 +4072,7 @@ def _activate_cuda_trace() -> None: ... # Defined in torch/csrc/Module.cpp def _current_graph_task_id() -> _int: ... -def _current_autograd_node() -> _Node: ... +def _current_autograd_node() -> _Node: ... 
class _OutOfMemoryError: pass diff --git a/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py b/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py index e0d7fb4fc..d72c4679f 100644 --- a/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py +++ b/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py @@ -1,11 +1,10 @@ import pytest -import torch - import pytorch_pfn_extras as ppe +import torch def test_stream(): - cupy = pytest.importorskip('cupy') + cupy = pytest.importorskip("cupy") assert 0 == cupy.cuda.get_current_stream().ptr assert 0 == torch.cuda.current_stream().cuda_stream @@ -42,7 +41,7 @@ def test_stream_none(): class TestMemoryPool: @pytest.fixture def cupy(self): - cupy = pytest.importorskip('cupy') + cupy = pytest.importorskip("cupy") mempool = cupy.get_default_memory_pool() yield cupy mempool.free_all_blocks() @@ -98,7 +97,8 @@ def test_use_torch_mempool_stream_mismatch(self, cupy): try: stream.use() with pytest.raises( - RuntimeError, match='pytorch_pfn_extras.cuda.stream'): + RuntimeError, match="pytorch_pfn_extras.cuda.stream" + ): arr = cupy.arange(10) del arr finally: diff --git a/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py b/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py index 6c7238a1e..ee868f850 100644 --- a/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py +++ b/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py @@ -1,6 +1,5 @@ -import torch - import pytorch_pfn_extras as ppe +import torch class DummyDataset(torch.utils.data.Dataset): diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py index c21f35404..dbb6df388 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py @@ -1,15 +1,19 @@ import numpy as np - import pytorch_pfn_extras as ppe class DummyDataset(ppe.dataset.TabularDataset): - def __init__( - self, size=10, keys=('a', 'b', 'c'), mode=tuple, - return_array=False, callback=None, convert=False): + self, + size=10, + keys=("a", "b", "c"), + mode=tuple, + return_array=False, + callback=None, + convert=False, + ): if mode is None: - keys = keys[0], + keys = (keys[0],) self._keys = keys self._mode = mode @@ -47,6 +51,6 @@ def get_examples(self, indices, key_indices): def convert(self, data): if self._convert: - return 'converted' + return "converted" else: return super(DummyDataset, self).convert(data) diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py index 9ec9ddb2d..91b426e89 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py @@ -1,13 +1,11 @@ import pytest - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( + dummy_dataset, # NOQA +) -@pytest.mark.parametrize( - 'mode', - [tuple, dict, None] -) +@pytest.mark.parametrize("mode", [tuple, dict, None]) def test_astuple(mode): dataset = dummy_dataset.DummyDataset(mode=mode, convert=True) view = dataset.astuple() @@ -15,15 +13,11 @@ def test_astuple(mode): assert len(view) == len(dataset) assert view.keys == dataset.keys assert 
view.mode == tuple - assert ( - view.get_examples(None, None) == dataset.get_examples(None, None)) - assert view.convert(view.fetch()) == 'converted' + assert view.get_examples(None, None) == dataset.get_examples(None, None) + assert view.convert(view.fetch()) == "converted" -@pytest.mark.parametrize( - 'mode', - [tuple, dict, None] -) +@pytest.mark.parametrize("mode", [tuple, dict, None]) def test_asdict(mode): dataset = dummy_dataset.DummyDataset(mode=mode, convert=True) view = dataset.asdict() @@ -31,6 +25,5 @@ def test_asdict(mode): assert len(view) == len(dataset) assert view.keys == dataset.keys assert view.mode == dict - assert ( - view.get_examples(None, None) == dataset.get_examples(None, None)) - assert view.convert(view.fetch()) == 'converted' + assert view.get_examples(None, None) == dataset.get_examples(None, None) + assert view.convert(view.fetch()) == "converted" diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py index b7209aa51..19197c264 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py @@ -3,57 +3,61 @@ import numpy as np import pytest - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( + dummy_dataset, # NOQA +) mode_a = [tuple, dict, None] mode_b = [tuple, dict, None] return_array = [True, False] parameter_set = [ - {'indices': None, - 'expected_indices_a': None, - 'expected_indices_b': None}, - {'indices': [3, 1, 4, 12, 14, 13, 7, 5], - 'expected_indices_a': [3, 1, 4, 7, 5], - 'expected_indices_b': [2, 4, 3]}, - {'indices': [3, 1, 4], - 'expected_indices_a': [3, 1, 4]}, - {'indices': slice(13, 6, -2), - 'expected_indices_a': slice(9, 6, -2), - 'expected_indices_b': slice(3, None, -2)}, - {'indices': slice(9, None, -2), - 'expected_indices_a': slice(9, None, -2)}, - {'indices': [1, 2, 1], - 'expected_indices_a': [1, 2, 1]}, - {'indices': []}, + {"indices": None, "expected_indices_a": None, "expected_indices_b": None}, + { + "indices": [3, 1, 4, 12, 14, 13, 7, 5], + "expected_indices_a": [3, 1, 4, 7, 5], + "expected_indices_b": [2, 4, 3], + }, + {"indices": [3, 1, 4], "expected_indices_a": [3, 1, 4]}, + { + "indices": slice(13, 6, -2), + "expected_indices_a": slice(9, 6, -2), + "expected_indices_b": slice(3, None, -2), + }, + {"indices": slice(9, None, -2), "expected_indices_a": slice(9, None, -2)}, + {"indices": [1, 2, 1], "expected_indices_a": [1, 2, 1]}, + {"indices": []}, ] @pytest.mark.parametrize( - 'mode_a, mode_b, return_array, parameter_set', - itertools.product(mode_a, mode_b, return_array, parameter_set) + "mode_a, mode_b, return_array, parameter_set", + itertools.product(mode_a, mode_b, return_array, parameter_set), ) def test_concat(mode_a, mode_b, return_array, parameter_set): def callback_a(indices, key_indices): - assert indices == parameter_set['expected_indices_a'] + assert indices == parameter_set["expected_indices_a"] assert key_indices is None dataset_a = dummy_dataset.DummyDataset( - keys=('a', 'b', 'c') if mode_b else ('a',), + keys=("a", "b", "c") if mode_b else ("a",), mode=mode_a, - return_array=return_array, callback=callback_a, - convert=True) + return_array=return_array, + callback=callback_a, + convert=True, + ) def callback_b(indices, key_indices): - assert indices == 
parameter_set['expected_indices_b'] + assert indices == parameter_set["expected_indices_b"] assert key_indices is None dataset_b = dummy_dataset.DummyDataset( size=5, - keys=('a', 'b', 'c') if mode_a else ('a',), + keys=("a", "b", "c") if mode_a else ("a",), mode=mode_b, - return_array=return_array, callback=callback_b) + return_array=return_array, + callback=callback_b, + ) view = dataset_a.concat(dataset_b) assert isinstance(view, ppe.dataset.TabularDataset) @@ -61,27 +65,28 @@ def callback_b(indices, key_indices): assert view.keys == dataset_a.keys assert view.mode == dataset_a.mode - output = view.get_examples(parameter_set['indices'], None) + output = view.get_examples(parameter_set["indices"], None) data = np.hstack((dataset_a.data, dataset_b.data)) - if parameter_set['indices'] is not None: - data = data[:, parameter_set['indices']] + if parameter_set["indices"] is not None: + data = data[:, parameter_set["indices"]] for out, d in itertools.zip_longest(output, data): np.testing.assert_equal(out, d) if return_array and operator.xor( - ('expected_indices_a' in parameter_set), - ('expected_indices_b' in parameter_set)): + ("expected_indices_a" in parameter_set), + ("expected_indices_b" in parameter_set), + ): assert isinstance(out, np.ndarray) else: assert isinstance(out, list) - assert view.convert(output) == 'converted' + assert view.convert(output) == "converted" def test_concat_key_length(): dataset_a = dummy_dataset.DummyDataset() - dataset_b = dummy_dataset.DummyDataset(keys=('a', 'b')) + dataset_b = dummy_dataset.DummyDataset(keys=("a", "b")) with pytest.raises(ValueError): dataset_a.concat(dataset_b) @@ -89,7 +94,7 @@ def test_concat_key_length(): def test_concat_key_order(): dataset_a = dummy_dataset.DummyDataset() - dataset_b = dummy_dataset.DummyDataset(keys=('b', 'a', 'c')) + dataset_b = dummy_dataset.DummyDataset(keys=("b", "a", "c")) with pytest.raises(ValueError): dataset_a.concat(dataset_b) diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py index 3236bb795..0f1e6014a 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py @@ -1,21 +1,17 @@ import pytest - import pytorch_pfn_extras as ppe from pytorch_pfn_extras.dataset import tabular -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( + dummy_dataset, # NOQA +) -@pytest.mark.parametrize( - 'mode', - [tuple, dict, None] -) +@pytest.mark.parametrize("mode", [tuple, dict, None]) def test_delegate_dataset(mode): - dataset = tabular.DelegateDataset( - dummy_dataset.DummyDataset(mode=mode)) + dataset = tabular.DelegateDataset(dummy_dataset.DummyDataset(mode=mode)) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == len(dataset.dataset) assert dataset.keys == dataset.dataset.keys assert dataset.mode == dataset.dataset.mode - assert ( - dataset.get_example(3) == dataset.dataset.get_example(3)) + assert dataset.get_example(3) == dataset.dataset.get_example(3) diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py index 6ca854ccb..783e2ac92 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py +++ 
b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py @@ -1,12 +1,10 @@ import numpy as np import pytest - import pytorch_pfn_extras as ppe from pytorch_pfn_extras.dataset import tabular class TestFromData: - def test_unary_array(self): dataset = tabular.from_data(np.arange(10)) @@ -20,11 +18,11 @@ def test_unary_array(self): assert isinstance(output, np.ndarray) def test_unary_array_with_key(self): - dataset = tabular.from_data(('key_a', np.arange(10))) + dataset = tabular.from_data(("key_a", np.arange(10))) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert dataset.keys == ('key_a',) + assert dataset.keys == ("key_a",) assert dataset.mode is None output = dataset.slice[[1, 3]].fetch() @@ -44,11 +42,11 @@ def test_unary_list(self): assert isinstance(output, list) def test_unary_list_with_key(self): - dataset = tabular.from_data(('key_a', [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])) + dataset = tabular.from_data(("key_a", [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert dataset.keys == ('key_a',) + assert dataset.keys == ("key_a",) assert dataset.mode is None output = dataset.slice[[1, 3]].fetch() @@ -56,12 +54,12 @@ def test_unary_list_with_key(self): assert isinstance(output, list) def test_unary_callable_unary(self): - dataset = tabular.from_data(('key_a', lambda i: i * i), size=10) + dataset = tabular.from_data(("key_a", lambda i: i * i), size=10) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert dataset.keys == ('key_a',) - assert(dataset.mode) is None + assert dataset.keys == ("key_a",) + assert (dataset.mode) is None output = dataset.slice[[1, 3]].fetch() np.testing.assert_equal(output, [1, 9]) @@ -69,11 +67,12 @@ def test_unary_callable_unary(self): def test_unary_callable_tuple(self): dataset = tabular.from_data( - (('key_a', 'key_b'), lambda i: (i * i, -i)), size=10) + (("key_a", "key_b"), lambda i: (i * i, -i)), size=10 + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert dataset.keys == ('key_a', 'key_b') + assert dataset.keys == ("key_a", "key_b") assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -83,16 +82,17 @@ def test_unary_callable_tuple(self): def test_unary_callable_dict(self): dataset = tabular.from_data( - (('key_a', 'key_b'), - lambda i: {'key_a': i * i, 'key_b': -i}), size=10) + (("key_a", "key_b"), lambda i: {"key_a": i * i, "key_b": -i}), + size=10, + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert dataset.keys == ('key_a', 'key_b') + assert dataset.keys == ("key_a", "key_b") assert dataset.mode == dict output = dataset.slice[[1, 3]].fetch() - np.testing.assert_equal(output, {'key_a': [1, 9], 'key_b': [-1, -3]}) + np.testing.assert_equal(output, {"key_a": [1, 9], "key_b": [-1, -3]}) for out in output.values(): assert isinstance(out, list) @@ -102,11 +102,12 @@ def test_unary_callable_without_key(self): def test_unary_callable_without_size(self): with pytest.raises(ValueError): - tabular.from_data(('key_a', lambda i: i * i)) + tabular.from_data(("key_a", lambda i: i * i)) def test_tuple_array_list(self): dataset = tabular.from_data( - (np.arange(10), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])) + (np.arange(10), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]) + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 @@ -120,12 +121,13 @@ def test_tuple_array_list(self): def 
test_tuple_array_with_key_list(self): dataset = tabular.from_data( - (('key_a', np.arange(10)), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])) + (("key_a", np.arange(10)), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]) + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 assert len(dataset.keys) == 2 - assert dataset.keys[0] == 'key_a' + assert dataset.keys[0] == "key_a" assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -135,12 +137,13 @@ def test_tuple_array_with_key_list(self): def test_tuple_array_list_with_key(self): dataset = tabular.from_data( - (np.arange(10), ('key_b', [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]))) + (np.arange(10), ("key_b", [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])) + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 assert len(dataset.keys) == 2 - assert dataset.keys[1] == 'key_b' + assert dataset.keys[1] == "key_b" assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -149,13 +152,12 @@ def test_tuple_array_list_with_key(self): assert isinstance(output[1], list) def test_tuple_array_callable_unary(self): - dataset = tabular.from_data( - (np.arange(10), ('key_b', lambda i: i * i))) + dataset = tabular.from_data((np.arange(10), ("key_b", lambda i: i * i))) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 assert len(dataset.keys) == 2 - assert dataset.keys[1] == 'key_b' + assert dataset.keys[1] == "key_b" assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -165,12 +167,13 @@ def test_tuple_array_callable_unary(self): def test_tuple_array_callable_tuple(self): dataset = tabular.from_data( - (np.arange(10), (('key_b', 'key_c'), lambda i: (i * i, -i)))) + (np.arange(10), (("key_b", "key_c"), lambda i: (i * i, -i))) + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 assert len(dataset.keys) == 3 - assert dataset.keys[1] == ('key_b') + assert dataset.keys[1] == ("key_b") assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -180,13 +183,16 @@ def test_tuple_array_callable_tuple(self): def test_tuple_array_callable_dict(self): dataset = tabular.from_data( - (np.arange(10), (('key_b', 'key_c'), - lambda i: {'key_b': i * i, 'key_c': -i}))) + ( + np.arange(10), + (("key_b", "key_c"), lambda i: {"key_b": i * i, "key_c": -i}), + ) + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 assert len(dataset.keys) == 3 - assert dataset.keys[1] == ('key_b') + assert dataset.keys[1] == ("key_b") assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -196,11 +202,12 @@ def test_tuple_array_callable_dict(self): def test_tuple_array_with_key_callable_unary(self): dataset = tabular.from_data( - (('key_a', np.arange(10)), ('key_b', lambda i: i * i))) + (("key_a", np.arange(10)), ("key_b", lambda i: i * i)) + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert dataset.keys == ('key_a', 'key_b') + assert dataset.keys == ("key_a", "key_b") assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -210,11 +217,12 @@ def test_tuple_array_with_key_callable_unary(self): def test_tuple_callable_unary_callable_unary(self): dataset = tabular.from_data( - (('key_a', lambda i: i * i), ('key_b', lambda i: -i)), size=10) + (("key_a", lambda i: i * i), ("key_b", lambda i: -i)), size=10 + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert dataset.keys == ('key_a', 'key_b') + assert dataset.keys == ("key_a", 
"key_b") assert dataset.mode == tuple output = dataset.slice[[1, 3]].fetch() @@ -225,87 +233,97 @@ def test_tuple_callable_unary_callable_unary(self): def test_tuple_callable_unary_callable_unary_without_size(self): with pytest.raises(ValueError): tabular.from_data( - (('key_a', lambda i: i * i), ('key_b', lambda i: -i))) + (("key_a", lambda i: i * i), ("key_b", lambda i: -i)) + ) def test_dict_array_list(self): dataset = tabular.from_data( - {'key_a': np.arange(10), 'key_b': [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]}) + {"key_a": np.arange(10), "key_b": [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]} + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert set(dataset.keys) == {'key_a', 'key_b'} + assert set(dataset.keys) == {"key_a", "key_b"} assert dataset.mode == dict output = dataset.slice[[1, 3]].fetch() - np.testing.assert_equal(output, {'key_a': [1, 3], 'key_b': [7, 8]}) - assert isinstance(output['key_a'], np.ndarray) - assert isinstance(output['key_b'], list) + np.testing.assert_equal(output, {"key_a": [1, 3], "key_b": [7, 8]}) + assert isinstance(output["key_a"], np.ndarray) + assert isinstance(output["key_b"], list) def test_dict_array_callable_unary(self): dataset = tabular.from_data( - {'key_a': np.arange(10), 'key_b': lambda i: i * i}) + {"key_a": np.arange(10), "key_b": lambda i: i * i} + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert set(dataset.keys) == {'key_a', 'key_b'} + assert set(dataset.keys) == {"key_a", "key_b"} assert dataset.mode == dict output = dataset.slice[[1, 3]].fetch() - np.testing.assert_equal(output, {'key_a': [1, 3], 'key_b': [1, 9]}) - assert isinstance(output['key_a'], np.ndarray) - assert isinstance(output['key_b'], list) + np.testing.assert_equal(output, {"key_a": [1, 3], "key_b": [1, 9]}) + assert isinstance(output["key_a"], np.ndarray) + assert isinstance(output["key_b"], list) def test_dict_array_callable_tuple(self): dataset = tabular.from_data( - {'key_a': np.arange(10), - ('key_b', 'key_c'): lambda i: (i * i, -i)}) + {"key_a": np.arange(10), ("key_b", "key_c"): lambda i: (i * i, -i)} + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert set(dataset.keys) == {'key_a', 'key_b', 'key_c'} + assert set(dataset.keys) == {"key_a", "key_b", "key_c"} assert dataset.mode == dict output = dataset.slice[[1, 3]].fetch() np.testing.assert_equal( - output, {'key_a': [1, 3], 'key_b': [1, 9], 'key_c': [-1, -3]}) - assert isinstance(output['key_a'], np.ndarray) - assert isinstance(output['key_b'], list) - assert isinstance(output['key_c'], list) + output, {"key_a": [1, 3], "key_b": [1, 9], "key_c": [-1, -3]} + ) + assert isinstance(output["key_a"], np.ndarray) + assert isinstance(output["key_b"], list) + assert isinstance(output["key_c"], list) def test_dict_array_callable_dict(self): dataset = tabular.from_data( - {'key_a': np.arange(10), - ('key_b', 'key_c'): lambda i: {'key_b': i * i, 'key_c': -i}}) + { + "key_a": np.arange(10), + ("key_b", "key_c"): lambda i: {"key_b": i * i, "key_c": -i}, + } + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert set(dataset.keys) == {'key_a', 'key_b', 'key_c'} + assert set(dataset.keys) == {"key_a", "key_b", "key_c"} assert dataset.mode == dict output = dataset.slice[[1, 3]].fetch() np.testing.assert_equal( - output, {'key_a': [1, 3], 'key_b': [1, 9], 'key_c': [-1, -3]}) - assert isinstance(output['key_a'], np.ndarray) - assert isinstance(output['key_b'], list) - assert 
isinstance(output['key_c'], list) + output, {"key_a": [1, 3], "key_b": [1, 9], "key_c": [-1, -3]} + ) + assert isinstance(output["key_a"], np.ndarray) + assert isinstance(output["key_b"], list) + assert isinstance(output["key_c"], list) def test_dict_callable_unary_callable_unary(self): dataset = tabular.from_data( - {'key_a': lambda i: i * i, 'key_b': lambda i: -i}, size=10) + {"key_a": lambda i: i * i, "key_b": lambda i: -i}, size=10 + ) assert isinstance(dataset, ppe.dataset.TabularDataset) assert len(dataset) == 10 - assert set(dataset.keys) == {'key_a', 'key_b'} + assert set(dataset.keys) == {"key_a", "key_b"} output = dataset.slice[[1, 3]].fetch() - np.testing.assert_equal(output, {'key_a': [1, 9], 'key_b': [-1, -3]}) - assert isinstance(output['key_a'], list) - assert isinstance(output['key_b'], list) + np.testing.assert_equal(output, {"key_a": [1, 9], "key_b": [-1, -3]}) + assert isinstance(output["key_a"], list) + assert isinstance(output["key_b"], list) def test_dict_callable_unary_callable_unary_without_size(self): with pytest.raises(ValueError): - tabular.from_data(( - {'key_a': lambda i: i * i, 'key_b': lambda i: -i})) + tabular.from_data( + ({"key_a": lambda i: i * i, "key_b": lambda i: -i}) + ) def test_unique(self): dataset_a = tabular.from_data(np.arange(10)) diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py index 38bd6beff..79016d2e4 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py @@ -2,9 +2,10 @@ import numpy as np import pytest - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( + dummy_dataset, # NOQA +) def _filter_params(params): @@ -13,20 +14,22 @@ def _filter_params(params): key_size += 3 if param[0] else 1 key_size += 2 if param[1] else 1 - if param[3] and \ - any(key_size <= key_index for key_index in param[3]): + if param[3] and any(key_size <= key_index for key_index in param[3]): continue yield param @pytest.mark.parametrize( - 'mode_a, mode_b, return_array, key_indices', - _filter_params(itertools.product( - [tuple, dict, None], - [tuple, dict, None], - [True, False], - [None, (0, 4, 1), (0, 2), (1, 0), ()])) + "mode_a, mode_b, return_array, key_indices", + _filter_params( + itertools.product( + [tuple, dict, None], + [tuple, dict, None], + [True, False], + [None, (0, 4, 1), (0, 2), (1, 0), ()], + ) + ), ) def test_join(mode_a, mode_b, return_array, key_indices): if key_indices is None: @@ -37,13 +40,13 @@ def test_join(mode_a, mode_b, return_array, key_indices): key_size_a = 3 if mode_a else 1 key_indices_a = tuple( - key_index - for key_index in key_indices - if key_index < key_size_a) + key_index for key_index in key_indices if key_index < key_size_a + ) key_indices_b = tuple( key_index - key_size_a for key_index in key_indices - if key_size_a <= key_index) + if key_size_a <= key_index + ) if key_indices_a: expected_key_indices_a = key_indices_a @@ -56,16 +59,21 @@ def callback_a(indices, key_indices): dataset_a = dummy_dataset.DummyDataset( mode=mode_a, - return_array=return_array, callback=callback_a, - convert=True) + return_array=return_array, + callback=callback_a, + convert=True, + ) def callback_b(indices, key_indices): assert indices is None assert key_indices == expected_key_indices_b - 
dataset_b = dummy_dataset. DummyDataset( - keys=('d', 'e'), mode=mode_b, - return_array=return_array, callback=callback_b) + dataset_b = dummy_dataset.DummyDataset( + keys=("d", "e"), + mode=mode_b, + return_array=return_array, + callback=callback_b, + ) view = dataset_a.join(dataset_b) assert isinstance(view, ppe.dataset.TabularDataset) @@ -86,12 +94,12 @@ def callback_b(indices, key_indices): else: assert isinstance(out, list) - assert view.convert(output) == 'converted' + assert view.convert(output) == "converted" def test_join_length(): dataset_a = dummy_dataset.DummyDataset() - dataset_b = dummy_dataset.DummyDataset(size=5, keys=('d', 'e')) + dataset_b = dummy_dataset.DummyDataset(size=5, keys=("d", "e")) with pytest.raises(ValueError): dataset_a.join(dataset_b) @@ -99,7 +107,7 @@ def test_join_length(): def test_join_conflict_key(): dataset_a = dummy_dataset.DummyDataset() - dataset_b = dummy_dataset.DummyDataset(keys=('a', 'd')) + dataset_b = dummy_dataset.DummyDataset(keys=("a", "d")) with pytest.raises(ValueError): dataset_a.join(dataset_b) diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py index 5473ecc08..d8c2773bf 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py @@ -3,9 +3,10 @@ import numpy as np import pytest - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( + dummy_dataset, # NOQA +) def _values_to_dicts(names, values): @@ -18,143 +19,170 @@ def safe_zip(ns, vs): assert isinstance(vs, (tuple, list)) and len(ns) == len(vs) return zip(ns, vs) - names = names.split(',') + names = names.split(",") params = [dict(safe_zip(names, value_list)) for value_list in values] return params def product(parameter): if isinstance(parameter, dict): - return product_dict(*[ - _values_to_dicts(names, values) - for names, values in sorted(parameter.items())]) + return product_dict( + *[ + _values_to_dicts(names, values) + for names, values in sorted(parameter.items()) + ] + ) elif isinstance(parameter, list): # list of lists of dicts if not all(isinstance(_, list) for _ in parameter): - raise TypeError('parameter must be list of lists of dicts') + raise TypeError("parameter must be list of lists of dicts") if not all(isinstance(_, dict) for l in parameter for _ in l): # NOQA - raise TypeError('parameter must be list of lists of dicts') + raise TypeError("parameter must be list of lists of dicts") return product_dict(*parameter) else: raise TypeError( - 'parameter must be either dict or list. Actual: {}'.format( - type(parameter))) + "parameter must be either dict or list. 
Actual: {}".format( + type(parameter) + ) + ) def product_dict(*parameters): return [ {k: v for dic in dicts for k, v in dic.items()} - for dicts in itertools.product(*parameters)] + for dicts in itertools.product(*parameters) + ] def _filter_params(params): for param in params: - if 'expected_len' in param and \ - isinstance(param['get_examples_indices'], list) and \ - any(param['expected_len'] <= index - for index in param['get_examples_indices']): + if ( + "expected_len" in param + and isinstance(param["get_examples_indices"], list) + and any( + param["expected_len"] <= index + for index in param["get_examples_indices"] + ) + ): continue - if 'expected_keys' in param and \ - isinstance(param['get_examples_key_indices'], tuple) and \ - any(len(param['expected_keys']) <= key_index - for key_index in param['get_examples_key_indices']): + if ( + "expected_keys" in param + and isinstance(param["get_examples_key_indices"], tuple) + and any( + len(param["expected_keys"]) <= key_index + for key_index in param["get_examples_key_indices"] + ) + ): continue # To reduce the number of tests, # drop combinations of indices and keys. # (check only `slice[indices]` and `slice[:, keys]`) - if (not (param['indices'] == slice(None) - and param['get_examples_indices'] is None) - and not (param['keys'] is None - and param['get_examples_key_indices'] is None)): + if not ( + param["indices"] == slice(None) + and param["get_examples_indices"] is None + ) and not ( + param["keys"] is None and param["get_examples_key_indices"] is None + ): continue yield param -params = _filter_params(product_dict( +params = _filter_params( product_dict( - [{'mode': tuple}, {'mode': dict}], + product_dict( + [{"mode": tuple}, {"mode": dict}], + [ + {"keys": None, "expected_keys": ("a", "b", "c")}, + {"keys": 1, "expected_keys": ("b",)}, + {"keys": (1,), "expected_keys": ("b",)}, + {"keys": 3, "key_exception": IndexError}, + {"keys": (3,), "key_exception": IndexError}, + {"keys": "c", "expected_keys": ("c",)}, + {"keys": ("c",), "expected_keys": ("c",)}, + {"keys": "d", "key_exception": KeyError}, + {"keys": ("d",), "key_exception": KeyError}, + {"keys": (-1, "a"), "expected_keys": ("c", "a")}, + {"keys": (), "expected_keys": ()}, + ], + ) + + product_dict( + [{"mode": None}], + [ + {"keys": None, "expected_keys": ("a",)}, + {"keys": 0, "expected_keys": ("a",)}, + {"keys": (0,), "expected_keys": ("a",)}, + {"keys": 1, "key_exception": IndexError}, + {"keys": (1,), "key_exception": IndexError}, + {"keys": "a", "expected_keys": ("a",)}, + {"keys": ("a",), "expected_keys": ("a",)}, + {"keys": "b", "key_exception": KeyError}, + {"keys": ("b",), "key_exception": KeyError}, + {"keys": (), "expected_keys": ()}, + ], + ), + product( + { + "return_array": [True, False], + "integer": [int, np.int32], + } + ), [ - {'keys': None, 'expected_keys': ('a', 'b', 'c')}, - {'keys': 1, 'expected_keys': ('b',)}, - {'keys': (1,), 'expected_keys': ('b',)}, - {'keys': 3, 'key_exception': IndexError}, - {'keys': (3,), 'key_exception': IndexError}, - {'keys': 'c', 'expected_keys': ('c',)}, - {'keys': ('c',), 'expected_keys': ('c',)}, - {'keys': 'd', 'key_exception': KeyError}, - {'keys': ('d',), 'key_exception': KeyError}, - {'keys': (-1, 'a'), 'expected_keys': ('c', 'a')}, - {'keys': (), 'expected_keys': ()}, + {"indices": slice(None), "expected_len": 10}, + {"indices": [3, -2], "expected_len": 2}, + {"indices": [11, 1], "index_exception": IndexError}, + {"indices": [i in {1, 3} for i in range(10)], "expected_len": 2}, + {"indices": [True] * 11, 
"index_exception": ValueError}, + {"indices": slice(3, None, -2), "expected_len": 2}, + {"indices": [False, 3, 9, 5, True], "expected_len": 5}, + {"indices": [], "expected_len": 0}, ], + product( + { + "get_examples_indices": [ + None, + [1], + [1, 0], + slice(0, 2, 1), + slice(1, None, -1), + [], + ], + "get_examples_key_indices": [None, (1,), (1, 0), ()], + } + ), ) - + product_dict( - [{'mode': None}], - [ - {'keys': None, 'expected_keys': ('a',)}, - {'keys': 0, 'expected_keys': ('a',)}, - {'keys': (0,), 'expected_keys': ('a',)}, - {'keys': 1, 'key_exception': IndexError}, - {'keys': (1,), 'key_exception': IndexError}, - {'keys': 'a', 'expected_keys': ('a',)}, - {'keys': ('a',), 'expected_keys': ('a',)}, - {'keys': 'b', 'key_exception': KeyError}, - {'keys': ('b',), 'key_exception': KeyError}, - {'keys': (), 'expected_keys': ()}, - ], - ), - product({ - 'return_array': [True, False], - 'integer': [int, np.int32], - }), - [ - {'indices': slice(None), 'expected_len': 10}, - {'indices': [3, -2], 'expected_len': 2}, - {'indices': [11, 1], 'index_exception': IndexError}, - {'indices': [i in {1, 3} for i in range(10)], 'expected_len': 2}, - {'indices': [True] * 11, 'index_exception': ValueError}, - {'indices': slice(3, None, -2), 'expected_len': 2}, - {'indices': [False, 3, 9, 5, True], 'expected_len': 5}, - {'indices': [], 'expected_len': 0}, - ], - product({ - 'get_examples_indices': [ - None, [1], [1, 0], slice(0, 2, 1), slice(1, None, -1), []], - 'get_examples_key_indices': [None, (1,), (1, 0), ()], - }), -)) - - -@pytest.mark.parametrize( - 'test_args', - params ) + + +@pytest.mark.parametrize("test_args", params) def test_slice(test_args): - exception = test_args.get('index_exception', None) \ - or test_args.get('key_exception', None) + exception = test_args.get("index_exception", None) or test_args.get( + "key_exception", None + ) - indices = test_args['indices'] - keys = test_args['keys'] - mode = test_args['mode'] - return_array = test_args['return_array'] - get_examples_indices = test_args['get_examples_indices'] - get_examples_key_indices = test_args['get_examples_key_indices'] + indices = test_args["indices"] + keys = test_args["keys"] + mode = test_args["mode"] + return_array = test_args["return_array"] + get_examples_indices = test_args["get_examples_indices"] + get_examples_key_indices = test_args["get_examples_key_indices"] if isinstance(indices, list): indices = [ - index if isinstance(index, bool) else test_args['integer'](index) - for index in indices] + index if isinstance(index, bool) else test_args["integer"](index) + for index in indices + ] def callback(indices, key_indices): - if isinstance(indices, list) \ - or isinstance(get_examples_indices, list): + if isinstance(indices, list) or isinstance(get_examples_indices, list): assert isinstance(indices, list) - elif isinstance(indices, slice) \ - or isinstance(get_examples_indices, slice): + elif isinstance(indices, slice) or isinstance( + get_examples_indices, slice + ): assert isinstance(indices, slice) else: assert indices is None @@ -165,8 +193,8 @@ def callback(indices, key_indices): assert isinstance(key_indices, tuple) dataset = dummy_dataset.DummyDataset( - mode=mode, return_array=return_array, callback=callback, - convert=True) + mode=mode, return_array=return_array, callback=callback, convert=True + ) if exception is not None: with pytest.raises(exception): @@ -184,15 +212,13 @@ def callback(indices, key_indices): if isinstance(keys, tuple): keys = keys else: - keys = keys, - key_indices = [ - {'a': 0, 'b': 1, 
'c': 2}.get(key, key) for key in keys] - data = dataset.data[key_indices][ - :, _indices_for_numpy(indices)] + keys = (keys,) + key_indices = [{"a": 0, "b": 1, "c": 2}.get(key, key) for key in keys] + data = dataset.data[key_indices][:, _indices_for_numpy(indices)] assert isinstance(view, ppe.dataset.TabularDataset) - assert len(view) == test_args['expected_len'] - assert view.keys == test_args['expected_keys'] + assert len(view) == test_args["expected_len"] + assert view.keys == test_args["expected_keys"] if keys is None: assert view.mode == mode elif isinstance(keys, tuple): @@ -200,8 +226,7 @@ def callback(indices, key_indices): else: assert view.mode is None - output = view.get_examples( - get_examples_indices, get_examples_key_indices) + output = view.get_examples(get_examples_indices, get_examples_key_indices) if get_examples_indices is not None: data = data[:, _indices_for_numpy(get_examples_indices)] @@ -215,22 +240,24 @@ def callback(indices, key_indices): else: assert isinstance(out, list) - assert view.convert(output) == 'converted' + assert view.convert(output) == "converted" # Replace list of bool with ndarray of bool # since old numpy cannot handle list of bool. def _indices_for_numpy(indices): with warnings.catch_warnings(): - warnings.simplefilter('ignore', FutureWarning) + warnings.simplefilter("ignore", FutureWarning) if len(np.empty(2)[[False, True]]) == 1: # new numpy return indices # old numpy - if isinstance(indices, list) and \ - len(indices) > 0 and \ - isinstance(indices[0], bool): + if ( + isinstance(indices, list) + and len(indices) > 0 + and isinstance(indices[0], bool) + ): return np.array(indices) else: return indices diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py index 63885f243..84390cab8 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py @@ -2,10 +2,9 @@ import numpy as np import pytest - -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( # NOQA dummy_dataset, -) # NOQA +) @pytest.mark.parametrize( @@ -42,7 +41,8 @@ def callback(indices, key_indices): def test_convert(self, mode, return_array): dataset = dummy_dataset.DummyDataset( - mode=mode, return_array=return_array) + mode=mode, return_array=return_array + ) output = dataset.convert(dataset.fetch()) if mode is tuple: @@ -80,7 +80,8 @@ def callback(indices, key_indices): def test_iter(self, mode, return_array): dataset = dummy_dataset.DummyDataset( - mode=mode, return_array=return_array) + mode=mode, return_array=return_array + ) it = iter(dataset) for i in range(10): if mode is tuple: diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py index 3391ed1ae..63744a587 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py @@ -2,35 +2,41 @@ import numpy as np import pytest - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( + dummy_dataset, # NOQA +) # filter out invalid combinations of params def 
_filter_params(params): for param in params: - if param[1] is None and \ - isinstance(param[3], tuple) and \ - any(1 <= key_index - for key_index in param[3]): + if ( + param[1] is None + and isinstance(param[3], tuple) + and any(1 <= key_index for key_index in param[3]) + ): continue yield param @pytest.mark.parametrize( - 'in_mode, out_mode, indices, key_indices, with_batch', - _filter_params(itertools.product( - [tuple, dict, None], - [tuple, dict, None], - [None, [1, 3], slice(None, 2)], - [None, (0,), (1,), (1, 0)], - [False, True])) + "in_mode, out_mode, indices, key_indices, with_batch", + _filter_params( + itertools.product( + [tuple, dict, None], + [tuple, dict, None], + [None, [1, 3], slice(None, 2)], + [None, (0,), (1,), (1, 0)], + [False, True], + ) + ), ) def test_transform(in_mode, out_mode, indices, key_indices, with_batch): dataset = dummy_dataset.DummyDataset( - mode=in_mode, return_array=True, convert=True) + mode=in_mode, return_array=True, convert=True + ) def transform(*args, **kwargs): if in_mode is tuple: @@ -40,11 +46,11 @@ def transform(*args, **kwargs): elif in_mode is dict: assert len(args) == 0 assert len(kwargs) == 3 - a, b, c = kwargs['a'], kwargs['b'], kwargs['c'] + a, b, c = kwargs["a"], kwargs["b"], kwargs["c"] elif in_mode is None: assert len(args) == 1 assert len(kwargs) == 0 - a, = args + (a,) = args b, c = a, a if with_batch: @@ -59,7 +65,7 @@ def transform(*args, **kwargs): if out_mode is tuple: return a + b, b + c elif out_mode is dict: - return {'alpha': a + b, 'beta': b + c} + return {"alpha": a + b, "beta": b + c} elif out_mode is None: return a + b + c @@ -71,11 +77,11 @@ def transform_alpha(*args, **kwargs): elif in_mode is dict: assert len(args) == 0 assert len(kwargs) == 3 - a, b, c = kwargs['a'], kwargs['b'], kwargs['c'] + a, b, c = kwargs["a"], kwargs["b"], kwargs["c"] elif in_mode is None: assert len(args) == 1 assert len(kwargs) == 0 - a, = args + (a,) = args b, c = a, a if with_batch: @@ -88,9 +94,9 @@ def transform_alpha(*args, **kwargs): assert isinstance(c, float) if out_mode is tuple: - return a + b, + return (a + b,) elif out_mode is dict: - return {'alpha': a + b} + return {"alpha": a + b} elif out_mode is None: return a + b + c @@ -102,11 +108,11 @@ def transform_beta(*args, **kwargs): elif in_mode is dict: assert len(args) == 0 assert len(kwargs) == 3 - a, b, c = kwargs['a'], kwargs['b'], kwargs['c'] + a, b, c = kwargs["a"], kwargs["b"], kwargs["c"] elif in_mode is None: assert len(args) == 1 assert len(kwargs) == 0 - a, = args + (a,) = args b, c = a, a if with_batch: @@ -119,51 +125,49 @@ def transform_beta(*args, **kwargs): assert isinstance(c, float) if out_mode is tuple: - return b + c, + return (b + c,) elif out_mode is dict: - return {'beta': b + c} + return {"beta": b + c} elif out_mode is None: return a + b + c if in_mode is not None: a, b, c = dataset.data else: - a, = dataset.data + (a,) = dataset.data b, c = a, a if out_mode is not None: if in_mode is not None: - d_transform = [ - ((('a', 'b', 'c'), ('alpha', 'beta')), transform)] + d_transform = [((("a", "b", "c"), ("alpha", "beta")), transform)] else: d_transform = [ - ((('a',), ('alpha',)), transform_alpha), - ((('a',), ('beta',)), transform_beta)] + ((("a",), ("alpha",)), transform_alpha), + ((("a",), ("beta",)), transform_beta), + ] if with_batch: - view = dataset.transform_batch(('alpha', 'beta'), d_transform) + view = dataset.transform_batch(("alpha", "beta"), d_transform) else: - view = dataset.transform(('alpha', 'beta'), d_transform) + view = 
dataset.transform(("alpha", "beta"), d_transform) data = np.vstack((a + b, b + c)) else: if in_mode is not None: - d_transform = [ - ((('a', 'b', 'c'), ('alpha',)), transform_alpha)] + d_transform = [((("a", "b", "c"), ("alpha",)), transform_alpha)] else: - d_transform = [ - ((('a',), ('alpha',)), transform_alpha)] + d_transform = [((("a",), ("alpha",)), transform_alpha)] if with_batch: - view = dataset.transform_batch(('alpha',), d_transform) + view = dataset.transform_batch(("alpha",), d_transform) else: - view = dataset.transform(('alpha',), d_transform) + view = dataset.transform(("alpha",), d_transform) data = (a + b + c)[None] assert isinstance(view, ppe.dataset.TabularDataset) assert len(view) == len(dataset) if out_mode is not None: - assert view.keys == ('alpha', 'beta') + assert view.keys == ("alpha", "beta") assert view.mode == out_mode else: - assert view.keys == ('alpha',) + assert view.keys == ("alpha",) assert view.mode == out_mode output = view.get_examples(indices, key_indices) @@ -180,15 +184,11 @@ def transform_beta(*args, **kwargs): else: assert isinstance(out, list) - assert view.convert(view.fetch()) == 'converted' + assert view.convert(view.fetch()) == "converted" -@pytest.mark.parametrize( - 'mode', - [tuple, dict, None] -) +@pytest.mark.parametrize("mode", [tuple, dict, None]) class TestTransformInvalid: - def setup_method(self): self.count = 0 @@ -205,9 +205,9 @@ def _transform(self, a, b, c): mode = tuple if mode is tuple: - return a, + return (a,) elif mode is dict: - return {'a': a} + return {"a": a} elif mode is None: return a @@ -215,8 +215,8 @@ def test_transform_inconsistent_mode(self, mode): dataset = dummy_dataset.DummyDataset() self.mode = mode view = dataset.transform( - ('a',), - [((('a', 'b', 'c'), ('a',)), self._transform)]) + ("a",), [((("a", "b", "c"), ("a",)), self._transform)] + ) view.get_examples([0], None) with pytest.raises(ValueError): view.get_examples([0], None) @@ -225,8 +225,8 @@ def test_transform_batch_inconsistent_mode(self, mode): dataset = dummy_dataset.DummyDataset() self.mode = mode view = dataset.transform_batch( - ('a',), - [((('a', 'b', 'c'), ('a',)), self._transform)]) + ("a",), [((("a", "b", "c"), ("a",)), self._transform)] + ) view.get_examples(None, None) with pytest.raises(ValueError): view.get_examples(None, None) @@ -237,14 +237,14 @@ def test_transform_batch_length_changed(self, mode): def transform_batch(a, b, c): if self.mode is tuple: - return a + [0], + return (a + [0],) elif self.mode is dict: - return {'a': a + [0]} + return {"a": a + [0]} elif self.mode is None: return a + [0] view = dataset.transform_batch( - ('a',), - [((('a', 'b', 'c'), ('a',)), transform_batch)]) + ("a",), [((("a", "b", "c"), ("a",)), transform_batch)] + ) with pytest.raises(ValueError): view.get_examples(None, None) diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py index 0de654ad5..d8a135d42 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py @@ -1,14 +1,12 @@ import numpy as np import pytest - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( + dummy_dataset, # NOQA +) -@pytest.mark.parametrize( - 'mode', - [tuple, dict, None] -) 
+@pytest.mark.parametrize("mode", [tuple, dict, None]) def test_with_converter(mode): dataset = dummy_dataset.DummyDataset(mode=mode) @@ -19,18 +17,18 @@ def converter(*args, **kwargs): elif mode is dict: assert args == () np.testing.assert_equal( - kwargs, dict(zip(('a', 'b', 'c'), dataset.data))) + kwargs, dict(zip(("a", "b", "c"), dataset.data)) + ) elif mode is None: np.testing.assert_equal(args, tuple(dataset.data)) assert kwargs == {} - return 'converted' + return "converted" view = dataset.with_converter(converter) assert isinstance(view, ppe.dataset.TabularDataset) assert len(view) == len(dataset) assert view.keys == dataset.keys assert view.mode == dataset.mode - assert ( - view.get_examples(None, None) == dataset.get_examples(None, None)) - assert view.convert(view.fetch()) == 'converted' + assert view.get_examples(None, None) == dataset.get_examples(None, None) + assert view.convert(view.fetch()) == "converted" diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py index 84d1beb8a..ca63f4880 100644 --- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py +++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py @@ -1,23 +1,22 @@ import pytest import torch - -from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( +from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( # NOQA dummy_dataset, -) # NOQA +) @pytest.mark.parametrize( - 'batch_size,mode', + "batch_size,mode", [(1, dict), (2, dict), (8, dict), (1, tuple), (2, tuple), (8, tuple)], ) def test_with_dataloader(batch_size, mode): size = 10 - keys = ('a', 'b', 'c') + keys = ("a", "b", "c") dataset = dummy_dataset.DummyDataset(size=size, keys=keys, mode=mode) expected = torch.tensor(dataset.data).type(torch.float64) expected_per_key = [ [ - expected[i, j * batch_size:(j + 1) * batch_size] + expected[i, j * batch_size : (j + 1) * batch_size] for j in range((size + batch_size - 1) // batch_size) ] for i in range(len(keys)) @@ -27,4 +26,5 @@ def test_with_dataloader(batch_size, mode): for i, example in enumerate(dataloader): for j, key in enumerate(keys): assert torch.allclose( - expected_per_key[j][i], example[key if mode == dict else j]) + expected_per_key[j][i], example[key if mode == dict else j] + ) diff --git a/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py b/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py index 1554d4d6f..c8dbb4c0f 100644 --- a/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py +++ b/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py @@ -1,8 +1,7 @@ -import torch.distributed as dist - -import pytest from unittest import mock +import pytest +import torch.distributed as dist from pytorch_pfn_extras.distributed import DistributedValidationSampler _world_size = 4 @@ -17,10 +16,11 @@ def base_dataset(): def test_default(base_dataset): expected_lengths = [6, 5, 5, 5] sample_idxs = [] - with mock.patch.object(dist, 'get_world_size', return_value=_world_size), \ - mock.patch.object(dist, 'is_available', return_value=True): + with mock.patch.object( + dist, "get_world_size", return_value=_world_size + ), mock.patch.object(dist, "is_available", return_value=True): for rank in range(_world_size): - with mock.patch.object(dist, 
'get_rank', return_value=rank): + with mock.patch.object(dist, "get_rank", return_value=rank): sampler = DistributedValidationSampler(base_dataset) assert len(sampler) == expected_lengths[rank] sample_idxs += list(sampler) @@ -39,11 +39,14 @@ def test_no_shuffle(base_dataset): [11, 12, 13, 14, 15], [16, 17, 18, 19, 20], ] - with mock.patch.object(dist, 'get_world_size', return_value=_world_size), \ - mock.patch.object(dist, 'is_available', return_value=True): + with mock.patch.object( + dist, "get_world_size", return_value=_world_size + ), mock.patch.object(dist, "is_available", return_value=True): for rank in range(_world_size): - with mock.patch.object(dist, 'get_rank', return_value=rank): - sampler = DistributedValidationSampler(base_dataset, shuffle=False) + with mock.patch.object(dist, "get_rank", return_value=rank): + sampler = DistributedValidationSampler( + base_dataset, shuffle=False + ) assert list(sampler) == expected_samples[rank] @@ -51,17 +54,27 @@ def test_manual_num_replicas_and_ranks(base_dataset): # When manually specifying num_replicas and rank, # it doesn't rely on these torch.distributed functions. expected_lengths = [6, 5, 5, 5] - with mock.patch.object(dist, 'get_world_size', side_effect=AssertionError()), \ - mock.patch.object(dist, 'is_available', side_effect=AssertionError()), \ - mock.patch.object(dist, 'get_rank', side_effect=AssertionError()): + with mock.patch.object( + dist, "get_world_size", side_effect=AssertionError() + ), mock.patch.object( + dist, "is_available", side_effect=AssertionError() + ), mock.patch.object( + dist, "get_rank", side_effect=AssertionError() + ): for rank in range(_world_size): - sampler = DistributedValidationSampler(base_dataset, num_replicas=_world_size, rank=rank) + sampler = DistributedValidationSampler( + base_dataset, num_replicas=_world_size, rank=rank + ) assert len(sampler) == expected_lengths[rank] def test_seed(base_dataset): - sampler1 = DistributedValidationSampler(base_dataset, num_replicas=_world_size, rank=0, seed=1) - sampler2 = DistributedValidationSampler(base_dataset, num_replicas=_world_size, rank=0, seed=2) + sampler1 = DistributedValidationSampler( + base_dataset, num_replicas=_world_size, rank=0, seed=1 + ) + sampler2 = DistributedValidationSampler( + base_dataset, num_replicas=_world_size, rank=0, seed=2 + ) assert list(sampler1) != list(sampler2) @@ -73,7 +86,7 @@ def test_no_distributed_available(base_dataset): def test_invalid_rank(base_dataset): - with mock.patch.object(dist, 'get_world_size', return_value=_world_size): + with mock.patch.object(dist, "get_world_size", return_value=_world_size): with pytest.raises(ValueError): DistributedValidationSampler(base_dataset, rank=-1) with pytest.raises(ValueError): diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py index 1e758cb9e..c53183b15 100644 --- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py +++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py @@ -1,16 +1,15 @@ import pytest import torch - from pytorch_pfn_extras.nn import Ensure, ensure class TestEnsure: def test_wrong_initialization(self): - with pytest.raises(ValueError, match='both arguments'): + with pytest.raises(ValueError, match="both arguments"): Ensure(shape=None, dtype=None) @pytest.mark.parametrize( - 'shape', [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)] + "shape", [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)] ) def 
test_valid_shape(self, shape): tensor = torch.zeros(shape) @@ -20,69 +19,81 @@ def test_valid_shape(self, shape): ensure(tensor, shape) @pytest.mark.parametrize( - 'shape', [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)] + "shape", [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)] ) def test_invalid_shape(self, shape): tensor = torch.zeros((1, 2, 3)) module = Ensure(shape=shape) - with pytest.raises(ValueError, match='input shape is'): + with pytest.raises(ValueError, match="input shape is"): module(tensor) - with pytest.raises(ValueError, match='input shape is'): + with pytest.raises(ValueError, match="input shape is"): ensure(tensor, shape) - @pytest.mark.parametrize('shape_t, shape_c', [ - ((), (2,)), - ((3,), ()), - ((1,), (2,)), - ((1, 1), (2, 1)), - ((1, 1), (2,)), - ((2, 1), (2, 2)), - ((2, 4), (4,)), - ((2, 3, 4), (1, 4)), - ]) + @pytest.mark.parametrize( + "shape_t, shape_c", + [ + ((), (2,)), + ((3,), ()), + ((1,), (2,)), + ((1, 1), (2, 1)), + ((1, 1), (2,)), + ((2, 1), (2, 2)), + ((2, 4), (4,)), + ((2, 3, 4), (1, 4)), + ], + ) def test_broadcastable_shape(self, shape_t, shape_c): tensor = torch.zeros(shape_t) module = Ensure(shape=shape_c, broadcastable=True) module(tensor) - @pytest.mark.parametrize('shape_t, shape_c', [ - ((3,), (2,)), - ((2, 3), (2, 2)), - ((2, 4), (3, 4)), - ((2, 4), (3,)), - ((2, 3, 4), (2, 2, 1)), - ]) + @pytest.mark.parametrize( + "shape_t, shape_c", + [ + ((3,), (2,)), + ((2, 3), (2, 2)), + ((2, 4), (3, 4)), + ((2, 4), (3,)), + ((2, 3, 4), (2, 2, 1)), + ], + ) def test_nonbroadcastable_shape(self, shape_t, shape_c): tensor = torch.zeros(shape_t) module = Ensure(shape=shape_c, broadcastable=True) - with pytest.raises(ValueError, match='non broadcastable'): + with pytest.raises(ValueError, match="non broadcastable"): module(tensor) - @pytest.mark.parametrize('shape_t, shape_c', [ - ((2,), (None,)), - ((2, 2), (2, None)), - ((2, 1), (None, 2)), - ((1, 4), (None, 4)), - ((2, 3, 4), (2, None, 4)), - ]) + @pytest.mark.parametrize( + "shape_t, shape_c", + [ + ((2,), (None,)), + ((2, 2), (2, None)), + ((2, 1), (None, 2)), + ((1, 4), (None, 4)), + ((2, 3, 4), (2, None, 4)), + ], + ) def test_unknown_shape(self, shape_t, shape_c): tensor = torch.zeros(shape_t) module = Ensure(shape=shape_c) module(tensor) - @pytest.mark.parametrize('shape_t, shape_c', [ - ((3, 2), (2, None)), - ((1, 4), (None, 2)), - ((2, 3, 4), (3, None, 4)), - ]) + @pytest.mark.parametrize( + "shape_t, shape_c", + [ + ((3, 2), (2, None)), + ((1, 4), (None, 2)), + ((2, 3, 4), (3, None, 4)), + ], + ) def test_invalid_unknown_shape(self, shape_t, shape_c): tensor = torch.zeros(shape_t) module = Ensure(shape=shape_c) - with pytest.raises(ValueError, match='non broadcastable'): + with pytest.raises(ValueError, match="non broadcastable"): module(tensor) @pytest.mark.parametrize( - 'dtype', [torch.int32, torch.float32, torch.complex64] + "dtype", [torch.int32, torch.float32, torch.complex64] ) def test_valid_dtypes(self, dtype): tensor = torch.zeros(1, dtype=dtype) @@ -90,36 +101,45 @@ def test_valid_dtypes(self, dtype): module(tensor) ensure(tensor, None, dtype) - @pytest.mark.parametrize('dtype_t, dtype_c', [ - (torch.int32, torch.int16), - (torch.int32, torch.float32), - (torch.float32, torch.float64), - (torch.float32, torch.complex64), - ]) + @pytest.mark.parametrize( + "dtype_t, dtype_c", + [ + (torch.int32, torch.int16), + (torch.int32, torch.float32), + (torch.float32, torch.float64), + (torch.float32, torch.complex64), + ], + ) def test_invalid_dtypes(self, dtype_t, dtype_c): tensor = torch.zeros(1, 
dtype=dtype_t) module = Ensure(shape=None, dtype=dtype_c) - with pytest.raises(ValueError, match='input dtype'): + with pytest.raises(ValueError, match="input dtype"): module(tensor) - @pytest.mark.parametrize('dtype_t, dtype_c', [ - (torch.int32, torch.float32), - (torch.int32, torch.complex128), - (torch.int8, torch.float16), - ]) + @pytest.mark.parametrize( + "dtype_t, dtype_c", + [ + (torch.int32, torch.float32), + (torch.int32, torch.complex128), + (torch.int8, torch.float16), + ], + ) def test_dtypes_with_cast(self, dtype_t, dtype_c): tensor = torch.zeros(1, dtype=dtype_t) module = Ensure(shape=None, dtype=dtype_c, can_cast=True) module(tensor) - @pytest.mark.parametrize('dtype_t, dtype_c', [ - (torch.complex64, torch.int32), - (torch.float32, torch.int32), - ]) + @pytest.mark.parametrize( + "dtype_t, dtype_c", + [ + (torch.complex64, torch.int32), + (torch.float32, torch.int32), + ], + ) def test_invalid_dtypes_with_cast(self, dtype_t, dtype_c): tensor = torch.zeros(1, dtype=dtype_t) module = Ensure(shape=None, dtype=dtype_c, can_cast=True) - with pytest.raises(ValueError, match='be casted to'): + with pytest.raises(ValueError, match="be casted to"): module(tensor) def test_valid_shape_and_dtype(self): @@ -131,7 +151,7 @@ def test_valid_shape_and_dtype(self): ensure(tensor, shape, dtype) # Too many warnings to list them all - @pytest.mark.filterwarnings('ignore') + @pytest.mark.filterwarnings("ignore") def test_jit_module(self): shape = (10, 5) dtype = torch.float32 @@ -146,7 +166,7 @@ def test_jit_module(self): jit_module(tensor) # Tracing with a different shape fails during trace process - with pytest.raises(ValueError, match='input shape is'): + with pytest.raises(ValueError, match="input shape is"): jit_module = torch.jit.trace(module, (tensor,)) def test_torchscript_module(self): @@ -160,5 +180,5 @@ def test_torchscript_module(self): shape = (5, 5) tensor = torch.zeros(shape, dtype=dtype) # torchscript changes the exception type - with pytest.raises(torch.jit.Error, match='input shape is'): + with pytest.raises(torch.jit.Error, match="input shape is"): jit_module(tensor) diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py index b9fc52d5e..0a0f09cb5 100644 --- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py +++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py @@ -1,14 +1,14 @@ -import unittest -import pytest import functools +import unittest import warnings import numpy +import pytest +import pytorch_pfn_extras as ppe import torch from torch import nn -import pytorch_pfn_extras as ppe -assertions = unittest.TestCase('__init__') +assertions = unittest.TestCase("__init__") class UserDefinedLayer(nn.Module): @@ -19,27 +19,29 @@ def forward(self): pass -@pytest.mark.parametrize('container', [ - nn.Sequential, - nn.ModuleList, - nn.ModuleDict, -]) -@pytest.mark.parametrize('irregular_layer', [ - UserDefinedLayer, - # No reset_parameters - nn.ReLU, - # use reset_running_stats - functools.partial( - nn.BatchNorm1d, 1), - # use _reset_parameters - functools.partial( - nn.MultiheadAttention, 1, 1), - # ppe.nn layer - functools.partial( - ppe.nn.LazyConv1d, None, 1, 1), -]) +@pytest.mark.parametrize( + "container", + [ + nn.Sequential, + nn.ModuleList, + nn.ModuleDict, + ], +) +@pytest.mark.parametrize( + "irregular_layer", + [ + UserDefinedLayer, + # No reset_parameters + nn.ReLU, + # use 
reset_running_stats + functools.partial(nn.BatchNorm1d, 1), + # use _reset_parameters + functools.partial(nn.MultiheadAttention, 1, 1), + # ppe.nn layer + functools.partial(ppe.nn.LazyConv1d, None, 1, 1), + ], +) class TestExtendedSequential(object): - @pytest.fixture(autouse=True) def setUp(self, container, irregular_layer): self.l1 = ppe.nn.LazyLinear(None, 3) @@ -51,9 +53,7 @@ def setUp(self, container, irregular_layer): if container == nn.Sequential: self.s1 = container(self.l1, self.l2) elif container == nn.ModuleDict: - self.s1 = container({ - 'l1': self.l1, - 'l2': self.l2}) + self.s1 = container({"l1": self.l1, "l2": self.l2}) else: self.s1 = container([self.l1, self.l2]) self.container = container @@ -71,36 +71,46 @@ def test_repeat_with_init(self): # bias is filled with 0, so they should have the same values if self.container == nn.ModuleDict: numpy.testing.assert_array_equal( - ret[0][0]['l1'].bias.detach().numpy(), - ret[1][0]['l1'].bias.detach().numpy()) + ret[0][0]["l1"].bias.detach().numpy(), + ret[1][0]["l1"].bias.detach().numpy(), + ) else: numpy.testing.assert_array_equal( ret[0][0][0].bias.detach().numpy(), - ret[1][0][0].bias.detach().numpy()) + ret[1][0][0].bias.detach().numpy(), + ) # weight is initialized randomly, so they should be different assertions.assertFalse( - numpy.array_equal(ret[0][1].weight.detach().numpy(), - self.l3.weight.detach().numpy())) + numpy.array_equal( + ret[0][1].weight.detach().numpy(), + self.l3.weight.detach().numpy(), + ) + ) # And the object should also be different - assertions.assertIsNot(ret[0][1].weight.detach().numpy(), - self.l3.weight.detach().numpy()) + assertions.assertIsNot( + ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy() + ) # Repeated elements should be different objects assertions.assertIsNot(ret[0], ret[1]) # Also for the arrays - assertions.assertIsNot(ret[0][1].weight.detach().numpy(), - ret[1][1].weight.detach().numpy()) + assertions.assertIsNot( + ret[0][1].weight.detach().numpy(), ret[1][1].weight.detach().numpy() + ) # And values should be different assertions.assertFalse( - numpy.array_equal(ret[0][1].weight.detach().numpy(), - ret[1][1].weight.detach().numpy())) + numpy.array_equal( + ret[0][1].weight.detach().numpy(), + ret[1][1].weight.detach().numpy(), + ) + ) assertions.assertEqual(len(ret), 2) - ret = self.s2.repeat(0, mode='init') + ret = self.s2.repeat(0, mode="init") assertions.assertEqual(len(ret), 0) def test_repeat_with_copy(self): # s2 ((l1 -> l2) -> l3 -> l4) -> s2 ((l1 -> l2) -> l3 -> l4) - ret = self.s2.repeat(2, mode='copy') + ret = self.s2.repeat(2, mode="copy") assertions.assertIsNot(ret[0], self.s2) assertions.assertIs(type(ret[0]), type(self.s2)) assertions.assertIsNot(ret[1], self.s2) @@ -111,14 +121,17 @@ def test_repeat_with_copy(self): if self.container == nn.ModuleDict: numpy.testing.assert_array_equal( ret[0][0]["l1"].bias.detach().numpy(), - ret[1][0]["l1"].bias.detach().numpy()) + ret[1][0]["l1"].bias.detach().numpy(), + ) else: numpy.testing.assert_array_equal( ret[0][0][0].bias.detach().numpy(), - ret[1][0][0].bias.detach().numpy()) + ret[1][0][0].bias.detach().numpy(), + ) # W is shallowy copied, so the values should be same numpy.testing.assert_array_equal( - ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy()) + ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy() + ) # But the object should be different assertions.assertIsNot(ret[0][1].weight, self.l3.weight) # Repeated elements should be different objects @@ -127,16 +140,16 @@ def 
test_repeat_with_copy(self): assertions.assertIsNot(ret[0][1].weight, ret[1][1].weight) # But the values should be same numpy.testing.assert_array_equal( - ret[0][1].weight.detach().numpy(), - ret[1][1].weight.detach().numpy()) + ret[0][1].weight.detach().numpy(), ret[1][1].weight.detach().numpy() + ) assertions.assertEqual(len(ret), 2) - ret = self.s2.repeat(0, mode='copy') + ret = self.s2.repeat(0, mode="copy") assertions.assertEqual(len(ret), 0) def test_repeat_with_share(self): # s2 ((l1 -> l2) -> l3 -> l4) -> s2 ((l1 -> l2) -> l3 -> l4) - ret = self.s2.repeat(2, mode='share') + ret = self.s2.repeat(2, mode="share") assertions.assertIsNot(ret[0], self.s2) assertions.assertIs(type(ret[0]), type(self.s2)) assertions.assertIsNot(ret[1], self.s2) @@ -146,16 +159,20 @@ def test_repeat_with_share(self): if self.container == nn.ModuleDict: numpy.testing.assert_array_equal( ret[0][0]["l1"].bias.detach().numpy(), - ret[1][0]["l1"].bias.detach().numpy()) + ret[1][0]["l1"].bias.detach().numpy(), + ) else: numpy.testing.assert_array_equal( ret[0][0][0].bias.detach().numpy(), - ret[1][0][0].bias.detach().numpy()) + ret[1][0][0].bias.detach().numpy(), + ) # W is shallowy copied, so the values should be same numpy.testing.assert_array_equal( - ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy()) + ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy() + ) numpy.testing.assert_array_equal( - ret[1][1].weight.detach().numpy(), self.l3.weight.detach().numpy()) + ret[1][1].weight.detach().numpy(), self.l3.weight.detach().numpy() + ) # And the object should also be same assertions.assertIs(ret[0][1].weight, self.l3.weight) assertions.assertIs(ret[1][1].weight, self.l3.weight) @@ -163,7 +180,7 @@ def test_repeat_with_share(self): assertions.assertIsNot(ret[0], ret[1]) assertions.assertEqual(len(ret), 2) - ret = self.s2.repeat(0, mode='share') + ret = self.s2.repeat(0, mode="share") assertions.assertEqual(len(ret), 0) @@ -193,7 +210,7 @@ class UserDefinedLayerWithParameters(nn.Module): def __init__(self): super().__init__() param = nn.Parameter(torch.zeros(1, 1)) - self.register_parameter('weight', param) + self.register_parameter("weight", param) def forward(self): pass @@ -202,22 +219,25 @@ def forward(self): class UserDefinedLayerWithBuffer(nn.Module): def __init__(self): super().__init__() - self.register_buffer('weight', torch.zeros(1, 1)) + self.register_buffer("weight", torch.zeros(1, 1)) def forward(self): pass -@pytest.mark.parametrize('module', [ - # buit-in, no parameters - nn.ReLU, - # no parameters - UserDefinedLayer, - # has `_reset_parameters` - UserDefinedLayerWithUnderScoreReset, - # has `reset_parameters` - UserDefinedLayerWithReset, -]) +@pytest.mark.parametrize( + "module", + [ + # buit-in, no parameters + nn.ReLU, + # no parameters + UserDefinedLayer, + # has `_reset_parameters` + UserDefinedLayerWithUnderScoreReset, + # has `reset_parameters` + UserDefinedLayerWithReset, + ], +) def test_no_warning_when_repeat(module): model = ppe.nn.ExtendedSequential(module()) # no warnings are raised on these modules @@ -226,10 +246,13 @@ def test_no_warning_when_repeat(module): model.repeat(2) -@pytest.mark.parametrize('module', [ - UserDefinedLayerWithParameters, - UserDefinedLayerWithBuffer, -]) +@pytest.mark.parametrize( + "module", + [ + UserDefinedLayerWithParameters, + UserDefinedLayerWithBuffer, + ], +) def test_warning_when_repeat(module): model = ppe.nn.ExtendedSequential(module()) # warnings are raised on these modules diff --git 
a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py index 25297d181..c91cda51b 100644 --- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py +++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py @@ -2,21 +2,21 @@ import pytest import torch +from pytorch_pfn_extras.nn.modules.lazy import ( + LazyInitializationMixin, + UninitializedParameter, +) from torch import nn from torch.nn import functional as F -from pytorch_pfn_extras.nn.modules.lazy import LazyInitializationMixin -from pytorch_pfn_extras.nn.modules.lazy import UninitializedParameter - class _MyFunc(torch.nn.Module): - def __init__(self, in_features, out_features): super().__init__() self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) - self.register_buffer('const', torch.full((in_features,), 1.0)) + self.register_buffer("const", torch.full((in_features,), 1.0)) self._reset_params() def forward(self, input): @@ -27,9 +27,8 @@ def _reset_params(self): class _LazyMyFunc(LazyInitializationMixin, _MyFunc): - - lazy_parameter_names = ('weight',) - lazy_buffer_names = ('const',) + lazy_parameter_names = ("weight",) + lazy_buffer_names = ("const",) def __init__(self, in_features, out_features): super().__init__(in_features or 0, out_features) @@ -41,7 +40,8 @@ def forward(self, input): if isinstance(self.weight, UninitializedParameter): self.in_features = input.shape[-1] self.weight = torch.nn.Parameter( - self.weight.new_empty((self.out_features, self.in_features))) + self.weight.new_empty((self.out_features, self.in_features)) + ) self.const = self.const.new_full((self.in_features,), 1) self._reset_params() self.to(input.device) @@ -53,7 +53,6 @@ def _reset_params(self): class LazyTestBase: - def get_original_module(self): raise NotImplementedError @@ -122,15 +121,20 @@ def test_lazy_warning(self): m = self.get_lazy_module() with pytest.warns(UserWarning) as record: torch.optim.SGD(m.parameters(), lr=0.1) - assert ('Use of uninitialized lazy parameter in Optimizer ' - 'has been detected' in record[0].message.args[0]) - - @pytest.mark.parametrize('init_src, init_dst', [ - (True, True), - (True, False), - (False, True), - (False, False), - ]) + assert ( + "Use of uninitialized lazy parameter in Optimizer " + "has been detected" in record[0].message.args[0] + ) + + @pytest.mark.parametrize( + "init_src, init_dst", + [ + (True, True), + (True, False), + (False, True), + (False, False), + ], + ) def test_save_load(self, init_src, init_dst): torch.manual_seed(0) input = self.get_input() @@ -148,8 +152,7 @@ def test_save_load(self, init_src, init_dst): for name in model_dst.lazy_parameter_names ] module_buffers = [ - getattr(model_dst, name) - for name in model_dst.lazy_buffer_names + getattr(model_dst, name) for name in model_dst.lazy_buffer_names ] with tempfile.NamedTemporaryFile(delete=False) as f: @@ -181,7 +184,6 @@ def test_save_load(self, init_src, init_dst): class TestLazyMyFunc(LazyTestBase): - def get_original_module(self): return _MyFunc(10, 20) diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py index a27764929..6c236c436 100644 --- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py +++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py @@ -1,14 +1,16 @@ import 
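# Illustration (a minimal sketch, not taken from the patched files): the lazy
# modules covered by test_lazy*.py defer shape inference until the first
# forward pass.
import pytorch_pfn_extras as ppe
import torch

layer = ppe.nn.LazyLinear(None, 20)  # in_features left undetermined
x = torch.randn(8, 10)
y = layer(x)                         # weight is materialized as (20, 10) here
assert layer.weight.shape == (20, 10)
assert y.shape == (8, 20)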
torch +from pytorch_pfn_extras.nn import ( # NOQA + LazyBatchNorm1d, + LazyBatchNorm2d, + LazyBatchNorm3d, +) +from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import ( + LazyTestBase, +) from torch import nn -from pytorch_pfn_extras.nn import LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d # NOQA - -from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import \ - LazyTestBase - class TestLazyBatchNorm1d(LazyTestBase): - def get_original_module(self): return nn.BatchNorm1d(10) @@ -20,7 +22,6 @@ def get_input(self): class TestLazyBatchNorm2d(LazyTestBase): - def get_original_module(self): return nn.BatchNorm2d(10) @@ -32,7 +33,6 @@ def get_input(self): class TestLazyBatchNorm3d(LazyTestBase): - def get_original_module(self): return nn.BatchNorm3d(10) diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py index c71d2d2ee..8488e212c 100644 --- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py +++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py @@ -1,14 +1,12 @@ import torch -from torch import nn - from pytorch_pfn_extras.nn import LazyConv1d, LazyConv2d, LazyConv3d - -from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import \ - LazyTestBase +from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import ( + LazyTestBase, +) +from torch import nn class TestLazyConv1d(LazyTestBase): - def get_original_module(self): return nn.Conv1d(3, 4, 2) @@ -20,7 +18,6 @@ def get_input(self): class TestLazyConv2d(LazyTestBase): - def get_original_module(self): return nn.Conv2d(3, 4, 2) @@ -32,7 +29,6 @@ def get_input(self): class TestLazyConv3d(LazyTestBase): - def get_original_module(self): return nn.Conv3d(3, 4, 2) diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py index c57b72d31..f38702c0e 100644 --- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py +++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py @@ -1,14 +1,12 @@ import torch -from torch import nn - from pytorch_pfn_extras.nn import LazyLinear - -from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import \ - LazyTestBase +from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import ( + LazyTestBase, +) +from torch import nn class TestLazyLinear(LazyTestBase): - def get_original_module(self): return nn.Linear(10, 20) diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py b/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py index ec2fd1b01..657402d8e 100644 --- a/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py +++ b/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py @@ -1,20 +1,18 @@ import os import sys -import urllib.request import tempfile +import urllib.request import numpy as np import pytest import pytorch_pfn_extras import torch -from torch import multiprocessing as mp +from pytorch_pfn_extras.nn.parallel import DistributedDataParallel from torch import distributed as dist +from torch import multiprocessing as mp from torch import nn from torch.utils.checkpoint import checkpoint -from pytorch_pfn_extras.nn.parallel import DistributedDataParallel - - context = mp.get_context("spawn") @@ -57,8 +55,8 @@ def _to_zero(values, group): class MyModule(nn.Module): def 
__init__(self): super().__init__() - self.param0 = nn.Parameter(torch.tensor(-1.)) - self.param1 = nn.Parameter(torch.tensor(1.)) + self.param0 = nn.Parameter(torch.tensor(-1.0)) + self.param1 = nn.Parameter(torch.tensor(1.0)) buf = torch.zeros(1) self.register_buffer("buffer", buf) @@ -87,9 +85,9 @@ def forward(self, x): def _run(init_file, input, module, rank, args, step, device_type): init_method = "file://{}".format(urllib.request.pathname2url(init_file)) - dist.init_process_group(backend="gloo", - init_method=init_method, - world_size=2, rank=rank) + dist.init_process_group( + backend="gloo", init_method=init_method, world_size=2, rank=rank + ) if device_type == "cpu": device = torch.device(device_type) elif device_type == "cuda": @@ -106,11 +104,9 @@ def _run(init_file, input, module, rank, args, step, device_type): return output.detach(), module.state_dict(), grads -def _launch(inputs, - modules=None, - args=None, - step=Steps._step, - device_type="cpu"): +def _launch( + inputs, modules=None, args=None, step=Steps._step, device_type="cpu" +): procs = [] with tempfile.TemporaryDirectory() as tmpdir, context.Pool(2) as pool: if modules is None: @@ -121,9 +117,8 @@ def _launch(inputs, file = os.path.join(tmpdir, "init") for i, (input, module) in enumerate(zip(inputs, modules)): p = pool.apply_async( - _run, - args=(file, input, module, i, args, step, - device_type)) + _run, args=(file, input, module, i, args, step, device_type) + ) procs.append(p) return [p.get() for p in procs] @@ -136,39 +131,44 @@ def _device_types(): @pytest.mark.skipif( - sys.platform == 'win32', - reason='DDP not fully supported on Windows') + sys.platform == "win32", reason="DDP not fully supported on Windows" +) class TestDistributedDataParallel: def test_save_load(self): module = MyModule() with_ddp = DistributedDataParallel(module) assert module.state_dict().keys() == with_ddp.state_dict().keys() module.load_state_dict(with_ddp.state_dict()) - assert np.array_equal(module.state_dict()["param0"], - with_ddp.state_dict()["param0"]) - assert np.array_equal(module.state_dict()["param1"], - with_ddp.state_dict()["param1"]) - assert np.array_equal(module.state_dict()["buffer"], - with_ddp.state_dict()["buffer"]) - - @pytest.mark.parametrize('device_type', _device_types()) + assert np.array_equal( + module.state_dict()["param0"], with_ddp.state_dict()["param0"] + ) + assert np.array_equal( + module.state_dict()["param1"], with_ddp.state_dict()["param1"] + ) + assert np.array_equal( + module.state_dict()["buffer"], with_ddp.state_dict()["buffer"] + ) + + @pytest.mark.parametrize("device_type", _device_types()) def test_sync_init_params(self, device_type): module0 = MyModule() - module0.param0.data = torch.tensor([1.]) + module0.param0.data = torch.tensor([1.0]) r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([2.])], + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], modules=[module0, MyModule()], - device_type=device_type) + device_type=device_type, + ) assert r0[0].item() == 1 assert r1[0].item() == 2 assert r0[1]["param0"].item() == 1.0 assert r1[1]["param0"].item() == 1.0 - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_all_reduce(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([2.])], - device_type=device_type) + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], + device_type=device_type, + ) assert r0[0].item() == -1 assert r1[0].item() == -2 assert r0[2]["module.param0"].item() == 1.5 
@@ -176,52 +176,59 @@ def test_all_reduce(self, device_type): assert r0[2]["module.param1"] is None assert r1[2]["module.param1"] is None - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_specific_reduce(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([2.])], + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], args={"reduce_function": Collectives._to_zero}, - device_type=device_type) + device_type=device_type, + ) assert r0[2]["module.param0"].item() == 0.0 assert r1[2]["module.param0"].item() == 0.0 - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_nosync_buffer(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([2.])], + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], args={"broadcast_buffers": False}, - device_type=device_type) + device_type=device_type, + ) assert r0[0].item() == -1 assert r1[0].item() == -2 assert r0[1]["buffer"].item() == 1 assert r1[1]["buffer"].item() == 2 - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_sync_buffer(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([2.])], + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], args={"broadcast_buffers": True}, - device_type=device_type) + device_type=device_type, + ) assert r0[0].item() == -1 assert r1[0].item() == -2 assert r0[1]["buffer"].item() == 1 assert r1[1]["buffer"].item() == 1 - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_specific_broadcast(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([2.])], - args={"broadcast_function": Collectives._to_zero, - "broadcast_buffers": True}, - device_type=device_type) + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], + args={ + "broadcast_function": Collectives._to_zero, + "broadcast_buffers": True, + }, + device_type=device_type, + ) assert r0[1]["buffer"].item() == 0.0 assert r1[1]["buffer"].item() == 0.0 - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_define_by_run(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([-1])], - device_type=device_type) + inputs=[torch.tensor([1.0]), torch.tensor([-1])], + device_type=device_type, + ) assert r0[0].item() == -1 assert r1[0].item() == 1 assert r0[2]["module.param0"].item() == 0.5 @@ -229,12 +236,13 @@ def test_define_by_run(self, device_type): assert r0[2]["module.param1"].item() == 0.5 assert r1[2]["module.param1"].item() == 0.5 - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_no_sync(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([1.]), torch.tensor([2.])], + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], step=Steps._step_with_no_sync, - device_type=device_type) + device_type=device_type, + ) assert r0[0].item() == -1 assert r1[0].item() == -2 assert r0[2]["module.param0"].item() == 1 @@ -242,12 +250,13 @@ def test_no_sync(self, device_type): assert r0[2]["module.param1"] is None assert r1[2]["module.param1"] is None - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) def test_hook(self, device_type): r0, r1 = _launch( 
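# Illustration (a minimal sketch, not taken from the patched files): wrapping a
# module in ppe's DistributedDataParallel keeps the original state_dict keys,
# which is what test_save_load above relies on. Running an actual all-reduce
# additionally needs an initialized process group, as set up by _run() above.
import torch
from pytorch_pfn_extras.nn.parallel import DistributedDataParallel

module = torch.nn.Linear(2, 2)
ddp = DistributedDataParallel(module)
assert module.state_dict().keys() == ddp.state_dict().keys()
module.load_state_dict(ddp.state_dict())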
- inputs=[torch.tensor([1.]), torch.tensor([2.])], + inputs=[torch.tensor([1.0]), torch.tensor([2.0])], step=Steps._step_with_hook, - device_type=device_type) + device_type=device_type, + ) assert r0[0].item() == -1 assert r1[0].item() == -2 assert r0[2]["module.param0"].item() == 0 @@ -255,19 +264,22 @@ def test_hook(self, device_type): assert r0[2]["module.param1"].item() == 0 assert r1[2]["module.param1"].item() == 0 - @pytest.mark.parametrize('device_type', _device_types()) + @pytest.mark.parametrize("device_type", _device_types()) @pytest.mark.skipif( not pytorch_pfn_extras.requires("1.6.0"), reason="Variable._execution_engine.queue_callback does not work " - "with checkpointing when torch < 1.6.0") + "with checkpointing when torch < 1.6.0", + ) def test_checkpoint(self, device_type): r0, r1 = _launch( - inputs=[torch.tensor([[1.]]), torch.tensor([[2.]])], + inputs=[torch.tensor([[1.0]]), torch.tensor([[2.0]])], modules=[MyModuleWithCheckpoint(), MyModuleWithCheckpoint()], step=Steps._step_with_hook, - device_type=device_type) + device_type=device_type, + ) grad0 = r0[2] grad1 = r1[2] for key in grad0.keys(): - assert np.array_equal(grad0[key].cpu().numpy(), - grad1[key].cpu().numpy()) + assert np.array_equal( + grad0[key].cpu().numpy(), grad1[key].cpu().numpy() + ) diff --git a/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py b/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py index 2ac018b4d..00a4dbcb6 100644 --- a/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py +++ b/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py @@ -1,39 +1,34 @@ import os import pytest -import torch - import pytorch_pfn_extras as ppe +import torch - -_profiler_available = ( - os.name != 'nt' - or ppe.requires("1.9") -) +_profiler_available = os.name != "nt" or ppe.requires("1.9") @pytest.mark.skipif(not _profiler_available, reason="profiler is not available") -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_record(device): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = torch.nn.Linear(30, 40) model.to(device) x = torch.arange(30, dtype=torch.float32).to(device) with torch.profiler.profile() as prof: - with ppe.profiler.record('my_tag_1'): + with ppe.profiler.record("my_tag_1"): model(x) keys = [event.key for event in prof.key_averages()] - assert 'my_tag_1' in keys - assert 'aten::linear' in keys + assert "my_tag_1" in keys + assert "aten::linear" in keys @pytest.mark.skipif(not _profiler_available, reason="profiler is not available") -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_record_without_tag(device): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = torch.nn.Linear(30, 40) model.to(device) @@ -44,19 +39,19 @@ def test_record_without_tag(device): model(x) keys = [event.key for event in prof.key_averages()] - assert 'aten::linear' in keys - assert any(k.endswith('test_record_without_tag') for k in keys) + assert "aten::linear" in keys + assert any(k.endswith("test_record_without_tag") for k in keys) @pytest.mark.skipif(not _profiler_available, reason="profiler is not available") -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_record_function(device): - if not torch.cuda.is_available() and 
device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = torch.nn.Linear(30, 40) model.to(device) - @ppe.profiler.record_function('my_tag_2') + @ppe.profiler.record_function("my_tag_2") def my_run(x): model(x) @@ -65,14 +60,14 @@ def my_run(x): my_run(x) keys = [event.key for event in prof.key_averages()] - assert 'aten::linear' in keys - assert 'my_tag_2' in keys + assert "aten::linear" in keys + assert "my_tag_2" in keys @pytest.mark.skipif(not _profiler_available, reason="profiler is not available") -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_record_function_without_tag(device): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = torch.nn.Linear(30, 40) model.to(device) @@ -86,14 +81,14 @@ def my_run(x): my_run(x) keys = [event.key for event in prof.key_averages()] - assert 'aten::linear' in keys - assert 'my_run' in keys + assert "aten::linear" in keys + assert "my_run" in keys @pytest.mark.skipif(not _profiler_available, reason="profiler is not available") -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_record_iterable(device): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = torch.nn.Linear(30, 40) model.to(device) @@ -102,20 +97,20 @@ def test_record_iterable(device): iters = [x, x, x] with torch.profiler.profile() as prof: - for x in ppe.profiler.record_iterable('my_tag_3', iters): + for x in ppe.profiler.record_iterable("my_tag_3", iters): model(x) keys = [event.key for event in prof.key_averages()] - assert 'aten::linear' in keys - assert 'my_tag_3-0' in keys - assert 'my_tag_3-1' in keys - assert 'my_tag_3-2' in keys + assert "aten::linear" in keys + assert "my_tag_3-0" in keys + assert "my_tag_3-1" in keys + assert "my_tag_3-2" in keys @pytest.mark.skipif(not _profiler_available, reason="profiler is not available") -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_record_iterable_without_tag(device): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = torch.nn.Linear(30, 40) model.to(device) @@ -128,7 +123,7 @@ def test_record_iterable_without_tag(device): model(x) keys = [event.key for event in prof.key_averages()] - assert 'aten::linear' in keys - assert any(k.endswith('test_record_iterable_without_tag-0') for k in keys) - assert any(k.endswith('test_record_iterable_without_tag-1') for k in keys) - assert any(k.endswith('test_record_iterable_without_tag-2') for k in keys) + assert "aten::linear" in keys + assert any(k.endswith("test_record_iterable_without_tag-0") for k in keys) + assert any(k.endswith("test_record_iterable_without_tag-1") for k in keys) + assert any(k.endswith("test_record_iterable_without_tag-2") for k in keys) diff --git a/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py b/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py index 91388b7bc..62eafcc33 100644 --- a/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py +++ b/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py @@ -4,7 +4,6 @@ import time import pytest - from pytorch_pfn_extras.profiler import TimeSummary, get_time_summary @@ -43,8 +42,9 @@ def 
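# Illustration (a minimal sketch, not taken from the patched files): the
# ppe.profiler.record helper exercised by test_record.py above, shown on CPU.
import pytorch_pfn_extras as ppe
import torch

model = torch.nn.Linear(30, 40)
x = torch.arange(30, dtype=torch.float32)

with torch.profiler.profile() as prof:
    with ppe.profiler.record("my_forward"):  # appears as a "my_forward" event
        model(x)

keys = [event.key for event in prof.key_averages()]
assert "my_forward" in keys
assert "aten::linear" in keys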
worker(summary): @pytest.mark.skipif( - sys.platform == 'win32', - reason='Multiprocessing not fully supported on Windows') + sys.platform == "win32", + reason="Multiprocessing not fully supported on Windows", +) def test_report_from_other_process(): summary = TimeSummary() p = mp.Process(target=worker, args=(summary,)) @@ -65,8 +65,9 @@ def worker1(): @pytest.mark.skipif( - sys.platform == 'win32', - reason='Multiprocessing not fully supported on Windows') + sys.platform == "win32", + reason="Multiprocessing not fully supported on Windows", +) def test_global_summary(): time_summary = get_time_summary() time_summary.initialize() @@ -98,10 +99,14 @@ def test_clear(): def test_multiprocessing_start_method(): # Ensure that importing PPE does not initialize multiprocessing context. # See #238 for the context. - subprocess.check_call([ - sys.executable, - '-c', - ('import multiprocessing as mp; ' - + 'import pytorch_pfn_extras; ' - + 'mp.set_start_method("spawn"); ') - ]) + subprocess.check_call( + [ + sys.executable, + "-c", + ( + "import multiprocessing as mp; " + + "import pytorch_pfn_extras; " + + 'mp.set_start_method("spawn"); ' + ), + ] + ) diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py index b1115e810..b8a97faee 100644 --- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py +++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py @@ -1,20 +1,17 @@ import types import warnings -import torch -import torch.nn.functional as F - import pytorch_pfn_extras as ppe import pytorch_pfn_extras.utils.comparer as _comp +import torch +import torch.nn.functional as F from pytorch_pfn_extras.onnx._as_output import trace class JITRuntime(ppe.runtime.PyTorchRuntime): - def move_module(self, module): - def new_forward(self, *args): - if hasattr(self, '_traced_mod'): + if hasattr(self, "_traced_mod"): out = self._traced_mod(*args) inter_size = len(self._names) if inter_size == 0: @@ -23,7 +20,10 @@ def new_forward(self, *args): out = [out] return dict( **{str(i): x for i, x in enumerate(out[:-inter_size])}, - **{name: x for name, x in zip(self._names, out[-inter_size:])}, + **{ + name: x + for name, x in zip(self._names, out[-inter_size:]) + }, ) new_forward = self.forward @@ -33,7 +33,8 @@ def new_forward(self, *args): with warnings.catch_warnings(): warnings.simplefilter("ignore") self._traced_mod = torch.jit.trace_module( - new_module, {"forward": args}) + new_module, {"forward": args} + ) self._names = [out.name for out in outputs.values] self.forward = new_forward @@ -41,7 +42,7 @@ def new_forward(self, *args): def forward_with_init(self, *args, **kwargs): # `module.forward` is called multiple times while tracing. 
- handler = getattr(_comp._thread_local, 'handler', None) + handler = getattr(_comp._thread_local, "handler", None) if handler is not None: handler._reset_intermediate_values() return self._orig_forward(*args, **kwargs) @@ -67,9 +68,9 @@ def __init__(self): def forward(self, x, t): y = self.model(x) - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" loss = F.l1_loss(y, t) - ppe.reporting.report({prefix + '/loss': loss}) + ppe.reporting.report({prefix + "/loss": loss}) return loss @@ -82,26 +83,60 @@ def _get_jit_cpu_model(device_type): def test_jit_runtime_trainer(): - model, optimizer = _get_jit_cpu_model('jit-cpu') - trainer = ppe.engine.create_trainer(model, optimizer, 10, device='jit-cpu') + model, optimizer = _get_jit_cpu_model("jit-cpu") + trainer = ppe.engine.create_trainer(model, optimizer, 10, device="jit-cpu") data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(100)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(100) + ] + ) trainer.run(data) def test_jit_runtime_evaluator(): - model, optimizer = _get_jit_cpu_model('jit-cpu') - evaluator = ppe.engine.create_evaluator(model, device='jit-cpu') + model, optimizer = _get_jit_cpu_model("jit-cpu") + evaluator = ppe.engine.create_evaluator(model, device="jit-cpu") data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(100)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(100) + ] + ) evaluator.run(data) def test_jit_runtime_trainer_with_evaluator(): - model, optimizer = _get_jit_cpu_model('jit-cpu') - evaluator = ppe.engine.create_evaluator(model, device='jit-cpu') + model, optimizer = _get_jit_cpu_model("jit-cpu") + evaluator = ppe.engine.create_evaluator(model, device="jit-cpu") trainer = ppe.engine.create_trainer( - model, optimizer, 10, device='jit-cpu', evaluator=evaluator) + model, optimizer, 10, device="jit-cpu", evaluator=evaluator + ) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(100)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(100) + ] + ) trainer.run(data, data) diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py index 696dd2582..2ee50f7ab 100644 --- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py +++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py @@ -1,6 +1,5 @@ -import torch - import pytorch_pfn_extras as ppe +import torch class FallbackRuntime(ppe.runtime.BaseRuntime): @@ -13,26 +12,26 @@ class MyCustomRuntime(ppe.runtime.PyTorchRuntime): def test_registry_register(): registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime) - registry.register('dummy_device', MyCustomRuntime) + registry.register("dummy_device", MyCustomRuntime) assert ( - registry.get_runtime_class_for_device_spec('dummy_device') + registry.get_runtime_class_for_device_spec("dummy_device") == MyCustomRuntime ) def test_registry_fallback(): registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime) - registry.register('dummy_device', MyCustomRuntime) + registry.register("dummy_device", MyCustomRuntime) assert ( - registry.get_runtime_class_for_device_spec('unknown_device') + registry.get_runtime_class_for_device_spec("unknown_device") == FallbackRuntime ) def test_registry_torch_device(): registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime) - 
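# Illustration (a minimal sketch, not taken from the patched files): the
# Trainer/Evaluator pattern used by the JIT-runtime tests above, but on the
# plain "cpu" device. As in MyModel above, the module returns its loss
# directly and reports it.
import pytorch_pfn_extras as ppe
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(20, 10)

    def forward(self, x, t):
        prefix = "train" if self.training else "val"
        loss = F.l1_loss(self.linear(x), t)
        ppe.reporting.report({prefix + "/loss": loss})
        return loss

model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
evaluator = ppe.engine.create_evaluator(model, device="cpu")
trainer = ppe.engine.create_trainer(
    model, optimizer, 2, device="cpu", evaluator=evaluator
)
data = torch.utils.data.DataLoader(
    [(torch.rand(20), torch.rand(10)) for _ in range(16)]
)
trainer.run(data, data)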
registry.register('cpu', MyCustomRuntime) + registry.register("cpu", MyCustomRuntime) assert ( - registry.get_runtime_class_for_device_spec(torch.device('cpu')) + registry.get_runtime_class_for_device_spec(torch.device("cpu")) == MyCustomRuntime ) diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py index aa6b94ad4..b0a921b1a 100644 --- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py +++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py @@ -1,22 +1,19 @@ import contextlib import pytest -import torch - import pytorch_pfn_extras as ppe +import torch @pytest.mark.skipif( - not torch.cuda.is_available(), - reason='Moving across devices requires CUDA' + not torch.cuda.is_available(), reason="Moving across devices requires CUDA" ) class TestPytorchRuntime: - @pytest.mark.parametrize('device', ['cpu', 'cuda']) + @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize( - 'batch', [{'x': torch.zeros(1)}, - [torch.zeros(1)], - torch.zeros(1), - object()]) + "batch", + [{"x": torch.zeros(1)}, [torch.zeros(1)], torch.zeros(1), object()], + ) def test_convert_batch(self, device, batch): rt = ppe.runtime.PyTorchRuntime(device, {}) cbatch = rt.convert_batch(batch) @@ -31,14 +28,14 @@ def test_convert_batch(self, device, batch): else: assert cbatch is batch - @pytest.mark.parametrize('device', ['cpu', 'cuda']) + @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_move_module(self, device): rt = ppe.runtime.PyTorchRuntime(device, {}) module = torch.nn.Linear(1, 1) module = rt.move_module(module) assert module.weight.device.type == device - @pytest.mark.parametrize('device', ['cpu', 'cuda']) + @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_move_tensor(self, device): rt = ppe.runtime.PyTorchRuntime(device, {}) tensor = torch.zeros(10) @@ -66,20 +63,20 @@ def __init__(self): super().__init__() self.layer1 = torch.nn.Linear(10, 10) self.layer2 = torch.nn.Linear(10, 10) - ppe.to(self.layer2, device='dummy', runtime_class=DummyRuntime) + ppe.to(self.layer2, device="dummy", runtime_class=DummyRuntime) def test_runtime_container(): module = MyModule() # This is a top module, so it won't show child ones for _ in ppe.runtime._runtime.named_runtime_modules(module): - pytest.fail('Never reach') + pytest.fail("Never reach") def test_split_runtime_container(): module = SplitModule() for name, mod in ppe.runtime._runtime.named_runtime_modules(module): - assert name == 'layer2' + assert name == "layer2" assert mod is module.layer2 @@ -89,24 +86,29 @@ def __init__(self): super().__init__() self.layer1 = SplitModule() self.layer2 = SplitModule() + module = MultiLevelSplitModule() - expected = [('layer2', module.layer1.layer2), - ('layer2', module.layer2.layer2)] + expected = [ + ("layer2", module.layer1.layer2), + ("layer2", module.layer2.layer2), + ] for expected, (name, mod) in zip( - expected, ppe.runtime._runtime.named_runtime_modules(module)): + expected, ppe.runtime._runtime.named_runtime_modules(module) + ): assert name == expected[0] assert mod is expected[1] for _ in zip( - expected, ppe.runtime._runtime.named_runtime_modules( - module, recursive=False)): - pytest.fail('Never reach') + expected, + ppe.runtime._runtime.named_runtime_modules(module, recursive=False), + ): + pytest.fail("Never reach") def test_module_change_forward(): class Module1(torch.nn.Module): def forward(self, input): - raise RuntimeError('The module forward should never be executed') + 
raise RuntimeError("The module forward should never be executed") class Module2: def __init__(self): @@ -129,7 +131,7 @@ def move_module(self, module): with pytest.raises(RuntimeError): module(None) - ppe.to(module, device='dummy', runtime_class=ForwardIntercepterRuntime) + ppe.to(module, device="dummy", runtime_class=ForwardIntercepterRuntime) assert int(module(None)) == 5 @@ -164,9 +166,9 @@ def trace(cls, event_name, arg): called = 2 assert called == 0 - with ppe.runtime.BaseRuntime.trace('dummy', None): + with ppe.runtime.BaseRuntime.trace("dummy", None): assert called == 0 assert called == 0 - with TracerRuntime.trace('dummy', None): + with TracerRuntime.trace("dummy", None): assert called == 1 assert called == 2 diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py index 86a46f7b9..37eb725db 100644 --- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py +++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py @@ -1,13 +1,13 @@ import io -import pytest -import torch +import pytest import pytorch_pfn_extras as ppe +import torch @pytest.mark.gpu def test_tensor_ppe_to(): - device = 'cuda:0' + device = "cuda:0" tensor = torch.zeros(10) out = ppe.to(tensor, device) assert str(out.device) == device @@ -22,7 +22,7 @@ def __init__(self): @pytest.mark.gpu def test_module_ppe_to(): - device = 'cuda:0' + device = "cuda:0" module = MyModule() ppe.to(module, device) assert all([str(p.device) == device for p in module.parameters()]) @@ -30,7 +30,7 @@ def test_module_ppe_to(): def test_invalid_ppe_to(): - device = 'cpu' + device = "cpu" with pytest.raises(ValueError): ppe.to(object(), device) @@ -46,24 +46,23 @@ def initialize_module(self, module, loader_or_batch): def test_module_split_ppe_to(): module = MyModule() - ppe.to(module.layer2, 'dummy', runtime_class=MyRuntime, - options={'opt': 1}) + ppe.to(module.layer2, "dummy", runtime_class=MyRuntime, options={"opt": 1}) rt_layer1 = ppe.runtime._runtime._module_runtime_tag(module.layer1) rt_layer2 = ppe.runtime._runtime._module_runtime_tag(module.layer2) assert str(next(iter(module.layer1.parameters())).device) == "cpu" assert rt_layer1 is None assert isinstance(rt_layer2, MyRuntime) - assert rt_layer2.device_spec == 'dummy' - assert rt_layer2.options['opt'] == 1 + assert rt_layer2.device_spec == "dummy" + assert rt_layer2.options["opt"] == 1 def test_module_split_ppe_to_config(): # Deprecated "config" option. 
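# Illustration (a minimal sketch, not taken from the patched files): moving a
# single submodule with ppe.to and a user-defined runtime, mirroring
# test_module_split_ppe_to above. MyRuntime and the "dummy" device name are
# placeholders; a real runtime would transfer the module in move_module().
import pytorch_pfn_extras as ppe
import torch

class MyRuntime(ppe.runtime.BaseRuntime):
    def move_module(self, module):
        return module  # placeholder: no actual device transfer

    def initialize_module(self, module, loader_or_batch):
        pass

net = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Linear(4, 4),
)
# Only the second layer is tagged with (and handled by) the custom runtime;
# the first layer stays on the CPU.
ppe.to(net[1], "dummy", runtime_class=MyRuntime, options={"opt": 1})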
module = MyModule() - ppe.to(module, 'dummy', runtime_class=MyRuntime, config={'opt': 1}) + ppe.to(module, "dummy", runtime_class=MyRuntime, config={"opt": 1}) rt_layer1 = ppe.runtime._runtime._module_runtime_tag(module) assert isinstance(rt_layer1, MyRuntime) - assert rt_layer1.options['opt'] == 1 + assert rt_layer1.options["opt"] == 1 class NonPicklableRuntime(ppe.runtime.BaseRuntime): diff --git a/tests/pytorch_pfn_extras_tests/test_config.py b/tests/pytorch_pfn_extras_tests/test_config.py index 302b6b0e7..60c151d5b 100644 --- a/tests/pytorch_pfn_extras_tests/test_config.py +++ b/tests/pytorch_pfn_extras_tests/test_config.py @@ -2,28 +2,26 @@ import os import tempfile import unittest -import pytest -from pytorch_pfn_extras.config import Config -from pytorch_pfn_extras.config import customize_type +import pytest +from pytorch_pfn_extras.config import Config, customize_type def func_0(a, b, c=10): return a + b + c -@customize_type(c='/foo/v0') +@customize_type(c="/foo/v0") def func_1(a, b, c): - return {'d': a * b, 'e': c} + return {"d": a * b, "e": c} -@customize_type(config='!/') +@customize_type(config="!/") def func_2(config): return json.dumps(config) class Cls0(object): - def __init__(self, a, b, c=10): self.a = a self.b = b @@ -33,9 +31,8 @@ def __eq__(self, other): return (self.a, self.b, self.c) == (other.a, other.b, other.c) -@customize_type(c='../../foo/v0') +@customize_type(c="../../foo/v0") class Cls1(object): - def __init__(self, a, b, c): self.d = a * b self.e = c @@ -45,263 +42,280 @@ def __eq__(self, other): class TestConfig(unittest.TestCase): - types = { - 'func_0': func_0, - 'func_1': func_1, - 'func_2': func_2, - 'cls_0': Cls0, - 'cls_1': Cls1, + "func_0": func_0, + "func_1": func_1, + "func_2": func_2, + "cls_0": Cls0, + "cls_1": Cls1, } def test_config(self): - config = Config({ - 'foo': { - 'v0': {'type': 'func_0', 'a': 1, 'b': 2}, - 'v1': {'type': 'func_0', 'a': 1, 'b': 2, 'c': 3}, - 'v2': {'type': 'func_1', 'a': 1, 'b': 2}, - 'v3': {'type': 'func_1', 'a': 1, 'b': 2, 'c': 3}, - }, - 'bar': [ - {'type': 'cls_0', 'a': 1, 'b': 2}, - {'type': 'cls_0', 'a': 1, 'b': 2, 'c': 3}, - {'type': 'cls_1', 'a': 1, 'b': 2}, - {'type': 'cls_1', 'a': 1, 'b': 2, 'c': 3}, - ], - 'baz': { - 'v0': '@/foo/v2.d', - 'v1': '@../bar/1/c', - 'v2': '@/bar/3.d', - 'v3': '@../foo/v3', - } - }, self.types) - - self.assertEqual(config['/'], { - 'foo': { - 'v0': 13, - 'v1': 6, - 'v2': {'d': 2, 'e': 13}, - 'v3': {'d': 2, 'e': 3}, + config = Config( + { + "foo": { + "v0": {"type": "func_0", "a": 1, "b": 2}, + "v1": {"type": "func_0", "a": 1, "b": 2, "c": 3}, + "v2": {"type": "func_1", "a": 1, "b": 2}, + "v3": {"type": "func_1", "a": 1, "b": 2, "c": 3}, + }, + "bar": [ + {"type": "cls_0", "a": 1, "b": 2}, + {"type": "cls_0", "a": 1, "b": 2, "c": 3}, + {"type": "cls_1", "a": 1, "b": 2}, + {"type": "cls_1", "a": 1, "b": 2, "c": 3}, + ], + "baz": { + "v0": "@/foo/v2.d", + "v1": "@../bar/1/c", + "v2": "@/bar/3.d", + "v3": "@../foo/v3", + }, }, - 'bar': [ - Cls0(1, 2, 10), - Cls0(1, 2, 3), - Cls1(1, 2, 13), - Cls1(1, 2, 3), - ], - 'baz': { - 'v0': 2, - 'v1': 3, - 'v2': 2, - 'v3': {'d': 2, 'e': 3}, + self.types, + ) + + self.assertEqual( + config["/"], + { + "foo": { + "v0": 13, + "v1": 6, + "v2": {"d": 2, "e": 13}, + "v3": {"d": 2, "e": 3}, + }, + "bar": [ + Cls0(1, 2, 10), + Cls0(1, 2, 3), + Cls1(1, 2, 13), + Cls1(1, 2, 3), + ], + "baz": { + "v0": 2, + "v1": 3, + "v2": 2, + "v3": {"d": 2, "e": 3}, + }, }, - }) + ) def test_config_escape(self): pre_eval_config = { - 'foo': { - 'v0': {'type': 'func_0', 
'a': 1, 'b': 2}, + "foo": { + "v0": {"type": "func_0", "a": 1, "b": 2}, }, - 'bar': {'type': 'func_2'}, + "bar": {"type": "func_2"}, } config = Config(pre_eval_config, self.types) - self.assertEqual(config['!/foo'], { - 'v0': {'type': 'func_0', 'a': 1, 'b': 2}, - }) - self.assertEqual(json.loads(config['/bar']), pre_eval_config) + self.assertEqual( + config["!/foo"], + { + "v0": {"type": "func_0", "a": 1, "b": 2}, + }, + ) + self.assertEqual(json.loads(config["/bar"]), pre_eval_config) def test_config_load_path(self): - with tempfile.TemporaryDirectory() as temp0, \ - tempfile.TemporaryDirectory() as temp1: - with open(os.path.join(temp0, 'foo.json'), mode='w') as f: - json.dump({ - 'foo': {'v0': {'type': 'func_0', 'a': 1, 'b': 2}}, - 'bar': {'import': os.path.join(temp1, 'bar.json')}, - 'baz': { - 'import': 'baz.json', - '0/b': 3, - '1/d': [1, 2], + with tempfile.TemporaryDirectory() as temp0, tempfile.TemporaryDirectory() as temp1: + with open(os.path.join(temp0, "foo.json"), mode="w") as f: + json.dump( + { + "foo": {"v0": {"type": "func_0", "a": 1, "b": 2}}, + "bar": {"import": os.path.join(temp1, "bar.json")}, + "baz": { + "import": "baz.json", + "0/b": 3, + "1/d": [1, 2], + }, }, - }, f) - with open(os.path.join(temp1, 'bar.json'), mode='w') as f: - json.dump({'type': 'func_0', 'a': 3, 'b': 4}, f) - with open(os.path.join(temp0, 'baz.json'), mode='w') as f: - json.dump([ - {'type': 'func_1', 'a': 1, 'b': 2}, - {'d': 3, 'e': 4}, - ], f) + f, + ) + with open(os.path.join(temp1, "bar.json"), mode="w") as f: + json.dump({"type": "func_0", "a": 3, "b": 4}, f) + with open(os.path.join(temp0, "baz.json"), mode="w") as f: + json.dump( + [ + {"type": "func_1", "a": 1, "b": 2}, + {"d": 3, "e": 4}, + ], + f, + ) config = Config.load_path( - os.path.join(temp0, 'foo.json'), types=self.types) - - self.assertEqual(config['!/foo'], - {'v0': {'type': 'func_0', 'a': 1, 'b': 2}}) - self.assertEqual(config['/foo'], {'v0': 13}) - self.assertEqual(config['!/bar'], {'type': 'func_0', 'a': 3, 'b': 4}) - self.assertEqual(config['/bar'], 17) - self.assertEqual(config['!/baz'], [ - {'type': 'func_1', 'a': 1, 'b': 3}, - {'d': [1, 2], 'e': 4}, - ]) - self.assertEqual(config['/baz'], [ - {'d': 3, 'e': 13}, - {'d': [1, 2], 'e': 4}, - ]) + os.path.join(temp0, "foo.json"), types=self.types + ) + + self.assertEqual( + config["!/foo"], {"v0": {"type": "func_0", "a": 1, "b": 2}} + ) + self.assertEqual(config["/foo"], {"v0": 13}) + self.assertEqual(config["!/bar"], {"type": "func_0", "a": 3, "b": 4}) + self.assertEqual(config["/bar"], 17) + self.assertEqual( + config["!/baz"], + [ + {"type": "func_1", "a": 1, "b": 3}, + {"d": [1, 2], "e": 4}, + ], + ) + self.assertEqual( + config["/baz"], + [ + {"d": 3, "e": 13}, + {"d": [1, 2], "e": 4}, + ], + ) def test_config_with_config_key_invalid_index(self): - config = Config([['a'], [['b', ['c', 'd']]]]) + config = Config([["a"], [["b", ["c", "d"]]]]) with self.assertRaises(IndexError) as cm: - config['/1/2/3'] + config["/1/2/3"] self.assertEqual( cm.exception.args[-2:], - ('2 not in !/1', - '/1/2/3 -> !/1/2/3 -> !/1/2')) + ("2 not in !/1", "/1/2/3 -> !/1/2/3 -> !/1/2"), + ) def test_config_with_config_key_invalid_key(self): - config = Config({'foo': {'bar': {'baz': None}}}) + config = Config({"foo": {"bar": {"baz": None}}}) with self.assertRaises(KeyError) as cm: - config['/foo/Bar/baz'] + config["/foo/Bar/baz"] self.assertEqual( cm.exception.args[-2:], - ('Bar not in !/foo', - '/foo/Bar/baz -> !/foo/Bar/baz -> !/foo/Bar')) + ("Bar not in !/foo", "/foo/Bar/baz -> !/foo/Bar/baz 
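# Illustration (a minimal sketch, not taken from the patched files): how the
# Config objects above resolve entries. A "type" key selects a callable from
# the types mapping, the remaining keys become its arguments, and "@" values
# reference other resolved entries. The add() helper is a placeholder.
from pytorch_pfn_extras.config import Config

def add(a, b, c=10):
    return a + b + c

config = Config(
    {
        "foo": {"v0": {"type": "add", "a": 1, "b": 2}},
        "bar": "@/foo/v0",
    },
    {"add": add},
)
assert config["/foo/v0"] == 13  # add(1, 2) with the default c=10
assert config["/bar"] == 13     # resolved through the reference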
-> !/foo/Bar"), + ) def test_config_with_config_key_invalid_type(self): - config = Config({'foo': [['b', {'baz': None}]]}) + config = Config({"foo": [["b", {"baz": None}]]}) with self.assertRaises(TypeError) as cm: - config['/foo/bar/baz'] + config["/foo/bar/baz"] self.assertEqual( cm.exception.args[-2:], - ('bar not in !/foo', - '/foo/bar/baz -> !/foo/bar/baz -> !/foo/bar')) + ("bar not in !/foo", "/foo/bar/baz -> !/foo/bar/baz -> !/foo/bar"), + ) def test_config_with_attr_key_invalid_index(self): - config = Config([['a'], [['b', ['c', 'd']]]]) + config = Config([["a"], [["b", ["c", "d"]]]]) with self.assertRaises(IndexError) as cm: - config['/.1.2.3'] + config["/.1.2.3"] self.assertEqual( cm.exception.args[-2:], - ('2 not in /.1 ([[\'b\', [\'c\', \'d\']]])', - '/.1.2.3 -> /.1.2')) + ("2 not in /.1 ([['b', ['c', 'd']]])", "/.1.2.3 -> /.1.2"), + ) def test_config_with_attr_key_invalid_key(self): - config = Config({'foo': {'bar': {'baz': None}}}) + config = Config({"foo": {"bar": {"baz": None}}}) with self.assertRaises(KeyError) as cm: - config['/.foo.Bar.baz'] + config["/.foo.Bar.baz"] self.assertEqual( cm.exception.args[-2:], - ('Bar not in /.foo ({\'bar\': {\'baz\': None}})', - '/.foo.Bar.baz -> /.foo.Bar')) + ( + "Bar not in /.foo ({'bar': {'baz': None}})", + "/.foo.Bar.baz -> /.foo.Bar", + ), + ) def test_config_with_attr_key_invalid_type(self): - config = Config({'foo': [['b', {'baz': None}]]}) + config = Config({"foo": [["b", {"baz": None}]]}) with self.assertRaises(TypeError) as cm: - config['/.foo.bar.baz'] + config["/.foo.bar.baz"] self.assertEqual( cm.exception.args[-2:], - ('bar not in /.foo ([[\'b\', {\'baz\': None}]])', - '/.foo.bar.baz -> /.foo.bar')) + ( + "bar not in /.foo ([['b', {'baz': None}]])", + "/.foo.bar.baz -> /.foo.bar", + ), + ) def test_config_with_invalid_type(self): - config = Config({'foo': [{'type': 'foo', 'a': 0, 'b': 1}]}) + config = Config({"foo": [{"type": "foo", "a": 0, "b": 1}]}) with self.assertRaises(KeyError) as cm: - config['/'] + config["/"] self.assertEqual( - cm.exception.args[-2:], - ('foo not in types', - '/ -> /foo -> /foo/0')) + cm.exception.args[-2:], ("foo not in types", "/ -> /foo -> /foo/0") + ) def test_config_with_invalid_call(self): def foo(): - raise RuntimeError('foo') + raise RuntimeError("foo") config = Config( - {'foo': [{'type': 'foo', 'a': 0, 'b': 1}]}, - types={'foo': foo}) + {"foo": [{"type": "foo", "a": 0, "b": 1}]}, types={"foo": foo} + ) with self.assertRaises(TypeError) as cm: - config['/'] + config["/"] - self.assertEqual( - cm.exception.args[-1:], - ('/ -> /foo -> /foo/0',)) + self.assertEqual(cm.exception.args[-1:], ("/ -> /foo -> /foo/0",)) def test_config_with_circular_dependency(self): - config = Config({'foo': '@/bar', 'bar': '@foo.d'}) + config = Config({"foo": "@/bar", "bar": "@foo.d"}) with self.assertRaises(RuntimeError) as cm: - config['/'] + config["/"] self.assertIn( cm.exception.args, { - ('Circular dependency', '/ -> /foo -> /bar -> /foo.d -> /foo'), - ('Circular dependency', '/ -> /bar -> /foo.d -> /foo -> /bar'), - }) + ("Circular dependency", "/ -> /foo -> /bar -> /foo.d -> /foo"), + ("Circular dependency", "/ -> /bar -> /foo.d -> /foo -> /bar"), + }, + ) def test_config_with_circular_import(self): with tempfile.TemporaryDirectory() as temp: - with open(os.path.join(temp, 'foo.json'), mode='w') as f: - json.dump({'a': {'import': 'bar.json'}}, f) - with open(os.path.join(temp, 'bar.json'), mode='w') as f: - json.dump([{'import': './foo.json'}], f) + with open(os.path.join(temp, "foo.json"), mode="w") as 
f: + json.dump({"a": {"import": "bar.json"}}, f) + with open(os.path.join(temp, "bar.json"), mode="w") as f: + json.dump([{"import": "./foo.json"}], f) with self.assertRaises(RuntimeError) as cm: - Config.load_path(os.path.join(temp, 'foo.json')) + Config.load_path(os.path.join(temp, "foo.json")) self.assertEqual( cm.exception.args, - ('Circular import', - '!/ of {foo} -> !/a of {foo} -> !/ of {bar}' - ' -> !/0 of {bar} -> !/ of {foo}'.format( - foo=os.path.join(temp, 'foo.json'), - bar=os.path.join(temp, 'bar.json')))) + ( + "Circular import", + "!/ of {foo} -> !/a of {foo} -> !/ of {bar}" + " -> !/0 of {bar} -> !/ of {foo}".format( + foo=os.path.join(temp, "foo.json"), + bar=os.path.join(temp, "bar.json"), + ), + ), + ) def test_config_with_args_update(self): - config = Config({ - 'foo': { - 'ls': ['first'] - } - }, self.types) + config = Config({"foo": {"ls": ["first"]}}, self.types) - assert config['/foo/ls/0'] == 'first' - config.update_via_args([('/foo/ls/0', 'changed')]) - assert config['/foo/ls/0'] == 'changed' + assert config["/foo/ls/0"] == "first" + config.update_via_args([("/foo/ls/0", "changed")]) + assert config["/foo/ls/0"] == "changed" def test_config_with_args_update_type_conversion(self): - config = Config({ - 'foo': { - 'ls': [0] - } - }, self.types) + config = Config({"foo": {"ls": [0]}}, self.types) - assert config['/foo/ls/0'] == 0 - config.update_via_args([('/foo/ls/0', '16')]) - assert config['/foo/ls/0'] == 16 + assert config["/foo/ls/0"] == 0 + config.update_via_args([("/foo/ls/0", "16")]) + assert config["/foo/ls/0"] == 16 def test_config_with_args_update_type_conversion_bool(self): - config = Config({ - 'foo': { - 'ls': [True] - } - }, self.types) - - assert config['/foo/ls/0'] - config.update_via_args([('/foo/ls/0', False)]) - assert not config['/foo/ls/0'] - config.update_via_args([('/foo/ls/0', "False")]) - assert not config['/foo/ls/0'] - config.update_via_args([('/foo/ls/0', "true")]) - assert config['/foo/ls/0'] - config.update_via_args([('/foo/ls/0', "TRUE")]) - assert config['/foo/ls/0'] - config.update_via_args([('/foo/ls/0', "false")]) - assert not config['/foo/ls/0'] + config = Config({"foo": {"ls": [True]}}, self.types) + + assert config["/foo/ls/0"] + config.update_via_args([("/foo/ls/0", False)]) + assert not config["/foo/ls/0"] + config.update_via_args([("/foo/ls/0", "False")]) + assert not config["/foo/ls/0"] + config.update_via_args([("/foo/ls/0", "true")]) + assert config["/foo/ls/0"] + config.update_via_args([("/foo/ls/0", "TRUE")]) + assert config["/foo/ls/0"] + config.update_via_args([("/foo/ls/0", "false")]) + assert not config["/foo/ls/0"] with pytest.raises(ValueError): - config.update_via_args([('/foo/ls/0', "alse")]) + config.update_via_args([("/foo/ls/0", "alse")]) diff --git a/tests/pytorch_pfn_extras_tests/test_config_types.py b/tests/pytorch_pfn_extras_tests/test_config_types.py index fe637b05a..95d559abb 100644 --- a/tests/pytorch_pfn_extras_tests/test_config_types.py +++ b/tests/pytorch_pfn_extras_tests/test_config_types.py @@ -4,90 +4,99 @@ import unittest import optuna - -from pytorch_pfn_extras.config_types import optuna_types -from pytorch_pfn_extras.config_types import load_path_with_optuna_types +from pytorch_pfn_extras.config_types import ( + load_path_with_optuna_types, + optuna_types, +) class TestConfigTypes(unittest.TestCase): - study = optuna.create_study() def test_config_optuna_types(self): def objective(trial): types = optuna_types(trial) self.assertEqual( - types['optuna_suggest_categorical'], - 
trial.suggest_categorical) - self.assertEqual( - types['optuna_suggest_discrete_uniform'], - trial.suggest_discrete_uniform) - self.assertEqual( - types['optuna_suggest_float'], - trial.suggest_float) + types["optuna_suggest_categorical"], trial.suggest_categorical + ) self.assertEqual( - types['optuna_suggest_int'], - trial.suggest_int) + types["optuna_suggest_discrete_uniform"], + trial.suggest_discrete_uniform, + ) + self.assertEqual(types["optuna_suggest_float"], trial.suggest_float) + self.assertEqual(types["optuna_suggest_int"], trial.suggest_int) self.assertEqual( - types['optuna_suggest_loguniform'], - trial.suggest_loguniform) + types["optuna_suggest_loguniform"], trial.suggest_loguniform + ) self.assertEqual( - types['optuna_suggest_uniform'], - trial.suggest_uniform) + types["optuna_suggest_uniform"], trial.suggest_uniform + ) return 0.0 + self.study.optimize(objective, n_trials=1) def test_load_path_with_optuna_types(self): low = 0 high = 8 with tempfile.TemporaryDirectory() as temp0: - with open(os.path.join(temp0, 'foo.json'), mode='w') as f: - json.dump({ - 'foo': { - 'type': 'optuna_suggest_int', - 'name': 'a', - 'low': low, - 'high': high - } - }, f) + with open(os.path.join(temp0, "foo.json"), mode="w") as f: + json.dump( + { + "foo": { + "type": "optuna_suggest_int", + "name": "a", + "low": low, + "high": high, + } + }, + f, + ) def objective(trial): config = load_path_with_optuna_types( - os.path.join(temp0, 'foo.json'), trial) - self.assertIsInstance(config['/foo'], int) - self.assertGreaterEqual(config['/foo'], low) - self.assertLessEqual(config['/foo'], high) + os.path.join(temp0, "foo.json"), trial + ) + self.assertIsInstance(config["/foo"], int) + self.assertGreaterEqual(config["/foo"], low) + self.assertLessEqual(config["/foo"], high) return 0.0 + self.study.optimize(objective, n_trials=2 * (high - low + 1)) def test_load_path_with_optuna_types_with_types_argument(self): low = 0 high = 8 with tempfile.TemporaryDirectory() as temp0: - with open(os.path.join(temp0, 'foo.json'), mode='w') as f: - json.dump({ - 'foo': { - 'type': 'optuna_suggest_int', - 'name': 'a', - 'low': low, - 'high': high + with open(os.path.join(temp0, "foo.json"), mode="w") as f: + json.dump( + { + "foo": { + "type": "optuna_suggest_int", + "name": "a", + "low": low, + "high": high, + }, + "bar": {"type": "dict", "x": 0}, }, - 'bar': { - 'type': 'dict', - 'x': 0 - } - }, f) + f, + ) def objective(trial): config = load_path_with_optuna_types( - os.path.join(temp0, 'foo.json'), trial, - types={'optuna_suggest_int': float, 'dict': dict}) - self.assertIsInstance(config['/foo'], int) - self.assertGreaterEqual(config['/foo'], low) - self.assertLessEqual(config['/foo'], high) - self.assertIsInstance(config['/bar'], dict) - self.assertEqual(config['/bar']['x'], 0) + os.path.join(temp0, "foo.json"), + trial, + types={"optuna_suggest_int": float, "dict": dict}, + ) + self.assertIsInstance(config["/foo"], int) + self.assertGreaterEqual(config["/foo"], low) + self.assertLessEqual(config["/foo"], high) + self.assertIsInstance(config["/bar"], dict) + self.assertEqual(config["/bar"]["x"], 0) return 0.0 + self.assertWarns( - UserWarning, self.study.optimize, objective, - n_trials=2 * (high - low + 1)) + UserWarning, + self.study.optimize, + objective, + n_trials=2 * (high - low + 1), + ) diff --git a/tests/pytorch_pfn_extras_tests/test_handler.py b/tests/pytorch_pfn_extras_tests/test_handler.py index 87e7b381f..cf3615f2e 100644 --- a/tests/pytorch_pfn_extras_tests/test_handler.py +++ 
b/tests/pytorch_pfn_extras_tests/test_handler.py @@ -1,10 +1,9 @@ -import unittest.mock import contextlib +import unittest.mock -import torch import pytest - import pytorch_pfn_extras as ppe +import torch def torch_testing_assert_close(*args, **kwargs): @@ -91,14 +90,14 @@ def forward(self, x): class MockTrainer: def __init__(self): - self.models = {'main': MockModule()} + self.models = {"main": MockModule()} self.optimizers = {} self.epoch = 0 class MockEvaluator: def __init__(self): - self.models = {'main': MockModule()} + self.models = {"main": MockModule()} self.optimizers = {} self.epoch = 0 @@ -112,7 +111,7 @@ def train_epoch_end(self, epoch, models): def train_step(self, models, optimizers, batch_idx, batch): assert batch.converted - return models['main'](batch) + return models["main"](batch) def train_step_optimizers(self, models, optimizers, batch_idx): self._train_step_optimizers_called = True @@ -122,111 +121,110 @@ def train_validation_begin(self, models): def eval_step(self, models, batch_idx, batch): assert batch.converted - return models['main'](batch) + return models["main"](batch) class HandlerTester: def _get_handler(self, options=None): if options is None: options = {} - ppe.runtime.runtime_registry.register('test_rt', MockRuntime) + ppe.runtime.runtime_registry.register("test_rt", MockRuntime) trainer = MockTrainer() logic = MockLogic() handler = ppe.handler.Handler( - logic, MockRuntime('test_rt', {}), options + logic, MockRuntime("test_rt", {}), options ) return handler, trainer, logic def _move_modules(self, module, to_move): for name in to_move: - if name == 'self': - ppe.to(module, 'test_rt') + if name == "self": + ppe.to(module, "test_rt") else: - ppe.to(getattr(module, name), 'test_rt') + ppe.to(getattr(module, name), "test_rt") def _assert_called(self, module, to_move, function): for name, mod in module.named_modules(): - if (mod is module and 'self' in to_move) or (name in to_move): - assert getattr(mod._ppe_runtime, f'_{function}_called') + if (mod is module and "self" in to_move) or (name in to_move): + assert getattr(mod._ppe_runtime, f"_{function}_called") assert mod._ppe_runtime._called_module == mod else: - if hasattr(mod, '_ppe_runtime'): + if hasattr(mod, "_ppe_runtime"): assert mod._ppe_runtime._called_module != mod class TestHandlerTrainSync(HandlerTester): - @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_setup(self, to_move): handler, trainer, _ = self._get_handler() - module = trainer.models['main'] + module = trainer.models["main"] self._move_modules(module, to_move) handler.train_setup(trainer, []) - self._assert_called(module, to_move, 'initialize') + self._assert_called(module, to_move, "initialize") @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_cleanup(self, to_move): handler, trainer, _ = self._get_handler() - module = trainer.models['main'] + module = trainer.models["main"] self._move_modules(module, to_move) handler.train_cleanup(trainer) - self._assert_called(module, to_move, 'train_cleanup') + self._assert_called(module, to_move, "train_cleanup") @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_epoch_begin(self, to_move): handler, trainer, logic = self._get_handler() - module = 
trainer.models['main'] + module = trainer.models["main"] self._move_modules(module, to_move) handler.train_epoch_begin(trainer, []) - self._assert_called(module, to_move, 'train_epoch_begin') + self._assert_called(module, to_move, "train_epoch_begin") assert logic._train_epoch_begin_called @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_epoch_end(self, to_move): handler, trainer, logic = self._get_handler() - module = trainer.models['main'] + module = trainer.models["main"] self._move_modules(module, to_move) # Should check that the handler completes handler.train_epoch_end(trainer) - self._assert_called(module, to_move, 'train_epoch_end') + self._assert_called(module, to_move, "train_epoch_end") assert logic._train_epoch_end_called @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_step(self, to_move): handler, trainer, logic = self._get_handler() - module = trainer.models['main'] + module = trainer.models["main"] self._move_modules(module, to_move) callback = unittest.mock.Mock(return_value=None) handler.train_step(trainer, 0, None, callback) callback.assert_called_once_with(0, 1) - self._assert_called(module, to_move, 'train_pre_step') + self._assert_called(module, to_move, "train_pre_step") assert logic._train_step_optimizers_called @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_post_step(self, to_move): - options = {'train_report_keys': ['output']} + options = {"train_report_keys": ["output"]} handler, trainer, _ = self._get_handler(options) - module = trainer.models['main'] + module = trainer.models["main"] self._move_modules(module, to_move) reporter = ppe.reporting.Reporter() with reporter: - handler.train_post_step(trainer, 0, None, {'output': 1}) - assert reporter.observation['train/output'] == 1 - self._assert_called(module, to_move, 'train_post_step') + handler.train_post_step(trainer, 0, None, {"output": 1}) + assert reporter.observation["train/output"] == 1 + self._assert_called(module, to_move, "train_post_step") class TestHandlerValidationSync(HandlerTester): @@ -236,53 +234,53 @@ def _get_handler(self, options=None): return handler, evaluator, logic @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_eval_setup(self, to_move): handler, evaluator, _ = self._get_handler() - module = evaluator.models['main'] + module = evaluator.models["main"] self._move_modules(module, to_move) handler.eval_setup(evaluator, []) - self._assert_called(module, to_move, 'initialize') + self._assert_called(module, to_move, "initialize") @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_validation_begin(self, to_move): handler, evaluator, logic = self._get_handler() - module = evaluator.models['main'] + module = evaluator.models["main"] self._move_modules(module, to_move) handler.train_validation_begin(None, evaluator) - self._assert_called(module, to_move, 'train_validation_begin') + self._assert_called(module, to_move, "train_validation_begin") assert logic._train_validation_begin_called @pytest.mark.parametrize( - 'to_move', 
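# Illustration (a minimal sketch, not taken from the patched files): the
# Reporter pattern these handler tests use to capture reported values.
import pytorch_pfn_extras as ppe

reporter = ppe.reporting.Reporter()
with reporter:
    ppe.reporting.report({"train/output": 1.0})
assert reporter.observation["train/output"] == 1.0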
[('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_eval_step(self, to_move): handler, evaluator, logic = self._get_handler() - module = evaluator.models['main'] + module = evaluator.models["main"] self._move_modules(module, to_move) callback = unittest.mock.Mock(return_value=None) handler.eval_step(evaluator, 0, None, callback) callback.assert_called_once_with(0, 1) - self._assert_called(module, to_move, 'eval_pre_step') + self._assert_called(module, to_move, "eval_pre_step") @pytest.mark.parametrize( - 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')] + "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")] ) def test_train_post_step(self, to_move): - options = {'eval_report_keys': ['output']} + options = {"eval_report_keys": ["output"]} handler, evaluator, _ = self._get_handler(options) - module = evaluator.models['main'] + module = evaluator.models["main"] self._move_modules(module, to_move) reporter = ppe.reporting.Reporter() with reporter: - handler.eval_post_step(evaluator, 0, None, {'output': 1}) - assert reporter.observation['val/output'] == 1 - self._assert_called(module, to_move, 'eval_post_step') + handler.eval_post_step(evaluator, 0, None, {"output": 1}) + assert reporter.observation["val/output"] == 1 + self._assert_called(module, to_move, "eval_post_step") @pytest.mark.gpu @@ -291,7 +289,7 @@ def run_autocast(self, options): trainer = MockTrainer() logic = ppe.handler.Logic(options=options) handler = ppe.handler.Handler( - logic, ppe.runtime.PyTorchRuntime('cuda', {}), {} + logic, ppe.runtime.PyTorchRuntime("cuda", {}), {} ) completed = False @@ -300,11 +298,11 @@ class _MModule(torch.nn.Module): def forward(self, x, y): return torch.mm(x, y) - trainer.models['main'] = _MModule() - trainer.optimizers['main'] = torch.optim.SGD( + trainer.models["main"] = _MModule() + trainer.optimizers["main"] = torch.optim.SGD( [torch.nn.Parameter(torch.zeros(10))], 0.01 ) - ppe.to(trainer.models['main'], 'cuda') + ppe.to(trainer.models["main"], "cuda") completed = False def callback(batch_idx, outs): @@ -322,38 +320,38 @@ def callback(batch_idx, outs): completed = True inputs = { - 'x': torch.rand((2, 2)).cuda(), - 'y': torch.rand((2, 2)).cuda(), + "x": torch.rand((2, 2)).cuda(), + "y": torch.rand((2, 2)).cuda(), } handler.train_step(trainer, 0, inputs, callback) assert completed - @pytest.mark.parametrize('autocast', [True, False]) + @pytest.mark.parametrize("autocast", [True, False]) def test_autocast(self, autocast): - self.run_autocast({'autocast': autocast}) + self.run_autocast({"autocast": autocast}) def test_autocast_not_enabled(self): old_enable = ppe.runtime._autocast._cuda_amp_available try: ppe.runtime._autocast._cuda_amp_available = False with pytest.raises(RuntimeError): - ppe.handler.Logic(options={'autocast': True}) + ppe.handler.Logic(options={"autocast": True}) finally: ppe.runtime._autocast._cuda_amp_available = old_enable - @pytest.mark.skipif(not ppe.requires("1.10.0"), reason="requires PyTorch>=1.10") + @pytest.mark.skipif( + not ppe.requires("1.10.0"), reason="requires PyTorch>=1.10" + ) @pytest.mark.parametrize( - 'device_type, dtype', - [("cpu", torch.bfloat16), ("cuda", torch.float16)] + "device_type, dtype", [("cpu", torch.bfloat16), ("cuda", torch.float16)] ) def test_autocast_options(self, device_type, dtype): self.run_autocast( - {"autocast": {'device_type': device_type, "dtype": dtype}} + {"autocast": {"device_type": device_type, "dtype": dtype}} ) class TestLogic: - def 
test_train_epoch_begin(self): # Check that the DataLoader has the sampler updated class _MockedDL: @@ -361,15 +359,17 @@ def __init__(self): class _Sampler: def set_epoch(self, epoch): self.epoch = epoch + self.sampler = _Sampler() + logic = ppe.handler.Logic() loader = _MockedDL() - models = {'main': torch.nn.Linear(1, 1)} + models = {"main": torch.nn.Linear(1, 1)} # The model should be set to train mode - models['main'].eval() - assert not models['main'].training + models["main"].eval() + assert not models["main"].training logic.train_epoch_begin(models, 10, loader) - assert models['main'].training + assert models["main"].training assert loader.sampler.epoch == 10 def _run_step(self, logic, device): @@ -384,26 +384,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return super().forward(x).sum() model = _Module().to(device) - models = {'main': model} - optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0, 0)} + models = {"main": model} + optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0, 0)} out = logic.train_step(models, optimizers, 0, input) return models, optimizers, input, out def test_train_step(self): logic = ppe.handler.Logic() - models, optimizers, input, out = self._run_step(logic, 'cpu') - model = models['main'] + models, optimizers, input, out = self._run_step(logic, "cpu") + model = models["main"] assert input.grad is not None # The gradient of a linear layer is its transposed weight torch_testing_assert_close(input.grad, model.weight.T) torch_testing_assert_close(out, model(input)) @pytest.mark.parametrize( - 'to_backprop', - [None, ('0',), ('0', '1'), ('0', '1', '2'), ('1', '2'), ('2',)] + "to_backprop", + [None, ("0",), ("0", "1"), ("0", "1", "2"), ("1", "2"), ("2",)], ) def test_train_step_backward(self, to_backprop): - logic = ppe.handler.Logic(options={'backward_outputs': to_backprop}) + logic = ppe.handler.Logic(options={"backward_outputs": to_backprop}) input = torch.rand(1, 1) input.requires_grad = True @@ -415,20 +415,22 @@ def __init__(self): self.l2 = torch.nn.Linear(1, 1) def forward(self, x): - return {'0': self.l0(x), '1': self.l1(x), '2': self.l2(x)} + return {"0": self.l0(x), "1": self.l1(x), "2": self.l2(x)} + model = _MultiOutModel() - models = {'main': model} - optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0)} + models = {"main": model} + optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0)} assert input.grad is None if to_backprop is None: - to_backprop = ('0', '1', '2') + to_backprop = ("0", "1", "2") # Copy the original parameters to check that they were not updated original_parameters = {} for val in to_backprop: - original_parameters[val] = getattr( - model, f'l{val}').weight.detach().clone() + original_parameters[val] = ( + getattr(model, f"l{val}").weight.detach().clone() + ) outs = logic.train_step(models, optimizers, 0, input) @@ -436,13 +438,14 @@ def forward(self, x): assert len(outs.keys()) == 3 grad = torch.zeros(1) for val in to_backprop: - grad = grad + getattr(model, f'l{val}').weight.T + grad = grad + getattr(model, f"l{val}").weight.T torch_testing_assert_close(input.grad, grad) # Check that logic step does not change the value of weight for val in original_parameters: torch_testing_assert_close( - original_parameters[val], getattr(model, f'l{val}').weight) + original_parameters[val], getattr(model, f"l{val}").weight + ) def test_train_step_backward_nograd(self): logic = ppe.handler.Logic() @@ -455,19 +458,19 @@ def __init__(self): self.l0 = torch.nn.Linear(1, 1) def forward(self, x): - 
return {'0': x} + return {"0": x} model = _DummyModel() - models = {'main': model} - optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0)} + models = {"main": model} + optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0)} assert input.grad is None outs = logic.train_step(models, optimizers, 0, input) - assert outs['0'].grad is None + assert outs["0"].grad is None def test_train_step_backward_invalid(self): - logic = ppe.handler.Logic(options={'backward_outputs': 'abcd'}) + logic = ppe.handler.Logic(options={"backward_outputs": "abcd"}) input = torch.rand(1, 1) input.requires_grad = True @@ -477,20 +480,20 @@ def __init__(self): self.l0 = torch.nn.Linear(1, 1) def forward(self, x): - return {'0': x} + return {"0": x} model = _DummyModel() - models = {'main': model} - optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0)} + models = {"main": model} + optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0)} assert input.grad is None - with pytest.warns(UserWarning, match='backward value: abcd'): + with pytest.warns(UserWarning, match="backward value: abcd"): logic.train_step(models, optimizers, 0, input) def test_train_step_optimizers(self): logic = ppe.handler.Logic() - models, optimizers, input, out = self._run_step(logic, 'cpu') - model = models['main'] + models, optimizers, input, out = self._run_step(logic, "cpu") + model = models["main"] m_weight = model.weight.clone().detach() w_grad = model.weight.grad.clone().detach() logic.train_step_optimizers(model, optimizers, 0) @@ -500,10 +503,10 @@ def test_train_step_optimizers(self): @pytest.mark.gpu def test_grad_scaler(self): scaler = torch.cuda.amp.GradScaler() - options = {'grad_scaler': scaler} + options = {"grad_scaler": scaler} logic = ppe.handler.Logic(options=options) - models, optimizers, input, out = self._run_step(logic, 'cuda') - model = models['main'] + models, optimizers, input, out = self._run_step(logic, "cuda") + model = models["main"] m_weight = model.weight.clone().detach() w_grad = model.weight.grad.clone().detach() # The gradient of a linear layer is its transposed weight @@ -513,11 +516,12 @@ def test_grad_scaler(self): # Checks that the value was correctly updated and gradients deescaled # before the update torch_testing_assert_close( - scaler.scale(m_weight) - w_grad, scaler.scale(model.weight.T)) + scaler.scale(m_weight) - w_grad, scaler.scale(model.weight.T) + ) @pytest.mark.gpu def test_invalid_grad_scaler(self): - options = {'grad_scaler': object()} + options = {"grad_scaler": object()} with pytest.raises(RuntimeError): ppe.handler.Logic(options=options) @@ -526,7 +530,7 @@ def test_disabled_grad_scaler(self): old_enable = ppe.runtime._autocast._cuda_amp_available try: ppe.runtime._autocast._cuda_amp_available = False - options = {'grad_scaler': torch.cuda.amp.GradScaler()} + options = {"grad_scaler": torch.cuda.amp.GradScaler()} with pytest.raises(RuntimeError): ppe.handler.Logic(options=options) finally: @@ -534,23 +538,23 @@ def test_disabled_grad_scaler(self): def test_train_validation_begin(self): logic = ppe.handler.Logic() - models = {'main': torch.nn.Linear(1, 1)} - models['main'].train() - assert models['main'].training + models = {"main": torch.nn.Linear(1, 1)} + models["main"].train() + assert models["main"].training logic.train_validation_begin(models) - assert not models['main'].training + assert not models["main"].training def test_eval_step(self): logic = ppe.handler.Logic() input = torch.rand(1, 1) model = torch.nn.Linear(1, 1) - models = {'main': model} - 
models['main'].eval() + models = {"main": model} + models["main"].eval() out = logic.eval_step(models, 0, input) torch_testing_assert_close(out, model(input)) @pytest.mark.gpu def test_use_grad_scaler_with_clousure(self): - options = {'grad_scaler': torch.cuda.amp.GradScaler()} + options = {"grad_scaler": torch.cuda.amp.GradScaler()} with pytest.raises(RuntimeError): ppe.handler.ClousureLogic(options=options) diff --git a/tests/pytorch_pfn_extras_tests/test_logging.py b/tests/pytorch_pfn_extras_tests/test_logging.py index e4709afa9..55e8ee9cb 100644 --- a/tests/pytorch_pfn_extras_tests/test_logging.py +++ b/tests/pytorch_pfn_extras_tests/test_logging.py @@ -7,17 +7,17 @@ def test_file_output(): try: with tempfile.NamedTemporaryFile() as logfile: logfile.close() # this is needed for Windows - logging._configure_logging(filename=logfile.name, level='DEBUG') + logging._configure_logging(filename=logfile.name, level="DEBUG") logger = logging._get_root_logger() - logger.info('TEST LOG MESSAGE') + logger.info("TEST LOG MESSAGE") with open(logfile.name) as f: - assert 'TEST LOG MESSAGE' in f.read() + assert "TEST LOG MESSAGE" in f.read() finally: logging._configure_logging() def test_get_logger(): - logger = logging.get_logger('app') + logger = logging.get_logger("app") logger.setLevel(logging.DEBUG) - assert logger.name == 'ppe.app' + assert logger.name == "ppe.app" assert logger.level == logging.DEBUG diff --git a/tests/pytorch_pfn_extras_tests/test_logic.py b/tests/pytorch_pfn_extras_tests/test_logic.py index 9d85f8821..6bf97fd33 100644 --- a/tests/pytorch_pfn_extras_tests/test_logic.py +++ b/tests/pytorch_pfn_extras_tests/test_logic.py @@ -1,13 +1,13 @@ from typing import Any, Mapping -import pytest +from unittest import mock +import pytest +import pytorch_pfn_extras as ppe import torch from torch import nn +from torch.nn import Module +from torch.nn import functional as F from torch.optim import Optimizer -from torch.nn import functional as F, Module -from unittest import mock - -import pytorch_pfn_extras as ppe class MyModel(torch.nn.Module): @@ -35,15 +35,15 @@ def __init__(self, model): def forward(self, x, t): y = self.model(x) - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" loss = F.l1_loss(y, t) - ppe.reporting.report({prefix + '/loss': loss}) + ppe.reporting.report({prefix + "/loss": loss}) return loss -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_trainer(device): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() iters_per_epoch = 10 epochs = 20 @@ -52,19 +52,41 @@ def test_trainer(device): model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(iters_per_epoch)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(iters_per_epoch) + ] + ) backward_fn = mock.Mock(return_value=None) trainer = ppe.engine.create_trainer( - model_with_loss, optimizer, epochs, + model_with_loss, + optimizer, + epochs, device=device, - options={'backward_function': backward_fn} + options={"backward_function": backward_fn}, ) trainer.run(data) assert backward_fn.call_count == epochs * iters_per_epoch -@pytest.mark.parametrize("trigger", [(1, "epoch"), (0.5, "epoch"), (10, "iteration"), (5, "iteration"), (1, "iteration")]) +@pytest.mark.parametrize( + "trigger", + 
[ + (1, "epoch"), + (0.5, "epoch"), + (10, "iteration"), + (5, "iteration"), + (1, "iteration"), + ], +) def test_train_step_mode_with_evaluator(trigger): iters_per_epoch = 10 epochs = 20 @@ -73,11 +95,28 @@ def test_train_step_mode_with_evaluator(trigger): model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(iters_per_epoch)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(iters_per_epoch) + ] + ) backward_fn = mock.Mock(return_value=None) class LogicWithTrainStepCheck(ppe.handler.Logic): - def train_step(self, models: Mapping[str, Module], optimizers: Mapping[str, Optimizer], batch_idx: int, batch: Any) -> Any: + def train_step( + self, + models: Mapping[str, Module], + optimizers: Mapping[str, Optimizer], + batch_idx: int, + batch: Any, + ) -> Any: model = models[self.model_name] assert model.training return super().train_step(models, optimizers, batch_idx, batch) @@ -94,7 +133,7 @@ def train_step(self, models: Mapping[str, Module], optimizers: Mapping[str, Opti ), trigger, ), - options={'backward_function': backward_fn} + options={"backward_function": backward_fn}, ) trainer.run(data, data) assert backward_fn.call_count == epochs * iters_per_epoch diff --git a/tests/pytorch_pfn_extras_tests/test_reporter.py b/tests/pytorch_pfn_extras_tests/test_reporter.py index 463e5d907..44b6e59e0 100644 --- a/tests/pytorch_pfn_extras_tests/test_reporter.py +++ b/tests/pytorch_pfn_extras_tests/test_reporter.py @@ -6,9 +6,8 @@ import numpy import pytest -import torch - import pytorch_pfn_extras as ppe +import torch def test_empty_reporter(): @@ -41,12 +40,8 @@ def thread_func(reporter, record): record2 = [] reporter1 = ppe.reporting.Reporter() reporter2 = ppe.reporting.Reporter() - thread1 = threading.Thread( - target=thread_func, - args=(reporter1, record1)) - thread2 = threading.Thread( - target=thread_func, - args=(reporter2, record2)) + thread1 = threading.Thread(target=thread_func, args=(reporter1, record1)) + thread2 = threading.Thread(target=thread_func, args=(reporter2, record2)) thread1.daemon = True thread2.daemon = True thread1.start() @@ -72,71 +67,72 @@ def test_scope(): def test_add_observer(): reporter = ppe.reporting.Reporter() observer = object() - reporter.add_observer('o', observer) + reporter.add_observer("o", observer) - reporter.report({'x': 1}, observer) + reporter.report({"x": 1}, observer) observation = reporter.observation - assert 'o/x' in observation - assert observation['o/x'] == 1 - assert 'x'not in observation + assert "o/x" in observation + assert observation["o/x"] == 1 + assert "x" not in observation def test_add_observers(): reporter = ppe.reporting.Reporter() observer1 = object() - reporter.add_observer('o1', observer1) + reporter.add_observer("o1", observer1) observer2 = object() - reporter.add_observer('o2', observer2) + reporter.add_observer("o2", observer2) - reporter.report({'x': 1}, observer1) - reporter.report({'y': 2}, observer2) + reporter.report({"x": 1}, observer1) + reporter.report({"y": 2}, observer2) observation = reporter.observation - assert 'o1/x' in observation - assert observation['o1/x'] == 1 - assert 'o2/y' in observation - assert observation['o2/y'] == 2 - assert 'x' not in observation - assert 'y' not in observation - assert 'o1/y' not in observation - assert 'o2/x' not in observation + assert "o1/x" in observation + assert observation["o1/x"] == 1 + assert "o2/y" in observation 
+ assert observation["o2/y"] == 2 + assert "x" not in observation + assert "y" not in observation + assert "o1/y" not in observation + assert "o2/x" not in observation def test_report_without_observer(): reporter = ppe.reporting.Reporter() - reporter.report({'x': 1}) + reporter.report({"x": 1}) observation = reporter.observation - assert 'x' in observation - assert observation['x'] == 1 + assert "x" in observation + assert observation["x"] == 1 # ppe.reporting.report + def test_report_without_reporter(): observer = object() - ppe.reporting.report({'x': 1}, observer) + ppe.reporting.report({"x": 1}, observer) def test_report(): reporter = ppe.reporting.Reporter() with reporter: - ppe.reporting.report({'x': 1}) + ppe.reporting.report({"x": 1}) observation = reporter.observation - assert 'x' in observation - assert observation['x'] == 1 + assert "x" in observation + assert observation["x"] == 1 def test_report_with_observer(): reporter = ppe.reporting.Reporter() observer = object() - reporter.add_observer('o', observer) + reporter.add_observer("o", observer) with reporter: - ppe.reporting.report({'x': 1}, observer) + ppe.reporting.report({"x": 1}, observer) observation = reporter.observation - assert 'o/x' in observation - assert observation['o/x'] == 1 + assert "o/x" in observation + assert observation["o/x"] == 1 def test_report_with_unregistered_observer(): @@ -144,7 +140,7 @@ def test_report_with_unregistered_observer(): observer = object() with reporter: with pytest.raises(KeyError): - ppe.reporting.report({'x': 1}, observer) + ppe.reporting.report({"x": 1}, observer) def test_report_scope(): @@ -153,21 +149,21 @@ def test_report_scope(): with reporter: with ppe.reporting.report_scope(observation): - ppe.reporting.report({'x': 1}) + ppe.reporting.report({"x": 1}) - assert 'x' in observation - assert observation['x'] == 1 - assert 'x' not in reporter.observation + assert "x" in observation + assert observation["x"] == 1 + assert "x" not in reporter.observation def test_report_tensor_detached(): reporter = ppe.reporting.Reporter() - x = torch.tensor(numpy.array(1, 'float32'), requires_grad=True) + x = torch.tensor(numpy.array(1, "float32"), requires_grad=True) with reporter: - ppe.reporting.report({'x': x}) + ppe.reporting.report({"x": x}) observation = reporter.observation - assert 'x' in observation - assert not observation['x'].requires_grad + assert "x" in observation + assert not observation["x"].requires_grad assert x.requires_grad @@ -182,17 +178,18 @@ def test_report_callable(): # ppe.reporting.Summary + def test_summary_basic(): summary = ppe.reporting.Summary() - summary.add(torch.Tensor(numpy.array(1, 'float32'))) - summary.add(torch.Tensor(numpy.array(-2, 'float32'))) + summary.add(torch.Tensor(numpy.array(1, "float32"))) + summary.add(torch.Tensor(numpy.array(-2, "float32"))) mean = summary.compute_mean() - numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, 'f')) + numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, "f")) mean, std = summary.make_statistics() - numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, 'f')) - numpy.testing.assert_allclose(std.numpy(), numpy.array(1.5, 'f')) + numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, "f")) + numpy.testing.assert_allclose(std.numpy(), numpy.array(1.5, "f")) def test_summary_int(): @@ -206,28 +203,28 @@ def test_summary_int(): mean, std = summary.make_statistics() numpy.testing.assert_allclose(mean, 2) - numpy.testing.assert_allclose(std, numpy.sqrt(2. 
/ 3.)) + numpy.testing.assert_allclose(std, numpy.sqrt(2.0 / 3.0)) def test_summary_float(): summary = ppe.reporting.Summary() - summary.add(1.) - summary.add(2.) - summary.add(3.) + summary.add(1.0) + summary.add(2.0) + summary.add(3.0) mean = summary.compute_mean() - numpy.testing.assert_allclose(mean, 2.) + numpy.testing.assert_allclose(mean, 2.0) mean, std = summary.make_statistics() - numpy.testing.assert_allclose(mean, 2.) - numpy.testing.assert_allclose(std, numpy.sqrt(2. / 3.)) + numpy.testing.assert_allclose(mean, 2.0) + numpy.testing.assert_allclose(std, numpy.sqrt(2.0 / 3.0)) def test_summary_weight(): summary = ppe.reporting.Summary() - summary.add(1., 0.5) - summary.add(2., numpy.array(0.4)) - summary.add(3., torch.autograd.Variable(torch.Tensor(numpy.array(0.3)))) + summary.add(1.0, 0.5) + summary.add(2.0, numpy.array(0.4)) + summary.add(3.0, torch.autograd.Variable(torch.Tensor(numpy.array(0.3)))) mean = summary.compute_mean() val = (1 * 0.5 + 2 * 0.4 + 3 * 0.3) / (0.5 + 0.4 + 0.3) @@ -244,21 +241,21 @@ def test_summary_deferred_add(): def test_summary_add_operator(): s1 = ppe.reporting.Summary() - s1.add(1.) - s1.add(2.) - s1.add(3.) + s1.add(1.0) + s1.add(2.0) + s1.add(3.0) s2 = ppe.reporting.Summary() - s2.add(4.) - s2.add(5.) - s2.add(6.) - s2.add(7.) + s2.add(4.0) + s2.add(5.0) + s2.add(6.0) + s2.add(7.0) s = s1 + s2 assert s.state_dict() == { - '_x': 28.0, - '_x2': 140.0, - '_n': 7, + "_x": 28.0, + "_x2": 140.0, + "_n": 7, } @@ -279,12 +276,14 @@ def _check_summary_serialize(value1, value2, value3): torch.save(summary.state_dict(), f.name) # Load tensors in CPU to simulate a snapshot restore summary2.load_state_dict( - torch.load(f.name, map_location=torch.device('cpu'))) + torch.load(f.name, map_location=torch.device("cpu")) + ) summary2.add(value3) - expected_mean = float((value1 + value2 + value3) / 3.) + expected_mean = float((value1 + value2 + value3) / 3.0) expected_std = math.sqrt( - (value1**2 + value2**2 + value3**2) / 3. 
- expected_mean**2) + (value1**2 + value2**2 + value3**2) / 3.0 - expected_mean**2 + ) mean = summary2.compute_mean() if isinstance(mean, torch.Tensor): @@ -305,21 +304,25 @@ def test_serialize_array_float(): numpy.array(1.5, numpy.float32), numpy.array(2.0, numpy.float32), # sum of the above two is non-integer - numpy.array(3.5, numpy.float32)) + numpy.array(3.5, numpy.float32), + ) def test_serialize_array_int(): _check_summary_serialize( numpy.array(1, numpy.int32), numpy.array(-2, numpy.int32), - numpy.array(2, numpy.int32)) + numpy.array(2, numpy.int32), + ) def test_serialize_scalar_float(): _check_summary_serialize( - 1.5, 2.0, + 1.5, + 2.0, # sum of the above two is non-integer - 3.5) + 3.5, + ) def test_serialize_scalar_int(): @@ -328,9 +331,8 @@ def test_serialize_scalar_int(): def test_serialize_tensor(): _check_summary_serialize( - torch.tensor(1.5), - torch.tensor(2.0), - torch.tensor(3.5)) + torch.tensor(1.5), torch.tensor(2.0), torch.tensor(3.5) + ) @pytest.mark.gpu @@ -338,18 +340,21 @@ def test_serialize_tensor_cuda(): _check_summary_serialize( torch.tensor(1.5).cuda(), torch.tensor(2.0).cuda(), - torch.tensor(3.5).cuda()) + torch.tensor(3.5).cuda(), + ) def test_serialize_tensor_with_grad(): _check_summary_serialize( torch.tensor(1.5, requires_grad=True), torch.tensor(2.0, requires_grad=True), - 3.5) + 3.5, + ) # ppe.reporting.DictSummary + def _check_dict_summary(summary, data): mean = summary.compute_mean() assert set(mean.keys()) == set(data.keys()) @@ -358,150 +363,171 @@ def _check_dict_summary(summary, data): numpy.testing.assert_allclose(mean[name], m) stats = summary.make_statistics() - assert ( - set(stats.keys()) - == set(data.keys()).union(name + '.std' for name in data.keys())) + assert set(stats.keys()) == set(data.keys()).union( + name + ".std" for name in data.keys() + ) for name in data.keys(): m = sum(data[name]) / float(len(data[name])) s = numpy.sqrt( - sum(x * x for x in data[name]) / float(len(data[name])) - - m * m) + sum(x * x for x in data[name]) / float(len(data[name])) - m * m + ) numpy.testing.assert_allclose(stats[name], m) - numpy.testing.assert_allclose(stats[name + '.std'], s) + numpy.testing.assert_allclose(stats[name + ".std"], s) def test_dict_summary(): summary = ppe.reporting.DictSummary() - summary.add({'numpy': numpy.array(3, 'f'), 'int': 1, 'float': 4.}) - summary.add({'numpy': numpy.array(1, 'f'), 'int': 5, 'float': 9.}) - summary.add({'numpy': numpy.array(2, 'f'), 'int': 6, 'float': 5.}) - summary.add({'numpy': numpy.array(3, 'f'), 'int': 5, 'float': 8.}) + summary.add({"numpy": numpy.array(3, "f"), "int": 1, "float": 4.0}) + summary.add({"numpy": numpy.array(1, "f"), "int": 5, "float": 9.0}) + summary.add({"numpy": numpy.array(2, "f"), "int": 6, "float": 5.0}) + summary.add({"numpy": numpy.array(3, "f"), "int": 5, "float": 8.0}) - _check_dict_summary(summary, { - 'numpy': (3., 1., 2., 3.), - 'int': (1, 5, 6, 5), - 'float': (4., 9., 5., 8.), - }) + _check_dict_summary( + summary, + { + "numpy": (3.0, 1.0, 2.0, 3.0), + "int": (1, 5, 6, 5), + "float": (4.0, 9.0, 5.0, 8.0), + }, + ) def test_dit_summary_sparse(): summary = ppe.reporting.DictSummary() - summary.add({'a': 3., 'b': 1.}) - summary.add({'a': 1., 'b': 5., 'c': 9.}) - summary.add({'b': 6.}) - summary.add({'a': 3., 'b': 5., 'c': 8.}) + summary.add({"a": 3.0, "b": 1.0}) + summary.add({"a": 1.0, "b": 5.0, "c": 9.0}) + summary.add({"b": 6.0}) + summary.add({"a": 3.0, "b": 5.0, "c": 8.0}) - _check_dict_summary(summary, { - 'a': (3., 1., 3.), - 'b': (1., 5., 6., 5.), - 'c': (9., 
8.), - }) + _check_dict_summary( + summary, + { + "a": (3.0, 1.0, 3.0), + "b": (1.0, 5.0, 6.0, 5.0), + "c": (9.0, 8.0), + }, + ) def test_dict_summary_weight(): summary = ppe.reporting.DictSummary() - summary.add({'a': (1., 0.5)}) - summary.add({'a': (2., numpy.array(0.4))}) + summary.add({"a": (1.0, 0.5)}) + summary.add({"a": (2.0, numpy.array(0.4))}) summary.add( - {'a': (3., torch.autograd.Variable(torch.Tensor(numpy.array(0.3))))}) + {"a": (3.0, torch.autograd.Variable(torch.Tensor(numpy.array(0.3))))} + ) mean = summary.compute_mean() val = (1 * 0.5 + 2 * 0.4 + 3 * 0.3) / (0.5 + 0.4 + 0.3) - numpy.testing.assert_allclose(mean['a'].numpy(), val) + numpy.testing.assert_allclose(mean["a"].numpy(), val) arr = numpy.array([0.5]) with pytest.raises(ValueError): - summary.add({'a': (4., arr)}) + summary.add({"a": (4.0, arr)}) var = torch.autograd.Variable(torch.Tensor(numpy.array([0.5]))) with pytest.raises(ValueError): - summary.add({'a': (4., var)}) + summary.add({"a": (4.0, var)}) def test_dict_summary_serialize(): summary = ppe.reporting.DictSummary() - summary.add({'numpy': numpy.array(3, 'f'), 'int': 1, 'float': 4.}) - summary.add({'numpy': numpy.array(1, 'f'), 'int': 5, 'float': 9.}) - summary.add({'numpy': numpy.array(2, 'f'), 'int': 6, 'float': 5.}) + summary.add({"numpy": numpy.array(3, "f"), "int": 1, "float": 4.0}) + summary.add({"numpy": numpy.array(1, "f"), "int": 5, "float": 9.0}) + summary.add({"numpy": numpy.array(2, "f"), "int": 6, "float": 5.0}) summary2 = ppe.reporting.DictSummary() summary2.load_state_dict(summary.state_dict()) - summary2.add({'numpy': numpy.array(3, 'f'), 'int': 5, 'float': 8.}) + summary2.add({"numpy": numpy.array(3, "f"), "int": 5, "float": 8.0}) - _check_dict_summary(summary2, { - 'numpy': (3., 1., 2., 3.), - 'int': (1, 5, 6, 5), - 'float': (4., 9., 5., 8.), - }) + _check_dict_summary( + summary2, + { + "numpy": (3.0, 1.0, 2.0, 3.0), + "int": (1, 5, 6, 5), + "float": (4.0, 9.0, 5.0, 8.0), + }, + ) -@pytest.mark.parametrize('delimiter', ['/', '.']) +@pytest.mark.parametrize("delimiter", ["/", "."]) @pytest.mark.parametrize( # How the state of the summary is transferred. 
- 'transfer_protocol', + "transfer_protocol", [ - 'direct', # Use state_dict() and load_state_dict() - 'torch', # Use torch.save() and torch.load() - ]) + "direct", # Use state_dict() and load_state_dict() + "torch", # Use torch.save() and torch.load() + ], +) def test_dict_summary_serialize_names_with_delimiter( - delimiter, transfer_protocol): - key1 = 'a{d}b'.format(d=delimiter) - key2 = '{d}a{d}b'.format(d=delimiter) - key3 = 'a{d}b{d}'.format(d=delimiter) + delimiter, transfer_protocol +): + key1 = "a{d}b".format(d=delimiter) + key2 = "{d}a{d}b".format(d=delimiter) + key3 = "a{d}b{d}".format(d=delimiter) summary = ppe.reporting.DictSummary() - summary.add({key1: 3., key2: 1., key3: 4.}) - summary.add({key1: 1., key2: 5., key3: 9.}) - summary.add({key1: 2., key2: 6., key3: 5.}) + summary.add({key1: 3.0, key2: 1.0, key3: 4.0}) + summary.add({key1: 1.0, key2: 5.0, key3: 9.0}) + summary.add({key1: 2.0, key2: 6.0, key3: 5.0}) - if transfer_protocol == 'direct': + if transfer_protocol == "direct": summary2 = ppe.reporting.DictSummary() summary2.load_state_dict(summary.state_dict()) else: - assert transfer_protocol == 'torch' + assert transfer_protocol == "torch" f = io.BytesIO() torch.save(summary, f) summary2 = torch.load(io.BytesIO(f.getvalue())) - summary2.add({key1: 3., key2: 5., key3: 8.}) + summary2.add({key1: 3.0, key2: 5.0, key3: 8.0}) - _check_dict_summary(summary2, { - key1: (3., 1., 2., 3.), - key2: (1., 5., 6., 5.), - key3: (4., 9., 5., 8.), - }) + _check_dict_summary( + summary2, + { + key1: (3.0, 1.0, 2.0, 3.0), + key2: (1.0, 5.0, 6.0, 5.0), + key3: (4.0, 9.0, 5.0, 8.0), + }, + ) def test_serialize_overwrite_different_names(): summary = ppe.reporting.DictSummary() - summary.add({'a': 3., 'b': 1.}) - summary.add({'a': 1., 'b': 5.}) + summary.add({"a": 3.0, "b": 1.0}) + summary.add({"a": 1.0, "b": 5.0}) summary2 = ppe.reporting.DictSummary() - summary2.add({'c': 5.}) + summary2.add({"c": 5.0}) summary2.load_state_dict(summary.state_dict()) - _check_dict_summary(summary2, { - 'a': (3., 1.), - 'b': (1., 5.), - }) + _check_dict_summary( + summary2, + { + "a": (3.0, 1.0), + "b": (1.0, 5.0), + }, + ) def test_serialize_overwrite_rollback(): summary = ppe.reporting.DictSummary() - summary.add({'a': 3., 'b': 1.}) - summary.add({'a': 1., 'b': 5.}) + summary.add({"a": 3.0, "b": 1.0}) + summary.add({"a": 1.0, "b": 5.0}) state = summary.state_dict() - summary.add({'a': 2., 'b': 6., 'c': 5.}) - summary.add({'a': 3., 'b': 4., 'c': 6.}) + summary.add({"a": 2.0, "b": 6.0, "c": 5.0}) + summary.add({"a": 3.0, "b": 4.0, "c": 6.0}) summary.load_state_dict(state) - summary.add({'a': 3., 'b': 5., 'c': 8.}) + summary.add({"a": 3.0, "b": 5.0, "c": 8.0}) - _check_dict_summary(summary, { - 'a': (3., 1., 3.), - 'b': (1., 5., 5.), - 'c': (8.,), - }) + _check_dict_summary( + summary, + { + "a": (3.0, 1.0, 3.0), + "b": (1.0, 5.0, 5.0), + "c": (8.0,), + }, + ) def test_dict_summary_deferred_add(): @@ -509,26 +535,32 @@ def test_dict_summary_deferred_add(): summary.add({"x": lambda: 1.0, "y": lambda: 2.0}) summary.add({"x": lambda: -1.0}) - _check_dict_summary(summary, { - "x": (1.0, -1.0), - "y": (2.0,), - }) + _check_dict_summary( + summary, + { + "x": (1.0, -1.0), + "y": (2.0,), + }, + ) def test_dict_summary_add_operator(): s1 = ppe.reporting.DictSummary() - s1.add({'a': 1., 'b': 0.1, 'c': 0.01}) - s1.add({'a': 2., 'b': 0.2, 'c': 0.02}) + s1.add({"a": 1.0, "b": 0.1, "c": 0.01}) + s1.add({"a": 2.0, "b": 0.2, "c": 0.02}) s2 = ppe.reporting.DictSummary() - s2.add({'a': 3., 'b': 0.3, 'f': 0.03}) - 
s2.add({'a': 4., 'b': 0.4, 'f': 0.04}) - s2.add({'a': 5., 'b': 0.5, 'f': 0.05}) + s2.add({"a": 3.0, "b": 0.3, "f": 0.03}) + s2.add({"a": 4.0, "b": 0.4, "f": 0.04}) + s2.add({"a": 5.0, "b": 0.5, "f": 0.05}) s = s1 + s2 - _check_dict_summary(s, { - 'a': [1, 2, 3, 4, 5], - 'b': [0.1, 0.2, 0.3, 0.4, 0.5], - 'c': [0.01, 0.02], - 'f': [0.03, 0.04, 0.05], - }) + _check_dict_summary( + s, + { + "a": [1, 2, 3, 4, 5], + "b": [0.1, 0.2, 0.3, 0.4, 0.5], + "c": [0.01, 0.02], + "f": [0.03, 0.04, 0.05], + }, + ) diff --git a/tests/pytorch_pfn_extras_tests/test_tensor.py b/tests/pytorch_pfn_extras_tests/test_tensor.py index a05ff280a..7f1d6ad10 100644 --- a/tests/pytorch_pfn_extras_tests/test_tensor.py +++ b/tests/pytorch_pfn_extras_tests/test_tensor.py @@ -1,10 +1,9 @@ import numpy import pytest +import pytorch_pfn_extras as ppe import torch import torch.utils.dlpack -import pytorch_pfn_extras as ppe - def test_from_ndarray_numpy(): np_arr = numpy.arange(24).reshape(2, 3, 4) @@ -21,7 +20,7 @@ def test_from_ndarray_numpy_neg(): def test_from_ndarray_cupy(): - cupy = pytest.importorskip('cupy') + cupy = pytest.importorskip("cupy") cp_arr = cupy.arange(24).reshape(2, 3, 4) tensor = ppe.from_ndarray(cp_arr) assert cp_arr.data.ptr == tensor.data_ptr() @@ -29,7 +28,7 @@ def test_from_ndarray_cupy(): def test_from_ndarray_cupy_neg(): - cupy = pytest.importorskip('cupy') + cupy = pytest.importorskip("cupy") cp_arr = cupy.flip(cupy.arange(24).reshape(2, 3, 4)) tensor = ppe.from_ndarray(cp_arr) # copy assert cp_arr.data.ptr != tensor.data_ptr() @@ -50,7 +49,7 @@ def test_as_ndarray_cpu(): def test_as_ndarray_cupy(): - cupy = pytest.importorskip('cupy') + cupy = pytest.importorskip("cupy") tensor = torch.arange(24).reshape(2, 3, 4).cuda() arr = ppe.as_ndarray(tensor) assert isinstance(arr, cupy.ndarray) @@ -60,14 +59,14 @@ def test_as_ndarray_cupy(): def test_get_xp_numpy(): assert ppe.get_xp(torch.ones(4)) is numpy - assert ppe.get_xp(torch.device('cpu')) is numpy + assert ppe.get_xp(torch.device("cpu")) is numpy assert ppe.get_xp(numpy.ones(4)) is numpy def test_get_xp_cupy(): - cupy = pytest.importorskip('cupy') + cupy = pytest.importorskip("cupy") assert ppe.get_xp(torch.ones(4).cuda()) is cupy - assert ppe.get_xp(torch.device('cuda:0')) is cupy + assert ppe.get_xp(torch.device("cuda:0")) is cupy assert ppe.get_xp(cupy.ones(1)) is cupy @@ -76,12 +75,22 @@ def test_get_xp_invalid_type(): ppe.get_xp([1, 2, 3]) -@pytest.mark.parametrize('dtype', [ - 'bool', - 'uint8', 'int8', 'int16', 'int32', 'int64', - 'float16', 'float32', 'float64', - 'complex64', 'complex128', -]) +@pytest.mark.parametrize( + "dtype", + [ + "bool", + "uint8", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", + "complex64", + "complex128", + ], +) def test_torch_numpy_dtype(dtype): torch_dtype = getattr(torch, dtype) numpy_dtype = numpy.dtype(dtype) @@ -93,6 +102,6 @@ def test_torch_numpy_dtype_unsupported(): with pytest.raises(TypeError): ppe.as_numpy_dtype(torch.bfloat16) with pytest.raises(TypeError): - ppe.from_numpy_dtype(numpy.dtype('object')) + ppe.from_numpy_dtype(numpy.dtype("object")) with pytest.raises(TypeError): ppe.from_numpy_dtype(None) diff --git a/tests/pytorch_pfn_extras_tests/test_torchscript.py b/tests/pytorch_pfn_extras_tests/test_torchscript.py index 840a9c9c7..d3460b573 100644 --- a/tests/pytorch_pfn_extras_tests/test_torchscript.py +++ b/tests/pytorch_pfn_extras_tests/test_torchscript.py @@ -1,10 +1,10 @@ -import torch import pytorch_pfn_extras.torchscript as ts +import torch def 
test_find_inplace(): def f(v: torch.Tensor) -> None: - v += torch.ones((1,2,3)) + v += torch.ones((1, 2, 3)) def g(v: torch.Tensor): f(v) @@ -17,7 +17,7 @@ def g(v: torch.Tensor): def test_find_inplace_not_found(): def f(v: torch.Tensor) -> torch.Tensor: - return torch.ones((1,2,3)) + return torch.ones((1, 2, 3)) s = torch.jit.script(f) diff --git a/tests/pytorch_pfn_extras_tests/test_writing.py b/tests/pytorch_pfn_extras_tests/test_writing.py index 47106dd6f..b8ea93274 100644 --- a/tests/pytorch_pfn_extras_tests/test_writing.py +++ b/tests/pytorch_pfn_extras_tests/test_writing.py @@ -1,20 +1,22 @@ -import tempfile import os +import tempfile import pytest - import pytorch_pfn_extras as ppe -@pytest.mark.filterwarnings("ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning") +@pytest.mark.filterwarnings( + "ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning" +) def test_tensorboard_writing(): - pytest.importorskip('tensorboard') + pytest.importorskip("tensorboard") data = {"a": 1, "iteration": 1} with tempfile.TemporaryDirectory() as tempd: writer = ppe.writing.TensorBoardWriter( - out_dir=tempd, filename_suffix='_test') + out_dir=tempd, filename_suffix="_test" + ) writer(None, None, data) # Check that the file was generated for snap in os.listdir(tempd): - assert '_test' in snap + assert "_test" in snap writer.finalize() diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py index ab5c37b68..908da5624 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py @@ -1,45 +1,55 @@ import tempfile import pytest - import pytorch_pfn_extras as ppe from pytorch_pfn_extras.training import extensions - params = [ - ([None, 4.0, 4.5, 3.0, 3.5], - [4.0, 4.0, 3.0, 3.0], - [1, 1, 3, 3], - [3, 3, 9, 9], - extensions.MinValue), - ([None, 3.0, 4.5, 4.0, 5.0], - [3.0, 4.5, 4.5, 5.0], - [1, 2, 2, 4], - [3, 6, 6, 12], - extensions.MaxValue), + ( + [None, 4.0, 4.5, 3.0, 3.5], + [4.0, 4.0, 3.0, 3.0], + [1, 1, 3, 3], + [3, 3, 9, 9], + extensions.MinValue, + ), + ( + [None, 3.0, 4.5, 4.0, 5.0], + [3.0, 4.5, 4.5, 5.0], + [1, 2, 2, 4], + [3, 6, 6, 12], + extensions.MaxValue, + ), ] -@pytest.mark.parametrize('observed_values,expected_best_values,' - 'expected_best_epochs,expected_best_iterations,BestValueT', - params) -def test_best_observation(observed_values, expected_best_values, - expected_best_epochs, expected_best_iterations, BestValueT): +@pytest.mark.parametrize( + "observed_values,expected_best_values," + "expected_best_epochs,expected_best_iterations,BestValueT", + params, +) +def test_best_observation( + observed_values, + expected_best_values, + expected_best_epochs, + expected_best_iterations, + BestValueT, +): max_epochs = 4 iters_per_epoch = 3 with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir + ) def observer_fn(manager): return observed_values[manager.epoch] - observe = extensions.observe_value('value', observer_fn) - manager.extend(observe, trigger=(1, 'epoch')) + observe = extensions.observe_value("value", observer_fn) + manager.extend(observe, trigger=(1, "epoch")) - best_value = BestValueT('value') - manager.extend(best_value, 
trigger=(1, 'epoch')) + best_value = BestValueT("value") + manager.extend(best_value, trigger=(1, "epoch")) for epoch in range(max_epochs): for _ in range(iters_per_epoch): @@ -51,16 +61,16 @@ def observer_fn(manager): # Save/Load state dict (snapshot support) assert best_value.state_dict() == { - '_best_trigger': { - '_best_value': expected_best_values[-1], - '_summary': {}, - 'interval_trigger': {} + "_best_trigger": { + "_best_value": expected_best_values[-1], + "_summary": {}, + "interval_trigger": {}, }, - '_best_it': expected_best_iterations[-1], - '_best_epoch': expected_best_epochs[-1], + "_best_it": expected_best_iterations[-1], + "_best_epoch": expected_best_epochs[-1], } - best_value2 = BestValueT('value') + best_value2 = BestValueT("value") best_value2.load_state_dict(best_value.state_dict()) assert best_value2.best_value == expected_best_values[-1] assert best_value2.best_iteration == expected_best_iterations[-1] @@ -73,10 +83,11 @@ def test_key_error(): with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir + ) - best_observation = extensions.BestValue('value', lambda a, b: a < b) - manager.extend(best_observation, trigger=(1, 'epoch')) + best_observation = extensions.BestValue("value", lambda a, b: a < b) + manager.extend(best_observation, trigger=(1, "epoch")) with pytest.raises(KeyError) as e: for _ in range(max_epochs): @@ -87,7 +98,7 @@ def test_key_error(): def test_error_before_first_call(): - best_observation = extensions.BestValue('value', lambda a, b: a < b) + best_observation = extensions.BestValue("value", lambda a, b: a < b) with pytest.raises(RuntimeError): best_observation.best_value with pytest.raises(RuntimeError): diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py index fb27501eb..a38b3f6aa 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py @@ -2,26 +2,24 @@ import os import tempfile -import torch import pytest - -from pytorch_pfn_extras import distributed -from pytorch_pfn_extras import training +import torch +from pytorch_pfn_extras import distributed, training from pytorch_pfn_extras.training import extensions def _create_distributed_model(gpu=True): comm_size, comm_rank, comm_local_rank, device = _init_distributed(True) - device = torch.device('cuda:{}'.format(comm_local_rank) - if gpu else 'cpu') + device = torch.device("cuda:{}".format(comm_local_rank) if gpu else "cpu") model = torch.nn.Linear(128, 1) if torch.distributed.is_initialized(): if not gpu: raise pytest.skip("Distributed tests require GPUs.") model = torch.nn.parallel.DistributedDataParallel( - model.to(device), device_ids=[comm_local_rank]) + model.to(device), device_ids=[comm_local_rank] + ) else: model = model.to(device) @@ -32,27 +30,27 @@ def get_trainer(path): epochs = 10 # FIXME model = _create_distributed_model() optimizer = torch.optim.SGD(model.parameters(), lr=1.0) - optimizers = {'main': optimizer} - models = {'main': model} + optimizers = {"main": optimizer} + models = {"main": model} return training.ExtensionsManager( - models, optimizers, epochs, iters_per_epoch=1, out_dir=path) + models, optimizers, epochs, 
iters_per_epoch=1, out_dir=path + ) def _init_distributed(use_cuda): - if ('OMPI_COMM_WORLD_SIZE' in os.environ): - size, rank, local_rank = ( - distributed.initialize_ompi_environment( - backend="nccl", init_method="env")) + if "OMPI_COMM_WORLD_SIZE" in os.environ: + size, rank, local_rank = distributed.initialize_ompi_environment( + backend="nccl", init_method="env" + ) else: pytest.skip("This test requires MPI to run") - device = torch.device( - "cuda:{}".format(local_rank) if use_cuda else "cpu") + device = torch.device("cuda:{}".format(local_rank) if use_cuda else "cpu") return size, rank, local_rank, device -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def path(): with tempfile.TemporaryDirectory() as t_path: yield t_path @@ -73,11 +71,11 @@ def test_distributed_snapshot(path): torch.distributed.barrier() saver_rank = 0 - fmt = 'snapshot_iter_{.iteration}' + fmt = "snapshot_iter_{.iteration}" snapshot = extensions.snapshot(filename=fmt, saver_rank=saver_rank) trainer = get_trainer(path) - trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2) + trainer.extend(snapshot, trigger=(1, "iteration"), priority=2) for _ in range(1): with trainer.run_iteration(): pass @@ -91,7 +89,8 @@ def test_distributed_snapshot(path): new_trainer = get_trainer(path) new_trainer.load_state_dict(torch.load(os.path.join(path, found[0]))) assert _model_params_equal( - trainer._models['main'], new_trainer._models['main']) + trainer._models["main"], new_trainer._models["main"] + ) if comm_size > 1: torch.distributed.barrier() diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py index 54e9be181..60e4957a1 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py @@ -2,38 +2,35 @@ import numpy import pytest -import torch -import torch.distributed as dist - import pytorch_pfn_extras as ppe import pytorch_pfn_extras.training.extensions as ext +import torch +import torch.distributed as dist class DummyModel(torch.nn.Module): - def __init__(self): super().__init__() self.args = [] def forward(self, x): self.args.append(x) - ppe.reporting.report({'loss': x.sum()}, self) + ppe.reporting.report({"loss": x.sum()}, self) return x def custom_metric(batch, out, last_iter): - ppe.reporting.report({'custom-metric': out.sum()}) + ppe.reporting.report({"custom-metric": out.sum()}) class DummyModelTwoArgs(torch.nn.Module): - def __init__(self): super().__init__() self.args = [] def forward(self, x, y): self.args.append((x, y)) - ppe.reporting.report({'loss': x.sum() + y.sum()}, self) + ppe.reporting.report({"loss": x.sum() + y.sum()}, self) def _torch_batch_to_numpy(batch): @@ -43,11 +40,11 @@ def _torch_batch_to_numpy(batch): return batch.squeeze(0).numpy() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def evaluator_dummies(): data = [ - numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f') - for _ in range(2)] + numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f") for _ in range(2) + ] data_loader = torch.utils.data.DataLoader(data) target = DummyModel() @@ -60,7 +57,7 @@ def test_evaluate(evaluator_dummies): data, data_loader, target, evaluator, expect_mean = evaluator_dummies reporter = ppe.reporting.Reporter() - reporter.add_observer('target', target) + reporter.add_observer("target", target) with reporter: mean = 
evaluator.evaluate() @@ -71,10 +68,12 @@ def test_evaluate(evaluator_dummies): assert len(target.args) == len(data) for i in range(len(data)): numpy.testing.assert_array_equal( - _torch_batch_to_numpy(target.args[i]), data[i]) + _torch_batch_to_numpy(target.args[i]), data[i] + ) numpy.testing.assert_almost_equal( - mean['target/loss'], expect_mean, decimal=4) + mean["target/loss"], expect_mean, decimal=4 + ) def test_metric(evaluator_dummies): @@ -82,9 +81,8 @@ def test_metric(evaluator_dummies): evaluator.add_metric(custom_metric) mean = evaluator() # 'main' is used by default - assert 'custom-metric' in mean - numpy.testing.assert_almost_equal( - mean['main/loss'], expect_mean, decimal=4) + assert "custom-metric" in mean + numpy.testing.assert_almost_equal(mean["main/loss"], expect_mean, decimal=4) def test_call(evaluator_dummies): @@ -92,18 +90,18 @@ def test_call(evaluator_dummies): mean = evaluator() # 'main' is used by default - numpy.testing.assert_almost_equal( - mean['main/loss'], expect_mean, decimal=4) + numpy.testing.assert_almost_equal(mean["main/loss"], expect_mean, decimal=4) def test_evaluator_name(evaluator_dummies): data, data_loader, target, evaluator, expect_mean = evaluator_dummies - evaluator.name = 'eval' + evaluator.name = "eval" mean = evaluator() # name is used as a prefix numpy.testing.assert_almost_equal( - mean['eval/main/loss'], expect_mean, decimal=4) + mean["eval/main/loss"], expect_mean, decimal=4 + ) def test_current_report(evaluator_dummies): @@ -118,16 +116,19 @@ def test_current_report(evaluator_dummies): def test_evaluator_tuple_data(): data = [ - (numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f'), - numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f')) - for _ in range(2)] + ( + numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"), + numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"), + ) + for _ in range(2) + ] data_loader = torch.utils.data.DataLoader(data) target = DummyModelTwoArgs() evaluator = ppe.training.extensions.Evaluator(data_loader, target) reporter = ppe.reporting.Reporter() - reporter.add_observer('target', target) + reporter.add_observer("target", target) with reporter: mean = evaluator.evaluate() @@ -135,72 +136,83 @@ def test_evaluator_tuple_data(): for i in range(len(data)): assert len(target.args[i]) == len(data[i]) numpy.testing.assert_array_equal( - _torch_batch_to_numpy(target.args[i][0]), data[i][0]) + _torch_batch_to_numpy(target.args[i][0]), data[i][0] + ) numpy.testing.assert_array_equal( - _torch_batch_to_numpy(target.args[i][1]), data[i][1]) + _torch_batch_to_numpy(target.args[i][1]), data[i][1] + ) expect_mean = numpy.mean([numpy.sum(x) for x in data]) numpy.testing.assert_almost_equal( - mean['target/loss'], expect_mean, decimal=4) + mean["target/loss"], expect_mean, decimal=4 + ) def test_evaluator_dict_data(): data = [ - {'x': numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f'), - 'y': numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f')} - for _ in range(2)] + { + "x": numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"), + "y": numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"), + } + for _ in range(2) + ] data_loader = torch.utils.data.DataLoader(data) target = DummyModelTwoArgs() evaluator = ppe.training.extensions.Evaluator(data_loader, target) reporter = ppe.reporting.Reporter() - reporter.add_observer('target', target) + reporter.add_observer("target", target) with reporter: mean = evaluator.evaluate() assert len(target.args) == len(data) for i in range(len(data)): numpy.testing.assert_array_equal( - 
_torch_batch_to_numpy(target.args[i][0]), data[i]['x']) + _torch_batch_to_numpy(target.args[i][0]), data[i]["x"] + ) numpy.testing.assert_array_equal( - _torch_batch_to_numpy(target.args[i][1]), data[i]['y']) + _torch_batch_to_numpy(target.args[i][1]), data[i]["y"] + ) expect_mean = numpy.mean( - [numpy.sum(x['x']) + numpy.sum(x['y']) for x in data]) + [numpy.sum(x["x"]) + numpy.sum(x["y"]) for x in data] + ) numpy.testing.assert_almost_equal( - mean['target/loss'], expect_mean, decimal=4) + mean["target/loss"], expect_mean, decimal=4 + ) def test_evaluator_with_eval_func(): - data = [ - numpy.random.uniform(-1, 1, (3, 4)).astype('f') for _ in range(2)] + data = [numpy.random.uniform(-1, 1, (3, 4)).astype("f") for _ in range(2)] data_loader = torch.utils.data.DataLoader(data) target = DummyModel() evaluator = ppe.training.extensions.Evaluator( - data_loader, {}, eval_func=target) + data_loader, {}, eval_func=target + ) reporter = ppe.reporting.Reporter() - reporter.add_observer('target', target) + reporter.add_observer("target", target) with reporter: evaluator.evaluate() assert len(target.args) == len(data) for i in range(len(data)): numpy.testing.assert_array_equal( - _torch_batch_to_numpy(target.args[i]), data[i]) + _torch_batch_to_numpy(target.args[i]), data[i] + ) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def evaluator_with_progress(): - data = [ - numpy.random.uniform(-1, 1, (3, 4)).astype('f') for _ in range(2)] + data = [numpy.random.uniform(-1, 1, (3, 4)).astype("f") for _ in range(2)] data_loader = torch.utils.data.DataLoader(data, batch_size=1) target = DummyModel() evaluator = ppe.training.extensions.Evaluator( - data_loader, {}, eval_func=target, progress_bar=True) + data_loader, {}, eval_func=target, progress_bar=True + ) return evaluator, target @@ -208,43 +220,43 @@ def test_evaluator_progress_bar(capsys, evaluator_with_progress): evaluator, target = evaluator_with_progress reporter = ppe.reporting.Reporter() - reporter.add_observer('target', target) + reporter.add_observer("target", target) with reporter: evaluator.evaluate() stdout = capsys.readouterr().out - assert 'validation [.........' in stdout - assert ' 0 iterations' in stdout - assert ' inf iters/sec.' in stdout - assert 'validation [#########' in stdout - assert ' 1 iterations' in stdout + assert "validation [........." in stdout + assert " 0 iterations" in stdout + assert " inf iters/sec." in stdout + assert "validation [#########" in stdout + assert " 1 iterations" in stdout def test_evaluator_progress_bar_custom_label(capsys, evaluator_with_progress): evaluator, target = evaluator_with_progress - evaluator.name = 'my_own_evaluator' # Set a (long) custom name + evaluator.name = "my_own_evaluator" # Set a (long) custom name reporter = ppe.reporting.Reporter() - reporter.add_observer('target', target) + reporter.add_observer("target", target) with reporter: evaluator.evaluate() stdout = capsys.readouterr().out - assert 'my_own_evaluator [.........' in stdout - assert ' 0 iterations' in stdout - assert ' inf iters/sec.' in stdout - assert 'my_own_evaluator [#########' in stdout - assert ' 1 iterations' in stdout + assert "my_own_evaluator [........." in stdout + assert " 0 iterations" in stdout + assert " inf iters/sec." in stdout + assert "my_own_evaluator [#########" in stdout + assert " 1 iterations" in stdout # Code excerpts to test IgniteEvaluator class IgniteDummyModel(torch.nn.Module): def __init__(self): super(IgniteDummyModel, self).__init__() - self.count = 0. 
+ self.count = 0.0 def forward(self, *args): - ppe.reporting.report({'x': self.count}, self) - self.count += 1. - return 0. + ppe.reporting.report({"x": self.count}, self) + self.count += 1.0 + return 0.0 def create_dummy_evaluator(model): @@ -265,7 +277,7 @@ def test_ignite_evaluator_reporting_metrics(): try: from ignite.metrics import MeanSquaredError except ImportError: - pytest.skip('pytorch-ignite is not installed') + pytest.skip("pytorch-ignite is not installed") # This tests verifies that either, usuer manually reported metrics # and ignite calculated ones are correctly reflected in the reporter @@ -279,7 +291,7 @@ def test_ignite_evaluator_reporting_metrics(): evaluator = create_dummy_evaluator(model) # Attach metrics to the evaluator metric = MeanSquaredError() - metric.attach(evaluator, 'mse') + metric.attach(evaluator, "mse") evaluator_ignite_ext = ppe.training.extensions.IgniteEvaluator( evaluator, loader, model, progress_bar=False ) @@ -287,51 +299,58 @@ def test_ignite_evaluator_reporting_metrics(): with reporter: result = evaluator_ignite_ext() # Internally reported metrics - assert result['main/x'] == 1.5 + assert result["main/x"] == 1.5 # Ignite calculated metric - assert result['val/mse'] == 0.0 + assert result["val/mse"] == 0.0 def test_distributed_evaluation(): - dummy_data = [] # Note: has no effect to the evaluation + dummy_data = [] # Note: has no effect to the evaluation data_loader = torch.utils.data.DataLoader(dummy_data) target = DummyModel() - with mock.patch.object(dist, 'is_initialized', return_value=True): + with mock.patch.object(dist, "is_initialized", return_value=True): evaluator = ext.DistributedEvaluator(data_loader, target) # Make a (simulated) summary for each rank worker_evaluations = [ - [1., 2., 3.], # assuming rank=0 - [4., 5., 6.], # rank=1 - [7., 8., 9.], # ... - [10., 11., 12.], + [1.0, 2.0, 3.0], # assuming rank=0 + [4.0, 5.0, 6.0], # rank=1 + [7.0, 8.0, 9.0], # ... 
+ [10.0, 11.0, 12.0], ] worker_summaries = [] for accs in worker_evaluations: s = ppe.reporting.DictSummary() for acc in accs: - s.add({'target/score': acc}) + s.add({"target/score": acc}) worker_summaries.append(s) reporter = ppe.reporting.Reporter() - reporter.add_observer('target', target) + reporter.add_observer("target", target) with reporter: - with mock.patch.object(ppe.training.extensions.evaluator, '_dist_gather', - return_value=worker_summaries): + with mock.patch.object( + ppe.training.extensions.evaluator, + "_dist_gather", + return_value=worker_summaries, + ): mean = evaluator.evaluate() - assert mean['target/score'] == 6.5 + assert mean["target/score"] == 6.5 def test_distributed_evaluator_progress_bar(): - with mock.patch.object(dist, 'is_initialized', return_value=True): + with mock.patch.object(dist, "is_initialized", return_value=True): data_loader = torch.utils.data.DataLoader([]) target = DummyModel() - with mock.patch.object(dist, 'get_rank', return_value=0): - evaluator = ext.DistributedEvaluator(data_loader, target, progress_bar=True) + with mock.patch.object(dist, "get_rank", return_value=0): + evaluator = ext.DistributedEvaluator( + data_loader, target, progress_bar=True + ) assert evaluator._progress_bar # rank != 0 will forcibly set progress_bar=False - with mock.patch.object(dist, 'get_rank', return_value=1): - evaluator = ext.DistributedEvaluator(data_loader, target, progress_bar=True) + with mock.patch.object(dist, "get_rank", return_value=1): + evaluator = ext.DistributedEvaluator( + data_loader, target, progress_bar=True + ) assert not evaluator._progress_bar diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py index 5284b2289..c086b63bb 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py @@ -1,9 +1,7 @@ -import pytest import numpy +import pytest import torch - from pytorch_pfn_extras import training - from pytorch_pfn_extras.training.extensions import FailOnNonNumber @@ -16,8 +14,9 @@ def __len__(self): return len(self.values) def __getitem__(self, idx): - return numpy.array( - [self.values[idx]], numpy.float32), numpy.int64(idx % 2) + return numpy.array([self.values[idx]], numpy.float32), numpy.int64( + idx % 2 + ) class Model(torch.nn.Module): @@ -40,18 +39,18 @@ def forward(ctx, i): @staticmethod def backward(ctx, grad_output): - return grad_output + float('nan') + return grad_output + float("nan") def get_manager_model_optimizer(*, check_grad=True, grad_error=False): epochs = 3 model = Model(grad_error) optimizer = torch.optim.SGD(model.parameters(), lr=1.0) - optimizers = {'main': optimizer} - models = {'main': model} + optimizers = {"main": optimizer} + models = {"main": model} manager = training.ExtensionsManager( - models, optimizers, epochs, - iters_per_epoch=4) + models, optimizers, epochs, iters_per_epoch=4 + ) manager.extend(FailOnNonNumber(check_grad=check_grad)) return manager, model, optimizer @@ -79,27 +78,29 @@ def test_valid(): def test_nan(): manager, model, optimizer = get_manager_model_optimizer() with torch.no_grad(): - model.l1.weight[1, 0] = float('NaN') - with pytest.raises(RuntimeError, match='diverge'): + model.l1.weight[1, 0] = float("NaN") + with pytest.raises(RuntimeError, match="diverge"): run_train(manager, model, optimizer) def test_inf(): 
manager, model, optimizer = get_manager_model_optimizer() with torch.no_grad(): - model.l1.weight[2, 0] = float('inf') - with pytest.raises(RuntimeError, match='diverge'): + model.l1.weight[2, 0] = float("inf") + with pytest.raises(RuntimeError, match="diverge"): run_train(manager, model, optimizer) def test_check_grad(): manager, model, optimizer = get_manager_model_optimizer( - check_grad=True, grad_error=True) - with pytest.raises(RuntimeError, match='diverge'): + check_grad=True, grad_error=True + ) + with pytest.raises(RuntimeError, match="diverge"): run_train(manager, model, optimizer, optimizer_step=False) def test_no_check_grad(): manager, model, optimizer = get_manager_model_optimizer( - check_grad=False, grad_error=True) + check_grad=False, grad_error=True + ) run_train(manager, model, optimizer, optimizer_step=False) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py index 277b2626b..4e947c642 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py @@ -1,28 +1,30 @@ import io import os import tempfile -import yaml import pytorch_pfn_extras as ppe +import yaml from pytorch_pfn_extras.training import extensions -from pytorch_pfn_extras.training.extensions import log_report as log_report_module +from pytorch_pfn_extras.training.extensions import ( + log_report as log_report_module, +) def test_log_buffer(): buf = log_report_module._LogBuffer() looker = buf.emit_new_looker() assert buf.size() == 0 - buf.append('mes1') - buf.append('mes2') + buf.append("mes1") + buf.append("mes2") assert buf.size() == 2 - assert looker.get() == ['mes1', 'mes2'] + assert looker.get() == ["mes1", "mes2"] assert buf.size() == 2 looker.clear() assert buf.size() == 0 assert looker.get() == [] - buf.append('mes3') + buf.append("mes3") assert buf.size() == 1 - assert looker.get() == ['mes3'] + assert looker.get() == ["mes3"] assert buf.size() == 1 looker.clear() assert buf.size() == 0 @@ -33,15 +35,15 @@ def test_log_buffer_multiple_lookers(): buf = log_report_module._LogBuffer() looker1 = buf.emit_new_looker() looker2 = buf.emit_new_looker() - buf.append('mes1') - assert looker1.get() == ['mes1'] - assert looker2.get() == ['mes1'] + buf.append("mes1") + assert looker1.get() == ["mes1"] + assert looker2.get() == ["mes1"] assert buf.size() == 1 looker2.clear() assert buf.size() == 1 - buf.append('mes2') - assert looker1.get() == ['mes1', 'mes2'] - assert looker2.get() == ['mes2'] + buf.append("mes2") + assert looker1.get() == ["mes1", "mes2"] + assert looker2.get() == ["mes2"] assert buf.size() == 2 looker2.clear() assert buf.size() == 2 @@ -55,11 +57,13 @@ def test_buffer_size_log_report(): with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir + ) log_report = extensions.LogReport( - filename='out', format='yaml', append=True) - manager.extend(log_report, (1, 'epoch')) + filename="out", format="yaml", append=True + ) + manager.extend(log_report, (1, "epoch")) for _ in range(max_epochs): for _ in range(iters_per_epoch): @@ -67,7 +71,7 @@ def test_buffer_size_log_report(): with manager.run_iteration(): pass - with open(os.path.join(tmpdir, 'out')) as f: + with open(os.path.join(tmpdir, 
"out")) as f: data = f.read() values = yaml.load(data, Loader=yaml.SafeLoader) assert len(values) == max_epochs @@ -79,15 +83,17 @@ def test_buffer_size_log_report_and_print_report(): with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir + ) log_report = extensions.LogReport( - filename='out', format='yaml', append=True) - manager.extend(log_report, trigger=(1, 'epoch')) + filename="out", format="yaml", append=True + ) + manager.extend(log_report, trigger=(1, "epoch")) out = io.StringIO() print_report = extensions.PrintReport(out=out) - manager.extend(print_report, trigger=(3, 'epoch')) + manager.extend(print_report, trigger=(3, "epoch")) for _ in range(max_epochs): for _ in range(iters_per_epoch): @@ -95,7 +101,7 @@ def test_buffer_size_log_report_and_print_report(): with manager.run_iteration(): pass - with open(os.path.join(tmpdir, 'out')) as f: + with open(os.path.join(tmpdir, "out")) as f: data = f.read() values = yaml.load(data, Loader=yaml.SafeLoader) assert len(values) == max_epochs diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py index 221500b8b..f241b9fdf 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py @@ -3,21 +3,20 @@ import tempfile import pytest -import yaml - import pytorch_pfn_extras as ppe +import yaml from pytorch_pfn_extras.training import extensions @pytest.mark.parametrize( - 'filename,expected_format', + "filename,expected_format", [ - ('out.json', 'json'), - ('out.xyz', 'json'), - (None, 'json'), - ('out.yaml', 'yaml'), - ('out.jsonl', 'json-lines'), - ] + ("out.json", "json"), + ("out.xyz", "json"), + (None, "json"), + ("out.yaml", "yaml"), + ("out.jsonl", "json-lines"), + ], ) def test_format_from_ext(filename, expected_format): log_report = extensions.LogReport(filename=filename, format=None) @@ -25,14 +24,14 @@ def test_format_from_ext(filename, expected_format): @pytest.mark.parametrize( - 'format,append', + "format,append", [ - ('json', False), - ('json-lines', True), - ('json-lines', False), - ('yaml', True), - ('yaml', False), - ] + ("json", False), + ("json-lines", True), + ("json-lines", False), + ("yaml", True), + ("yaml", False), + ], ) def test_output(format, append): max_epochs = 3 @@ -40,34 +39,42 @@ def test_output(format, append): with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch, - out_dir=tmpdir) + {}, + {}, + max_epochs=max_epochs, + iters_per_epoch=iters_per_epoch, + out_dir=tmpdir, + ) log_report = extensions.LogReport( - filename='out', format=format, append=append) + filename="out", format=format, append=append + ) manager.extend(log_report) for epoch_idx in range(max_epochs): for _ in range(iters_per_epoch): with manager.run_iteration(): pass - with open(os.path.join(tmpdir, 'out')) as f: + with open(os.path.join(tmpdir, "out")) as f: data = f.read() - if format == 'json': + if format == "json": values = json.loads(data) - elif format == 'json-lines': + elif format == "json-lines": values = [json.loads(x) for x in data.splitlines()] - elif format == 'yaml': + elif format == "yaml": values = yaml.load(data, 
Loader=yaml.SafeLoader) assert len(values) == epoch_idx + 1 this_epoch = values.pop() - assert this_epoch['epoch'] == epoch_idx + 1 - assert (this_epoch['iteration'] - == (epoch_idx + 1) * iters_per_epoch) - assert 0 < this_epoch['elapsed_time'] + assert this_epoch["epoch"] == epoch_idx + 1 + assert ( + this_epoch["iteration"] == (epoch_idx + 1) * iters_per_epoch + ) + assert 0 < this_epoch["elapsed_time"] -@pytest.mark.filterwarnings("ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning") +@pytest.mark.filterwarnings( + "ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning" +) def test_tensorboard_writer(): - pytest.importorskip('tensorboard') + pytest.importorskip("tensorboard") max_epochs = 3 iters_per_epoch = 5 @@ -75,10 +82,15 @@ def test_tensorboard_writer(): with tempfile.TemporaryDirectory() as tmpdir: writer = ppe.writing.TensorBoardWriter(out_dir=tmpdir) log_report = extensions.LogReport( - writer=writer, trigger=(1, 'iteration')) + writer=writer, trigger=(1, "iteration") + ) manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch, - out_dir=tmpdir) + {}, + {}, + max_epochs=max_epochs, + iters_per_epoch=iters_per_epoch, + out_dir=tmpdir, + ) manager.extend(log_report) for _ in range(max_epochs): for _ in range(iters_per_epoch): @@ -89,13 +101,13 @@ def test_tensorboard_writer(): files = os.listdir(tmpdir) assert len(files) == 1 tb_file = files[0] - assert tb_file.startswith('events.out.') + assert tb_file.startswith("events.out.") # Won't play with protobuf, just ensure that our keys are in. - with open(os.path.join(tmpdir, tb_file), 'rb') as f: + with open(os.path.join(tmpdir, tb_file), "rb") as f: tb_data = f.read() - for key in ['epoch', 'iteration', 'elapsed_time']: - assert key.encode('ascii') in tb_data + for key in ["epoch", "iteration", "elapsed_time"]: + assert key.encode("ascii") in tb_data def test_deferred_values(): @@ -104,15 +116,19 @@ def test_deferred_values(): with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch, - out_dir=tmpdir) + {}, + {}, + max_epochs=max_epochs, + iters_per_epoch=iters_per_epoch, + out_dir=tmpdir, + ) log_report = extensions.LogReport(filename="out") manager.extend(log_report) for epoch_idx in range(max_epochs): for _ in range(iters_per_epoch): with manager.run_iteration(): ppe.reporting.report({"x": lambda: epoch_idx}) - with open(os.path.join(tmpdir, 'out')) as f: + with open(os.path.join(tmpdir, "out")) as f: data = f.read() values = json.loads(data) assert len(values) == epoch_idx + 1 diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py index 24612c896..d0be24cd2 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py @@ -1,9 +1,8 @@ import tempfile import pytest -import torch - import pytorch_pfn_extras as ppe +import torch def _setup_manager(): @@ -12,63 +11,73 @@ def _setup_manager(): sched = torch.optim.lr_scheduler.MultiStepLR( optim, milestones=[1, 2, 3], gamma=0.1, last_epoch=-1 ) - ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, 'iteration')) + ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, "iteration")) manager = ppe.training.ExtensionsManager( - {}, 
{'main': optim}, 1, extensions=[ext], iters_per_epoch=40) + {}, {"main": optim}, 1, extensions=[ext], iters_per_epoch=40 + ) return optim, manager def test_lr_scheduler(): optim, manager = _setup_manager() for i in range(4): - with manager.run_iteration(step_optimizers=['main']): + with manager.run_iteration(step_optimizers=["main"]): if i < 1: - assert optim.param_groups[0]['lr'] == pytest.approx(1.0) + assert optim.param_groups[0]["lr"] == pytest.approx(1.0) elif i < 2: - assert optim.param_groups[0]['lr'] == pytest.approx(1e-1) + assert optim.param_groups[0]["lr"] == pytest.approx(1e-1) elif i < 3: - assert optim.param_groups[0]['lr'] == pytest.approx(1e-2) + assert optim.param_groups[0]["lr"] == pytest.approx(1e-2) elif i < 4: - assert optim.param_groups[0]['lr'] == pytest.approx(1e-3) + assert optim.param_groups[0]["lr"] == pytest.approx(1e-3) def test_serialize_scheduler(): optim, manager = _setup_manager() for i in range(2): - with manager.run_iteration(step_optimizers=['main']): + with manager.run_iteration(step_optimizers=["main"]): if i < 1: - assert optim.param_groups[0]['lr'] == pytest.approx(1.0) + assert optim.param_groups[0]["lr"] == pytest.approx(1.0) else: - assert optim.param_groups[0]['lr'] == pytest.approx(1e-1) + assert optim.param_groups[0]["lr"] == pytest.approx(1e-1) state = manager.state_dict() optim, manager = _setup_manager() manager.load_state_dict(state) for i in range(2): - with manager.run_iteration(step_optimizers=['main']): + with manager.run_iteration(step_optimizers=["main"]): if i < 1: - assert optim.param_groups[0]['lr'] == pytest.approx(1e-2) + assert optim.param_groups[0]["lr"] == pytest.approx(1e-2) else: - assert optim.param_groups[0]['lr'] == pytest.approx(1e-3) + assert optim.param_groups[0]["lr"] == pytest.approx(1e-3) def test_reduce_lr_on_plateau(): param = torch.nn.Parameter(torch.zeros(10)) optim = torch.optim.SGD([param], 1.0) sched = torch.optim.lr_scheduler.ReduceLROnPlateau( - optim, patience=1, verbose=True) - ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, 'iteration')) + optim, patience=1, verbose=True + ) + ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, "iteration")) with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {'main': optim}, 1, extensions=[ext], iters_per_epoch=4, - out_dir=tmpdir) - manager.extend(ppe.training.extensions.LogReport( - filename=None, trigger=(1, "iteration"))) + {}, + {"main": optim}, + 1, + extensions=[ext], + iters_per_epoch=4, + out_dir=tmpdir, + ) + manager.extend( + ppe.training.extensions.LogReport( + filename=None, trigger=(1, "iteration") + ) + ) for _ in range(4): with manager.run_iteration(): - ppe.reporting.report({'val/loss': 1.0}) + ppe.reporting.report({"val/loss": 1.0}) lr = optim.param_groups[0]["lr"] assert lr == pytest.approx(1e-1) @@ -77,11 +86,13 @@ def test_reduce_lr_on_plateau_no_report(): param = torch.nn.Parameter(torch.zeros(10)) optim = torch.optim.SGD([param], 1.0) sched = torch.optim.lr_scheduler.ReduceLROnPlateau( - optim, patience=1, verbose=True) - ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, 'iteration')) + optim, patience=1, verbose=True + ) + ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, "iteration")) manager = ppe.training.ExtensionsManager( - {}, {'main': optim}, 1, extensions=[ext], iters_per_epoch=4) + {}, {"main": optim}, 1, extensions=[ext], iters_per_epoch=4 + ) with pytest.raises(ValueError): with manager.run_iteration(): pass diff --git 
a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py index 339f0e1bb..bc7eff230 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py @@ -1,5 +1,4 @@ import numpy - import pytorch_pfn_extras as ppe @@ -11,22 +10,26 @@ def test_run(): # NumPy<1.17 does not support array-like inputs in `numpy.random.randint`. data_correct = numpy.random.randint(10000, size=data_shape) % data_total - manager = ppe.training.ExtensionsManager( - {}, [], 100, - iters_per_epoch=5) + manager = ppe.training.ExtensionsManager({}, [], 100, iters_per_epoch=5) extension = ppe.training.extensions.MicroAverage( - 'main/correct', 'main/total', 'main/accuracy', - (trigger_iters, 'iteration')) - manager.extend(extension, trigger=(1, 'iteration')) + "main/correct", + "main/total", + "main/accuracy", + (trigger_iters, "iteration"), + ) + manager.extend(extension, trigger=(1, "iteration")) for js in numpy.ndindex(data_shape): with manager.run_iteration(): - ppe.reporting.report({ - 'main/correct': data_correct[js], - 'main/total': data_total[js], - }) + ppe.reporting.report( + { + "main/correct": data_correct[js], + "main/total": data_total[js], + } + ) assert ( # average is computed every trigger_iters - ('main/accuracy' in manager.observation) - == (js[1] == trigger_iters - 1)) + ("main/accuracy" in manager.observation) + == (js[1] == trigger_iters - 1) + ) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py index e9bfa81b6..2d56f1a2a 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py @@ -1,7 +1,6 @@ import warnings import pytest - from pytorch_pfn_extras.training import extensions @@ -9,6 +8,7 @@ def matplotlib_or_none(): try: import matplotlib + return matplotlib except ImportError: return None @@ -17,7 +17,7 @@ def matplotlib_or_none(): @pytest.fixture(scope="module") def matplotlib(matplotlib_or_none): if matplotlib_or_none is None: - pytest.skip('matplotlib is not installed') + pytest.skip("matplotlib is not installed") return matplotlib_or_none @@ -36,9 +36,9 @@ def test_lazy_import(matplotlib): # has to be called earlier. with warnings.catch_warnings(): - warnings.simplefilter('error') - matplotlib.use('Agg') + warnings.simplefilter("error") + matplotlib.use("Agg") # Test again with a different backend, because the above does not # generate a warning if matplotlib.use('Agg') is called and then # matplotlib.pyplot is imported. 
- matplotlib.use('PS') + matplotlib.use("PS") diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py index 8b0f013bd..fa4e5d672 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py @@ -8,7 +8,8 @@ def test_run_print_report(): max_epochs = 5 iters_per_epoch = 5 manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch + ) out = io.StringIO() log_report = extensions.LogReport() diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py index c4d91f431..3007b6b40 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py @@ -1,7 +1,6 @@ import io import pytest - import pytorch_pfn_extras as ppe from pytorch_pfn_extras.training.extensions import _ipython_module_available from pytorch_pfn_extras.training.extensions.log_report import _pandas_available @@ -10,13 +9,14 @@ @pytest.mark.skipif( not _ipython_module_available or not _pandas_available, reason="print report notebook import failed, " - "maybe ipython is not installed" + "maybe ipython is not installed", ) def test_run_print_report_notebook(): max_epochs = 5 iters_per_epoch = 5 manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch + ) out = io.StringIO() log_report = ppe.training.extensions.LogReport() @@ -32,5 +32,5 @@ def test_run_print_report_notebook(): pass -if __name__ == '__main__': - pytest.main([__file__, '-v', '-s']) +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py index a3844cc65..789f39a5e 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py @@ -1,12 +1,11 @@ +import json +import os import tempfile import time -import os -import json import pytest -import yaml - import pytorch_pfn_extras as ppe +import yaml def _body(): @@ -15,14 +14,14 @@ def _body(): @pytest.mark.parametrize( - 'format,append', + "format,append", [ - ('json', False), - ('json-lines', True), - ('json-lines', False), - ('yaml', True), - ('yaml', False), - ] + ("json", False), + ("json-lines", True), + ("json-lines", False), + ("yaml", True), + ("yaml", False), + ], ) def test_profile_report(format, append): ext = ppe.training.extensions.ProfileReport(format=format, append=append) @@ -31,22 +30,26 @@ def test_profile_report(format, append): # ppe.profiler.time_summary.clear() with tempfile.TemporaryDirectory() as tmpdir: manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch, - out_dir=tmpdir) + {}, + {}, + max_epochs=max_epochs, + iters_per_epoch=iters_per_epoch, + out_dir=tmpdir, + ) manager.extend(ext) for _epoch_idx in range(max_epochs): for _ 
in range(iters_per_epoch): with manager.run_iteration(): _body() - with open(os.path.join(tmpdir, 'log')) as f: + with open(os.path.join(tmpdir, "log")) as f: data = f.read() - if format == 'json': + if format == "json": values = json.loads(data) - elif format == 'json-lines': + elif format == "json-lines": values = [json.loads(x) for x in data.splitlines()] - elif format == 'yaml': + elif format == "yaml": values = yaml.load(data, Loader=yaml.SafeLoader) assert len(values) == _epoch_idx + 1 for value in values: - assert abs(value['iter-time'] - 0.1) < 2e-2 + assert abs(value["iter-time"] - 0.1) < 2e-2 diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py index 84a0c7246..cb7874dce 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py @@ -9,7 +9,8 @@ def test_run(): max_epochs = 5 iters_per_epoch = 5 manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch + ) out = io.StringIO() extension = ppe.training.extensions.ProgressBar( @@ -27,8 +28,13 @@ def test_run(): if manager.iteration < 2: continue status = out.getvalue() - assert '{} iter, {} epoch / {} epochs'.format( - manager.iteration, epoch, max_epochs) in status + assert ( + "{} iter, {} epoch / {} epochs".format( + manager.iteration, epoch, max_epochs + ) + in status + ) iters_per_sec = float( - re.findall(r'([0-9]+\.[0-9]*) iters/sec', status)[-1]) + re.findall(r"([0-9]+\.[0-9]*) iters/sec", status)[-1] + ) assert 7 <= iters_per_sec <= 12 diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py index 57559b329..4496820bd 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py @@ -1,26 +1,26 @@ import io import pytest +import pytorch_pfn_extras as ppe import torch +from pytorch_pfn_extras import training +from pytorch_pfn_extras.training.extensions import _ipython_module_available from torch import nn from torch.nn import Linear from torch.optim.adam import Adam -import pytorch_pfn_extras as ppe -from pytorch_pfn_extras import training -from pytorch_pfn_extras.training.extensions import _ipython_module_available - @pytest.mark.skipif( not _ipython_module_available, reason="progress bar notebook import failed, " - "maybe ipython is not installed" + "maybe ipython is not installed", ) def test_run_progress_bar_notebook(): max_epochs = 5 iters_per_epoch = 5 manager = ppe.training.ExtensionsManager( - {}, {}, max_epochs, iters_per_epoch=iters_per_epoch) + {}, {}, max_epochs, iters_per_epoch=iters_per_epoch + ) out = io.StringIO() extension = ppe.training.extensions.ProgressBarNotebook( @@ -36,22 +36,22 @@ def test_run_progress_bar_notebook(): with manager.run_iteration(): if manager.iteration < 2: continue - status = '{} iter, {} epoch / {} epochs'.format( - manager.iteration, epoch, max_epochs) + status = "{} iter, {} epoch / {} epochs".format( + manager.iteration, epoch, max_epochs + ) assert status in extension._status_html.value @pytest.mark.skipif( not _ipython_module_available, reason="progress 
bar notebook import failed, " - "maybe ipython is not installed" + "maybe ipython is not installed", ) def test_ignite_extensions_manager_with_progressbar_notebook(): - try: from ignite.engine import create_supervised_trainer except ImportError: - pytest.skip('pytorch-ignite not found') + pytest.skip("pytorch-ignite not found") max_epochs = 5 iters_per_epoch = 4 @@ -70,21 +70,21 @@ def forward(self, *args): def _fake_loss(*args): return torch.tensor([0.0], requires_grad=True) - trainer = create_supervised_trainer( - model, optimizer, _fake_loss) + trainer = create_supervised_trainer(model, optimizer, _fake_loss) manager = training.IgniteExtensionsManager( trainer, - {'model_name': model}, - {'optimizer_name': optimizer}, + {"model_name": model}, + {"optimizer_name": optimizer}, max_epochs, ) manager.extend(ppe.training.extensions.ProgressBarNotebook()) loader = torch.utils.data.DataLoader( - [(i, i) for i in range(iters_per_epoch)]) + [(i, i) for i in range(iters_per_epoch)] + ) trainer.run(loader, max_epochs=max_epochs) -if __name__ == '__main__': - pytest.main([__file__, '-v', '-s']) +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py index 8dbce0b8e..d9f6e039f 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py @@ -3,150 +3,157 @@ from unittest import mock import pytest - import pytorch_pfn_extras as ppe @pytest.mark.skipif( not ppe.training.extensions.slack._slack_sdk_available, - reason="Slack SDK not installed" + reason="Slack SDK not installed", ) class TestSlack: def _get_manager(self): return ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=5) - @pytest.mark.parametrize('thread',[False, True]) + @pytest.mark.parametrize("thread", [False, True]) def test_post_message(self, thread): manager = self._get_manager() - message = 'It {manager.iteration} loss: {loss}' + message = "It {manager.iteration} loss: {loss}" extension = ppe.training.extensions.Slack( - '0', message, token='123', thread=thread) + "0", message, token="123", thread=thread + ) t_ts = None if thread: t_ts = 1 - manager.extend(extension, trigger=(1, 'iteration')) + manager.extend(extension, trigger=(1, "iteration")) with mock.patch( - 'slack_sdk.WebClient.chat_postMessage', - return_value={'ok': True, 'ts': t_ts}, + "slack_sdk.WebClient.chat_postMessage", + return_value={"ok": True, "ts": t_ts}, ) as patched: with manager.run_iteration(): - assert 'Training started' in patched.call_args.kwargs["text"] - ppe.reporting.report({'loss': 0.5}) + assert "Training started" in patched.call_args.kwargs["text"] + ppe.reporting.report({"loss": 0.5}) patched.assert_called_with( - channel='0', text='It 1 loss: 0.5', thread_ts=t_ts + channel="0", text="It 1 loss: 0.5", thread_ts=t_ts ) with manager.run_iteration(): - ppe.reporting.report({'loss': 0.75}) + ppe.reporting.report({"loss": 0.75}) patched.assert_called_with( - channel='0', text='It 2 loss: 0.75', thread_ts=t_ts + channel="0", text="It 2 loss: 0.75", thread_ts=t_ts ) with manager.run_iteration(): - ppe.reporting.report({'loss': 0.75}) + ppe.reporting.report({"loss": 0.75}) with manager.run_iteration(): - ppe.reporting.report({'loss': 0.75}) + ppe.reporting.report({"loss": 0.75}) with manager.run_iteration(): - ppe.reporting.report({'loss': 0.75}) - assert 'Training 
finish' in patched.call_args.kwargs["text"] + ppe.reporting.report({"loss": 0.75}) + assert "Training finish" in patched.call_args.kwargs["text"] def test_post_message_on_error(self): manager = self._get_manager() - message = 'It {manager.iteration} loss: {loss}' + message = "It {manager.iteration} loss: {loss}" extension = ppe.training.extensions.Slack( - '0', message, token='123', thread=False) + "0", message, token="123", thread=False + ) t_ts = None - manager.extend(extension, trigger=(1, 'iteration')) + manager.extend(extension, trigger=(1, "iteration")) with mock.patch( - 'slack_sdk.WebClient.chat_postMessage', - return_value={'ok': True, 'ts': t_ts}, + "slack_sdk.WebClient.chat_postMessage", + return_value={"ok": True, "ts": t_ts}, ) as patched: try: with manager.run_iteration(): - raise RuntimeError('error') + raise RuntimeError("error") except RuntimeError: - assert 'Error during' in patched.call_args.kwargs["text"] + assert "Error during" in patched.call_args.kwargs["text"] def test_post_message_webhook(self): manager = self._get_manager() - message = 'It {manager.iteration} loss: {loss}' + message = "It {manager.iteration} loss: {loss}" extension = ppe.training.extensions.SlackWebhook( - url="http://test", msg=message) + url="http://test", msg=message + ) - manager.extend(extension, trigger=(1, 'iteration')) - payload_1 = json.dumps({'text': "It 1 loss: 0.5"}).encode('utf-8') - payload_2 = json.dumps({'text': "It 2 loss: 0.75"}).encode('utf-8') + manager.extend(extension, trigger=(1, "iteration")) + payload_1 = json.dumps({"text": "It 1 loss: 0.5"}).encode("utf-8") + payload_2 = json.dumps({"text": "It 2 loss: 0.75"}).encode("utf-8") with mock.patch( - 'urllib.request.urlopen', + "urllib.request.urlopen", return_value=SimpleNamespace(status=200), ) as patched: with manager.run_iteration(): - ppe.reporting.report({'loss': 0.5}) + ppe.reporting.report({"loss": 0.5}) assert patched.call_args.args[0].data == payload_1 with manager.run_iteration(): - ppe.reporting.report({'loss': 0.75}) + ppe.reporting.report({"loss": 0.75}) assert patched.call_args.args[0].data == payload_2 @pytest.mark.parametrize( - 'message', + "message", [ - 'It {manager.iteration} loss: {loss} custom: {context.foo}', - lambda m, c: 'It {manager.iteration} loss: {loss} custom: {context.foo}'.format( # NOQA - manager=m, context=c, **m.observation) - ] + "It {manager.iteration} loss: {loss} custom: {context.foo}", + lambda m, c: "It {manager.iteration} loss: {loss} custom: {context.foo}".format( # NOQA + manager=m, context=c, **m.observation + ), + ], ) def test_post_message_context(self, message): class _CustomContext: def __init__(self): - self.foo = 'bar' + self.foo = "bar" manager = self._get_manager() context = _CustomContext() extension = ppe.training.extensions.Slack( - '0', message, context=context, token='123') - manager.extend(extension, trigger=(1, 'iteration')) + "0", message, context=context, token="123" + ) + manager.extend(extension, trigger=(1, "iteration")) with mock.patch( - 'slack_sdk.WebClient.chat_postMessage', - return_value={'ok': True, 'ts': 1}, + "slack_sdk.WebClient.chat_postMessage", + return_value={"ok": True, "ts": 1}, ) as patched: with manager.run_iteration(): - ppe.reporting.report({'loss': 0.5}) + ppe.reporting.report({"loss": 0.5}) patched.assert_called_with( - channel='0', text='It 1 loss: 0.5 custom: bar', - thread_ts=1 + channel="0", text="It 1 loss: 0.5 custom: bar", thread_ts=1 ) - context.foo = 'test' + context.foo = "test" with manager.run_iteration(): - 
ppe.reporting.report({'loss': 0.75}) + ppe.reporting.report({"loss": 0.75}) patched.assert_called_with( - channel='0', text='It 2 loss: 0.75 custom: test', - thread_ts=1 + channel="0", text="It 2 loss: 0.75 custom: test", thread_ts=1 ) def test_post_message_files(self): manager = self._get_manager() - message = 'it: {manager.iteration}' - filenames = ['file_{manager.iteration}', '{manager._out}/abc'] + message = "it: {manager.iteration}" + filenames = ["file_{manager.iteration}", "{manager._out}/abc"] extension = ppe.training.extensions.Slack( - '0', message, filenames=filenames, token='123') - manager.extend(extension, trigger=(1, 'iteration')) + "0", message, filenames=filenames, token="123" + ) + manager.extend(extension, trigger=(1, "iteration")) with mock.patch( - 'slack_sdk.WebClient.chat_postMessage', - return_value={'ok': True, 'ts': 1}, - ), mock.patch('slack_sdk.WebClient.files_upload') as upload: + "slack_sdk.WebClient.chat_postMessage", + return_value={"ok": True, "ts": 1}, + ), mock.patch("slack_sdk.WebClient.files_upload") as upload: with manager.run_iteration(): pass - upload.assert_has_calls([ - mock.call(file='file_1'), - mock.call(file='result/abc'), - ], any_order=True) + upload.assert_has_calls( + [ + mock.call(file="file_1"), + mock.call(file="result/abc"), + ], + any_order=True, + ) def test_invalid(self): - message = 'it: {manager.iteration}' - filenames = ['file_{manager.iteration}', '{manager._out}/abc'] - with pytest.raises(RuntimeError, match='needed for communicating'): + message = "it: {manager.iteration}" + filenames = ["file_{manager.iteration}", "{manager._out}/abc"] + with pytest.raises(RuntimeError, match="needed for communicating"): ppe.training.extensions.Slack( - '0', message, start_msg=None, end_msg=None, filenames=filenames) + "0", message, start_msg=None, end_msg=None, filenames=filenames + ) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py index 75675bb0f..dd1608bff 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py @@ -5,29 +5,29 @@ import time from unittest import mock -import torch import pytest - import pytorch_pfn_extras as ppe -from pytorch_pfn_extras import training +import torch +from pytorch_pfn_extras import training, writing from pytorch_pfn_extras.training import extensions from pytorch_pfn_extras.training.extensions._snapshot import ( - _find_latest_snapshot, _find_snapshot_files, _find_stale_snapshots) -from pytorch_pfn_extras import writing + _find_latest_snapshot, + _find_snapshot_files, + _find_stale_snapshots, +) def get_trainer(*, out_dir, state_to_load=None, epochs=10): model_state_dict = {} optimizer_state_dict = {} - models = {'main': _StateDictModel(state_dict=model_state_dict)} - optimizers = {'main': _StateDictObj(state_dict=optimizer_state_dict)} + models = {"main": _StateDictModel(state_dict=model_state_dict)} + optimizers = {"main": _StateDictObj(state_dict=optimizer_state_dict)} return training.ExtensionsManager( - models, optimizers, epochs, - iters_per_epoch=10, - out_dir=out_dir) + models, optimizers, epochs, iters_per_epoch=10, out_dir=out_dir + ) -class _StateDictObj(): +class _StateDictObj: def __init__(self, *, state_dict=None): super().__init__() self.called_load_state_dict = 0 @@ -62,7 +62,8 @@ def test_call(): def test_savefun_and_writer_exclusive(): # savefun 
and writer arguments cannot be specified together. def savefun(*args, **kwargs): - pytest.fail('never reach') + pytest.fail("never reach") + writer = writing.SimpleWriter() with pytest.raises(TypeError): extensions.snapshot(savefun=savefun, writer=writer) @@ -72,104 +73,98 @@ def savefun(*args, **kwargs): extensions.snapshot_object(trainer, savefun=savefun, writer=writer) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def remover(): yield - if os.path.exists('myfile.dat'): - os.remove('myfile.dat') + if os.path.exists("myfile.dat"): + os.remove("myfile.dat") def test_save_file(remover): - trainer = get_trainer(out_dir='.') + trainer = get_trainer(out_dir=".") trainer._done = True w = writing.SimpleWriter() - snapshot = extensions.snapshot_object(trainer, 'myfile.dat', - writer=w) + snapshot = extensions.snapshot_object(trainer, "myfile.dat", writer=w) snapshot(trainer) - assert os.path.exists('myfile.dat') + assert os.path.exists("myfile.dat") def test_multi_target(remover): - trainer = get_trainer(out_dir='.') + trainer = get_trainer(out_dir=".") trainer._done = True - other_state_dict = {'test': True} + other_state_dict = {"test": True} other = _StateDictObj(state_dict=other_state_dict) w = ppe.writing.SimpleWriter() - target = {'trainer': trainer, 'other': other} - snapshot = extensions.snapshot_object(target, 'myfile.dat', - writer=w) + target = {"trainer": trainer, "other": other} + snapshot = extensions.snapshot_object(target, "myfile.dat", writer=w) snapshot(trainer) - assert os.path.exists('myfile.dat') + assert os.path.exists("myfile.dat") # Load the snapshot and verify it - state = torch.load('myfile.dat') - new_trainer = get_trainer(out_dir='.') + state = torch.load("myfile.dat") + new_trainer = get_trainer(out_dir=".") new_other = _StateDictObj(state_dict={}) - new_trainer.load_state_dict(state['trainer']) - new_other.load_state_dict(state['other']) + new_trainer.load_state_dict(state["trainer"]) + new_other.load_state_dict(state["other"]) assert new_trainer.state_dict() == trainer.state_dict() assert new_other.state_dict() == other_state_dict def test_multi_target_autoload(remover): - trainer = get_trainer(out_dir='.') + trainer = get_trainer(out_dir=".") trainer._done = True - other_state_dict = {'test': True} + other_state_dict = {"test": True} other = _StateDictObj(state_dict=other_state_dict) w = ppe.writing.SimpleWriter() - target = {'trainer': trainer, 'other': other} - snapshot = extensions.snapshot_object(target, 'myfile.dat', - writer=w) + target = {"trainer": trainer, "other": other} + snapshot = extensions.snapshot_object(target, "myfile.dat", writer=w) snapshot(trainer) - assert os.path.exists('myfile.dat') - new_trainer = get_trainer(out_dir='.') + assert os.path.exists("myfile.dat") + new_trainer = get_trainer(out_dir=".") new_other = _StateDictObj(state_dict={}) - target = {'trainer': new_trainer, 'other': new_other} - snapshot2 = extensions.snapshot_object(target, 'myfile.dat', - autoload=True) + target = {"trainer": new_trainer, "other": new_other} + snapshot2 = extensions.snapshot_object(target, "myfile.dat", autoload=True) # Load the snapshot and verify it - assert snapshot2.initialize(new_trainer) == 'myfile.dat' + assert snapshot2.initialize(new_trainer) == "myfile.dat" assert new_trainer.state_dict() == trainer.state_dict() assert new_other.state_dict() == other_state_dict def test_multi_target_autoload_not_found(remover): - trainer = get_trainer(out_dir='.') - other = _StateDictObj(state_dict={'original': 'state'}) + trainer = 
get_trainer(out_dir=".") + other = _StateDictObj(state_dict={"original": "state"}) - target = {'trainer': trainer, 'other': other} - snapshot = extensions.snapshot_object(target, 'myfile.dat', - autoload=True) + target = {"trainer": trainer, "other": other} + snapshot = extensions.snapshot_object(target, "myfile.dat", autoload=True) assert snapshot.initialize(trainer) is None - assert other.state_dict() == {'original': 'state'} + assert other.state_dict() == {"original": "state"} def test_clean_up_tempdir(remover): - trainer = get_trainer(out_dir='.') + trainer = get_trainer(out_dir=".") trainer._done = True - snapshot = extensions.snapshot_object(trainer, 'myfile.dat') + snapshot = extensions.snapshot_object(trainer, "myfile.dat") snapshot(trainer) - left_tmps = [fn for fn in os.listdir('.') - if fn.startswith('tmpmyfile.dat')] + left_tmps = [fn for fn in os.listdir(".") if fn.startswith("tmpmyfile.dat")] assert len(left_tmps) == 0 def test_on_error(): # Will fail when accessing the dummy optimizer - optimizers = {'main': object()} + optimizers = {"main": object()} trainer = training.ExtensionsManager( - {}, optimizers, 1, - iters_per_epoch=1, - out_dir='.') - filename = 'myfile-deadbeef.dat' + {}, optimizers, 1, iters_per_epoch=1, out_dir="." + ) + filename = "myfile-deadbeef.dat" - snapshot = extensions.snapshot_object(trainer, filename, - snapshot_on_error=True) + snapshot = extensions.snapshot_object( + trainer, filename, snapshot_on_error=True + ) trainer.extend(snapshot) assert not os.path.exists(filename) with pytest.raises(AttributeError): @@ -178,7 +173,7 @@ def test_on_error(): assert not os.path.exists(filename) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def path(): with tempfile.TemporaryDirectory() as t_path: yield t_path @@ -190,19 +185,22 @@ def snapshot_path(): yield t_path -@pytest.mark.parametrize('fmt', [ - 'snapshot_iter_{}', - 'snapshot_iter_{}.npz', - '{}_snapshot_man_suffix.npz', -]) +@pytest.mark.parametrize( + "fmt", + [ + "snapshot_iter_{}", + "snapshot_iter_{}.npz", + "{}_snapshot_man_suffix.npz", + ], +) def test_find_snapshot_files(fmt, path): files = (fmt.format(i) for i in range(1, 100)) - noise = ('dummy-foobar-iter{}'.format(i) for i in range(10, 304)) - noise2 = ('tmpsnapshot_iter_{}'.format(i) for i in range(10, 304)) for file in itertools.chain(noise, files, noise2): file = os.path.join(path, file) - open(file, 'w').close() + open(file, "w").close() writer = ppe.writing.SimpleWriter() snapshot_files = _find_snapshot_files(fmt, path, writer.fs) @@ -213,18 +211,21 @@ def test_find_snapshot_files(fmt, path): assert expected == sorted(list(snapshot_files)) -@pytest.mark.parametrize('fmt', [ - 'snapshot_iter_{}', - 'snapshot_iter_{}.npz', - '{}_snapshot_man_suffix.npz', -]) +@pytest.mark.parametrize( + "fmt", + [ + "snapshot_iter_{}", + "snapshot_iter_{}.npz", + "{}_snapshot_man_suffix.npz", + ], +) def test_find_latest_snapshot(fmt, path): files = [fmt.format(i) for i in range(1, 100)] base_timestamp = time.time() for i, file in enumerate(files): file = os.path.join(path, file) - open(file, 'w').close() + open(file, "w").close() # mtime resolution of some filesystems e.g.
ext3 or HFS+ # is a second and thus snapshot files such as @@ -240,27 +241,36 @@ def test_find_latest_snapshot(fmt, path): assert fmt.format(99) == _find_latest_snapshot(fmt, path, writer.fs) -@pytest.mark.parametrize('fmt', [ - 'snapshot_iter_{}_{}', - 'snapshot_iter_{}_{}.npz', - '{}_snapshot_man_{}-suffix.npz', - 'snapshot_iter_{}.{}', -]) +@pytest.mark.parametrize( + "fmt", + [ + "snapshot_iter_{}_{}", + "snapshot_iter_{}_{}.npz", + "{}_snapshot_man_{}-suffix.npz", + "snapshot_iter_{}.{}", + ], +) def test_find_snapshot_files2(fmt, path): - files = (fmt.format(i * 10, j * 10) for i, j - in itertools.product(range(0, 10), range(0, 10))) - noise = ('tmpsnapshot_iter_{}.{}'.format(i, j) - for i, j in zip(range(10, 304), range(10, 200))) + files = ( + fmt.format(i * 10, j * 10) + for i, j in itertools.product(range(0, 10), range(0, 10)) + ) + noise = ( + "tmpsnapshot_iter_{}.{}".format(i, j) + for i, j in zip(range(10, 304), range(10, 200)) + ) for file in itertools.chain(noise, files): file = os.path.join(path, file) - open(file, 'w').close() + open(file, "w").close() writer = ppe.writing.SimpleWriter() snapshot_files = _find_snapshot_files(fmt, path, writer.fs) - expected = [fmt.format(i * 10, j * 10) - for i, j in itertools.product(range(0, 10), range(0, 10))] + expected = [ + fmt.format(i * 10, j * 10) + for i, j in itertools.product(range(0, 10), range(0, 10)) + ] timestamps, snapshot_files = zip(*snapshot_files) expected.sort() @@ -268,19 +278,27 @@ def test_find_snapshot_files2(fmt, path): assert expected == snapshot_files -@pytest.mark.parametrize('length_retain', [ - (100, 30), (10, 30), (1, 1000), - (1000, 1), (1, 1), (1, 3), (2, 3), -]) +@pytest.mark.parametrize( + "length_retain", + [ + (100, 30), + (10, 30), + (1, 1000), + (1000, 1), + (1, 1), + (1, 3), + (2, 3), + ], +) def test_find_stale_snapshot(length_retain, path): length, retain = length_retain - fmt = 'snapshot_iter_{}' + fmt = "snapshot_iter_{}" files = [fmt.format(i) for i in range(0, length)] base_timestamp = time.time() - length * 2 for i, file in enumerate(files): file = os.path.join(path, file) - open(file, 'w').close() + open(file, "w").close() # Same comment applies here. 
See comment in ``TestFindSnapshot`` t = base_timestamp + i @@ -294,17 +312,18 @@ def test_remove_stale_snapshots(path): - fmt = 'snapshot_iter_{.iteration}' + fmt = "snapshot_iter_{.iteration}" retain = 3 - snapshot = extensions.snapshot(filename=fmt, n_retains=retain, - autoload=False) + snapshot = extensions.snapshot( + filename=fmt, n_retains=retain, autoload=False + ) trainer = get_trainer(out_dir=path) - trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2) + trainer.extend(snapshot, trigger=(1, "iteration"), priority=2) class TimeStampUpdater(training.Extension): t = time.time() - 100 - name = 'ts_updater' + name = "ts_updater" priority = 1 # This must be called after snapshot taken def __call__(self, _trainer): @@ -313,7 +332,7 @@ def __call__(self, _trainer): # For filesystems that have low timestamp precision os.utime(filename, (self.t, self.t)) - trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration')) + trainer.extend(TimeStampUpdater(), trigger=(1, "iteration")) for _ in range(10): with trainer.run_iteration(): pass @@ -324,30 +343,32 @@ def __call__(self, _trainer): assert retain == len(found) found.sort() # snapshot_iter_(8, 9, 10) expected - expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)] + expected = ["snapshot_iter_{}".format(i) for i in range(8, 11)] expected.sort() assert expected == found - trainer2 = get_trainer( - out_dir=path, state_to_load=trainer.state_dict()) + trainer2 = get_trainer(out_dir=path, state_to_load=trainer.state_dict()) snapshot2 = extensions.snapshot(filename=fmt, autoload=True) # Just making sure no error occurs snapshot2.initialize(trainer2) def test_remove_stale_snapshots_with_writer(path, snapshot_path): - fmt = 'snapshot_iter_{.iteration}' + fmt = "snapshot_iter_{.iteration}" retain = 3 - snapshot = extensions.snapshot(filename=fmt, n_retains=retain, - writer=ppe.writing.SimpleWriter(out_dir=snapshot_path), - autoload=False) + snapshot = extensions.snapshot( + filename=fmt, + n_retains=retain, + writer=ppe.writing.SimpleWriter(out_dir=snapshot_path), + autoload=False, + ) trainer = get_trainer(out_dir=path) - trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2) + trainer.extend(snapshot, trigger=(1, "iteration"), priority=2) class TimeStampUpdater(training.Extension): t = time.time() - 100 - name = 'ts_updater' + name = "ts_updater" priority = 1 # This must be called after snapshot taken def __call__(self, _trainer): @@ -356,7 +377,7 @@ def __call__(self, _trainer): # For filesystems that have low timestamp precision os.utime(filename, (self.t, self.t)) - trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration')) + trainer.extend(TimeStampUpdater(), trigger=(1, "iteration")) for _ in range(10): with trainer.run_iteration(): pass @@ -371,13 +392,16 @@ def __call__(self, _trainer): assert retain == len(found) found.sort() # snapshot_iter_(8, 9, 10) expected - expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)] + expected = ["snapshot_iter_{}".format(i) for i in range(8, 11)] expected.sort() assert expected == found - trainer2 = get_trainer( - out_dir=path, state_to_load=trainer.state_dict()) - snapshot2 = extensions.snapshot(filename=fmt, autoload=True, writer=ppe.writing.SimpleWriter(out_dir=snapshot_path)) + trainer2 = get_trainer(out_dir=path, state_to_load=trainer.state_dict()) + snapshot2 = extensions.snapshot( + filename=fmt, + autoload=True, + writer=ppe.writing.SimpleWriter(out_dir=snapshot_path), + ) # Just making sure no error occurs
snapshot2.initialize(trainer2) @@ -408,7 +432,7 @@ def test_model_transformations(path): transform_model=lambda n, x: x.wrapper_module(), ) - snapshot = extensions.snapshot(filename='test') + snapshot = extensions.snapshot(filename="test") snapshot(manager) assert model.accessed @@ -417,7 +441,7 @@ def test_model_transformations(path): def test_snapshot_autoload_twice(path): max_epochs = 10 iters_per_epoch = 10 - fmt = 'snapshot_iter_{.iteration}' + fmt = "snapshot_iter_{.iteration}" def get_epoch_indices(): manager = get_trainer(out_dir=path, epochs=max_epochs) @@ -461,13 +485,20 @@ def test_snapshot_autoload_with_writer(path, snapshot_path): trainer = get_trainer(out_dir=path, epochs=10) trainer.models["main"]._state_dict = {"value": 0} - snapshot = extensions.snapshot(filename=snapshot_filename, writer=ppe.writing.SimpleWriter(out_dir=snapshot_path)) + snapshot = extensions.snapshot( + filename=snapshot_filename, + writer=ppe.writing.SimpleWriter(out_dir=snapshot_path), + ) snapshot(trainer) assert os.path.isfile(os.path.join(snapshot_path, snapshot_filename)) assert not os.path.isfile(os.path.join(path, snapshot_filename)) trainer2 = get_trainer(out_dir=path, epochs=0) - snapshot2 = extensions.snapshot(filename=snapshot_filename, writer=ppe.writing.SimpleWriter(out_dir=snapshot_path), autoload=True) + snapshot2 = extensions.snapshot( + filename=snapshot_filename, + writer=ppe.writing.SimpleWriter(out_dir=snapshot_path), + autoload=True, + ) assert trainer2.state_dict() != trainer.state_dict() assert snapshot2.initialize(trainer2) == snapshot_filename diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py index 41b9ecfd9..c6578b111 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py @@ -1,14 +1,12 @@ import multiprocessing -import threading import tempfile +import threading from unittest import mock import pytest - from pytorch_pfn_extras import writing - -spshot_writers_path = 'pytorch_pfn_extras.writing' +spshot_writers_path = "pytorch_pfn_extras.writing" def test_simple_writer(): @@ -16,10 +14,10 @@ def test_simple_writer(): w = writing.SimpleWriter(foo=True) savefun = mock.MagicMock() with tempfile.TemporaryDirectory() as tempd: - w('myfile.dat', tempd, target, savefun=savefun) + w("myfile.dat", tempd, target, savefun=savefun) assert savefun.call_count == 1 assert savefun.call_args[0][0] == target - assert savefun.call_args[1]['foo'] is True + assert savefun.call_args[1]["foo"] is True def test_standard_writer(): @@ -27,11 +25,11 @@ def test_standard_writer(): w = writing.StandardWriter() worker = mock.MagicMock() worker.exitcode = 0 - name = spshot_writers_path + '.StandardWriter.create_worker' + name = spshot_writers_path + ".StandardWriter.create_worker" with mock.patch(name, return_value=worker): with tempfile.TemporaryDirectory() as tempd: - w('myfile.dat', tempd, target) - w('myfile.dat', tempd, target) + w("myfile.dat", tempd, target) + w("myfile.dat", tempd, target) w.finalize() assert worker.start.call_count == 2 @@ -42,16 +40,16 @@ def test_thread_writer_create_worker(): target = mock.MagicMock() w = writing.ThreadWriter() with tempfile.TemporaryDirectory() as tempd: - worker = w.create_worker('myfile.dat', tempd, target, append=False) + worker = w.create_worker("myfile.dat", tempd, target, 
append=False) assert isinstance(worker, threading.Thread) - w('myfile2.dat', tempd, 'test') + w("myfile2.dat", tempd, "test") w.finalize() def test_thread_writer_fail(): w = writing.ThreadWriter(savefun=None) with tempfile.TemporaryDirectory() as tempd: - w('myfile2.dat', tempd, 'test') + w("myfile2.dat", tempd, "test") with pytest.raises(RuntimeError): w.finalize() @@ -60,16 +58,16 @@ def test_process_writer_create_worker(): target = mock.MagicMock() w = writing.ProcessWriter() with tempfile.TemporaryDirectory() as tempd: - worker = w.create_worker('myfile.dat', tempd, target, append=False) + worker = w.create_worker("myfile.dat", tempd, target, append=False) assert isinstance(worker, multiprocessing.Process) - w('myfile2.dat', tempd, 'test') + w("myfile2.dat", tempd, "test") w.finalize() def test_process_writer_fail(): w = writing.ProcessWriter(savefun=None) with tempfile.TemporaryDirectory() as tempd: - w('myfile2.dat', tempd, 'test') + w("myfile2.dat", tempd, "test") with pytest.raises(RuntimeError): w.finalize() @@ -78,15 +76,17 @@ def test_queue_writer(): target = mock.MagicMock() q = mock.MagicMock() consumer = mock.MagicMock() - names = [spshot_writers_path + '.QueueWriter.create_queue', - spshot_writers_path + '.QueueWriter.create_consumer'] + names = [ + spshot_writers_path + ".QueueWriter.create_queue", + spshot_writers_path + ".QueueWriter.create_consumer", + ] with mock.patch(names[0], return_value=q): with mock.patch(names[1], return_value=consumer): w = writing.QueueWriter() with tempfile.TemporaryDirectory() as tempd: - w('myfile.dat', tempd, target) - w('myfile.dat', tempd, target) + w("myfile.dat", tempd, target) + w("myfile.dat", tempd, target) w.finalize() assert consumer.start.call_count == 1 @@ -96,8 +96,10 @@ def test_queue_writer(): def test_queue_writer_consume(): - names = [spshot_writers_path + '.QueueWriter.create_queue', - spshot_writers_path + '.QueueWriter.create_consumer'] + names = [ + spshot_writers_path + ".QueueWriter.create_queue", + spshot_writers_path + ".QueueWriter.create_consumer", + ] with mock.patch(names[0]): with mock.patch(names[1]): task = mock.MagicMock() diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py index 3f1bb0296..cba6dd4be 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py @@ -1,30 +1,25 @@ -import torch - import pytorch_pfn_extras as ppe +import torch def test_observe_value(): lr = 0.1 - manager = ppe.training.ExtensionsManager( - {}, [], 1, - iters_per_epoch=1) - extension = ppe.training.extensions.observe_value('lr', lambda x: lr) + manager = ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=1) + extension = ppe.training.extensions.observe_value("lr", lambda x: lr) manager.extend(extension) with manager.run_iteration(): pass - assert manager.observation['lr'] == lr + assert manager.observation["lr"] == lr def test_observe_lr(): lr = 0.01 - manager = ppe.training.ExtensionsManager( - {}, [], 1, - iters_per_epoch=1) + manager = ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=1) optimizer = torch.optim.Adam({torch.nn.Parameter()}, lr=lr) extension = ppe.training.extensions.observe_lr(optimizer) manager.extend(extension) with manager.run_iteration(): pass - assert manager.observation['lr'] == lr + assert manager.observation["lr"] == 
lr diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py index e6dad4b9b..3b433815a 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py @@ -2,32 +2,31 @@ import numpy import pytest -import torch - import pytorch_pfn_extras as ppe +import torch @pytest.fixture(scope="module") def matplotlib(): try: import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") return matplotlib except ImportError: - pytest.skip('matplotlib is not installed') + pytest.skip("matplotlib is not installed") def test_run_and_save_plot(matplotlib): - filename = 'variable_statistics_plot_test.png' + filename = "variable_statistics_plot_test.png" iterations = 2 - extension_trigger = (1, 'iteration') - manager = ppe.training.ExtensionsManager( - {}, [], 2, - iters_per_epoch=1) + extension_trigger = (1, "iteration") + manager = ppe.training.ExtensionsManager({}, [], 2, iters_per_epoch=1) x = torch.rand(1, 2, 3) extension = ppe.training.extensions.VariableStatisticsPlot( - x, trigger=extension_trigger, filename=filename) + x, trigger=extension_trigger, filename=filename + ) manager.extend(extension, trigger=extension_trigger) # In the following we explicitly use plot_report._available instead of @@ -47,12 +46,11 @@ def test_reservoir_size(): shape = (2, 7, 3) n = 5 reservoir_size = 3 - xs = [ - 2 * torch.rand(shape) - 1 for i in range(n)] + xs = [2 * torch.rand(shape) - 1 for i in range(n)] - reservoir = ( - ppe.training.extensions.variable_statistics_plot.Reservoir( - size=reservoir_size, data_shape=shape)) + reservoir = ppe.training.extensions.variable_statistics_plot.Reservoir( + size=reservoir_size, data_shape=shape + ) for x in xs: reservoir.add(x) idxs, data = reservoir.get_data() @@ -68,20 +66,23 @@ def test_statistician_percentile(): shape = (2, 7, 3) x = 2 * torch.rand(shape) - 1 - percentile_sigmas = (0., 100.) 
# min, max + percentile_sigmas = (0.0, 100.0) # min, max statistician = ( ppe.training.extensions.variable_statistics_plot.Statistician( - collect_mean=True, collect_std=True, - percentile_sigmas=percentile_sigmas)) + collect_mean=True, + collect_std=True, + percentile_sigmas=percentile_sigmas, + ) + ) stat = statistician(x, axis=None, dtype=x.dtype) for s in stat.values(): assert s.dtype == x.dtype - assert torch.allclose(stat['mean'], torch.mean(x)) - assert torch.allclose(stat['std'], torch.std(x)) + assert torch.allclose(stat["mean"], torch.mean(x)) + assert torch.allclose(stat["std"], torch.std(x)) - percentile = stat['percentile'] + percentile = stat["percentile"] assert len(percentile) == 2 assert torch.allclose(percentile[0], torch.min(x)) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py b/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py index 0ddaaf776..f0f1b49be 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py @@ -5,7 +5,8 @@ class TestEngine: def test_engine_extension(self): engine = ppe.training._trainer.Trainer( - None, evaluator=None, models={}, optimizers={}, max_epochs=10) + None, evaluator=None, models={}, optimizers={}, max_epochs=10 + ) extension = ppe.training.extensions.LogReport() engine.extend(extension) # Create the actual manager object @@ -15,26 +16,32 @@ def test_engine_extension(self): def test_engine_state_dict(self): manager = ppe.training.ExtensionsManager( - {}, {}, 10, iters_per_epoch=100) + {}, {}, 10, iters_per_epoch=100 + ) engine = ppe.training._trainer.Trainer( - None, evaluator=None, models={}, optimizers={}, max_epochs=10) + None, evaluator=None, models={}, optimizers={}, max_epochs=10 + ) engine._setup_manager(100) assert engine.state_dict() == manager.state_dict() def test_engine_load_state_dict(self): manager = ppe.training.ExtensionsManager( - {}, {}, 10, iters_per_epoch=100) + {}, {}, 10, iters_per_epoch=100 + ) engine = ppe.training._trainer.Trainer( - None, evaluator=None, models={}, optimizers={}, max_epochs=1) + None, evaluator=None, models={}, optimizers={}, max_epochs=1 + ) engine.load_state_dict(manager.state_dict()) engine._setup_manager(20) assert engine.state_dict() == manager.state_dict() def test_engine_load_state_dict_2(self): manager = ppe.training.ExtensionsManager( - {}, {}, 10, iters_per_epoch=100) + {}, {}, 10, iters_per_epoch=100 + ) engine = ppe.training._trainer.Trainer( - None, evaluator=None, models={}, optimizers={}, max_epochs=1) + None, evaluator=None, models={}, optimizers={}, max_epochs=1 + ) engine._setup_manager(20) engine.load_state_dict(manager.state_dict()) assert engine.state_dict() == manager.state_dict() @@ -42,22 +49,29 @@ def test_engine_load_state_dict_2(self): class TestEngineInvalid: def test_engine_wrong_models(self): - with pytest.raises(ValueError, match='model must be an instance'): + with pytest.raises(ValueError, match="model must be an instance"): ppe.training._trainer.Trainer( - None, evaluator=None, models=object(), optimizers={}, max_epochs=10) + None, + evaluator=None, + models=object(), + optimizers={}, + max_epochs=10, + ) def test_engine_not_started(self): engine = ppe.training._trainer.Trainer( - None, evaluator=None, models={}, optimizers={}, max_epochs=10) - with pytest.raises(RuntimeError, match='is not started'): + None, evaluator=None, models={}, optimizers={}, max_epochs=10 + ) + with pytest.raises(RuntimeError, match="is not started"): engine.state_dict() - 
with pytest.raises(RuntimeError, match='is not started'): + with pytest.raises(RuntimeError, match="is not started"): engine.manager def test_extend_after_init(self): engine = ppe.training._trainer.Trainer( - None, evaluator=None, models={}, optimizers={}, max_epochs=10) + None, evaluator=None, models={}, optimizers={}, max_epochs=10 + ) engine._setup_manager(10) extension = ppe.training.extensions.LogReport() - with pytest.raises(RuntimeError, match='cannot extend after'): + with pytest.raises(RuntimeError, match="cannot extend after"): engine.extend(extension) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py b/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py index 111b3dc59..43180affe 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py @@ -1,8 +1,6 @@ import pytest - -import torch - import pytorch_pfn_extras as ppe +import torch from pytorch_pfn_extras import engine @@ -15,25 +13,28 @@ def forward(self, x, t): g = t.clone() to_alter = int(10 * (1 - self.correct_ratio)) g[0:to_alter][:] -= 1 - return {'y': g} + return {"y": g} -@pytest.mark.parametrize('device', ['cpu']) -@pytest.mark.parametrize('accuracy', [0, 0.5, 1.0]) +@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("accuracy", [0, 0.5, 1.0]) def test_evaluator_with_metric(device, accuracy): model = MyModel(accuracy) data = torch.utils.data.DataLoader( - [{'x': torch.rand(20), 't': torch.rand(1)} for i in range(10)], - batch_size=10) + [{"x": torch.rand(20), "t": torch.rand(1)} for i in range(10)], + batch_size=10, + ) ppe.to(model, device) evaluator = engine.create_evaluator( - model, device=device, - metrics=[ppe.training.metrics.AccuracyMetric('t', 'y')], - options={'eval_report_keys': ['accuracy']}) + model, + device=device, + metrics=[ppe.training.metrics.AccuracyMetric("t", "y")], + options={"eval_report_keys": ["accuracy"]}, + ) evaluator.handler.eval_setup(evaluator, data) reporter = ppe.reporting.Reporter() observation = {} with reporter.scope(observation): evaluator.run(data) - assert pytest.approx(observation['val/accuracy']) == accuracy + assert pytest.approx(observation["val/accuracy"]) == accuracy diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py b/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py index e81beb433..f6f73e2a3 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py @@ -1,15 +1,14 @@ -import pytest -import torch - from unittest import mock +import pytest import pytorch_pfn_extras as ppe +import torch def _get_dummy_manager(): model = torch.nn.Module() return ppe.training.ExtensionsManager( - {'main': model}, + {"main": model}, [], # optimizers 10, # max_epochs iters_per_epoch=1, @@ -31,7 +30,7 @@ class MyExtension(ppe.training.Extension): pass ext = MyExtension() - assert ext.default_name == 'MyExtension' + assert ext.default_name == "MyExtension" def test_deleted_invoke_before_training(): @@ -46,13 +45,17 @@ class MyExtension(ppe.training.Extension): def test_make_extension(): initialize = mock.Mock() - @ppe.training.make_extension(trigger=(2, 'epoch'), default_name='my_ext', - priority=50, initializer=initialize) + @ppe.training.make_extension( + trigger=(2, "epoch"), + default_name="my_ext", + priority=50, + initializer=initialize, + ) def my_extension(trainer): pass - assert 
my_extension.trigger == (2, 'epoch') - assert my_extension.default_name == 'my_ext' + assert my_extension.trigger == (2, "epoch") + assert my_extension.default_name == "my_ext" assert my_extension.priority == 50 trainer = object() @@ -66,8 +69,8 @@ def test_make_extension_default_values(): def my_extension(trainer): pass - assert my_extension.trigger == (1, 'iteration') - assert my_extension.default_name == 'my_extension' + assert my_extension.trigger == (1, "iteration") + assert my_extension.default_name == "my_extension" assert my_extension.priority == ppe.training.PRIORITY_READER manager = object() my_extension.initialize(manager) @@ -75,6 +78,7 @@ def my_extension(trainer): def test_make_extension_unexpected_kwargs(): with pytest.raises(TypeError): + @ppe.training.make_extension(foo=1) def my_extension(_): pass @@ -92,9 +96,10 @@ def on_error(self, manager, exc, tb): assert isinstance(exc, RuntimeError) self.call_cnt += 1 - optimizers = {'main': object()} + optimizers = {"main": object()} manager = ppe.training.ExtensionsManager( - {}, optimizers, 1, iters_per_epoch=2) + {}, optimizers, 1, iters_per_epoch=2 + ) ext = DummyExt() manager.extend(ext) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py b/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py index f4801e7e2..87cafe79b 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py @@ -1,12 +1,11 @@ -import torch - import pytorch_pfn_extras as ppe +import torch def _get_dummy_manager(): model = torch.nn.Module() return ppe.training.ExtensionsManager( - {'main': model}, + {"main": model}, [], # optimizers 10, # max_epochs iters_per_epoch=1, @@ -16,25 +15,25 @@ def _get_dummy_manager(): def test_default_name(): class MyExtension(ppe.training.Extension): name = None - default_name = 'defalut_name' + default_name = "defalut_name" ext = MyExtension() entry = ppe.training.ExtensionEntry(ext) assert entry.name == MyExtension.default_name - entry = ppe.training.ExtensionEntry(ext, name='updated') - assert entry.name == 'updated' + entry = ppe.training.ExtensionEntry(ext, name="updated") + assert entry.name == "updated" def test_name(): class MyExtension(ppe.training.Extension): - name = 'name' - default_name = 'defalut_name' + name = "name" + default_name = "defalut_name" ext = MyExtension() entry = ppe.training.ExtensionEntry(ext) assert entry.name == MyExtension.name - entry = ppe.training.ExtensionEntry(ext, name='updated') - assert entry.name == 'updated' + entry = ppe.training.ExtensionEntry(ext, name="updated") + assert entry.name == "updated" def test_priority(): @@ -50,14 +49,14 @@ class MyExtension(ppe.training.Extension): def test_trigger(): class MyExtension(ppe.training.Extension): - trigger = (1, 'iteration') + trigger = (1, "iteration") ext = MyExtension() entry = ppe.training.ExtensionEntry(ext) assert isinstance(entry.trigger, ppe.training.triggers.IntervalTrigger) assert entry.trigger.period == 1 - assert entry.trigger.unit == 'iteration' - entry = ppe.training.ExtensionEntry(ext, trigger=(3, 'epoch')) + assert entry.trigger.unit == "iteration" + entry = ppe.training.ExtensionEntry(ext, trigger=(3, "epoch")) assert isinstance(entry.trigger, ppe.training.triggers.IntervalTrigger) assert entry.trigger.period == 3 - assert entry.trigger.unit == 'epoch' + assert entry.trigger.unit == "epoch" diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py 
b/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py index 8cafab62b..f0aa1b9a7 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py @@ -1,18 +1,13 @@ import pytest - -import torch -from torch import nn - import pytorch_pfn_extras as ppe +import torch from pytorch_pfn_extras import training +from torch import nn def test_manager_status_info(): manager = training.ExtensionsManager( - nn.Module(), - object(), - 10, - iters_per_epoch=4 + nn.Module(), object(), 10, iters_per_epoch=4 ) manager.iteration = 9 assert manager.iteration == 9 @@ -25,9 +20,7 @@ def test_manager_status_info(): class _DummyExtension(training.Extension): - - def __init__( - self, extension_id, call_record, init_record, use_model=False): + def __init__(self, extension_id, call_record, init_record, use_model=False): self.extension_id = extension_id self.call_record = call_record self.init_record = init_record @@ -37,12 +30,11 @@ def __call__(self, manager): self.call_record.append(self.extension_id) if self.use_model: # Check if models are accessible. - assert manager.models['main'] is not None - assert manager.raw_models['main'] is not None + assert manager.models["main"] is not None + assert manager.raw_models["main"] is not None class _DummyExtensionInitialize(_DummyExtension): - def initialize(self, manager): self.init_record.append(self.extension_id) @@ -53,8 +45,8 @@ def test_extensions_manager_extensions(): max_epochs = 5 iters_per_epoch = 4 manager = training.ExtensionsManager( - {'model_name': model}, - {'optimizer_name': optimizer}, + {"model_name": model}, + {"optimizer_name": optimizer}, max_epochs, iters_per_epoch=iters_per_epoch, ) @@ -72,33 +64,37 @@ def test_extensions_manager_extensions(): lambda manager: dummy5(manager), training.ExtensionEntry( _DummyExtension(6, call_record, init_record), - name='ext6', priority=-3, call_before_training=True, + name="ext6", + priority=-3, + call_before_training=True, ), training.ExtensionEntry( _DummyExtension(7, call_record, init_record), - name='ext7_', priority=-2, call_before_training=True, + name="ext7_", + priority=-2, + call_before_training=True, ), ] - manager.extend(exts[0], 'ext0', priority=2, call_before_training=True) - manager.extend(exts[1], 'ext1', priority=1, call_before_training=False) - manager.extend(exts[2], 'ext2', priority=3, call_before_training=False) - manager.extend(exts[3], 'ext3', priority=0, call_before_training=True) - manager.extend(exts[4], 'ext4', priority=4, call_before_training=True) - manager.extend(exts[5], 'ext5', priority=-1, call_before_training=True) + manager.extend(exts[0], "ext0", priority=2, call_before_training=True) + manager.extend(exts[1], "ext1", priority=1, call_before_training=False) + manager.extend(exts[2], "ext2", priority=3, call_before_training=False) + manager.extend(exts[3], "ext3", priority=0, call_before_training=True) + manager.extend(exts[4], "ext4", priority=4, call_before_training=True) + manager.extend(exts[5], "ext5", priority=-1, call_before_training=True) manager.extend(exts[6]) - manager.extend(exts[7], 'ext7', priority=-4, call_before_training=False) + manager.extend(exts[7], "ext7", priority=-4, call_before_training=False) - assert manager.get_extension('ext0') is exts[0] - assert manager.get_extension('ext1') is exts[1] - assert manager.get_extension('ext2') is exts[2] - assert manager.get_extension('ext3') is exts[3] - assert manager.get_extension('ext4') is exts[4] - assert 
manager.get_extension('ext6') is exts[6].extension - assert manager.get_extension('ext7') is exts[7].extension + assert manager.get_extension("ext0") is exts[0] + assert manager.get_extension("ext1") is exts[1] + assert manager.get_extension("ext2") is exts[2] + assert manager.get_extension("ext3") is exts[3] + assert manager.get_extension("ext4") is exts[4] + assert manager.get_extension("ext6") is exts[6].extension + assert manager.get_extension("ext7") is exts[7].extension with pytest.raises(ValueError): - manager.get_extension('ext10') + manager.get_extension("ext10") for it in range(max_epochs * iters_per_epoch): call_record.clear() @@ -121,7 +117,7 @@ def test_extensions_manager_extensions(): assert init_record == [] -class _StateDictObj(): +class _StateDictObj: def __init__(self, *, state_dict=None, state_dict_to_be_loaded=None): super().__init__() self.called_load_state_dict = 0 @@ -137,13 +133,11 @@ def load_state_dict(self, state_dict): class _StateDictModel(_StateDictObj, nn.Module): - def forward(self, *args): pass class _StateDictOptimizer(_StateDictObj): - def zero_grad(self): pass @@ -152,7 +146,6 @@ def step(self): class _StateDictExtension(_StateDictObj, training.Extension): - def __call__(self, manager): pass @@ -170,15 +163,16 @@ def test_extensions_manager_state_dict(): passed_iteration = 11 manager = training.ExtensionsManager( - {'model_name': _StateDictModel(state_dict=model_state_dict)}, - {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)}, + {"model_name": _StateDictModel(state_dict=model_state_dict)}, + {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)}, max_epochs, iters_per_epoch=iters_per_epoch, ) manager.extend( - _StateDictExtension( - state_dict=extension_state_dict), name='extension_name') + _StateDictExtension(state_dict=extension_state_dict), + name="extension_name", + ) for _ in range(passed_iteration): with manager.run_iteration(): @@ -187,29 +181,31 @@ def test_extensions_manager_state_dict(): state_dict = manager.state_dict() assert state_dict == { - 'ppe_version': ppe.__version__, - '_start_execution': passed_iteration, - '_start_iteration': passed_iteration, - 'models': {'model_name': model_state_dict}, - 'optimizers': {'optimizer_name': optimizer_state_dict}, - 'extensions': {'extension_name': { - 'extension': extension_state_dict, - 'trigger': { - }, - }}, + "ppe_version": ppe.__version__, + "_start_execution": passed_iteration, + "_start_iteration": passed_iteration, + "models": {"model_name": model_state_dict}, + "optimizers": {"optimizer_name": optimizer_state_dict}, + "extensions": { + "extension_name": { + "extension": extension_state_dict, + "trigger": {}, + } + }, } new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict) new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict) new_extension = _StateDictExtension( - state_dict_to_be_loaded=extension_state_dict) + state_dict_to_be_loaded=extension_state_dict + ) new_manager = training.ExtensionsManager( - {'model_name': new_model}, - {'optimizer_name': new_optimizer}, + {"model_name": new_model}, + {"optimizer_name": new_optimizer}, max_epochs, iters_per_epoch=iters_per_epoch, ) - new_manager.extend(new_extension, name='extension_name') + new_manager.extend(new_extension, name="extension_name") new_manager.load_state_dict(state_dict) assert new_model.called_load_state_dict == 1 assert new_optimizer.called_load_state_dict == 1 @@ -224,8 +220,8 @@ def test_extensions_manager_state_dict_old_ppe_no_version(): passed_iteration = 11 
manager = training.ExtensionsManager( - {'model_name': _StateDictModel(state_dict=model_state_dict)}, - {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)}, + {"model_name": _StateDictModel(state_dict=model_state_dict)}, + {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)}, max_epochs, iters_per_epoch=iters_per_epoch, ) @@ -237,15 +233,15 @@ def test_extensions_manager_state_dict_old_ppe_no_version(): new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict) new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict) manager_2 = training.ExtensionsManager( - {'model_name': new_model}, - {'optimizer_name': new_optimizer}, + {"model_name": new_model}, + {"optimizer_name": new_optimizer}, max_epochs, iters_per_epoch=iters_per_epoch, ) state_dict = manager.state_dict() - del state_dict['ppe_version'] - with pytest.warns(UserWarning, match='version'): + del state_dict["ppe_version"] + with pytest.warns(UserWarning, match="version"): manager_2.load_state_dict(state_dict) @@ -257,8 +253,8 @@ def test_extensions_manager_state_dict_old_ppe_version(): passed_iteration = 11 manager = training.ExtensionsManager( - {'model_name': _StateDictModel(state_dict=model_state_dict)}, - {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)}, + {"model_name": _StateDictModel(state_dict=model_state_dict)}, + {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)}, max_epochs, iters_per_epoch=iters_per_epoch, ) @@ -270,15 +266,15 @@ def test_extensions_manager_state_dict_old_ppe_version(): new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict) new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict) manager_2 = training.ExtensionsManager( - {'model_name': new_model}, - {'optimizer_name': new_optimizer}, + {"model_name": new_model}, + {"optimizer_name": new_optimizer}, max_epochs, iters_per_epoch=iters_per_epoch, ) state_dict = manager.state_dict() - state_dict['ppe_version'] = '0.4.0' - with pytest.warns(UserWarning, match='version'): + state_dict["ppe_version"] = "0.4.0" + with pytest.warns(UserWarning, match="version"): manager_2.load_state_dict(state_dict) @@ -290,8 +286,8 @@ def test_extensions_manager_state_dict_future_ppe_version(): passed_iteration = 11 manager = training.ExtensionsManager( - {'model_name': _StateDictModel(state_dict=model_state_dict)}, - {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)}, + {"model_name": _StateDictModel(state_dict=model_state_dict)}, + {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)}, max_epochs, iters_per_epoch=iters_per_epoch, ) @@ -303,24 +299,23 @@ def test_extensions_manager_state_dict_future_ppe_version(): new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict) new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict) manager_2 = training.ExtensionsManager( - {'model_name': new_model}, - {'optimizer_name': new_optimizer}, + {"model_name": new_model}, + {"optimizer_name": new_optimizer}, max_epochs, iters_per_epoch=iters_per_epoch, ) state_dict = manager.state_dict() - state_dict['ppe_version'] = '23.0.0' - with pytest.warns(UserWarning, match='version'): + state_dict["ppe_version"] = "23.0.0" + with pytest.warns(UserWarning, match="version"): manager_2.load_state_dict(state_dict) def test_ignite_extensions_manager_state_dict(): - try: from ignite.engine import create_supervised_trainer except ImportError: - pytest.skip('pytorch-ignite not found') + pytest.skip("pytorch-ignite not found") 
model_state_dict = object() optimizer_state_dict = object() @@ -332,54 +327,57 @@ def test_ignite_extensions_manager_state_dict(): model = _StateDictModel(state_dict=model_state_dict) optimizer = _StateDictOptimizer(state_dict=optimizer_state_dict) - trainer = create_supervised_trainer( - model, optimizer, _fake_loss) + trainer = create_supervised_trainer(model, optimizer, _fake_loss) manager = training.IgniteExtensionsManager( trainer, - {'model_name': model}, - {'optimizer_name': optimizer}, + {"model_name": model}, + {"optimizer_name": optimizer}, max_epochs, ) manager.extend( - _StateDictExtension( - state_dict=extension_state_dict), name='extension_name') + _StateDictExtension(state_dict=extension_state_dict), + name="extension_name", + ) loader = torch.utils.data.DataLoader( - [(i, i) for i in range(iters_per_epoch)]) + [(i, i) for i in range(iters_per_epoch)] + ) trainer.run(loader, max_epochs=max_epochs) state_dict = manager.state_dict() assert state_dict == { - 'ppe_version': ppe.__version__, - '_start_execution': passed_iteration, - '_start_iteration': passed_iteration, - '_epoch_length': iters_per_epoch, - 'models': {'model_name': model_state_dict}, - 'optimizers': {'optimizer_name': optimizer_state_dict}, - 'extensions': {'extension_name': { - 'extension': extension_state_dict, - 'trigger': { - }, - }}, + "ppe_version": ppe.__version__, + "_start_execution": passed_iteration, + "_start_iteration": passed_iteration, + "_epoch_length": iters_per_epoch, + "models": {"model_name": model_state_dict}, + "optimizers": {"optimizer_name": optimizer_state_dict}, + "extensions": { + "extension_name": { + "extension": extension_state_dict, + "trigger": {}, + } + }, } new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict) new_optimizer = _StateDictOptimizer( - state_dict_to_be_loaded=optimizer_state_dict) + state_dict_to_be_loaded=optimizer_state_dict + ) new_extension = _StateDictExtension( - state_dict_to_be_loaded=extension_state_dict) + state_dict_to_be_loaded=extension_state_dict + ) - new_trainer = create_supervised_trainer( - model, optimizer, _fake_loss) + new_trainer = create_supervised_trainer(model, optimizer, _fake_loss) new_manager = training.IgniteExtensionsManager( new_trainer, - {'model_name': new_model}, - {'optimizer_name': new_optimizer}, + {"model_name": new_model}, + {"optimizer_name": new_optimizer}, max_epochs, ) - new_manager.extend(new_extension, name='extension_name') + new_manager.extend(new_extension, name="extension_name") new_manager.load_state_dict(state_dict) assert new_model.called_load_state_dict == 1 assert new_optimizer.called_load_state_dict == 1 @@ -401,12 +399,12 @@ def test_extensions_manager_with_plain_model_and_optimizer(): state_dict = manager.state_dict() assert state_dict == { - 'ppe_version': ppe.__version__, - '_start_execution': 0, - '_start_iteration': 0, - 'models': {'main': model_state_dict}, - 'optimizers': {'main': optimizer_state_dict}, - 'extensions': {} + "ppe_version": ppe.__version__, + "_start_execution": 0, + "_start_iteration": 0, + "models": {"main": model_state_dict}, + "optimizers": {"main": optimizer_state_dict}, + "extensions": {}, } @@ -435,7 +433,7 @@ def test_model_transformations(): transform_model=lambda n, x: x.wrapper_module(), ) - assert not isinstance(manager.models['main'], Wrapper) + assert not isinstance(manager.models["main"], Wrapper) assert model.accessed @@ -466,7 +464,7 @@ def test_model_transformations_in_state_dict(): transform_model=lambda n, x: x.wrapper_module(), ) 
new_manager.load_state_dict(state_dict) - assert isinstance(new_manager.models['main'], _StateDictModel) + assert isinstance(new_manager.models["main"], _StateDictModel) def test_call_optimizers(): @@ -479,9 +477,9 @@ def test_call_optimizers(): 1, iters_per_epoch=1, ) - with manager.run_iteration(step_optimizers=['main']): + with manager.run_iteration(step_optimizers=["main"]): a.grad = torch.tensor([2.0]) - assert torch.equal(a.detach(), torch.tensor([-1.])) + assert torch.equal(a.detach(), torch.tensor([-1.0])) def test_needs_state_this_iteration(): @@ -489,15 +487,11 @@ def test_needs_state_this_iteration(): a = torch.ones(1, requires_grad=True) optimizer = torch.optim.SGD(lr=1.0, params=[a]) extension = _DummyExtension(0, [], [], True) - extension.name = 'Dummy' + extension.name = "Dummy" extension.needs_model_state = True - extension.trigger = (50, 'iteration') + extension.trigger = (50, "iteration") manager = training.ExtensionsManager( - m, - optimizer, - 1, - iters_per_epoch=100, - extensions=[extension] + m, optimizer, 1, iters_per_epoch=100, extensions=[extension] ) while not manager.stop_trigger: with manager.run_iteration(): @@ -509,27 +503,26 @@ def test_needs_state_this_iteration(): assert not manager.needs_state_this_iteration() -@pytest.mark.parametrize('priority', [ - None, - training.extension.PRIORITY_SNAPSHOT, - training.PRIORITY_WRITER, -]) +@pytest.mark.parametrize( + "priority", + [ + None, + training.extension.PRIORITY_SNAPSHOT, + training.PRIORITY_WRITER, + ], +) def test_extensions_accessing_models_without_flag(priority): m = torch.nn.Linear(5, 5) a = torch.ones(1, requires_grad=True) optimizer = torch.optim.SGD(lr=1.0, params=[a]) extension = _DummyExtension(0, [], [], True) - extension.name = 'Dummy' + extension.name = "Dummy" extension.needs_model_state = False - extension.trigger = (1, 'iteration') + extension.trigger = (1, "iteration") if priority is not None: extension.priority = priority manager = training.ExtensionsManager( - m, - optimizer, - 1, - iters_per_epoch=5, - extensions=[extension] + m, optimizer, 1, iters_per_epoch=5, extensions=[extension] ) while not manager.stop_trigger: with pytest.raises(RuntimeError): @@ -549,9 +542,10 @@ def __call__(self, manager): def finalize(self, manager): self.finalized = True - optimizers = {'main': object()} + optimizers = {"main": object()} manager = ppe.training.ExtensionsManager( - {}, optimizers, 1, iters_per_epoch=1) + {}, optimizers, 1, iters_per_epoch=1 + ) ext = DummyExt() manager.extend(ext) @@ -565,5 +559,5 @@ def finalize(self, manager): pass -if __name__ == '__main__': - pytest.main([__file__, '-v', '-s']) +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py index 564c7cec5..4cfd57f87 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py @@ -1,20 +1,17 @@ import os import tempfile import typing +from unittest import mock import pytest - +import pytorch_pfn_extras as ppe import torch +from pytorch_pfn_extras import engine, training from torch import nn from torch.nn import functional as F -from unittest import mock - -import pytorch_pfn_extras as ppe -from pytorch_pfn_extras import engine -from pytorch_pfn_extras import training -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def path(): with tempfile.TemporaryDirectory() as t_path: yield 
t_path @@ -45,45 +42,59 @@ def __init__(self, model): def forward(self, x, t): y = self.model(x) - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" loss = F.l1_loss(y, t) - ppe.reporting.report({prefix + '/loss': loss}) + ppe.reporting.report({prefix + "/loss": loss}) return loss def _make_extensions(): return [ - training.extensions.LogReport(trigger=(10, 'iteration')), + training.extensions.LogReport(trigger=(10, "iteration")), training.extensions.ProgressBar(update_interval=2), training.extensions.PrintReport( [ - 'epoch', - 'iteration', - 'train/loss', - 'val/loss', - 'val/accuracy', - 'elapsed_time', - 'time', + "epoch", + "iteration", + "train/loss", + "val/loss", + "val/accuracy", + "elapsed_time", + "time", ] ), ] -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_trainer(device, path): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = MyModel() ppe.to(model, device) model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(10)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(10) + ] + ) extensions = _make_extensions() trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, extensions=extensions, + model_with_loss, + optimizer, + 20, + device=device, + extensions=extensions, out_dir=path, ) trainer.run(data) @@ -94,12 +105,26 @@ def test_trainer_no_to(path): model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(10)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(10) + ] + ) extensions = _make_extensions() trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device='cpu', extensions=extensions, + model_with_loss, + optimizer, + 20, + device="cpu", + extensions=extensions, out_dir=path, ) with pytest.raises(RuntimeError, match="ppe.to"): @@ -107,44 +132,63 @@ def test_trainer_no_to(path): def test_trainer_invalid_options(path): - device = 'cpu' + device = "cpu" model = MyModel() ppe.to(model, device) model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) extensions = _make_extensions() - options = {'UNKNOWN_OPTIONS': True} + options = {"UNKNOWN_OPTIONS": True} with pytest.raises(ValueError, match="UNKNOWN_OPTIONS"): engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, extensions=extensions, + model_with_loss, + optimizer, + 20, + device=device, + extensions=extensions, out_dir=path, options=options, ) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('progress_bar', [True, False]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("progress_bar", [True, False]) def test_train_with_evaluator(device, progress_bar, path): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = MyModel() ppe.to(model, device) model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(10)]) + [ + ( + torch.rand( + 20, + ), + 
torch.rand( + 10, + ), + ) + for i in range(10) + ] + ) extensions = _make_extensions() evaluator = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar) + model_with_loss, device=device, progress_bar=progress_bar + ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator=evaluator, extensions=extensions, - out_dir=path + model_with_loss, + optimizer, + 20, + device=device, + evaluator=evaluator, + extensions=extensions, + out_dir=path, ) - mpath = 'pytorch_pfn_extras.training._evaluator.Evaluator.run' + mpath = "pytorch_pfn_extras.training._evaluator.Evaluator.run" with mock.patch(mpath) as patched: trainer.run(data, data) assert patched.call_count == 20 @@ -155,68 +199,120 @@ def test_train_with_evaluator(device, progress_bar, path): [(20, (1, "epoch")), (40, (5, "iteration"))], ) def test_evaluator_trigger(evaluator_trigger, path): - device = 'cpu' + device = "cpu" progress_bar = False model = MyModel() ppe.to(model, device) model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(10)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(10) + ] + ) extensions = _make_extensions() evaluator = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar) + model_with_loss, device=device, progress_bar=progress_bar + ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator=(evaluator, evaluator_trigger[1]), - extensions=extensions, out_dir=path + model_with_loss, + optimizer, + 20, + device=device, + evaluator=(evaluator, evaluator_trigger[1]), + extensions=extensions, + out_dir=path, ) - path = 'pytorch_pfn_extras.training._evaluator.Evaluator.run' + path = "pytorch_pfn_extras.training._evaluator.Evaluator.run" with mock.patch(path) as patched: trainer.run(data, data) assert patched.call_count == evaluator_trigger[0] def test_evaluator_dict(path): - device = 'cpu' + device = "cpu" progress_bar = False model = MyModel() ppe.to(model, device) model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(10)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(10) + ] + ) extensions = _make_extensions() evaluator1 = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar) + model_with_loss, device=device, progress_bar=progress_bar + ) evaluator2 = engine.create_evaluator( - model, device=device, progress_bar=progress_bar) + model, device=device, progress_bar=progress_bar + ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator={ - '1': evaluator1, # called 20 times. - '2': (evaluator2, (5, 'iteration')), # called 40 times. + model_with_loss, + optimizer, + 20, + device=device, + evaluator={ + "1": evaluator1, # called 20 times. + "2": (evaluator2, (5, "iteration")), # called 40 times. 
}, - extensions=extensions, out_dir=path + extensions=extensions, + out_dir=path, ) - path = 'pytorch_pfn_extras.training._evaluator.Evaluator.run' + path = "pytorch_pfn_extras.training._evaluator.Evaluator.run" with mock.patch(path) as patched: trainer.run(data, data) assert patched.call_count == 20 + 40 -@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_train_result_equal(device, path): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() train_data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(10)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(10) + ] + ) data = torch.utils.data.DataLoader( - [(torch.rand(20,),) for i in range(30)]) + [ + ( + torch.rand( + 20, + ), + ) + for i in range(30) + ] + ) def get_result_from_trainer(): model = MyModel() @@ -226,9 +322,12 @@ def get_result_from_trainer(): extensions = _make_extensions() trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, extensions=extensions, - out_dir=path + model_with_loss, + optimizer, + 20, + device=device, + extensions=extensions, + out_dir=path, ) trainer.run(train_data) @@ -293,14 +392,17 @@ def _compare_states(s1, s2): class TestTrainerState: def _get_trainer(self, epochs, out_dir): model = MyModel() - ppe.to(model, 'cpu') + ppe.to(model, "cpu") model_with_loss = MyModelWithLossFn(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) extensions = _make_extensions() trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device='cpu', extensions=extensions, - out_dir=out_dir + model_with_loss, + optimizer, + 20, + device="cpu", + extensions=extensions, + out_dir=out_dir, ) return trainer @@ -308,7 +410,18 @@ def test_trainer_state(self, path): torch.manual_seed(0) trainer = self._get_trainer(20, path) data = torch.utils.data.DataLoader( - [(torch.ones(20,), torch.ones(10,)) for i in range(10)]) + [ + ( + torch.ones( + 20, + ), + torch.ones( + 10, + ), + ) + for i in range(10) + ] + ) trainer.run(data) # State to be compared to state = trainer.state_dict() @@ -324,7 +437,18 @@ def test_trainer_state(self, path): def test_trainer_autoload(self, path): trainer = self._get_trainer(20, path) data = torch.utils.data.DataLoader( - [(torch.rand(20,), torch.rand(10,)) for i in range(10)]) + [ + ( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + ) + for i in range(10) + ] + ) trainer.extend(ppe.training.extensions.snapshot()) trainer.run(data) @@ -333,8 +457,7 @@ def test_trainer_autoload(self, path): # This forces engine initialization new_trainer._setup_manager(len(data)) assert new_trainer.epoch == 20 - assert _compare_states( - trainer.state_dict(), new_trainer.state_dict()) + assert _compare_states(trainer.state_dict(), new_trainer.state_dict()) class MyModelWithLossDictOutput(torch.nn.Module): @@ -344,32 +467,48 @@ def __init__(self, model): def forward(self, x, t): y = self.model(x) - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" loss = F.l1_loss(y, t) - ppe.reporting.report({prefix + '/loss': loss}) - return {'y': y, 'loss': loss} + ppe.reporting.report({prefix + "/loss": loss}) + return {"y": y, "loss": loss} -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('progress_bar', [True, False]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("progress_bar", 
[True, False]) def test_trainer_dict_input(device, progress_bar, path): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = MyModel() ppe.to(model, device) model_with_loss = MyModelWithLossDictOutput(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)]) + [ + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(10) + ] + ) extensions = _make_extensions() evaluator = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar) + model_with_loss, device=device, progress_bar=progress_bar + ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator=evaluator, extensions=extensions, - out_dir=path + model_with_loss, + optimizer, + 20, + device=device, + evaluator=evaluator, + extensions=extensions, + out_dir=path, ) trainer.run(data, data) @@ -393,65 +532,103 @@ def __init__(self, model): def forward(self, input): y = self.model(input.x) - prefix = 'train' if self.training else 'val' + prefix = "train" if self.training else "val" loss = F.l1_loss(y, input.t) - ppe.reporting.report({prefix + '/loss': loss}) + ppe.reporting.report({prefix + "/loss": loss}) return Output(y, loss, input.v) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('progress_bar', [True, False]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("progress_bar", [True, False]) def test_trainer_namedtuple_input(device, progress_bar, path): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = MyModel() ppe.to(model, device) model_with_loss = ModelNamedTupleIO(model) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [Input(torch.rand(20,), torch.rand(10,), str(i)) for i in range(10)]) + [ + Input( + torch.rand( + 20, + ), + torch.rand( + 10, + ), + str(i), + ) + for i in range(10) + ] + ) extensions = _make_extensions() evaluator = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar) + model_with_loss, device=device, progress_bar=progress_bar + ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator=evaluator, extensions=extensions, - out_dir=path + model_with_loss, + optimizer, + 20, + device=device, + evaluator=evaluator, + extensions=extensions, + out_dir=path, ) trainer.run(data, data) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('progress_bar', [True, False]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("progress_bar", [True, False]) def test_trainer_with_code_block(device, progress_bar, path): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = MyModel() model_with_loss = MyModelWithLossDictOutput(model) ppe.to(model_with_loss, device) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)]) + [ + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(10) + ] + ) extensions = _make_extensions() evaluator = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar, - 
logic=ppe.handler.CodeBlockLogic()) + model_with_loss, + device=device, + progress_bar=progress_bar, + logic=ppe.handler.CodeBlockLogic(), + ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator=evaluator, extensions=extensions, - out_dir=path, logic=ppe.handler.CodeBlockLogic() + model_with_loss, + optimizer, + 20, + device=device, + evaluator=evaluator, + extensions=extensions, + out_dir=path, + logic=ppe.handler.CodeBlockLogic(), ) trainer.run(data, data) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('progress_bar', [True, False]) -def test_trainer_with_code_block_with_multiple_optimizers(device, progress_bar, path): - if not torch.cuda.is_available() and device == 'cuda': +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("progress_bar", [True, False]) +def test_trainer_with_code_block_with_multiple_optimizers( + device, progress_bar, path +): + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = MyModel() model_with_loss = MyModelWithLossDictOutput(model) @@ -459,37 +636,66 @@ def test_trainer_with_code_block_with_multiple_optimizers(device, progress_bar, optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1) optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)]) + [ + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(10) + ] + ) extensions = _make_extensions() evaluator = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar, - logic=ppe.handler.CodeBlockLogic()) + model_with_loss, + device=device, + progress_bar=progress_bar, + logic=ppe.handler.CodeBlockLogic(), + ) trainer = engine.create_trainer( - model_with_loss, {"0": optimizer0, "1": optimizer1}, 20, - device=device, evaluator=evaluator, extensions=extensions, - out_dir=path, logic=ppe.handler.CodeBlockLogic() + model_with_loss, + {"0": optimizer0, "1": optimizer1}, + 20, + device=device, + evaluator=evaluator, + extensions=extensions, + out_dir=path, + logic=ppe.handler.CodeBlockLogic(), ) trainer.run(data, data) @pytest.mark.skipif( - os.name == 'nt' and not ppe.requires("1.9"), - reason='torch.profiler.profile is not supported.', + os.name == "nt" and not ppe.requires("1.9"), + reason="torch.profiler.profile is not supported.", ) def test_trainer_profile(): - device = 'cpu' + device = "cpu" model = MyModel() model_with_loss = MyModelWithLossDictOutput(model) ppe.to(model_with_loss, device) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)]) + [ + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(10) + ] + ) extensions = _make_extensions() - evaluator = engine.create_evaluator( - model_with_loss, device=device) + evaluator = engine.create_evaluator(model_with_loss, device=device) trace_handler = mock.Mock() warmup = 1 @@ -500,35 +706,58 @@ def test_trainer_profile(): schedule=torch.profiler.schedule(wait=0, warmup=warmup, active=active), ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator=evaluator, extensions=extensions, + model_with_loss, + optimizer, + 20, + device=device, + evaluator=evaluator, + extensions=extensions, profile=profile, ) trainer.run(data, data) assert trace_handler.call_count == 20 # n_epochs 
-@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('progress_bar', [True, False]) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("progress_bar", [True, False]) def test_trainer_with_clousure_logic(device, progress_bar, path): - if not torch.cuda.is_available() and device == 'cuda': + if not torch.cuda.is_available() and device == "cuda": pytest.skip() model = MyModel() model_with_loss = MyModelWithLossFn(model) ppe.to(model_with_loss, device) optimizer = torch.optim.SGD(model.parameters(), lr=0.1) data = torch.utils.data.DataLoader( - [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)]) + [ + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(10) + ] + ) extensions = _make_extensions() evaluator = engine.create_evaluator( - model_with_loss, device=device, progress_bar=progress_bar, - logic=ppe.handler.ClousureLogic()) + model_with_loss, + device=device, + progress_bar=progress_bar, + logic=ppe.handler.ClousureLogic(), + ) trainer = engine.create_trainer( - model_with_loss, optimizer, 20, - device=device, evaluator=evaluator, extensions=extensions, - out_dir=path, logic=ppe.handler.ClousureLogic(options={"backward_outputs": ["loss"]}) + model_with_loss, + optimizer, + 20, + device=device, + evaluator=evaluator, + extensions=extensions, + out_dir=path, + logic=ppe.handler.ClousureLogic(options={"backward_outputs": ["loss"]}), ) trainer.run(data, data) @@ -544,11 +773,16 @@ def test_trainer_with_autocast(path): extensions = [] autocast_options = {"autocast": True} evaluator = engine.create_evaluator( - model_with_loss, device="cuda", - options=autocast_options) + model_with_loss, device="cuda", options=autocast_options + ) engine.create_trainer( - model_with_loss, optimizer, 20, - device="cuda", evaluator=evaluator, extensions=extensions, - out_dir=path, options=autocast_options + model_with_loss, + optimizer, + 20, + device="cuda", + evaluator=evaluator, + extensions=extensions, + out_dir=path, + options=autocast_options, ) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py index 700c0e478..1ca566a8e 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py @@ -1,34 +1,38 @@ import pytest - from pytorch_pfn_extras import training -from pytorch_pfn_extras.training import _trigger_util -from pytorch_pfn_extras.training import triggers +from pytorch_pfn_extras.training import _trigger_util, triggers @pytest.mark.parametrize( - 'iters_per_epoch,trigger_args,expected', + "iters_per_epoch,trigger_args,expected", [ # Never fire trigger (2, None, [False, False, False, False, False, False, False]), - # Interval trigger - (2, (2, 'iteration'), - [False, True, False, True, False, True, False]), - (2, (2, 'epoch'), - [False, False, False, True, False, False, False]), - + (2, (2, "iteration"), [False, True, False, True, False, True, False]), + (2, (2, "epoch"), [False, False, False, True, False, False, False]), # Callable object - (2, _trigger_util.get_trigger(None), - [False, False, False, False, False, False, False]), - (2, triggers.IntervalTrigger(2, 'iteration'), - [False, True, False, True, False, True, False]), - (2, (lambda trainer: trainer.iteration == 3), - [False, False, True, False, False, False, False]), - ] + ( + 2, + _trigger_util.get_trigger(None), + [False, False, False, False, False, False, 
False], + ), + ( + 2, + triggers.IntervalTrigger(2, "iteration"), + [False, True, False, True, False, True, False], + ), + ( + 2, + (lambda trainer: trainer.iteration == 3), + [False, False, True, False, False, False, False], + ), + ], ) def test_get_trigger(iters_per_epoch, trigger_args, expected): trainer = training.ExtensionsManager( - {}, [], 100, iters_per_epoch=iters_per_epoch) + {}, [], 100, iters_per_epoch=iters_per_epoch + ) trigger = _trigger_util.get_trigger(trigger_args) # before the first iteration, trigger should be False diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py index 9d8e1e16d..10c539aa6 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py @@ -1,12 +1,10 @@ import numpy -import torch - import pytorch_pfn_extras as ppe +import torch def _test_trigger(trigger, key, accuracies, expected): - manager = ppe.training.ExtensionsManager( - {}, [], 100, iters_per_epoch=1) + manager = ppe.training.ExtensionsManager({}, [], 100, iters_per_epoch=1) for a, e in zip(accuracies, expected): with manager.run_iteration(): pass @@ -15,56 +13,59 @@ def _test_trigger(trigger, key, accuracies, expected): def test_early_stopping_trigger_with_accuracy(): - key = 'main/accuracy' + key = "main/accuracy" trigger = ppe.training.triggers.EarlyStoppingTrigger( - monitor=key, - patience=3, - check_trigger=(1, 'epoch'), - verbose=False) + monitor=key, patience=3, check_trigger=(1, "epoch"), verbose=False + ) accuracies = [ torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) - for acc in [0.5, 0.5, 0.6, 0.7, 0.6, 0.4, 0.3, 0.2]] + for acc in [0.5, 0.5, 0.6, 0.7, 0.6, 0.4, 0.3, 0.2] + ] expected = [False, False, False, False, False, False, True, True] _test_trigger(trigger, key, accuracies, expected) def test_early_stopping_trigger_with_loss(): - key = 'main/loss' + key = "main/loss" trigger = ppe.training.triggers.EarlyStoppingTrigger( - monitor=key, - patience=3, - check_trigger=(1, 'epoch')) + monitor=key, patience=3, check_trigger=(1, "epoch") + ) accuracies = [ torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) - for acc in [100, 80, 30, 10, 20, 24, 30, 35]] + for acc in [100, 80, 30, 10, 20, 24, 30, 35] + ] expected = [False, False, False, False, False, False, True, True] _test_trigger(trigger, key, accuracies, expected) def test_early_stopping_trigger_with_max_epoch(): - key = 'main/loss' + key = "main/loss" trigger = ppe.training.triggers.EarlyStoppingTrigger( monitor=key, patience=3, - check_trigger=(1, 'epoch'), - max_trigger=(3, 'epoch')) + check_trigger=(1, "epoch"), + max_trigger=(3, "epoch"), + ) accuracies = [ torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) - for acc in [100, 80, 30]] + for acc in [100, 80, 30] + ] expected = [False, False, True] _test_trigger(trigger, key, accuracies, expected) def test_early_stopping_trigger_with_max_iteration(): - key = 'main/loss' + key = "main/loss" trigger = ppe.training.triggers.EarlyStoppingTrigger( monitor=key, patience=3, - check_trigger=(1, 'epoch'), - max_trigger=(3, 'iteration')) + check_trigger=(1, "epoch"), + max_trigger=(3, "iteration"), + ) accuracies = [ torch.Tensor(numpy.asarray(acc, dtype=numpy.float32)) - for acc in [100, 80, 30]] + for acc in [100, 80, 30] + ] expected = [False, False, True] _test_trigger(trigger, key, 
accuracies, expected) diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py index 9f4630434..bc203b765 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py @@ -1,24 +1,22 @@ import pytest - from pytorch_pfn_extras import training from pytorch_pfn_extras.training import triggers - _argvalues = [ # iteration - (5, (2, 'iteration'), [False, True, False, True, False, True, False], 4), + (5, (2, "iteration"), [False, True, False, True, False, True, False], 4), # basic epoch - (1, (3, 'epoch'), [False, False, True, False, False, True, False], 4), + (1, (3, "epoch"), [False, False, True, False, False, True, False], 4), # fractional epoch - (2, (1.5, 'epoch'), [False, False, True, False, False, True, False], 4), + (2, (1.5, "epoch"), [False, False, True, False, False, True, False], 4), ] -@pytest.mark.parametrize( - 'iters_per_epoch,interval,expected,resume', _argvalues) +@pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues) def test_trigger(iters_per_epoch, interval, expected, resume): trainer = training.ExtensionsManager( - {}, [], 100, iters_per_epoch=iters_per_epoch) + {}, [], 100, iters_per_epoch=iters_per_epoch + ) trigger = triggers.IntervalTrigger(*interval) for e in expected: @@ -28,11 +26,11 @@ def test_trigger(iters_per_epoch, interval, expected, resume): assert trigger(trainer) == e -@pytest.mark.parametrize( - 'iters_per_epoch,interval,expected,resume', _argvalues) +@pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues) def test_resumed_trigger(iters_per_epoch, interval, expected, resume): trainer = training.ExtensionsManager( - {}, [], 100, iters_per_epoch=iters_per_epoch) + {}, [], 100, iters_per_epoch=iters_per_epoch + ) trigger = triggers.IntervalTrigger(*interval) for e in expected[:resume]: @@ -52,12 +50,11 @@ def test_resumed_trigger(iters_per_epoch, interval, expected, resume): assert new_trigger(trainer) == e -@pytest.mark.parametrize( - 'iters_per_epoch,interval,expected,resume', _argvalues) +@pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues) def test_str(iters_per_epoch, interval, expected, resume): trigger = triggers.IntervalTrigger(*interval) - expected = 'IntervalTrigger({}, \'{}\')'.format(*interval) + expected = "IntervalTrigger({}, '{}')".format(*interval) actual = str(trigger) assert expected == actual, 'Expected "{}" == "{}"'.format(expected, actual) @@ -65,4 +62,4 @@ def test_str(iters_per_epoch, interval, expected, resume): def test_invalid_unit(): with pytest.raises(ValueError): - triggers.IntervalTrigger(1, 'day') + triggers.IntervalTrigger(1, "day") diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py index b72ca885c..ec307f43d 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py @@ -1,7 +1,6 @@ import pytest - -from pytorch_pfn_extras.training import triggers from pytorch_pfn_extras import training +from pytorch_pfn_extras.training import triggers def _test_trigger(manager, trigger, key, accuracies, expected): @@ -17,80 +16,136 @@ 
def _compare(best_value, new_value): _trigger_test_params = [ # interval = 1 iterations - (triggers.MaxValueTrigger, ((1, 'iteration'),), 1, - [0.5, 0.5, 0.4, 0.6], [True, False, False, True], 1), - (triggers.MinValueTrigger, ((1, 'iteration'),), 1, - [0.5, 0.5, 0.4, 0.6], [True, False, True, False], 1), + ( + triggers.MaxValueTrigger, + ((1, "iteration"),), + 1, + [0.5, 0.5, 0.4, 0.6], + [True, False, False, True], + 1, + ), + ( + triggers.MinValueTrigger, + ((1, "iteration"),), + 1, + [0.5, 0.5, 0.4, 0.6], + [True, False, True, False], + 1, + ), # interval = 2 iterations - (triggers.MaxValueTrigger, ((2, 'iteration'),), 1, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, False, False, True], 2), - (triggers.MinValueTrigger, ((2, 'iteration'),), 1, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, True, False, False], 2), + ( + triggers.MaxValueTrigger, + ((2, "iteration"),), + 1, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, False, False, True], + 2, + ), + ( + triggers.MinValueTrigger, + ((2, "iteration"),), + 1, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, True, False, False], + 2, + ), # interval = 2 iterations, unaligned resume - (triggers.MaxValueTrigger, ((2, 'iteration'),), 1, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, False, False, True], 3), - (triggers.MinValueTrigger, ((2, 'iteration'),), 1, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, True, False, False], 3), + ( + triggers.MaxValueTrigger, + ((2, "iteration"),), + 1, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, False, False, True], + 3, + ), + ( + triggers.MinValueTrigger, + ((2, "iteration"),), + 1, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, True, False, False], + 3, + ), # interval = 1 epoch, 1 epoch = 2 iterations - (triggers.MaxValueTrigger, ((1, 'epoch'),), 2, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, False, False, True], 2), - (triggers.MinValueTrigger, ((1, 'epoch'),), 2, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, True, False, False], 2), + ( + triggers.MaxValueTrigger, + ((1, "epoch"),), + 2, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, False, False, True], + 2, + ), + ( + triggers.MinValueTrigger, + ((1, "epoch"),), + 2, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, True, False, False], + 2, + ), # interval = 1 epoch, 1 epoch = 2 iterations, unaligned resume - (triggers.MaxValueTrigger, ((1, 'epoch'),), 2, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, False, False, True], 3), - (triggers.MinValueTrigger, ((1, 'epoch'),), 2, - [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], - [False, True, False, False, False, True, False, False], 3), - + ( + triggers.MaxValueTrigger, + ((1, "epoch"),), + 2, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, False, False, True], + 3, + ), + ( + triggers.MinValueTrigger, + ((1, "epoch"),), + 2, + [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6], + [False, True, False, False, False, True, False, False], + 3, + ), # best_value trigger test - (triggers.BestValueTrigger, (_compare, (1, 'iteration')), 2, - [0.5, -0.5, -0.6, 0.6, 0.4, -0.4, -0.3, 0.3], - [True, False, False, False, True, False, True, False], 3), + ( + 
+        triggers.BestValueTrigger,
+        (_compare, (1, "iteration")),
+        2,
+        [0.5, -0.5, -0.6, 0.6, 0.4, -0.4, -0.3, 0.3],
+        [True, False, False, False, True, False, True, False],
+        3,
+    ),
 ]
 @pytest.mark.parametrize(
-    'trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume',
-    _trigger_test_params
+    "trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume",
+    _trigger_test_params,
 )
 def test_trigger(
-        trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
-        resume):
-    key = 'main/accuracy'
+    trigger_type, trigger_args, iters_per_epoch, accuracies, expected, resume
+):
+    key = "main/accuracy"
     manager = training.ExtensionsManager(
-        {}, [], 100, iters_per_epoch=iters_per_epoch)
+        {}, [], 100, iters_per_epoch=iters_per_epoch
+    )
     trigger = trigger_type(key, *trigger_args)
-    _test_trigger(
-        manager, trigger, key, accuracies, expected)
+    _test_trigger(manager, trigger, key, accuracies, expected)
 @pytest.mark.parametrize(
-    'trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume',
-    _trigger_test_params
+    "trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume",
+    _trigger_test_params,
 )
 def test_resumed_trigger(
-        trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
-        resume):
-    key = 'main/accuracy'
+    trigger_type, trigger_args, iters_per_epoch, accuracies, expected, resume
+):
+    key = "main/accuracy"
     manager = training.ExtensionsManager(
-        {}, [], 100, iters_per_epoch=iters_per_epoch)
+        {}, [], 100, iters_per_epoch=iters_per_epoch
+    )
     trigger = trigger_type(key, *trigger_args)
-    _test_trigger(
-        manager, trigger, key, accuracies[:resume],
-        expected[:resume])
+    _test_trigger(manager, trigger, key, accuracies[:resume], expected[:resume])
     state = trigger.state_dict()
     new_trigger = trigger_type(key, *trigger_args)
     new_trigger.load_state_dict(state)
     _test_trigger(
-        manager, new_trigger, key, accuracies[resume:], expected[resume:])
+        manager, new_trigger, key, accuracies[resume:], expected[resume:]
+    )
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py
index 5b27a05e4..66404ceb4 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py
@@ -1,17 +1,17 @@
 import random
-import pytest
+import pytest
 import pytorch_pfn_extras as ppe
-
 _parametrize = pytest.mark.parametrize(
-    'iters_per_epoch,call_on_resume,resume',
+    "iters_per_epoch,call_on_resume,resume",
     [
         # basic
         (5, False, 4),
         # call on resume
         (5, True, 4),
-    ])
+    ],
+)
 @_parametrize
@@ -20,8 +20,8 @@ def test_trigger(iters_per_epoch, call_on_resume, resume):
     expected = [True] + [False] * 6
     finished = [False] + [True] * 6
     manager = ppe.training.ExtensionsManager(
-        {}, [], 100,
-        iters_per_epoch=iters_per_epoch)
+        {}, [], 100, iters_per_epoch=iters_per_epoch
+    )
     trigger = ppe.training.triggers.OnceTrigger(call_on_resume)
     for e, f in zip(expected, finished):
         assert trigger.finished == f
@@ -38,8 +38,8 @@ def test_resumed_trigger(iters_per_epoch, call_on_resume, resume):
     expected[resume] = True
     finished[resume] = False
     manager = ppe.training.ExtensionsManager(
-        {}, [], 100,
-        iters_per_epoch=iters_per_epoch)
+        {}, [], 100, iters_per_epoch=iters_per_epoch
+    )
     trigger = ppe.training.triggers.OnceTrigger(call_on_resume)
     for e, f in zip(expected[:resume], finished[:resume]):
         with manager.run_iteration():
@@ -64,8 +64,8 @@ def test_trigger_sparse_call(iters_per_epoch, call_on_resume, resume):
     finished = [False] + [True] * 6
     for _ in range(10):
         manager = ppe.training.ExtensionsManager(
-            {}, [], 100,
-            iters_per_epoch=iters_per_epoch)
+            {}, [], 100, iters_per_epoch=iters_per_epoch
+        )
         trigger = ppe.training.triggers.OnceTrigger(call_on_resume)
         accumulated = False
         accumulated_finished = True
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py
index d06858b4a..9d19eeb20 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py
@@ -1,25 +1,30 @@
 import pytest
-
 from pytorch_pfn_extras import training
 from pytorch_pfn_extras.training import triggers
-
 _scheduled_trigger_test_params = [
     # single iteration
-    (2, (2, 'iteration'),
-     [False, True, False, False, False, False, False], 3),
+    (2, (2, "iteration"), [False, True, False, False, False, False, False], 3),
     # multiple iteration
-    (2, ([2, 4], 'iteration'),
-     [False, True, False, True, False, False, False], 3),
+    (
+        2,
+        ([2, 4], "iteration"),
+        [False, True, False, True, False, False, False],
+        3,
+    ),
     # single epoch
-    (3, (1, 'epoch'), [False, False, True, False, False, False, False], 3),
+    (3, (1, "epoch"), [False, False, True, False, False, False, False], 3),
     # multiple epoch
-    (3, ([1, 2], 'epoch'), [False, False, True, False, False, True, False], 4),
+    (3, ([1, 2], "epoch"), [False, False, True, False, False, True, False], 4),
     # single fractional epoch
-    (2, (1.5, 'epoch'), [False, False, True, False, False, False, False], 4),
+    (2, (1.5, "epoch"), [False, False, True, False, False, False, False], 4),
     # multiple fractional epoch
-    (2, ([1.5, 2.5], 'epoch'),
-     [False, False, True, False, True, False, False], 4),
+    (
+        2,
+        ([1.5, 2.5], "epoch"),
+        [False, False, True, False, True, False, False],
+        4,
+    ),
     # TODO(imanishi): Restore these tests after supported.
     # # single unaligned epoch
     # (2.5, (1, 'epoch'), [False, False, True, False, False, False, False], 4),
@@ -42,30 +47,27 @@ def _test_trigger(trainer, trigger, expected):
 @pytest.mark.parametrize(
-    'iters_per_epoch,schedule,expected,resume',
-    _scheduled_trigger_test_params
+    "iters_per_epoch,schedule,expected,resume", _scheduled_trigger_test_params
 )
 def test_trigger(iters_per_epoch, schedule, expected, resume):
     trainer = training.ExtensionsManager(
-        {}, [], 100, iters_per_epoch=iters_per_epoch)
+        {}, [], 100, iters_per_epoch=iters_per_epoch
+    )
     trigger = triggers.ManualScheduleTrigger(*schedule)
     _test_trigger(trainer, trigger, expected)
 @pytest.mark.parametrize(
-    'iters_per_epoch,schedule,expected,resume',
-    _scheduled_trigger_test_params
+    "iters_per_epoch,schedule,expected,resume", _scheduled_trigger_test_params
 )
-def test_resumed_trigger(
-        iters_per_epoch, schedule, expected, resume):
+def test_resumed_trigger(iters_per_epoch, schedule, expected, resume):
     trainer = training.ExtensionsManager(
-        {}, [], 100, iters_per_epoch=iters_per_epoch)
+        {}, [], 100, iters_per_epoch=iters_per_epoch
+    )
     trigger = triggers.ManualScheduleTrigger(*schedule)
-    _test_trigger(
-        trainer, trigger,
-        expected[:resume])
+    _test_trigger(trainer, trigger, expected[:resume])
     state = trigger.state_dict()
     new_trigger = triggers.ManualScheduleTrigger(*schedule)
@@ -76,4 +78,4 @@ def test_resumed_trigger(
 def test_invalid_unit():
     with pytest.raises(ValueError):
-        triggers.ManualScheduleTrigger(1, 'day')
+        triggers.ManualScheduleTrigger(1, "day")
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py
index 6f7c76624..cf3efb694 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py
@@ -2,7 +2,6 @@
 class DummyTrainer:
-
     def __init__(self):
         self.elapsed_time = 0
diff --git a/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py b/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py
index 8b44c2eb8..0a51a6fe8 100644
--- a/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py
+++ b/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py
@@ -1,7 +1,6 @@
 import pytest
-import torch
-
 import pytorch_pfn_extras as ppe
+import torch
 class SubNet(torch.nn.Module):
@@ -27,9 +26,9 @@ def __init__(self, checkpoint_type):
     def forward(self, x):
         x = self.bn1(self.conv1(x)).relu()
-        if self.checkpoint_type == 'none':
+        if self.checkpoint_type == "none":
             x = self.part1(x)
-        elif self.checkpoint_type == 'bnaware':
+        elif self.checkpoint_type == "bnaware":
             x = ppe.utils.checkpoint.checkpoint(self.part1, x)
         x = self.part2(x)
@@ -58,7 +57,7 @@ def _get_bn_stats_test_checkpoint(cp_type):
 @pytest.mark.gpu
 def test_checkpoint():
-    baseline = _get_bn_stats_test_checkpoint('none')
-    ckpt = _get_bn_stats_test_checkpoint('bnaware')
+    baseline = _get_bn_stats_test_checkpoint("none")
+    ckpt = _get_bn_stats_test_checkpoint("bnaware")
     for p_b, p_c in zip(baseline, ckpt):
         assert torch.allclose(p_b, p_c)
diff --git a/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py b/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py
index afcac9f0e..9ac934f30 100644
--- a/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py
+++ b/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py
@@ -1,9 +1,8 @@
 import typing
 import pytest
-import torch
-
 import pytorch_pfn_extras as ppe
+import torch
 class Model(torch.nn.Module):
@@ -42,13 +41,15 @@ def _get_trainer_with_evaluator(device, ret_val, model_class=Model):
     optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
     evaluator = ppe.engine.create_evaluator(model, device=device)
     trainer = ppe.engine.create_trainer(
-        model, optimizer, 1, device=device, evaluator=evaluator)
+        model, optimizer, 1, device=device, evaluator=evaluator
+    )
     return trainer
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_every_iter(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0)
     engine_gpu = engine_fn("cuda:0", 1.0)
@@ -66,8 +67,9 @@ def test_compare_every_iter(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_comparer_wrong(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0)
     engine_gpu = engine_fn("cuda:0", 0.5)
@@ -101,8 +103,9 @@ def __call__(self, eng_name_1, eng_name_2, out_name, out_1, out_2):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_comparer_n_iters(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0)
     engine_gpu = engine_fn("cuda:0", 1.0)
@@ -119,7 +122,8 @@ def test_comparer_n_iters(engine_fn):
         eval_2 = list(torch.ones(10) for _ in range(10))
         comp.compare(
             {"cpu": (train_1, eval_1), "gpu": (train_2, eval_2)},
-            n_iters=n_iters)
+            n_iters=n_iters,
+        )
         assert comp.compare_fn.times_called == 6
     else:
         comp.compare({"cpu": train_1, "gpu": train_2}, n_iters=n_iters)
@@ -127,14 +131,16 @@ def test_comparer_n_iters(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_comparer_kwargs(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0)
     engine_gpu = engine_fn("cuda:0", 0.991)
     compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
     comp = ppe.utils.comparer.OutputsComparer(
-        {"cpu": engine_cpu, "gpu": engine_gpu}, "a",
+        {"cpu": engine_cpu, "gpu": engine_gpu},
+        "a",
         compare_fn=compare_fn,
     )
     train_1 = list(torch.ones(10) for _ in range(10))
@@ -150,22 +156,29 @@ def test_comparer_kwargs(engine_fn):
 @pytest.mark.gpu
 def test_comparer_incompat_trigger():
     model_cpu = Model("cpu", 1.0)
-    ppe.to(model_cpu, 'cpu')
+    ppe.to(model_cpu, "cpu")
     optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=1.0)
     trainer_cpu = ppe.engine.create_trainer(
-        model_cpu, optimizer_cpu, 1, device="cpu",
+        model_cpu,
+        optimizer_cpu,
+        1,
+        device="cpu",
     )
     model_gpu = Model("cuda:0", 1.0)
-    ppe.to(model_gpu, 'cuda:0')
+    ppe.to(model_gpu, "cuda:0")
     optimizer_gpu = torch.optim.SGD(model_gpu.parameters(), lr=1.0)
     trainer_gpu = ppe.engine.create_trainer(
-        model_gpu, optimizer_gpu, 1, device="cuda:0",
+        model_gpu,
+        optimizer_gpu,
+        1,
+        device="cuda:0",
         stop_trigger=(1, "iteration"),
     )
     comp = ppe.utils.comparer.OutputsComparer(
-        {"cpu": trainer_cpu, "gpu": trainer_gpu}, "a",
+        {"cpu": trainer_cpu, "gpu": trainer_gpu},
+        "a",
     )
     train_1 = list(torch.ones(10) for _ in range(10))
     train_2 = list(torch.ones(10) for _ in range(10))
@@ -174,13 +187,15 @@ def test_comparer_incompat_trigger():
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_concurrency(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0)
     engine_gpu = engine_fn("cuda:0", 1.0)
     comp = ppe.utils.comparer.OutputsComparer(
-        {"cpu": engine_cpu, "gpu": engine_gpu}, "a",
+        {"cpu": engine_cpu, "gpu": engine_gpu},
+        "a",
         concurrency=1,
     )
     train_1 = list(torch.ones(10) for _ in range(10))
@@ -194,13 +209,15 @@ def test_compare_concurrency(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_concurrency_wrong(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0)
     engine_gpu = engine_fn("cuda:0", 0.5)
     comp = ppe.utils.comparer.OutputsComparer(
-        {"cpu": engine_cpu, "gpu": engine_gpu}, "a",
+        {"cpu": engine_cpu, "gpu": engine_gpu},
+        "a",
         concurrency=1,
     )
     train_1 = list(torch.ones(10) for _ in range(10))
@@ -232,22 +249,23 @@ def forward(self, x):
 def test_model_comparer():
     model_cpu = ModelForComparer()
     model_gpu = ModelForComparer()
-    ppe.to(model_cpu, 'cpu')
-    ppe.to(model_gpu, 'cuda:0')
+    ppe.to(model_cpu, "cpu")
+    ppe.to(model_gpu, "cuda:0")
     # Make the models to have the same initial weights
     model_gpu.load_state_dict(model_cpu.state_dict())
-    ppe.to(model_gpu, device='cuda:0')
+    ppe.to(model_gpu, device="cuda:0")
     optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01)
     trainer_cpu = ppe.engine.create_trainer(
-        model_cpu, optimizer_cpu, 1, device='cpu')
+        model_cpu, optimizer_cpu, 1, device="cpu"
+    )
     optimizer_gpu = torch.optim.SGD(model_gpu.parameters(), lr=0.01)
     trainer_gpu = ppe.engine.create_trainer(
-        model_gpu, optimizer_gpu, 1, device='cuda:0')
+        model_gpu, optimizer_gpu, 1, device="cuda:0"
+    )
     compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
     comp = ppe.utils.comparer.ModelComparer(
-        {"cpu": trainer_cpu, "gpu": trainer_gpu},
-        compare_fn=compare_fn
+        {"cpu": trainer_cpu, "gpu": trainer_gpu}, compare_fn=compare_fn
     )
     train_1 = list(torch.ones(2, 10, 10, 10) for _ in range(10))
@@ -259,19 +277,20 @@ def test_model_comparer_invalid():
     model_cpu = ModelForComparer()
     model_gpu = ModelForComparer()
-    ppe.to(model_cpu, 'cpu')
-    ppe.to(model_gpu, device='cuda:0')
+    ppe.to(model_cpu, "cpu")
+    ppe.to(model_gpu, device="cuda:0")
     optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01)
     trainer_cpu = ppe.engine.create_trainer(
-        model_cpu, optimizer_cpu, 1, device='cpu')
+        model_cpu, optimizer_cpu, 1, device="cpu"
+    )
     optimizer_gpu = torch.optim.SGD(model_gpu.parameters(), lr=0.01)
     trainer_gpu = ppe.engine.create_trainer(
-        model_gpu, optimizer_gpu, 1, device='cuda:0')
+        model_gpu, optimizer_gpu, 1, device="cuda:0"
+    )
     compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
     comp = ppe.utils.comparer.ModelComparer(
-        {"cpu": trainer_cpu, "gpu": trainer_gpu},
-        compare_fn=compare_fn
+        {"cpu": trainer_cpu, "gpu": trainer_gpu}, compare_fn=compare_fn
     )
     train_1 = list(torch.ones(2, 10, 10, 10) for _ in range(10))
@@ -294,12 +313,15 @@ def forward(self, x):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_tuple_output(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0, model_class=ModelRetTuple)
     engine_gpu = engine_fn("cuda:0", 1.0, model_class=ModelRetTuple)
-    comp = ppe.utils.comparer.OutputsComparer({"cpu": engine_cpu, "gpu": engine_gpu})
+    comp = ppe.utils.comparer.OutputsComparer(
+        {"cpu": engine_cpu, "gpu": engine_gpu}
+    )
     train_1 = list(torch.ones(10) for _ in range(10))
     train_2 = list(torch.ones(10) for _ in range(10))
     if engine_fn is _get_trainer_with_evaluator:
@@ -329,12 +351,15 @@ def forward(self, x):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_namedtuple_output(engine_fn):
     engine_cpu = engine_fn("cpu", 1.0, model_class=ModelRetNamedTuple)
     engine_gpu = engine_fn("cuda:0", 1.0, model_class=ModelRetNamedTuple)
-    comp = ppe.utils.comparer.OutputsComparer({"cpu": engine_cpu, "gpu": engine_gpu})
+    comp = ppe.utils.comparer.OutputsComparer(
+        {"cpu": engine_cpu, "gpu": engine_gpu}
+    )
     train_1 = list(torch.ones(10) for _ in range(10))
     train_2 = list(torch.ones(10) for _ in range(10))
     if engine_fn is _get_trainer_with_evaluator:
diff --git a/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py b/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py
index 51532836e..bccda9b34 100644
--- a/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py
+++ b/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py
@@ -2,10 +2,9 @@
 import typing
 import pytest
+import pytorch_pfn_extras as ppe
 import torch
 import torch.nn.functional as F
-
-import pytorch_pfn_extras as ppe
 from pytorch_pfn_extras_tests.runtime_tests.test_jit_runtime import JITRuntime
@@ -25,18 +24,28 @@ def forward(self, x):
 def _get_trainer(
-        model_class, device, args, loader, *,
-        seed=0, max_epochs=10, stop_trigger=None):
+    model_class,
+    device,
+    args,
+    loader,
+    *,
+    seed=0,
+    max_epochs=10,
+    stop_trigger=None,
+):
     torch.manual_seed(seed)
     model = model_class(device, *args)
     ppe.to(model, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
     trainer = ppe.engine.create_trainer(
-        model, optimizer, max_epochs, device=device, stop_trigger=stop_trigger)
+        model, optimizer, max_epochs, device=device, stop_trigger=stop_trigger
+    )
     return trainer, (loader,)
-def _get_evaluator(model_class, device, args, loader, *, seed=0, max_epochs=None):
+def _get_evaluator(
+    model_class, device, args, loader, *, seed=0, max_epochs=None
+):
     torch.manual_seed(seed)
     model = model_class(device, *args)
     ppe.to(model, device)
@@ -45,22 +54,35 @@ def _get_evaluator(model_class, device, args, loader, *, seed=0, max_epochs=None
 def _get_trainer_with_evaluator(
-        model_class, device, args, loader, *,
-        seed=0, max_epochs=10, stop_trigger=None):
+    model_class,
+    device,
+    args,
+    loader,
+    *,
+    seed=0,
+    max_epochs=10,
+    stop_trigger=None,
+):
     torch.manual_seed(seed)
     model = model_class(device, *args)
     ppe.to(model, device)
     optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
     evaluator = ppe.engine.create_evaluator(model, device=device)
     trainer = ppe.engine.create_trainer(
-        model, optimizer, max_epochs, device=device,
-        evaluator=evaluator, stop_trigger=stop_trigger)
+        model,
+        optimizer,
+        max_epochs,
+        device=device,
+        evaluator=evaluator,
+        stop_trigger=stop_trigger,
+    )
     return trainer, (loader, loader)
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_every_epoch(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -72,8 +94,9 @@ def test_compare_every_epoch(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_comparer_wrong(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -91,7 +114,10 @@ def __init__(self, n_iters=None):
         self.n_iters = n_iters
     def __call__(self, eng_name_1, eng_name_2, out_name, out_1, out_2):
-        assert out_name in ("output:a", "output:iter",)
+        assert out_name in (
+            "output:a",
+            "output:iter",
+        )
         assert eng_name_1 in ("cpu", "gpu")
         assert eng_name_1 != eng_name_2
         if out_name == "output:iter":
@@ -104,16 +130,22 @@ def __call__(self, eng_name_1, eng_name_2, out_name, out_1, out_2):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_trainer_with_evaluator]
+)
 def test_comparer_trigger(engine_fn):
     n_iters = 3
     loader = list(torch.ones(10) for _ in range(10))
-    trainer_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader, max_epochs=1)
-    trainer_gpu, loaders_gpu = engine_fn(Model, "cuda:0", [1.0], loader, max_epochs=1)
+    trainer_cpu, loaders_cpu = engine_fn(
+        Model, "cpu", [1.0], loader, max_epochs=1
+    )
+    trainer_gpu, loaders_gpu = engine_fn(
+        Model, "cuda:0", [1.0], loader, max_epochs=1
+    )
     compare_fn = _CustomComparer(n_iters)
     comp = ppe.utils.comparer.Comparer(
-        trigger=(n_iters, "iteration"), compare_fn=compare_fn)
+        trigger=(n_iters, "iteration"), compare_fn=compare_fn
+    )
     comp.add_engine("cpu", trainer_cpu, *loaders_cpu)
     comp.add_engine("gpu", trainer_gpu, *loaders_gpu)
     comp.compare()
@@ -124,8 +156,9 @@ def test_comparer_trigger(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_comparer_kwargs(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -138,13 +171,15 @@ def test_comparer_kwargs(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_trainer_with_evaluator]
+)
 def test_comparer_incompat_trigger(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     trainer_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
-    trainer_gpu, loaders_gpu = engine_fn(Model, "cuda:0", [1.0], loader,
-                                         stop_trigger=(1, "iteration"))
+    trainer_gpu, loaders_gpu = engine_fn(
+        Model, "cuda:0", [1.0], loader, stop_trigger=(1, "iteration")
+    )
     comp = ppe.utils.comparer.Comparer(outputs=["a"])
     comp.add_engine("cpu", trainer_cpu, *loaders_cpu)
     comp.add_engine("gpu", trainer_gpu, *loaders_gpu)
@@ -153,8 +188,9 @@ def test_comparer_incompat_trigger(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_concurrency(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -166,8 +202,9 @@ def test_compare_concurrency(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_concurrency_wrong(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -195,8 +232,9 @@ def forward(self, x):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_model_comparer(engine_fn):
     loader = list(torch.ones(2, 10, 10, 10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader)
@@ -210,12 +248,17 @@ def test_model_comparer(engine_fn):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_model_comparer_invalid(engine_fn):
     loader = list(torch.ones(2, 10, 10, 10) for _ in range(10))
-    engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader, seed=0)
-    engine_gpu, loaders_gpu = engine_fn(ModelForComparer, "cuda:0", [], loader, seed=1)
+    engine_cpu, loaders_cpu = engine_fn(
+        ModelForComparer, "cpu", [], loader, seed=0
+    )
+    engine_gpu, loaders_gpu = engine_fn(
+        ModelForComparer, "cuda:0", [], loader, seed=1
+    )
     comp = ppe.utils.comparer.Comparer(outputs=["a"])
     compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
     comp = ppe.utils.comparer.Comparer(compare_fn=compare_fn, params=True)
@@ -239,8 +282,9 @@ def forward(self, x):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_tuple_output(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(ModelRetTuple, "cpu", [1.0], loader)
@@ -270,12 +314,17 @@ def forward(self, x):
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_namedtuple_output(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
-    engine_cpu, loaders_cpu = engine_fn(ModelRetNamedTuple, "cpu", [1.0], loader)
-    engine_gpu, loaders_gpu = engine_fn(ModelRetNamedTuple, "cuda:0", [1.0], loader)
+    engine_cpu, loaders_cpu = engine_fn(
+        ModelRetNamedTuple, "cpu", [1.0], loader
+    )
+    engine_gpu, loaders_gpu = engine_fn(
+        ModelRetNamedTuple, "cuda:0", [1.0], loader
+    )
     comp = ppe.utils.comparer.Comparer()
     comp.add_engine("cpu", engine_cpu, *loaders_cpu)
     comp.add_engine("gpu", engine_gpu, *loaders_gpu)
@@ -290,76 +339,122 @@ def __init__(self, device, offset):
     def forward(self, x, t):
         y = self.model(x)
-        prefix = 'train' if self.training else 'val'
+        prefix = "train" if self.training else "val"
         loss = F.l1_loss(y, t)
-        ppe.reporting.report({prefix + '/loss': loss})
+        ppe.reporting.report({prefix + "/loss": loss})
         return loss + self.offset
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_jit_runtime_output_comparer(engine_fn):
     loader = torch.utils.data.DataLoader(
-        [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+        [
+            (
+                torch.rand(
+                    20,
+                ),
+                torch.rand(
+                    10,
+                ),
+            )
+            for i in range(100)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
     engine_cpu, loaders_cpu = engine_fn(MyModel, "cpu", [0.0], loader)
     engine_gpu, loaders_gpu = engine_fn(MyModel, "cuda:0", [0.0], loader)
     engine_jit, loaders_jit = engine_fn(MyModel, "jit-cpu", [0.0], loader)
     comp = ppe.utils.comparer.Comparer()
-    comp.add_engine('cpu', engine_cpu, *loaders_cpu)
-    comp.add_engine('gpu', engine_gpu, *loaders_gpu)
-    comp.add_engine('jit', engine_jit, *loaders_jit)
+    comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+    comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+    comp.add_engine("jit", engine_jit, *loaders_jit)
     comp.compare()
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_jit_runtime_output_comparer_invalid(engine_fn):
     loader = torch.utils.data.DataLoader(
-        [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+        [
+            (
+                torch.rand(
+                    20,
+                ),
+                torch.rand(
+                    10,
+                ),
+            )
+            for i in range(100)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
     engine_cpu, loaders_cpu = engine_fn(MyModel, "cpu", [0.0], loader)
     engine_gpu, loaders_gpu = engine_fn(MyModel, "cuda:0", [0.0], loader)
     engine_jit, loaders_jit = engine_fn(MyModel, "jit-cpu", [0.5], loader)
     comp = ppe.utils.comparer.Comparer()
-    comp.add_engine('cpu', engine_cpu, *loaders_cpu)
-    comp.add_engine('gpu', engine_gpu, *loaders_gpu)
-    comp.add_engine('jit', engine_jit, *loaders_jit)
+    comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+    comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+    comp.add_engine("jit", engine_jit, *loaders_jit)
     with pytest.raises(AssertionError):
         comp.compare()
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_jit_runtime_model_comparer(engine_fn):
-    loader = torch.utils.data.DataLoader([torch.rand(20,) for i in range(100)])
+    loader = torch.utils.data.DataLoader(
+        [
+            torch.rand(
+                20,
+            )
+            for i in range(100)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
     engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader)
     engine_gpu, loaders_gpu = engine_fn(ModelForComparer, "cuda:0", [], loader)
     engine_jit, loaders_jit = engine_fn(ModelForComparer, "jit-cpu", [], loader)
     comp = ppe.utils.comparer.Comparer(params=True)
-    comp.add_engine('cpu', engine_cpu, *loaders_cpu)
-    comp.add_engine('gpu', engine_gpu, *loaders_gpu)
-    comp.add_engine('jit', engine_jit, *loaders_jit)
+    comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+    comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+    comp.add_engine("jit", engine_jit, *loaders_jit)
     comp.compare()
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_jit_runtime_model_comparer_invalid(engine_fn):
-    loader = torch.utils.data.DataLoader([torch.rand(20,) for i in range(100)])
+    loader = torch.utils.data.DataLoader(
+        [
+            torch.rand(
+                20,
+            )
+            for i in range(100)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
-    engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader, seed=0)
-    engine_gpu, loaders_gpu = engine_fn(ModelForComparer, "cuda:0", [], loader, seed=0)
-    engine_jit, loaders_jit = engine_fn(ModelForComparer, "jit-cpu", [], loader, seed=1)
+    engine_cpu, loaders_cpu = engine_fn(
+        ModelForComparer, "cpu", [], loader, seed=0
+    )
+    engine_gpu, loaders_gpu = engine_fn(
+        ModelForComparer, "cuda:0", [], loader, seed=0
+    )
+    engine_jit, loaders_jit = engine_fn(
+        ModelForComparer, "jit-cpu", [], loader, seed=1
+    )
     comp = ppe.utils.comparer.Comparer(params=True)
-    comp.add_engine('cpu', engine_cpu, *loaders_cpu)
-    comp.add_engine('gpu', engine_gpu, *loaders_gpu)
-    comp.add_engine('jit', engine_jit, *loaders_jit)
+    comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+    comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+    comp.add_engine("jit", engine_jit, *loaders_jit)
     with pytest.raises(AssertionError):
         comp.compare()
@@ -373,104 +468,156 @@ def __init__(self, device, intermediate_value):
     def forward(self, x, t):
         y = self.model(x)
         for i in range(5):
-            ppe.utils.comparer.intermediate_value('y', y + self.hidden + i)
+            ppe.utils.comparer.intermediate_value("y", y + self.hidden + i)
         loss = F.l1_loss(y, t)
         return loss
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_intermediate(engine_fn):
     loader = torch.utils.data.DataLoader(
-        [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+        [
+            (
+                torch.rand(
+                    20,
+                ),
+                torch.rand(
+                    10,
+                ),
+            )
+            for i in range(100)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
     engine_cpu, loaders_cpu = engine_fn(
-        ModelForIntermediateValue, "cpu", [10.0], loader)
+        ModelForIntermediateValue, "cpu", [10.0], loader
+    )
     engine_gpu, loaders_gpu = engine_fn(
-        ModelForIntermediateValue, "cuda:0", [10.0], loader)
+        ModelForIntermediateValue, "cuda:0", [10.0], loader
+    )
     engine_jit, loaders_jit = engine_fn(
-        ModelForIntermediateValue, "jit-cpu", [10.0], loader)
+        ModelForIntermediateValue, "jit-cpu", [10.0], loader
+    )
     comp = ppe.utils.comparer.Comparer()
-    comp.add_engine('cpu', engine_cpu, *loaders_cpu)
-    comp.add_engine('gpu', engine_gpu, *loaders_gpu)
-    comp.add_engine('jit', engine_jit, *loaders_jit)
+    comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+    comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+    comp.add_engine("jit", engine_jit, *loaders_jit)
     comp.compare()
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_intermediate_invalid(engine_fn):
     loader = torch.utils.data.DataLoader(
-        [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+        [
+            (
+                torch.rand(
+                    20,
+                ),
+                torch.rand(
+                    10,
+                ),
+            )
+            for i in range(100)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
     engine_cpu, loaders_cpu = engine_fn(
-        ModelForIntermediateValue, "cpu", [10.0], loader)
+        ModelForIntermediateValue, "cpu", [10.0], loader
+    )
     engine_gpu, loaders_gpu = engine_fn(
-        ModelForIntermediateValue, "cuda:0", [10.0], loader)
+        ModelForIntermediateValue, "cuda:0", [10.0], loader
+    )
     engine_jit, loaders_jit = engine_fn(
-        ModelForIntermediateValue, "jit-cpu", [11.1], loader)
+        ModelForIntermediateValue, "jit-cpu", [11.1], loader
+    )
     comp = ppe.utils.comparer.Comparer()
-    comp.add_engine('cpu', engine_cpu, *loaders_cpu)
-    comp.add_engine('gpu', engine_gpu, *loaders_gpu)
-    comp.add_engine('jit', engine_jit, *loaders_jit)
+    comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+    comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+    comp.add_engine("jit", engine_jit, *loaders_jit)
     with pytest.raises(AssertionError):
         comp.compare()
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 @pytest.mark.parametrize("model_class", [MyModel, ModelForIntermediateValue])
 @pytest.mark.parametrize("params", [False, True])
 def test_dump(engine_fn, model_class, params):
     loader = torch.utils.data.DataLoader(
-        [(torch.rand(20,), torch.rand(10,)) for i in range(5)])
+        [
+            (
+                torch.rand(
+                    20,
+                ),
+                torch.rand(
+                    10,
+                ),
+            )
+            for i in range(5)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
     engine_cpu, loaders_cpu = engine_fn(model_class, "cpu", [1.0], loader)
     engine_gpu, loaders_gpu = engine_fn(model_class, "cuda:0", [1.0], loader)
     engine_jit, loaders_jit = engine_fn(model_class, "jit-cpu", [1.0], loader)
     comp = ppe.utils.comparer.Comparer(params=params)
     with tempfile.TemporaryDirectory() as tmpdir:
-        comp.dump(engine_cpu, f'{tmpdir}/cpu', *loaders_cpu)
-        comp.dump(engine_gpu, f'{tmpdir}/gpu', *loaders_gpu)
-        comp.dump(engine_jit, f'{tmpdir}/jit', *loaders_jit)
-        comp.add_dump('cpu', f'{tmpdir}/cpu')
-        comp.add_dump('gpu', f'{tmpdir}/gpu')
-        comp.add_dump('jit', f'{tmpdir}/jit')
+        comp.dump(engine_cpu, f"{tmpdir}/cpu", *loaders_cpu)
+        comp.dump(engine_gpu, f"{tmpdir}/gpu", *loaders_gpu)
+        comp.dump(engine_jit, f"{tmpdir}/jit", *loaders_jit)
+        comp.add_dump("cpu", f"{tmpdir}/cpu")
+        comp.add_dump("gpu", f"{tmpdir}/gpu")
+        comp.add_dump("jit", f"{tmpdir}/jit")
         comp.compare()
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
-@pytest.mark.parametrize("model_class", [
-    MyModel, ModelForIntermediateValue])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
+@pytest.mark.parametrize("model_class", [MyModel, ModelForIntermediateValue])
 def test_dump_invalid(engine_fn, model_class):
     loader = torch.utils.data.DataLoader(
-        [(torch.rand(20,), torch.rand(10,)) for i in range(5)])
+        [
+            (
+                torch.rand(
+                    20,
+                ),
+                torch.rand(
+                    10,
+                ),
+            )
+            for i in range(5)
+        ]
+    )
     ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
     engine_cpu, loaders_cpu = engine_fn(model_class, "cpu", [1.0], loader)
     engine_gpu, loaders_gpu = engine_fn(model_class, "cuda:0", [1.0], loader)
     engine_jit, loaders_jit = engine_fn(model_class, "jit-cpu", [2.0], loader)
     comp = ppe.utils.comparer.Comparer()
-    with tempfile.TemporaryDirectory() as tmpdir1, \
-            tempfile.TemporaryDirectory() as tmpdir2, \
-            tempfile.TemporaryDirectory() as tmpdir3:
+    with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2, tempfile.TemporaryDirectory() as tmpdir3:
         comp.dump(engine_cpu, tmpdir1, *loaders_cpu)
         comp.dump(engine_gpu, tmpdir2, *loaders_gpu)
         comp.dump(engine_jit, tmpdir3, *loaders_jit)
-        comp.add_dump('cpu', tmpdir1)
-        comp.add_dump('gpu', tmpdir2)
-        comp.add_dump('jit', tmpdir3)
+        comp.add_dump("cpu", tmpdir1)
+        comp.add_dump("gpu", tmpdir2)
+        comp.add_dump("jit", tmpdir3)
         with pytest.raises(AssertionError):
             comp.compare()
 @pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
-    _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+    "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
 def test_compare_baseline(engine_fn):
     loader = list(torch.ones(10) for _ in range(10))
     engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)