diff --git a/.github/workflows/pretest-and-test.yml b/.github/workflows/pretest-and-test.yml
index 3ccb18b0b..44ebb5722 100644
--- a/.github/workflows/pretest-and-test.yml
+++ b/.github/workflows/pretest-and-test.yml
@@ -38,7 +38,7 @@ jobs:
- name: Code Style
run: |
- pip install pysen black==21.11b1 flake8==4.0.1 isort==5.10.1 mypy==0.991
+ pip install pysen black==23.3.0 flake8==4.0.1 isort==5.10.1 mypy==0.991
pip install types-PyYAML types-setuptools
cp "$(pip show torch | awk '/^Location:/ { print $2 }')/torch/__init__.py" stubs/torch/__init__.py
MYPYPATH="${PWD}/stubs" pysen run lint
diff --git a/docs/source/conf.py b/docs/source/conf.py
index b2ef36240..fd906b89b 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -17,9 +17,9 @@
# -- Project information -----------------------------------------------------
-project = 'pytorch-pfn-extras'
-copyright = '2021, Preferred Networks, Inc.'
-author = 'Preferred Networks, Inc.'
+project = "pytorch-pfn-extras"
+copyright = "2021, Preferred Networks, Inc."
+author = "Preferred Networks, Inc."
# -- General configuration ---------------------------------------------------
@@ -28,15 +28,15 @@
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autosummary',
- 'sphinx.ext.napoleon',
- 'myst_parser',
+ "sphinx.ext.autosummary",
+ "sphinx.ext.napoleon",
+ "myst_parser",
]
autodoc_typehints = "description"
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
@@ -51,9 +51,9 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
-html_theme = 'pydata_sphinx_theme'
+html_theme = "pydata_sphinx_theme"
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
diff --git a/example/ignite-mnist.py b/example/ignite-mnist.py
index c048d5ea5..2d2750102 100644
--- a/example/ignite-mnist.py
+++ b/example/ignite-mnist.py
@@ -1,20 +1,20 @@
from argparse import ArgumentParser
+import pytorch_pfn_extras as ppe
+import pytorch_pfn_extras.training.extensions as extensions
+import torch
+import torch.nn.functional as F
+from ignite.engine import (
+ Events,
+ create_supervised_evaluator,
+ create_supervised_trainer,
+)
+from ignite.metrics import Accuracy, Loss
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
-import torch
-import torch.nn.functional as F
-from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import MNIST
-
-from ignite.engine import Events
-from ignite.engine import create_supervised_trainer
-from ignite.engine import create_supervised_evaluator
-from ignite.metrics import Accuracy, Loss
-
-import pytorch_pfn_extras as ppe
-import pytorch_pfn_extras.training.extensions as extensions
+from torchvision.transforms import Compose, Normalize, ToTensor
class Net(nn.Module):
@@ -40,55 +40,75 @@ def get_data_loaders(train_batch_size, val_batch_size):
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(
- MNIST(download=True, root="../data", transform=data_transform,
- train=True),
- batch_size=train_batch_size, shuffle=True)
+ MNIST(
+ download=True, root="../data", transform=data_transform, train=True
+ ),
+ batch_size=train_batch_size,
+ shuffle=True,
+ )
val_loader = DataLoader(
- MNIST(download=False, root="../data", transform=data_transform,
- train=False),
- batch_size=val_batch_size, shuffle=False)
+ MNIST(
+ download=False,
+ root="../data",
+ transform=data_transform,
+ train=False,
+ ),
+ batch_size=val_batch_size,
+ shuffle=False,
+ )
return train_loader, val_loader
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
train_loader, val_loader = get_data_loaders(
- train_batch_size, val_batch_size)
+ train_batch_size, val_batch_size
+ )
model = Net()
- device = 'cpu'
+ device = "cpu"
if torch.cuda.is_available():
- device = 'cuda:0'
+ device = "cuda:0"
model = model.to(device)
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
optimizer.step()
trainer = create_supervised_trainer(
- model, optimizer, F.nll_loss, device=device)
+ model, optimizer, F.nll_loss, device=device
+ )
evaluator = create_supervised_evaluator(
model,
- metrics={'acc': Accuracy(), 'loss': Loss(F.nll_loss)},
- device=device)
+ metrics={"acc": Accuracy(), "loss": Loss(F.nll_loss)},
+ device=device,
+ )
# manager.extend(...) also works
my_extensions = [
extensions.LogReport(),
extensions.ProgressBar(),
extensions.observe_lr(optimizer=optimizer),
- extensions.ParameterStatistics(model, prefix='model'),
+ extensions.ParameterStatistics(model, prefix="model"),
extensions.VariableStatisticsPlot(model),
extensions.snapshot(),
extensions.IgniteEvaluator(
- evaluator, val_loader, model, progress_bar=True),
- extensions.PlotReport(['train/loss'], 'epoch', filename='loss.png'),
- extensions.PrintReport([
- 'epoch', 'iteration', 'train/loss', 'lr',
- 'model/fc2.bias/grad/min', 'val/loss', 'val/acc',
- ]),
+ evaluator, val_loader, model, progress_bar=True
+ ),
+ extensions.PlotReport(["train/loss"], "epoch", filename="loss.png"),
+ extensions.PrintReport(
+ [
+ "epoch",
+ "iteration",
+ "train/loss",
+ "lr",
+ "model/fc2.bias/grad/min",
+ "val/loss",
+ "val/acc",
+ ]
+ ),
]
- models = {'main': model}
- optimizers = {'main': optimizer}
+ models = {"main": model}
+ optimizers = {"main": optimizer}
manager = ppe.training.IgniteExtensionsManager(
- trainer, models, optimizers, args.epochs,
- extensions=my_extensions)
+ trainer, models, optimizers, args.epochs, extensions=my_extensions
+ )
# Lets load the snapshot
if args.snapshot is not None:
@@ -97,30 +117,57 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
@trainer.on(Events.ITERATION_COMPLETED)
def report_loss(engine):
- ppe.reporting.report({'train/loss': engine.state.output})
+ ppe.reporting.report({"train/loss": engine.state.output})
trainer.run(train_loader, max_epochs=epochs)
if __name__ == "__main__":
parser = ArgumentParser()
- parser.add_argument('--batch_size', type=int, default=64,
- help='input batch size for training (default: 64)')
- parser.add_argument('--val_batch_size', type=int, default=1000,
- help='input batch size for validation (default: 1000)')
- parser.add_argument('--epochs', type=int, default=10,
- help='number of epochs to train (default: 10)')
- parser.add_argument('--lr', type=float, default=0.01,
- help='learning rate (default: 0.01)')
- parser.add_argument('--momentum', type=float, default=0.5,
- help='SGD momentum (default: 0.5)')
- parser.add_argument('--log_interval', type=int, default=10,
- help='how many batches to wait before logging '
- 'training status')
- parser.add_argument('--snapshot', type=str, default=None,
- help='path to snapshot file')
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=64,
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--val_batch_size",
+ type=int,
+ default=1000,
+ help="input batch size for validation (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+ "--lr", type=float, default=0.01, help="learning rate (default: 0.01)"
+ )
+ parser.add_argument(
+ "--momentum",
+ type=float,
+ default=0.5,
+ help="SGD momentum (default: 0.5)",
+ )
+ parser.add_argument(
+ "--log_interval",
+ type=int,
+ default=10,
+ help="how many batches to wait before logging " "training status",
+ )
+ parser.add_argument(
+ "--snapshot", type=str, default=None, help="path to snapshot file"
+ )
args = parser.parse_args()
- run(args.batch_size, args.val_batch_size, args.epochs, args.lr,
- args.momentum, args.log_interval)
+ run(
+ args.batch_size,
+ args.val_batch_size,
+ args.epochs,
+ args.lr,
+ args.momentum,
+ args.log_interval,
+ )
diff --git a/example/mnist.py b/example/mnist.py
index cab6b1cb0..ba2ddaa74 100644
--- a/example/mnist.py
+++ b/example/mnist.py
@@ -1,13 +1,13 @@
import argparse
+
+import pytorch_pfn_extras as ppe
+import pytorch_pfn_extras.training.extensions as extensions
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
-import pytorch_pfn_extras as ppe
-import pytorch_pfn_extras.training.extensions as extensions
-
class Net(nn.Module):
def __init__(self):
@@ -33,56 +33,98 @@ def train(manager, args, model, device, train_loader):
while not manager.stop_trigger:
model.train()
for _, (data, target) in enumerate(train_loader):
- with manager.run_iteration(step_optimizers=['main']):
+ with manager.run_iteration(step_optimizers=["main"]):
data, target = data.to(device), target.to(device)
output = model(data)
loss = F.nll_loss(output, target)
- ppe.reporting.report({'train/loss': loss.item()})
+ ppe.reporting.report({"train/loss": loss.item()})
loss.backward()
def test(args, model, device, data, target):
- """ The extension loops over the iterator in order to
- drive the evaluator progress bar and reporting
- averages
+ """The extension loops over the iterator in order to
+ drive the evaluator progress bar and report
+ averages
"""
model.eval()
data, target = data.to(device), target.to(device)
output = model(data)
# Final result will be average of averages of the same size
- test_loss = F.nll_loss(output, target, reduction='mean').item()
- ppe.reporting.report({'val/loss': test_loss})
+ test_loss = F.nll_loss(output, target, reduction="mean").item()
+ ppe.reporting.report({"val/loss": test_loss})
pred = output.argmax(dim=1, keepdim=True)
correct = pred.eq(target.view_as(pred)).sum().item()
- ppe.reporting.report({'val/acc': correct / len(data)})
+ ppe.reporting.report({"val/acc": correct / len(data)})
def main():
# Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000,
- metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=10, metavar='N',
- help='number of epochs to train (default: 10)')
- parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
- help='learning rate (default: 0.01)')
- parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
- help='SGD momentum (default: 0.5)')
- parser.add_argument('--no-cuda', dest='cuda',
- action='store_false', default=True,
- help='disables CUDA training')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--save-model', action='store_true', default=False,
- help='For Saving the current Model')
- parser.add_argument('--snapshot', type=str, default=None,
- help='path to snapshot file')
- parser.add_argument('--slack', type=str, default=None,
- help='post to the specified Slack channel')
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=0.01,
+ metavar="LR",
+ help="learning rate (default: 0.01)",
+ )
+ parser.add_argument(
+ "--momentum",
+ type=float,
+ default=0.5,
+ metavar="M",
+ help="SGD momentum (default: 0.5)",
+ )
+ parser.add_argument(
+ "--no-cuda",
+ dest="cuda",
+ action="store_false",
+ default=True,
+ help="disables CUDA training",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=1,
+ metavar="S",
+ help="random seed (default: 1)",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--snapshot", type=str, default=None, help="path to snapshot file"
+ )
+ parser.add_argument(
+ "--slack",
+ type=str,
+ default=None,
+ help="post to the specified Slack channel",
+ )
args = parser.parse_args()
use_cuda = args.cuda and torch.cuda.is_available()
@@ -90,68 +132,98 @@ def main():
device = torch.device("cuda" if use_cuda else "cpu")
- kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
+ kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
- datasets.MNIST('../data', train=True, download=True,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=args.batch_size, shuffle=True,
- **kwargs) # type: ignore[arg-type]
+ datasets.MNIST(
+ "../data",
+ train=True,
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ ),
+ batch_size=args.batch_size,
+ shuffle=True,
+ **kwargs, # type: ignore[arg-type]
+ )
test_loader = torch.utils.data.DataLoader(
- datasets.MNIST('../data', train=False, transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=args.test_batch_size, shuffle=True,
- **kwargs) # type: ignore[arg-type]
+ datasets.MNIST(
+ "../data",
+ train=False,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ ),
+ batch_size=args.test_batch_size,
+ shuffle=True,
+ **kwargs, # type: ignore[arg-type]
+ )
model = Net()
model.to(device)
optimizer = optim.SGD(
- model.parameters(), lr=args.lr, momentum=args.momentum)
+ model.parameters(), lr=args.lr, momentum=args.momentum
+ )
# manager.extend(...) also works
my_extensions = [
extensions.LogReport(),
-
# Enables TensorBoard support.
# Run `tensorboard --logdir runs` to launch the TensorBoard.
extensions.LogReport(
- writer=ppe.writing.TensorBoardWriter(out_dir='runs'),
- trigger=(1, 'iteration')),
+ writer=ppe.writing.TensorBoardWriter(out_dir="runs"),
+ trigger=(1, "iteration"),
+ ),
extensions.ProgressBar(),
extensions.observe_lr(optimizer=optimizer),
- extensions.ParameterStatistics(model, prefix='model'),
+ extensions.ParameterStatistics(model, prefix="model"),
extensions.VariableStatisticsPlot(model),
extensions.Evaluator(
- test_loader, model,
- eval_func=lambda data, target:
- test(args, model, device, data, target),
- progress_bar=True),
+ test_loader,
+ model,
+ eval_func=lambda data, target: test(
+ args, model, device, data, target
+ ),
+ progress_bar=True,
+ ),
extensions.PlotReport(
- ['train/loss', 'val/loss'], 'epoch', filename='loss.png'),
- extensions.PrintReport(['epoch', 'iteration',
- 'train/loss', 'lr', 'model/fc2.bias/grad/min',
- 'val/loss', 'val/acc']),
+ ["train/loss", "val/loss"], "epoch", filename="loss.png"
+ ),
+ extensions.PrintReport(
+ [
+ "epoch",
+ "iteration",
+ "train/loss",
+ "lr",
+ "model/fc2.bias/grad/min",
+ "val/loss",
+ "val/acc",
+ ]
+ ),
extensions.snapshot(),
]
if args.slack is not None:
- my_extensions.append(extensions.Slack(
- channel=args.slack,
- msg='Epoch #{manager.epoch}: val/loss = {val/loss}',
- # Surround the username with <> to mention.
- end_msg='{default}\n<@your_slack_user_name>',
-
- # Upload any artifacts generated during the training.
- filenames=['result/statistics.png'],
- # You can specify when to upload these files.
- # e.g., only at the final epoch:
- # upload_trigger=(args.epochs, 'epoch'),
- ))
+ my_extensions.append(
+ extensions.Slack(
+ channel=args.slack,
+ msg="Epoch #{manager.epoch}: val/loss = {val/loss}",
+ # Surround the username with <> to mention.
+ end_msg="{default}\n<@your_slack_user_name>",
+ # Upload any artifacts generated during the training.
+ filenames=["result/statistics.png"],
+ # You can specify when to upload these files.
+ # e.g., only at the final epoch:
+ # upload_trigger=(args.epochs, 'epoch'),
+ )
+ )
# Custom stop triggers can be added to the manager and
# their status accessed through `manager.stop_trigger`
@@ -159,10 +231,13 @@ def main():
# trigger = ppe.training.triggers.EarlyStoppingTrigger(
# check_trigger=(1, 'epoch'), monitor='val/loss')
manager = ppe.training.ExtensionsManager(
- model, optimizer, args.epochs,
+ model,
+ optimizer,
+ args.epochs,
extensions=my_extensions,
iters_per_epoch=len(train_loader),
- stop_trigger=trigger)
+ stop_trigger=trigger,
+ )
# Lets load the snapshot
if args.snapshot is not None:
state = torch.load(args.snapshot)
@@ -172,9 +247,9 @@ def main():
# to get access to the reporter and other facilities
# test(args, model, device, test_loader)
- if (args.save_model):
+ if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/example/mnist_custom_logic.py b/example/mnist_custom_logic.py
index 9721cf101..4fcb01f3b 100644
--- a/example/mnist_custom_logic.py
+++ b/example/mnist_custom_logic.py
@@ -1,13 +1,13 @@
import argparse
+
+import pytorch_pfn_extras as ppe
+import pytorch_pfn_extras.training.extensions as extensions
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
-import pytorch_pfn_extras as ppe
-import pytorch_pfn_extras.training.extensions as extensions
-
class Net(nn.Module):
def __init__(self):
@@ -29,7 +29,6 @@ def forward(self, x):
class CustomLogic(ppe.handler.Logic):
-
def __init__(self, steps_per_update):
self.steps_per_update = steps_per_update
super().__init__()
@@ -58,67 +57,132 @@ def train_step_optimizers(self, models, optimizers, batch_idx):
def main():
# Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000,
- metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=10, metavar='N',
- help='number of epochs to train (default: 10)')
- parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
- help='learning rate (default: 0.01)')
- parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
- help='SGD momentum (default: 0.5)')
- parser.add_argument('--device', type=str, default='cuda',
- help='PyTorch device specifier')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--save-model', action='store_true', default=False,
- help='For Saving the current Model')
- parser.add_argument('--snapshot', type=str, default=None,
- help='path to snapshot file')
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=0.01,
+ metavar="LR",
+ help="learning rate (default: 0.01)",
+ )
+ parser.add_argument(
+ "--momentum",
+ type=float,
+ default=0.5,
+ metavar="M",
+ help="SGD momentum (default: 0.5)",
+ )
+ parser.add_argument(
+ "--device", type=str, default="cuda", help="PyTorch device specifier"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=1,
+ metavar="S",
+ help="random seed (default: 1)",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--snapshot", type=str, default=None, help="path to snapshot file"
+ )
args = parser.parse_args()
torch.manual_seed(args.seed)
- use_cuda = args.device.startswith('cuda')
+ use_cuda = args.device.startswith("cuda")
- kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
+ kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
- datasets.MNIST('../data', train=True, download=True,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=args.batch_size, shuffle=True,
+ datasets.MNIST(
+ "../data",
+ train=True,
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ ),
+ batch_size=args.batch_size,
+ shuffle=True,
collate_fn=ppe.dataloaders.utils.CollateAsDict(
- names=['data', 'target']), **kwargs) # type: ignore[arg-type]
+ names=["data", "target"]
+ ),
+ **kwargs, # type: ignore[arg-type]
+ )
test_loader = torch.utils.data.DataLoader(
- datasets.MNIST('../data', train=False, transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=args.test_batch_size, shuffle=True,
+ datasets.MNIST(
+ "../data",
+ train=False,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ ),
+ batch_size=args.test_batch_size,
+ shuffle=True,
collate_fn=ppe.dataloaders.utils.CollateAsDict(
- names=['data', 'target']), **kwargs) # type: ignore[arg-type]
+ names=["data", "target"]
+ ),
+ **kwargs, # type: ignore[arg-type]
+ )
model = Net()
optimizer = optim.SGD(
- model.parameters(), lr=args.lr, momentum=args.momentum)
+ model.parameters(), lr=args.lr, momentum=args.momentum
+ )
my_extensions = [
extensions.LogReport(),
extensions.ProgressBar(),
extensions.observe_lr(optimizer=optimizer),
- extensions.ParameterStatistics(model, prefix='model'),
+ extensions.ParameterStatistics(model, prefix="model"),
extensions.VariableStatisticsPlot(model),
extensions.PlotReport(
- ['train/loss', 'val/loss'], 'epoch', filename='loss.png'),
- extensions.PrintReport(['epoch', 'iteration',
- 'train/loss', 'lr', 'model/fc2.bias/grad/min',
- 'val/loss', 'val/accuracy']),
+ ["train/loss", "val/loss"], "epoch", filename="loss.png"
+ ),
+ extensions.PrintReport(
+ [
+ "epoch",
+ "iteration",
+ "train/loss",
+ "lr",
+ "model/fc2.bias/grad/min",
+ "val/loss",
+ "val/accuracy",
+ ]
+ ),
extensions.snapshot(),
]
@@ -138,13 +202,13 @@ def forward(self, data, target):
if model.training:
loss = F.nll_loss(output, target)
- ppe.reporting.report({'train/loss': loss.item()})
- return {'loss': loss}
+ ppe.reporting.report({"train/loss": loss.item()})
+ return {"loss": loss}
# Final result will be average of averages of the same size
- test_loss = F.nll_loss(output, target, reduction='mean').item()
+ test_loss = F.nll_loss(output, target, reduction="mean").item()
pred = output.argmax(dim=1, keepdim=True)
- return {'loss': test_loss, 'output': pred}
+ return {"loss": test_loss, "output": pred}
model_with_loss = ModelWithLoss(model)
trainer = ppe.engine.create_trainer(
@@ -158,9 +222,10 @@ def forward(self, data, target):
model_with_loss,
device=args.device,
progress_bar=True,
- metrics=[ppe.training.metrics.AccuracyMetric('target', 'output')],
- options={'eval_report_keys': ['loss', 'accuracy']}),
- options={'train_report_keys': ['loss']},
+ metrics=[ppe.training.metrics.AccuracyMetric("target", "output")],
+ options={"eval_report_keys": ["loss", "accuracy"]},
+ ),
+ options={"train_report_keys": ["loss"]},
logic=CustomLogic(3),
)
@@ -174,9 +239,9 @@ def forward(self, data, target):
trainer.run(train_loader, test_loader)
- if (args.save_model):
+ if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/example/mnist_ddp.py b/example/mnist_ddp.py
index b972be291..abee0502a 100644
--- a/example/mnist_ddp.py
+++ b/example/mnist_ddp.py
@@ -1,14 +1,13 @@
import argparse
+import pytorch_pfn_extras as ppe
+import pytorch_pfn_extras.training.extensions as extensions
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
-import pytorch_pfn_extras as ppe
-import pytorch_pfn_extras.training.extensions as extensions
-
class Net(nn.Module):
def __init__(self):
@@ -33,18 +32,18 @@ def train(manager, args, model, device, train_loader):
while not manager.stop_trigger:
model.train()
for _, (data, target) in enumerate(train_loader):
- with manager.run_iteration(step_optimizers=['main']):
+ with manager.run_iteration(step_optimizers=["main"]):
data, target = data.to(device), target.to(device)
output = model(data)
loss = F.nll_loss(output, target)
- ppe.reporting.report({'train/loss': loss.item()})
+ ppe.reporting.report({"train/loss": loss.item()})
loss.backward()
def test(args, model, device, data, target):
- """ The extension loops over the iterator in order to
- drive the evaluator progress bar and reporting
- averages
+ """The extension loops over the iterator in order to
+ drive the evaluator progress bar and report
+ averages
"""
model.eval()
test_loss = 0.0
@@ -52,52 +51,95 @@ def test(args, model, device, data, target):
data, target = data.to(device), target.to(device)
output = model(data)
# Final result will be average of averages of the same size
- test_loss += F.nll_loss(output, target, reduction='mean').item()
- ppe.reporting.report({'val/loss': test_loss})
+ test_loss += F.nll_loss(output, target, reduction="mean").item()
+ ppe.reporting.report({"val/loss": test_loss})
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
- ppe.reporting.report({'val/acc': correct / len(data)})
+ ppe.reporting.report({"val/acc": correct / len(data)})
def init_distributed(use_cuda=True):
# setup env for torch.distributed
- comm_world_size, comm_rank, comm_local_rank = (
- ppe.distributed.initialize_ompi_environment(
- backend="nccl", init_method="env"))
+ (
+ comm_world_size,
+ comm_rank,
+ comm_local_rank,
+ ) = ppe.distributed.initialize_ompi_environment(
+ backend="nccl", init_method="env"
+ )
if comm_rank == 0:
print("World size = {}".format(comm_world_size))
print("Rank = {}, Local Rank = {}".format(comm_rank, comm_local_rank))
torch.cuda.set_device(comm_local_rank)
device = torch.device(
- "cuda:{}".format(comm_local_rank) if use_cuda else "cpu")
+ "cuda:{}".format(comm_local_rank) if use_cuda else "cpu"
+ )
return comm_world_size, comm_rank, comm_local_rank, device
def main():
# Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000,
- metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=10, metavar='N',
- help='number of epochs to train (default: 10)')
- parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
- help='learning rate (default: 0.01)')
- parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
- help='SGD momentum (default: 0.5)')
- parser.add_argument('--no-cuda', dest='cuda',
- action='store_false', default=True,
- help='disables CUDA training')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--save-model', action='store_true', default=False,
- help='For Saving the current Model')
- parser.add_argument('--snapshot', type=str, default=None,
- help='path to snapshot file')
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=0.01,
+ metavar="LR",
+ help="learning rate (default: 0.01)",
+ )
+ parser.add_argument(
+ "--momentum",
+ type=float,
+ default=0.5,
+ metavar="M",
+ help="SGD momentum (default: 0.5)",
+ )
+ parser.add_argument(
+ "--no-cuda",
+ dest="cuda",
+ action="store_false",
+ default=True,
+ help="disables CUDA training",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=1,
+ metavar="S",
+ help="random seed (default: 1)",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--snapshot", type=str, default=None, help="path to snapshot file"
+ )
args = parser.parse_args()
use_cuda = args.cuda and torch.cuda.is_available()
@@ -106,14 +148,15 @@ def main():
torch.manual_seed(args.seed)
comm_world_size, comm_rank, comm_local_rank, device = init_distributed(
- use_cuda)
+ use_cuda
+ )
if comm_rank == 0:
print("World size = {}".format(comm_world_size))
print("Rank = {}, Local Rank = {}".format(comm_rank, comm_local_rank))
print("Device = {}".format(device))
- kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
- dataset_root = '../data'
+ kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+ dataset_root = "../data"
if comm_local_rank == 0:
# download mnist
datasets.MNIST(dataset_root, download=True)
@@ -122,37 +165,53 @@ def main():
train_dataset = datasets.MNIST(
dataset_root,
train=True,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,)),
- ]))
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ )
test_dataset = datasets.MNIST(
dataset_root,
train=False,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,)),
- ]))
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ )
train_sampler = torch.utils.data.DistributedSampler[int](
- train_dataset, num_replicas=comm_world_size, rank=comm_rank)
+ train_dataset, num_replicas=comm_world_size, rank=comm_rank
+ )
train_loader = torch.utils.data.DataLoader(
- train_dataset, batch_size=args.batch_size, sampler=train_sampler,
- **kwargs) # type: ignore[arg-type]
+ train_dataset,
+ batch_size=args.batch_size,
+ sampler=train_sampler,
+ **kwargs, # type: ignore[arg-type]
+ )
test_dataset_indices = list(range(len(test_dataset)))
local_test_dataset_indices = test_dataset_indices[
- comm_rank:len(test_dataset_indices):comm_world_size]
+ comm_rank : len(test_dataset_indices) : comm_world_size
+ ]
local_test_dataset = torch.utils.data.Subset(
- test_dataset, local_test_dataset_indices)
+ test_dataset, local_test_dataset_indices
+ )
test_loader = torch.utils.data.DataLoader(
- local_test_dataset, batch_size=args.test_batch_size, shuffle=True,
- **kwargs) # type: ignore[arg-type]
+ local_test_dataset,
+ batch_size=args.test_batch_size,
+ shuffle=True,
+ **kwargs, # type: ignore[arg-type]
+ )
model = ppe.nn.parallel.DistributedDataParallel(Net().to(device))
optimizer = optim.SGD(
- model.parameters(), lr=args.lr, momentum=args.momentum)
+ model.parameters(), lr=args.lr, momentum=args.momentum
+ )
# manager.extend(...) also works
if comm_local_rank == 0:
@@ -160,19 +219,30 @@ def main():
extensions.LogReport(),
extensions.ProgressBar(),
extensions.observe_lr(optimizer=optimizer),
- extensions.ParameterStatistics(model, prefix='model'),
+ extensions.ParameterStatistics(model, prefix="model"),
extensions.VariableStatisticsPlot(model),
extensions.Evaluator(
- test_loader, model,
- eval_func=lambda data, target:
- test(args, model, device, data, target),
- progress_bar=True),
+ test_loader,
+ model,
+ eval_func=lambda data, target: test(
+ args, model, device, data, target
+ ),
+ progress_bar=True,
+ ),
extensions.PlotReport(
- ['train/loss', 'val/loss'], 'epoch', filename='loss.png'),
- extensions.PrintReport(['epoch', 'iteration',
- 'train/loss', 'lr',
- 'model/fc2.bias/grad/min',
- 'val/loss', 'val/acc']),
+ ["train/loss", "val/loss"], "epoch", filename="loss.png"
+ ),
+ extensions.PrintReport(
+ [
+ "epoch",
+ "iteration",
+ "train/loss",
+ "lr",
+ "model/fc2.bias/grad/min",
+ "val/loss",
+ "val/acc",
+ ]
+ ),
extensions.snapshot(),
]
else:
@@ -184,10 +254,13 @@ def main():
# trigger = ppe.training.triggers.EarlyStoppingTrigger(
# check_trigger=(1, 'epoch'), monitor='val/loss')
manager = ppe.training.ExtensionsManager(
- model, optimizer, args.epochs,
+ model,
+ optimizer,
+ args.epochs,
extensions=my_extensions,
iters_per_epoch=len(train_loader),
- stop_trigger=trigger)
+ stop_trigger=trigger,
+ )
# Lets load the snapshot
if args.snapshot is not None:
state = torch.load(args.snapshot)
@@ -197,12 +270,12 @@ def main():
# to get access to the reporter and other facilities
# test(args, model, device, test_loader)
- if (args.save_model):
+ if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
# Wait for all processes to finish to complete successfully
torch.distributed.barrier()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/example/mnist_trainer.py b/example/mnist_trainer.py
index 55f9b5fcf..f0caff434 100644
--- a/example/mnist_trainer.py
+++ b/example/mnist_trainer.py
@@ -1,15 +1,14 @@
import argparse
import numpy
+import pytorch_pfn_extras as ppe
+import pytorch_pfn_extras.training.extensions as extensions
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
-import pytorch_pfn_extras as ppe
-import pytorch_pfn_extras.training.extensions as extensions
-
class Net(nn.Module):
def __init__(self):
@@ -40,88 +39,169 @@ def forward(self, data, target):
if self.training:
loss = F.nll_loss(output, target)
- ppe.reporting.report({'train/loss': loss.item()})
- return {'loss': loss}
+ ppe.reporting.report({"train/loss": loss.item()})
+ return {"loss": loss}
# Final result will be average of averages of the same size
- test_loss = F.nll_loss(output, target, reduction='mean').item()
+ test_loss = F.nll_loss(output, target, reduction="mean").item()
pred = output.argmax(dim=1, keepdim=True)
- return {'loss': test_loss, 'output': pred}
+ return {"loss": test_loss, "output": pred}
def main():
# Training settings
- parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
- parser.add_argument('--batch-size', type=int, default=64, metavar='N',
- help='input batch size for training (default: 64)')
- parser.add_argument('--test-batch-size', type=int, default=1000,
- metavar='N',
- help='input batch size for testing (default: 1000)')
- parser.add_argument('--epochs', type=int, default=10, metavar='N',
- help='number of epochs to train (default: 10)')
- parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
- help='learning rate (default: 0.01)')
- parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
- help='SGD momentum (default: 0.5)')
- parser.add_argument('--device', type=str, default='cuda',
- help='PyTorch device specifier')
- parser.add_argument('--seed', type=int, default=1, metavar='S',
- help='random seed (default: 1)')
- parser.add_argument('--deterministic', action='store_true', default=False,
- help='make the behavior deterministic')
- parser.add_argument('--save-model', action='store_true', default=False,
- help='For Saving the current Model')
- parser.add_argument('--snapshot', type=str, default=None,
- help='path to snapshot file')
- parser.add_argument('--compare-dump', type=str, default=None,
- help='directory to save comparer dump to')
- parser.add_argument('--compare-with', type=str, default=None,
- help='directory to load comparer dump from')
- parser.add_argument('--profiler', type=str, default=None,
- help='output mode for profiler results')
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10,
+ metavar="N",
+ help="number of epochs to train (default: 10)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=0.01,
+ metavar="LR",
+ help="learning rate (default: 0.01)",
+ )
+ parser.add_argument(
+ "--momentum",
+ type=float,
+ default=0.5,
+ metavar="M",
+ help="SGD momentum (default: 0.5)",
+ )
+ parser.add_argument(
+ "--device", type=str, default="cuda", help="PyTorch device specifier"
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=1,
+ metavar="S",
+ help="random seed (default: 1)",
+ )
+ parser.add_argument(
+ "--deterministic",
+ action="store_true",
+ default=False,
+ help="make the behavior deterministic",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--snapshot", type=str, default=None, help="path to snapshot file"
+ )
+ parser.add_argument(
+ "--compare-dump",
+ type=str,
+ default=None,
+ help="directory to save comparer dump to",
+ )
+ parser.add_argument(
+ "--compare-with",
+ type=str,
+ default=None,
+ help="directory to load comparer dump from",
+ )
+ parser.add_argument(
+ "--profiler",
+ type=str,
+ default=None,
+ help="output mode for profiler results",
+ )
args = parser.parse_args()
torch.manual_seed(args.seed)
numpy.random.seed(args.seed)
torch.use_deterministic_algorithms(args.deterministic)
- use_cuda = args.device.startswith('cuda')
+ use_cuda = args.device.startswith("cuda")
- kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
+ kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
- datasets.MNIST('../data', train=True, download=True,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=args.batch_size, shuffle=True,
+ datasets.MNIST(
+ "../data",
+ train=True,
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ ),
+ batch_size=args.batch_size,
+ shuffle=True,
collate_fn=ppe.dataloaders.utils.CollateAsDict(
- names=['data', 'target']), **kwargs) # type: ignore[arg-type]
+ names=["data", "target"]
+ ),
+ **kwargs, # type: ignore[arg-type]
+ )
test_loader = torch.utils.data.DataLoader(
- datasets.MNIST('../data', train=False, transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])),
- batch_size=args.test_batch_size, shuffle=True,
+ datasets.MNIST(
+ "../data",
+ train=False,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)),
+ ]
+ ),
+ ),
+ batch_size=args.test_batch_size,
+ shuffle=True,
collate_fn=ppe.dataloaders.utils.CollateAsDict(
- names=['data', 'target']), **kwargs) # type: ignore[arg-type]
+ names=["data", "target"]
+ ),
+ **kwargs, # type: ignore[arg-type]
+ )
model = Net()
optimizer = optim.SGD(
- model.parameters(), lr=args.lr, momentum=args.momentum)
+ model.parameters(), lr=args.lr, momentum=args.momentum
+ )
my_extensions = [
extensions.LogReport(),
extensions.ProgressBar(),
extensions.observe_lr(optimizer=optimizer),
- extensions.ParameterStatistics(model, prefix='model'),
+ extensions.ParameterStatistics(model, prefix="model"),
extensions.VariableStatisticsPlot(model),
extensions.PlotReport(
- ['train/loss', 'val/loss'], 'epoch', filename='loss.png'),
- extensions.PrintReport(['epoch', 'iteration',
- 'train/loss', 'lr', 'model/fc2.bias/grad/min',
- 'val/loss', 'val/accuracy']),
+ ["train/loss", "val/loss"], "epoch", filename="loss.png"
+ ),
+ extensions.PrintReport(
+ [
+ "epoch",
+ "iteration",
+ "train/loss",
+ "lr",
+ "model/fc2.bias/grad/min",
+ "val/loss",
+ "val/accuracy",
+ ]
+ ),
extensions.snapshot(),
]
@@ -133,25 +213,37 @@ def main():
profile = None
if args.profiler is not None:
- if args.profiler == 'tensorboard':
+ if args.profiler == "tensorboard":
+
def callback(prof):
- torch.profiler.tensorboard_trace_handler('./prof') # type: ignore[attr-defined]
- elif args.profiler == 'export_chrome_trace':
+ torch.profiler.tensorboard_trace_handler("./prof")(prof) # type: ignore[attr-defined]
+
+ elif args.profiler == "export_chrome_trace":
+
def callback(prof):
- prof.export_chrome_trace('./prof')
- elif args.profiler == 'export_stacks':
+ prof.export_chrome_trace("./prof")
+
+ elif args.profiler == "export_stacks":
+
def callback(prof):
- prof.export_stacks('./prof')
- elif args.profiler == 'to_pickle':
+ prof.export_stacks("./prof")
+
+ elif args.profiler == "to_pickle":
+
def callback(prof):
import pandas as pd
+
df = pd.DataFrame([e.__dict__ for e in prof.events()])
df.to_pickle(f"{trainer.epoch}.pkl")
- elif args.profiler == 'print':
+
+ elif args.profiler == "print":
+
def callback(prof):
table = prof.key_averages().table(
- sort_by="self_cuda_time_total", row_limit=-1)
+ sort_by="self_cuda_time_total", row_limit=-1
+ )
print(table)
+
else:
assert False
profile = torch.profiler.profile( # type: ignore[attr-defined]
@@ -160,7 +252,8 @@ def callback(prof):
torch.profiler.ProfilerActivity.CUDA, # type: ignore[attr-defined]
],
schedule=torch.profiler.schedule( # type: ignore[attr-defined]
- wait=0, warmup=0, active=len(train_loader)),
+ wait=0, warmup=0, active=len(train_loader)
+ ),
on_trace_ready=callback,
)
@@ -176,10 +269,10 @@ def callback(prof):
model_with_loss,
device=args.device,
progress_bar=True,
- metrics=[ppe.training.metrics.AccuracyMetric('target', 'output')],
- options={'eval_report_keys': ['loss', 'accuracy']},
+ metrics=[ppe.training.metrics.AccuracyMetric("target", "output")],
+ options={"eval_report_keys": ["loss", "accuracy"]},
),
- options={'train_report_keys': ['loss']},
+ options={"train_report_keys": ["loss"]},
profile=profile,
)
@@ -194,11 +287,11 @@ def callback(prof):
if args.compare_dump is not None or args.compare_with is not None:
comp = ppe.utils.comparer.Comparer(
compare_fn=ppe.utils.comparer.get_default_comparer(rtol=1e-2),
- outputs=['loss'],
+ outputs=["loss"],
)
if args.compare_dump is None:
# Compare the engine with an existing dump directory.
- comp.add_dump('baseline', args.compare_with)
+ comp.add_dump("baseline", args.compare_with)
comp.add_engine(args.device, trainer, train_loader, test_loader)
comp.compare()
else:
@@ -209,9 +302,9 @@ def callback(prof):
trainer.run(train_loader, test_loader)
- if (args.save_model):
+ if args.save_model:
torch.save(model.state_dict(), "mnist_cnn.pt")
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/pyproject.toml b/pyproject.toml
index e5a6980ea..9fc29385e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,12 +5,16 @@ requires = ["setuptools<64", "wheel"]
version = "0.10.1"
[tool.pysen.lint]
-enable_black = false
+enable_black = true
enable_flake8 = true
-enable_isort = false
+enable_isort = true
enable_mypy = true
mypy_preset = "strict"
line_length = 80
[[tool.pysen.lint.mypy_targets]]
paths = ["pytorch_pfn_extras"]
+
+[tool.pysen.lint.source]
+ includes = ["."]
+ excludes = ["pytorch_pfn_extras/onnx/", "tests/pytorch_pfn_extras_tests/onnx_tests/"]
diff --git a/pytorch_pfn_extras/__init__.py b/pytorch_pfn_extras/__init__.py
index 7dccbf965..6898f25c4 100644
--- a/pytorch_pfn_extras/__init__.py
+++ b/pytorch_pfn_extras/__init__.py
@@ -1,11 +1,12 @@
# Configure the logging before instantiating anything else
from pytorch_pfn_extras import logging # NOQA
+
logging._configure_logging()
from pytorch_pfn_extras import config # NOQA
from pytorch_pfn_extras import cuda # NOQA
-from pytorch_pfn_extras import dataset # NOQA
from pytorch_pfn_extras import dataloaders # NOQA
+from pytorch_pfn_extras import dataset # NOQA
from pytorch_pfn_extras import distributed # NOQA
from pytorch_pfn_extras import engine # NOQA
from pytorch_pfn_extras import handler # NOQA
@@ -16,14 +17,12 @@
from pytorch_pfn_extras import training # NOQA
from pytorch_pfn_extras import utils # NOQA
from pytorch_pfn_extras import writing # NOQA
-
-from pytorch_pfn_extras._tensor import from_ndarray # NOQA
from pytorch_pfn_extras._tensor import as_ndarray # NOQA
-from pytorch_pfn_extras._tensor import get_xp # NOQA
from pytorch_pfn_extras._tensor import as_numpy_dtype # NOQA
+from pytorch_pfn_extras._tensor import from_ndarray # NOQA
from pytorch_pfn_extras._tensor import from_numpy_dtype # NOQA
-from pytorch_pfn_extras.runtime._to import to # NOQA
-from pytorch_pfn_extras.runtime._map import map # NOQA
+from pytorch_pfn_extras._tensor import get_xp # NOQA
from pytorch_pfn_extras._torch_version import requires # NOQA
-
from pytorch_pfn_extras._version import __version__ # NOQA
+from pytorch_pfn_extras.runtime._map import map # NOQA
+from pytorch_pfn_extras.runtime._to import to # NOQA
diff --git a/pytorch_pfn_extras/_cupy/__init__.py b/pytorch_pfn_extras/_cupy/__init__.py
index 8d66a5704..cb0ad364b 100644
--- a/pytorch_pfn_extras/_cupy/__init__.py
+++ b/pytorch_pfn_extras/_cupy/__init__.py
@@ -1,15 +1,18 @@
try:
import cupy # NOQA
+
_cupy_import_error = None
except Exception as e:
from pytorch_pfn_extras._cupy import _cupy_stub as cupy # NOQA
+
_cupy_import_error = e
def ensure_cupy() -> None:
if _cupy_import_error is not None:
raise RuntimeError(
- f'CuPy is not available. Reason:\n{_cupy_import_error}')
+ f"CuPy is not available. Reason:\n{_cupy_import_error}"
+ )
def is_available() -> bool:
diff --git a/pytorch_pfn_extras/_tensor.py b/pytorch_pfn_extras/_tensor.py
index 3c646784f..d94ad3016 100644
--- a/pytorch_pfn_extras/_tensor.py
+++ b/pytorch_pfn_extras/_tensor.py
@@ -1,11 +1,9 @@
+from typing import Any, Dict, Union
+
import numpy
import torch
import torch.utils.dlpack
-from typing import Any, Dict, Union
-
-from pytorch_pfn_extras._cupy import cupy
-from pytorch_pfn_extras._cupy import ensure_cupy
-
+from pytorch_pfn_extras._cupy import cupy, ensure_cupy
_NDArray = Any # TypeVar("_NDArray", numpy.ndarray, cupy.ndarray)
_NumpyDtype = Any # numpy.dtype
@@ -32,8 +30,9 @@ def from_ndarray(ndarray: _NDArray) -> torch.Tensor:
elif isinstance(ndarray, numpy.ndarray):
return torch.from_numpy(_copy_if_negative_strides(ndarray))
raise TypeError(
- 'expected numpy.ndarray or cupy.ndarray '
- f'(got {type(ndarray).__name__})')
+ "expected numpy.ndarray or cupy.ndarray "
+ f"(got {type(ndarray).__name__})"
+ )
def _copy_if_negative_strides(ndarray: _NDArray) -> _NDArray:
@@ -53,18 +52,18 @@ def as_ndarray(tensor: torch.Tensor) -> _NDArray:
cannot be tracked in the computational graph.
"""
devtype = tensor.device.type
- if devtype == 'cpu':
+ if devtype == "cpu":
return tensor.detach().numpy()
- elif devtype == 'cuda':
+ elif devtype == "cuda":
ensure_cupy()
- if hasattr(cupy, 'from_dlpack'):
+ if hasattr(cupy, "from_dlpack"):
# TODO: Avoid using ``torch.utils.dlpack.to_dlpack``.
# => return cupy.from_dlpack(tensor)
# Blocked by PyTorch 1.10 bug
# (https://github.com/pytorch/pytorch/pull/67618)
return cupy.from_dlpack(torch.utils.dlpack.to_dlpack(tensor))
return cupy.fromDlpack(torch.utils.dlpack.to_dlpack(tensor))
- raise ValueError(f'Tensor is on unsupported device: {devtype}')
+ raise ValueError(f"Tensor is on unsupported device: {devtype}")
def get_xp(obj: Union[_NDArray, torch.Tensor]) -> Any:
@@ -78,21 +77,22 @@ def get_xp(obj: Union[_NDArray, torch.Tensor]) -> Any:
elif isinstance(obj, torch.device):
devtype = obj.type
elif isinstance(obj, numpy.ndarray):
- devtype = 'cpu'
+ devtype = "cpu"
elif isinstance(obj, cupy.ndarray):
- devtype = 'cuda'
+ devtype = "cuda"
else:
raise TypeError(
- 'expected torch.Tensor, torch.device, numpy.ndarray, '
- f'or cupy.ndarray (got {type(obj).__name__})')
+ "expected torch.Tensor, torch.device, numpy.ndarray, "
+ f"or cupy.ndarray (got {type(obj).__name__})"
+ )
- if devtype == 'cpu':
+ if devtype == "cpu":
return numpy
- elif devtype == 'cuda':
+ elif devtype == "cuda":
ensure_cupy()
return cupy
- raise ValueError(f'unsupported device type: {devtype}')
+ raise ValueError(f"unsupported device type: {devtype}")
def as_numpy_dtype(torch_dtype: torch.dtype) -> _NumpyDtype:
@@ -107,7 +107,7 @@ def as_numpy_dtype(torch_dtype: torch.dtype) -> _NumpyDtype:
"""
numpy_dtype = _torch_dtype_mapping.get(torch_dtype, None)
if numpy_dtype is None:
- raise TypeError(f'NumPy does not support {torch_dtype} equivalent')
+ raise TypeError(f"NumPy does not support {torch_dtype} equivalent")
return numpy_dtype
@@ -123,27 +123,26 @@ def from_numpy_dtype(numpy_dtype: _NumpyDtype) -> torch.dtype:
"""
torch_dtype = _numpy_dtype_mapping.get(numpy_dtype, None)
if torch_dtype is None:
- raise TypeError(f'PyTorch does not support {numpy_dtype} equivalent')
+ raise TypeError(f"PyTorch does not support {numpy_dtype} equivalent")
return torch_dtype
_torch_dtype_mapping: Dict[torch.dtype, _NumpyDtype] = {
# https://pytorch.org/docs/stable/tensors.html
# https://numpy.org/doc/stable/user/basics.types.html
-
- torch.float32: numpy.dtype('float32'),
- torch.float64: numpy.dtype('float64'),
- torch.float16: numpy.dtype('float16'),
+ torch.float32: numpy.dtype("float32"),
+ torch.float64: numpy.dtype("float64"),
+ torch.float16: numpy.dtype("float16"),
# unsupported: torch.bfloat16
# unsupported: torch.complex32
- torch.complex64: numpy.dtype('complex64'),
- torch.complex128: numpy.dtype('complex128'),
- torch.uint8: numpy.dtype('uint8'),
- torch.int8: numpy.dtype('int8'),
- torch.int16: numpy.dtype('int16'),
- torch.int32: numpy.dtype('int32'),
- torch.int64: numpy.dtype('int64'),
- torch.bool: numpy.dtype('bool'),
+ torch.complex64: numpy.dtype("complex64"),
+ torch.complex128: numpy.dtype("complex128"),
+ torch.uint8: numpy.dtype("uint8"),
+ torch.int8: numpy.dtype("int8"),
+ torch.int16: numpy.dtype("int16"),
+ torch.int32: numpy.dtype("int32"),
+ torch.int64: numpy.dtype("int64"),
+ torch.bool: numpy.dtype("bool"),
}
_numpy_dtype_mapping: Dict[_NumpyDtype, torch.dtype] = {
diff --git a/pytorch_pfn_extras/_torch_version.py b/pytorch_pfn_extras/_torch_version.py
index 695b637cb..0806d34ae 100644
--- a/pytorch_pfn_extras/_torch_version.py
+++ b/pytorch_pfn_extras/_torch_version.py
@@ -2,6 +2,6 @@
from packaging.version import Version
-def requires(version: str, package: str = 'torch') -> bool:
+def requires(version: str, package: str = "torch") -> bool:
pkg_ver = pkg_resources.get_distribution(package).version
return Version(pkg_ver.split("+")[0].split("-")[0]) >= Version(version)
diff --git a/pytorch_pfn_extras/_version.py b/pytorch_pfn_extras/_version.py
index 400a10474..f6104e0c2 100644
--- a/pytorch_pfn_extras/_version.py
+++ b/pytorch_pfn_extras/_version.py
@@ -1 +1 @@
-__version__ = '0.6.7'
+__version__ = "0.6.7"
diff --git a/pytorch_pfn_extras/config.py b/pytorch_pfn_extras/config.py
index adf04b6d0..eafd74835 100644
--- a/pytorch_pfn_extras/config.py
+++ b/pytorch_pfn_extras/config.py
@@ -1,8 +1,16 @@
import json
import os
-from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union
import reprlib
-
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
ConfigKey = Tuple[Union[str, int], ...]
AttrKey = Tuple[Union[str, int], ...]
@@ -13,20 +21,21 @@
LoadTrace = Tuple[Tuple[str, ConfigKey], ...]
-def customize_type(**default_kwargs: Any) -> Callable[
- [Callable[..., Any]], Callable[..., Any]]:
+def customize_type(
+ **default_kwargs: Any,
+) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def deco(type_: Callable[..., Any]) -> Callable[..., Any]:
type_._custom_default_kwargs = default_kwargs # type: ignore[attr-defined] # NOQA
return type_
+
return deco
class Config(object):
-
def __init__(
- self,
- config: Any,
- types: Optional[Mapping[str, Callable[..., Any]]] = None,
+ self,
+ config: Any,
+ types: Optional[Mapping[str, Callable[..., Any]]] = None,
) -> None:
self._cache: Dict[KeyPair, Any] = {((), None): config}
self._types = types or {}
@@ -36,21 +45,21 @@ def __getitem__(self, key: str) -> Any:
@classmethod
def load_path(
- cls,
- path: str,
- *,
- loader: Optional[Loader] = None,
- types: Optional[Mapping[str, Callable[..., Any]]] = None,
- ) -> 'Config':
+ cls,
+ path: str,
+ *,
+ loader: Optional[Loader] = None,
+ types: Optional[Mapping[str, Callable[..., Any]]] = None,
+ ) -> "Config":
if loader is None:
loader = _json_loader
return cls(_load(path, loader, ()), types)
def _eval(
- self,
- config_key: ConfigKey,
- attr_key: Optional[AttrKey],
- trace: DumpTrace,
+ self,
+ config_key: ConfigKey,
+ attr_key: Optional[AttrKey],
+ trace: DumpTrace,
) -> Any:
if (config_key, attr_key) in self._cache:
return self._cache[(config_key, attr_key)]
@@ -58,9 +67,7 @@ def _eval(
circular = (config_key, attr_key) in trace
trace = (*trace, (config_key, attr_key))
if circular:
- raise RuntimeError(
- 'Circular dependency',
- _dump_trace(trace))
+ raise RuntimeError("Circular dependency", _dump_trace(trace))
def cache(value: Any) -> Any:
self._cache[(config_key, attr_key)] = value
@@ -69,18 +76,19 @@ def cache(value: Any) -> Any:
if attr_key:
obj = self._eval(config_key, attr_key[:-1], trace)
try:
- if isinstance(attr_key[-1], str) \
- and hasattr(obj, attr_key[-1]):
+ if isinstance(attr_key[-1], str) and hasattr(obj, attr_key[-1]):
return cache(getattr(obj, attr_key[-1]))
else:
return cache(obj[attr_key[-1]])
except Exception as e:
e.args = e.args + (
- '{} not in {} ({})'.format(
+ "{} not in {} ({})".format(
attr_key[-1],
_dump_key(config_key, attr_key[:-1]),
- reprlib.repr(obj)),
- _dump_trace(trace))
+ reprlib.repr(obj),
+ ),
+ _dump_trace(trace),
+ )
raise e
elif attr_key is None:
@@ -89,52 +97,64 @@ def cache(value: Any) -> Any:
return cache(config[config_key[-1]])
except Exception as e:
e.args = e.args + (
- '{} not in {}'.format(
- config_key[-1],
- _dump_key(config_key[:-1], None)),
- _dump_trace(trace))
+ "{} not in {}".format(
+ config_key[-1], _dump_key(config_key[:-1], None)
+ ),
+ _dump_trace(trace),
+ )
raise e
else:
config = self._eval(config_key, None, trace)
if isinstance(config, dict):
- if 'type' in config:
+ if "type" in config:
try:
- type_ = self._types[config['type']]
+ type_ = self._types[config["type"]]
except Exception as e:
e.args = e.args + (
- '{} not in types'.format(config['type']),
- _dump_trace(trace))
+ "{} not in types".format(config["type"]),
+ _dump_trace(trace),
+ )
raise e
else:
type_ = dict
kwargs = {}
for k in config.keys():
- if not k == 'type':
+ if not k == "type":
kwargs[k] = self._eval((*config_key, k), (), trace)
for k, v in getattr(
- type_, '_custom_default_kwargs', {}).items():
+ type_, "_custom_default_kwargs", {}
+ ).items():
if k not in kwargs:
kwargs[k] = self._eval(
- *_parse_key(v, config_key)[:2], trace)
+ *_parse_key(v, config_key)[:2], trace
+ )
try:
return cache(type_(**kwargs))
except Exception as e:
e.args = e.args + (
- '{} ({}) failed with kwargs {}'.format(
- config['type'], type_, reprlib.repr(kwargs)),
- _dump_trace(trace))
+ "{} ({}) failed with kwargs {}".format(
+ config["type"], type_, reprlib.repr(kwargs)
+ ),
+ _dump_trace(trace),
+ )
raise e
elif isinstance(config, list):
- return cache([
- self._eval((*config_key, i), (), trace)
- for i in range(len(config))])
- elif isinstance(config, str) and config.startswith('@'):
- return cache(self._eval(
- *_parse_key(config[1:], config_key[:-1])[:2], trace))
+ return cache(
+ [
+ self._eval((*config_key, i), (), trace)
+ for i in range(len(config))
+ ]
+ )
+ elif isinstance(config, str) and config.startswith("@"):
+ return cache(
+ self._eval(
+ *_parse_key(config[1:], config_key[:-1])[:2], trace
+ )
+ )
else:
return cache(config)
@@ -142,13 +162,12 @@ def update_via_args(self, args: Sequence[Tuple[str, Any]]) -> None:
for k, v in args:
n_k, c_k = _parse_key(k, ())[:2]
if (n_k, c_k) in self._cache:
- if (
- isinstance(self._cache[(n_k, c_k)], bool)
- and isinstance(v, str)
+ if isinstance(self._cache[(n_k, c_k)], bool) and isinstance(
+ v, str
):
if not v.lower() in ("true", "false"):
raise ValueError(
- f'bool should be true/false. Found {v}'
+ f"bool should be true/false. Found {v}"
)
v = v.lower() == "true"
self._cache[(n_k, c_k)] = type(self._cache[(n_k, c_k)])(v)
@@ -157,22 +176,22 @@ def update_via_args(self, args: Sequence[Tuple[str, Any]]) -> None:
def _parse_key(
- key: str, current_config_key: ConfigKey
+ key: str, current_config_key: ConfigKey
) -> Tuple[ConfigKey, Optional[AttrKey], bool]:
- if key.startswith('!'):
+ if key.startswith("!"):
key = key[1:]
escape = True
else:
escape = False
- if key.startswith('/'):
+ if key.startswith("/"):
key = key[1:]
rel = False
else:
rel = True
- config_key_str = key.split('/')
- config_key_str[-1], *attr_key_list = config_key_str[-1].split('.')
+ config_key_str = key.split("/")
+ config_key_str[-1], *attr_key_list = config_key_str[-1].split(".")
config_key = [_parse_k(k) for k in config_key_str]
attr_key: Optional[AttrKey] = tuple(_parse_k(k) for k in attr_key_list)
@@ -186,9 +205,9 @@ def _parse_key(
i = 0
while i < len(config_key):
- if config_key[i] in {'', '.'}:
+ if config_key[i] in {"", "."}:
config_key.pop(i)
- elif config_key[i] == '..':
+ elif config_key[i] == "..":
assert i > 0
config_key.pop(i)
config_key.pop(i - 1)
@@ -207,21 +226,21 @@ def _parse_k(k: str) -> Union[str, int]:
def _dump_key(config_key: ConfigKey, attr_key: Optional[AttrKey]) -> str:
- config_key_str = '/' + '/'.join(str(k) for k in config_key)
+ config_key_str = "/" + "/".join(str(k) for k in config_key)
if attr_key:
- attr_key_str = '.'.join(str(k) for k in attr_key)
- return config_key_str + '.' + attr_key_str
+ attr_key_str = ".".join(str(k) for k in attr_key)
+ return config_key_str + "." + attr_key_str
elif attr_key is None:
- return '!' + config_key_str
+ return "!" + config_key_str
else:
return config_key_str
def _dump_trace(trace: DumpTrace) -> str:
- return ' -> '.join(
- _dump_key(config_key, attr_key)
- for config_key, attr_key in trace)
+ return " -> ".join(
+ _dump_key(config_key, attr_key) for config_key, attr_key in trace
+ )
def _load(path: str, loader: Loader, trace: LoadTrace) -> ConfigType:
@@ -230,31 +249,37 @@ def _load(path: str, loader: Loader, trace: LoadTrace) -> ConfigType:
trace = (*trace, (path, ()))
if circular:
raise RuntimeError(
- 'Circular import',
- ' -> '.join('{} of {}'.format(_dump_key(config_key, None), path)
- for path, config_key in trace))
+ "Circular import",
+ " -> ".join(
+ "{} of {}".format(_dump_key(config_key, None), path)
+ for path, config_key in trace
+ ),
+ )
config = loader(path)
return _expand_import(config, os.path.dirname(path), loader, trace)
def _expand_import(
- config: ConfigType,
- workdir: str,
- loader: Loader,
- trace: LoadTrace,
+ config: ConfigType,
+ workdir: str,
+ loader: Loader,
+ trace: LoadTrace,
) -> ConfigType:
path, config_key = trace[-1]
if isinstance(config, dict):
- config = {k: _expand_import(v, workdir, loader,
- (*trace, (path, (*config_key, k))))
- for k, v in config.items()}
- if 'import' in config:
- path = config['import']
+ config = {
+ k: _expand_import(
+ v, workdir, loader, (*trace, (path, (*config_key, k)))
+ )
+ for k, v in config.items()
+ }
+ if "import" in config:
+ path = config["import"]
if not os.path.isabs(path):
path = os.path.join(workdir, path)
config_orig, config = config, _load(path, loader, trace)
for k, v in config_orig.items():
- if k == 'import':
+ if k == "import":
continue
config_key, attr_key, rel = _parse_key(k, ())
assert attr_key == ()
@@ -267,15 +292,20 @@ def _expand_import(
c[config_key[-1]] = v
except Exception as e:
e.args = e.args + (
- '{} not in {}'.format(
- _dump_key(config_key, attr_key), path),)
+ "{} not in {}".format(
+ _dump_key(config_key, attr_key), path
+ ),
+ )
raise e
return config
elif isinstance(config, list):
- return [_expand_import(v, workdir, loader,
- (*trace, (path, (*config_key, i))))
- for i, v in enumerate(config)]
+ return [
+ _expand_import(
+ v, workdir, loader, (*trace, (path, (*config_key, i)))
+ )
+ for i, v in enumerate(config)
+ ]
else:
return config
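
For reference, a minimal usage sketch of the config machinery reformatted above (not part of the diff): it assumes the `Config` constructor accepts a plain dict plus a `types` mapping, and that string keys such as "/model" are resolved through `_parse_key`/`_eval` as shown; the `make_model` factory and the key names are hypothetical.

from pytorch_pfn_extras.config import Config

def make_model(hidden, activation):
    # hypothetical factory registered under "make_model" in `types`
    return {"hidden": hidden, "activation": activation}

cfg = Config(
    {
        "hidden": 16,
        # "@/hidden" is resolved as a reference by _eval via _parse_key
        "model": {"type": "make_model", "hidden": "@/hidden", "activation": "relu"},
    },
    types={"make_model": make_model},
)

print(cfg["/model"])  # -> {'hidden': 16, 'activation': 'relu'}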
diff --git a/pytorch_pfn_extras/config_types.py b/pytorch_pfn_extras/config_types.py
index d5060b727..61cfdf148 100644
--- a/pytorch_pfn_extras/config_types.py
+++ b/pytorch_pfn_extras/config_types.py
@@ -1,14 +1,13 @@
import warnings
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from pytorch_pfn_extras import config
-
if TYPE_CHECKING:
import optuna
-def optuna_types(trial: 'optuna.trial.Trial') -> Dict[str, Any]:
+def optuna_types(trial: "optuna.trial.Trial") -> Dict[str, Any]:
types = {
"optuna_suggest_categorical": trial.suggest_categorical,
"optuna_suggest_discrete_uniform": trial.suggest_discrete_uniform,
@@ -21,15 +20,15 @@ def optuna_types(trial: 'optuna.trial.Trial') -> Dict[str, Any]:
def load_path_with_optuna_types(
- path: str,
- trial: 'optuna.trial.Trial',
- loader: Optional[config.Loader] = None,
- types: Optional[Dict[str, Callable[..., Any]]] = None,
+ path: str,
+ trial: "optuna.trial.Trial",
+ loader: Optional[config.Loader] = None,
+ types: Optional[Dict[str, Callable[..., Any]]] = None,
) -> config.Config:
if types is None:
types = {}
for key, value in optuna_types(trial).items():
if key in types:
- warnings.warn(key + ' is overwritten by optuna suggest.')
+ warnings.warn(key + " is overwritten by optuna suggest.")
types[key] = value
return config.Config.load_path(path, loader=loader, types=types)
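
A hedged sketch of how `load_path_with_optuna_types` is typically wired into an Optuna study; the config path, its contents (entries such as {"type": "optuna_suggest_uniform", ...}), and the "/lr" key are hypothetical, and an explicit JSON loader is passed so the sketch does not depend on the default loader.

import json
import optuna
from pytorch_pfn_extras.config_types import load_path_with_optuna_types

def objective(trial: "optuna.trial.Trial") -> float:
    # Entries like {"type": "optuna_suggest_uniform", ...} in the hypothetical
    # config file are resolved through the trial.suggest_* callables registered above.
    cfg = load_path_with_optuna_types(
        "config.json", trial, loader=lambda p: json.load(open(p))
    )
    lr = cfg["/lr"]  # assumed key defined in the hypothetical config
    return (lr - 0.01) ** 2  # placeholder objective

study = optuna.create_study()
study.optimize(objective, n_trials=10)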
diff --git a/pytorch_pfn_extras/cuda/__init__.py b/pytorch_pfn_extras/cuda/__init__.py
index e13ff6773..61642511e 100644
--- a/pytorch_pfn_extras/cuda/__init__.py
+++ b/pytorch_pfn_extras/cuda/__init__.py
@@ -1,3 +1,5 @@
from pytorch_pfn_extras.cuda._allocator import stream # NOQA
from pytorch_pfn_extras.cuda._allocator import use_torch_mempool_in_cupy # NOQA
-from pytorch_pfn_extras.cuda._allocator import use_default_mempool_in_cupy # NOQA
+from pytorch_pfn_extras.cuda._allocator import ( # NOQA
+ use_default_mempool_in_cupy,
+)
diff --git a/pytorch_pfn_extras/cuda/_allocator.py b/pytorch_pfn_extras/cuda/_allocator.py
index e57365b70..12b363013 100644
--- a/pytorch_pfn_extras/cuda/_allocator.py
+++ b/pytorch_pfn_extras/cuda/_allocator.py
@@ -2,10 +2,7 @@
from typing import Any, Generator, Optional
import torch
-
-from pytorch_pfn_extras._cupy import cupy
-from pytorch_pfn_extras._cupy import is_available, ensure_cupy
-
+from pytorch_pfn_extras._cupy import cupy, ensure_cupy, is_available
_allocator = None
@@ -49,7 +46,8 @@ def use_torch_mempool_in_cupy() -> None:
ensure_cupy()
_allocator = cupy.cuda.memory.PythonFunctionAllocator(
- _torch_alloc, _torch_free)
+ _torch_alloc, _torch_free
+ )
cupy.cuda.set_allocator(_allocator.malloc)
@@ -58,11 +56,11 @@ def _torch_alloc(size: int, device_id: int) -> Any:
cupy_stream_ptr = cupy.cuda.get_current_stream().ptr
if torch_stream_ptr != cupy_stream_ptr:
raise RuntimeError(
- 'The current stream set in PyTorch and CuPy must be same.'
- ' Use `pytorch_pfn_extras.cuda.stream` instead of'
- ' `torch.cuda.stream`.')
- return torch.cuda.caching_allocator_alloc(
- size, device_id, torch_stream_ptr)
+ "The current stream set in PyTorch and CuPy must be same."
+ " Use `pytorch_pfn_extras.cuda.stream` instead of"
+ " `torch.cuda.stream`."
+ )
+ return torch.cuda.caching_allocator_alloc(size, device_id, torch_stream_ptr)
def _torch_free(mem_ptr: int, device_id: int) -> None:
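
A hedged sketch of the allocator/stream API touched above: CuPy allocations are routed through PyTorch's caching allocator, and `ppe.cuda.stream` keeps both libraries on the same stream, as `_torch_alloc` requires. It assumes a CUDA-enabled environment with CuPy installed.

import cupy
import torch
import pytorch_pfn_extras as ppe

ppe.cuda.use_torch_mempool_in_cupy()  # CuPy now allocates from torch's memory pool

s = torch.cuda.Stream()
with ppe.cuda.stream(s):  # keeps the PyTorch and CuPy current streams in sync
    a = cupy.zeros((1024,), dtype=cupy.float32)  # served by torch's caching allocator
    b = torch.zeros(1024, device="cuda")         # shares the same pool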
diff --git a/pytorch_pfn_extras/dataloaders/__init__.py b/pytorch_pfn_extras/dataloaders/__init__.py
index 8a7549391..7f6abee2f 100644
--- a/pytorch_pfn_extras/dataloaders/__init__.py
+++ b/pytorch_pfn_extras/dataloaders/__init__.py
@@ -1,2 +1,2 @@
-from pytorch_pfn_extras.dataloaders.dataloader import DataLoader # NOQA
from pytorch_pfn_extras.dataloaders import utils # NOQA
+from pytorch_pfn_extras.dataloaders.dataloader import DataLoader # NOQA
diff --git a/pytorch_pfn_extras/dataloaders/utils.py b/pytorch_pfn_extras/dataloaders/utils.py
index 1bc338c03..f50fdbd1b 100644
--- a/pytorch_pfn_extras/dataloaders/utils.py
+++ b/pytorch_pfn_extras/dataloaders/utils.py
@@ -15,9 +15,11 @@ class CollateAsDict:
"""
def __init__(
- self, names: Sequence[str],
- collate_fn: Callable[..., Any] =
- torch.utils.data._utils.collate.default_collate,
+ self,
+ names: Sequence[str],
+ collate_fn: Callable[
+ ..., Any
+ ] = torch.utils.data._utils.collate.default_collate,
) -> None:
self.names = names
self.collate_fn = collate_fn
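
A hedged usage sketch for `CollateAsDict`: a dataset that yields (x, y) tuples is collated into a {"x": ..., "y": ...} batch; the dataset and key names below are illustrative only.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_pfn_extras.dataloaders.utils import CollateAsDict

dataset = TensorDataset(torch.randn(8, 3), torch.arange(8))
loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=CollateAsDict(names=("x", "y")),  # one name per tuple position
)
batch = next(iter(loader))
print(batch["x"].shape, batch["y"].shape)  # torch.Size([4, 3]) torch.Size([4])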
diff --git a/pytorch_pfn_extras/dataset/__init__.py b/pytorch_pfn_extras/dataset/__init__.py
index fa29a4a3b..b17bfe96e 100644
--- a/pytorch_pfn_extras/dataset/__init__.py
+++ b/pytorch_pfn_extras/dataset/__init__.py
@@ -1,3 +1,7 @@
-from pytorch_pfn_extras.dataset.tabular.tabular_dataset import TabularDataset # NOQA
-from pytorch_pfn_extras.dataset.shared_dataset import ItemNotFoundException # NOQA
from pytorch_pfn_extras.dataset.shared_dataset import SharedDataset # NOQA
+from pytorch_pfn_extras.dataset.shared_dataset import ( # NOQA
+ ItemNotFoundException,
+)
+from pytorch_pfn_extras.dataset.tabular.tabular_dataset import ( # NOQA
+ TabularDataset,
+)
diff --git a/pytorch_pfn_extras/dataset/shared_dataset.py b/pytorch_pfn_extras/dataset/shared_dataset.py
index 48ad8011e..0d91ad5a5 100644
--- a/pytorch_pfn_extras/dataset/shared_dataset.py
+++ b/pytorch_pfn_extras/dataset/shared_dataset.py
@@ -52,10 +52,11 @@ class ItemNotFoundException(Exception):
class SharedDataset(torch.utils.data.Dataset):
- """ Dataset that caches the load samples in shared memory
+ """Dataset that caches the load samples in shared memory
Args
"""
+
def __init__(self, sm_size, cache_type=InfiniteCache):
super().__init__()
self.cache = cache_type(sm_size)
@@ -64,7 +65,8 @@ def __getitem__(self, idx):
x = self.cache.get_value(idx)
if x is None:
raise ItemNotFoundException(
- 'Item {} is not in the cache'.format(idx))
+ "Item {} is not in the cache".format(idx)
+ )
return x
def is_cached(self, idx):
diff --git a/pytorch_pfn_extras/dataset/tabular/__init__.py b/pytorch_pfn_extras/dataset/tabular/__init__.py
index 79f19e53c..49b1898ab 100644
--- a/pytorch_pfn_extras/dataset/tabular/__init__.py
+++ b/pytorch_pfn_extras/dataset/tabular/__init__.py
@@ -4,6 +4,7 @@
from pytorch_pfn_extras.dataset.tabular import _slice # NOQA
from pytorch_pfn_extras.dataset.tabular import _transform # NOQA
from pytorch_pfn_extras.dataset.tabular import _with_converter # NOQA
-
-from pytorch_pfn_extras.dataset.tabular.delegate_dataset import DelegateDataset # NOQA
+from pytorch_pfn_extras.dataset.tabular.delegate_dataset import ( # NOQA
+ DelegateDataset,
+)
from pytorch_pfn_extras.dataset.tabular.from_data import from_data # NOQA
diff --git a/pytorch_pfn_extras/dataset/tabular/_asmode.py b/pytorch_pfn_extras/dataset/tabular/_asmode.py
index 543b61f52..82ae2460e 100644
--- a/pytorch_pfn_extras/dataset/tabular/_asmode.py
+++ b/pytorch_pfn_extras/dataset/tabular/_asmode.py
@@ -4,7 +4,6 @@
class _Astuple(tabular_dataset.TabularDataset):
-
def __init__(self, dataset):
self._dataset = dataset
@@ -27,7 +26,6 @@ def convert(self, data):
class _Asdict(tabular_dataset.TabularDataset):
-
def __init__(self, dataset):
self._dataset = dataset
diff --git a/pytorch_pfn_extras/dataset/tabular/_concat.py b/pytorch_pfn_extras/dataset/tabular/_concat.py
index 3ab36b488..43ef08c60 100644
--- a/pytorch_pfn_extras/dataset/tabular/_concat.py
+++ b/pytorch_pfn_extras/dataset/tabular/_concat.py
@@ -4,11 +4,10 @@
class _Concat(tabular_dataset.TabularDataset):
-
def __init__(self, *datasets):
for dataset in datasets[1:]:
if not dataset.keys == datasets[0].keys:
- raise ValueError('All datasets must have the same keys')
+ raise ValueError("All datasets must have the same keys")
self._datasets = datasets
@@ -32,12 +31,16 @@ def get_examples(self, indices, key_indices):
if indices is None:
examples = [
dataset.get_examples(None, key_indices)
- for dataset in self._datasets]
+ for dataset in self._datasets
+ ]
return tuple(
- [data
- for sub_examples in examples
- for data in sub_examples[col_index]]
- for col_index in range(n_cols))
+ [
+ data
+ for sub_examples in examples
+ for data in sub_examples[col_index]
+ ]
+ for col_index in range(n_cols)
+ )
elif isinstance(indices, slice):
start, stop, step = indices.indices(len(self))
@@ -53,8 +56,9 @@ def get_examples(self, indices, key_indices):
sub_stop = min(sub_stop, len(dataset))
else:
if sub_start >= len(dataset):
- sub_start = \
+ sub_start = (
len(dataset) + (sub_start - len(dataset)) % step
+ )
sub_stop = max(sub_stop, -1)
if len(range(sub_start, sub_stop, step)) > 0:
@@ -62,8 +66,11 @@ def get_examples(self, indices, key_indices):
sub_start = None
if sub_stop < 0 and step < 0:
sub_stop = None
- examples.append(dataset.get_examples(
- slice(sub_start, sub_stop, step), key_indices))
+ examples.append(
+ dataset.get_examples(
+ slice(sub_start, sub_stop, step), key_indices
+ )
+ )
offset += len(dataset)
@@ -75,10 +82,13 @@ def get_examples(self, indices, key_indices):
if step < 0:
examples.reverse()
return tuple(
- [data
- for sub_examples in examples
- for data in sub_examples[col_index]]
- for col_index in range(n_cols))
+ [
+ data
+ for sub_examples in examples
+ for data in sub_examples[col_index]
+ ]
+ for col_index in range(n_cols)
+ )
else:
examples = {}
@@ -90,12 +100,12 @@ def get_examples(self, indices, key_indices):
if index < offset or offset + len(dataset) <= index:
continue
sub_indices.append(index - offset)
- example_indices[p] = (
- dataset_index, len(sub_indices) - 1)
+ example_indices[p] = (dataset_index, len(sub_indices) - 1)
if len(sub_indices) > 0:
examples[dataset_index] = dataset.get_examples(
- sub_indices, key_indices)
+ sub_indices, key_indices
+ )
offset += len(dataset)
@@ -105,9 +115,12 @@ def get_examples(self, indices, key_indices):
return list(examples.values())[0]
else:
return tuple(
- [examples[dataset_index][col_index][p]
- for dataset_index, p in example_indices]
- for col_index in range(n_cols))
+ [
+ examples[dataset_index][col_index][p]
+ for dataset_index, p in example_indices
+ ]
+ for col_index in range(n_cols)
+ )
def convert(self, data):
return self._datasets[0].convert(data)
diff --git a/pytorch_pfn_extras/dataset/tabular/_join.py b/pytorch_pfn_extras/dataset/tabular/_join.py
index f04cbaa2c..e99541f8b 100644
--- a/pytorch_pfn_extras/dataset/tabular/_join.py
+++ b/pytorch_pfn_extras/dataset/tabular/_join.py
@@ -4,14 +4,13 @@
class _Join(tabular_dataset.TabularDataset):
-
def __init__(self, *datasets):
keys = set(datasets[0].keys)
for dataset in datasets[1:]:
if not len(dataset) == len(datasets[0]):
- raise ValueError('All datasets must have the same length')
+ raise ValueError("All datasets must have the same length")
if len(keys.intersection(dataset.keys)) > 0:
- raise ValueError('All keys must be unique among all datasets')
+ raise ValueError("All keys must be unique among all datasets")
keys = keys.union(dataset.keys)
self._datasets = datasets
@@ -35,7 +34,8 @@ def get_examples(self, indices, key_indices):
return tuple(
col
for dataset in self._datasets
- for col in dataset.get_examples(indices, None))
+ for col in dataset.get_examples(indices, None)
+ )
examples = {}
key_offset = 0
@@ -52,7 +52,8 @@ def get_examples(self, indices, key_indices):
sub_key_indices = tuple(sub_key_indices)
sub_examples = dataset.get_examples(indices, sub_key_indices)
for sub_key_index, col_example in zip(
- sub_key_indices, sub_examples):
+ sub_key_indices, sub_examples
+ ):
examples[key_offset + sub_key_index] = col_example
key_offset += len(dataset.keys)
diff --git a/pytorch_pfn_extras/dataset/tabular/_slice.py b/pytorch_pfn_extras/dataset/tabular/_slice.py
index 2eb02c89d..2c3fa222a 100644
--- a/pytorch_pfn_extras/dataset/tabular/_slice.py
+++ b/pytorch_pfn_extras/dataset/tabular/_slice.py
@@ -1,11 +1,9 @@
# mypy: ignore-errors
-from pytorch_pfn_extras.dataset.tabular import tabular_dataset
-from pytorch_pfn_extras.dataset.tabular import _utils
+from pytorch_pfn_extras.dataset.tabular import _utils, tabular_dataset
class _Slice(tabular_dataset.TabularDataset):
-
def __init__(self, dataset, indices, keys):
if keys is None:
self._unary = None
@@ -13,7 +11,7 @@ def __init__(self, dataset, indices, keys):
self._unary = False
else:
self._unary = True
- keys = keys,
+ keys = (keys,)
self._dataset = dataset
self._indices = _utils._as_indices(indices, len(dataset))
@@ -33,8 +31,9 @@ def keys(self):
if self._key_indices is None:
return self._dataset.keys
else:
- return tuple(self._dataset.keys[key_index]
- for key_index in self._key_indices)
+ return tuple(
+ self._dataset.keys[key_index] for key_index in self._key_indices
+ )
@property
def mode(self):
@@ -47,7 +46,8 @@ def mode(self):
def get_examples(self, indices, key_indices):
indices = _utils._merge_indices(
- self._indices, indices, len(self._dataset), len(self))
+ self._indices, indices, len(self._dataset), len(self)
+ )
key_indices = _utils._merge_key_indices(self._key_indices, key_indices)
return self._dataset.get_examples(indices, key_indices)
@@ -56,7 +56,6 @@ def convert(self, data):
class _SliceHelper(object):
-
def __init__(self, dataset):
self._dataset = dataset
diff --git a/pytorch_pfn_extras/dataset/tabular/_transform.py b/pytorch_pfn_extras/dataset/tabular/_transform.py
index 0cb21a367..b4fcc8b9f 100644
--- a/pytorch_pfn_extras/dataset/tabular/_transform.py
+++ b/pytorch_pfn_extras/dataset/tabular/_transform.py
@@ -1,7 +1,6 @@
# mypy: ignore-errors
-from pytorch_pfn_extras.dataset.tabular import tabular_dataset
-from pytorch_pfn_extras.dataset.tabular import _utils
+from pytorch_pfn_extras.dataset.tabular import _utils, tabular_dataset
class _TransformBase(tabular_dataset.TabularDataset):
@@ -12,14 +11,15 @@ def __init__(self, dataset, keys, transforms):
self._transforms = []
for s, t in transforms:
if any(k in key_set for k in s[1]):
- raise ValueError('Transformations must be disjoint')
+ raise ValueError("Transformations must be disjoint")
key_set.update(s[1])
ops_idx = _utils._as_key_indices(s[0], self._dataset.keys)
res_idx = _utils._as_key_indices(s[1], keys)
self._transforms.append(((ops_idx, res_idx), t))
if key_set != set(keys):
raise ValueError(
- 'Transformations must produce only all specified keys')
+ "Transformations must produce only all specified keys"
+ )
self._keys = keys
@@ -70,7 +70,6 @@ def convert(self, data):
class _Transform(_TransformBase):
-
def get_examples(self, indices, key_indices):
if key_indices is None:
key_indices = range(len(self._keys))
@@ -94,9 +93,7 @@ def get_examples(self, indices, key_indices):
out_example = transform(*inputs)
elif self._dataset.mode is dict:
keys = [self._dataset.keys[i] for i in ops_idx]
- out_example = transform(
- **dict(zip(keys, inputs))
- )
+ out_example = transform(**dict(zip(keys, inputs)))
elif self._dataset.mode is None:
out_example = transform(*inputs)
if isinstance(out_example, tuple):
@@ -113,7 +110,8 @@ def get_examples(self, indices, key_indices):
# we are slicing the outputs using key_indices
# the result key index needs to be recalculated
out_examples[key_indices.index(key_index)].append(
- out_example[col_index])
+ out_example[col_index]
+ )
elif isinstance(out_example, dict):
if hasattr(self, "_mode") and self._mode is not dict:
raise ValueError(
@@ -138,7 +136,8 @@ def get_examples(self, indices, key_indices):
if key_index is None:
continue
out_examples[key_indices.index(key_index)].append(
- out_example[col_index])
+ out_example[col_index]
+ )
return out_examples
@@ -147,7 +146,6 @@ def convert(self, data):
class _TransformBatch(_TransformBase):
-
def get_examples(self, indices, key_indices):
if indices is None:
len_ = len(self)
@@ -169,9 +167,7 @@ def get_examples(self, indices, key_indices):
out_example = transform(*inputs)
elif self._dataset.mode is dict:
keys = [self._dataset.keys[i] for i in ops_idx]
- out_example = transform(
- **dict(zip(keys, inputs))
- )
+ out_example = transform(**dict(zip(keys, inputs)))
elif self._dataset.mode is None:
out_example = transform(*inputs)
@@ -192,8 +188,9 @@ def get_examples(self, indices, key_indices):
# all the outputs are covered this works but when
# we are slicing the outputs using key_indices
# the result key index needs to be recalculated
- out_examples[key_indices.index(key_index)] = (
- out_example[col_index])
+ out_examples[key_indices.index(key_index)] = out_example[
+ col_index
+ ]
elif isinstance(out_example, dict):
if hasattr(self, "_mode") and self._mode is not dict:
raise ValueError(
@@ -208,8 +205,9 @@ def get_examples(self, indices, key_indices):
if key_index is None:
continue
key = self._keys[key_index]
- out_examples[key_indices.index(key_index)] = (
- out_example[key])
+ out_examples[key_indices.index(key_index)] = out_example[
+ key
+ ]
else:
if hasattr(self, "_mode") and self._mode is not None:
raise ValueError(
@@ -224,6 +222,7 @@ def get_examples(self, indices, key_indices):
for col_index, key_index in enumerate(t_res_idx):
if key_index is None:
continue
- out_examples[key_indices.index(key_index)] = (
- out_example[col_index])
+ out_examples[key_indices.index(key_index)] = out_example[
+ col_index
+ ]
return tuple(out_examples)
diff --git a/pytorch_pfn_extras/dataset/tabular/_utils.py b/pytorch_pfn_extras/dataset/tabular/_utils.py
index ed551d7fd..1a4dddf1a 100644
--- a/pytorch_pfn_extras/dataset/tabular/_utils.py
+++ b/pytorch_pfn_extras/dataset/tabular/_utils.py
@@ -11,8 +11,10 @@ def _as_indices(indices, len_):
if all(isinstance(index, (bool, np.bool_)) for index in indices):
if not len(indices) == len_:
- raise ValueError('The number of booleans is '
- 'different from the length of dataset')
+ raise ValueError(
+ "The number of booleans is "
+ "different from the length of dataset"
+ )
return [i for i, index in enumerate(indices) if index]
else:
checked_indices = []
@@ -22,8 +24,10 @@ def _as_indices(indices, len_):
index += len_
if index < 0 or len_ <= index:
raise IndexError(
- 'index {} is out of bounds for dataset with size {}'
- .format(index, len_))
+ "index {} is out of bounds for dataset with size {}".format(
+ index, len_
+ )
+ )
checked_indices.append(index)
return checked_indices
@@ -40,13 +44,15 @@ def _as_key_indices(keys, key_names):
key_index += len(key_names)
if key_index < 0 or len(key_names) <= key_index:
raise IndexError(
- 'index {} is out of bounds for keys with size {}'.format(
- key, len(key_names)))
+ "index {} is out of bounds for keys with size {}".format(
+ key, len(key_names)
+ )
+ )
else:
try:
key_index = key_names.index(key)
except ValueError:
- raise KeyError('{} does not exists'.format(key))
+ raise KeyError("{} does not exist".format(key))
key_indices.append(key_index)
return tuple(key_indices)
diff --git a/pytorch_pfn_extras/dataset/tabular/_with_converter.py b/pytorch_pfn_extras/dataset/tabular/_with_converter.py
index c8d7b67e4..4782a5c52 100644
--- a/pytorch_pfn_extras/dataset/tabular/_with_converter.py
+++ b/pytorch_pfn_extras/dataset/tabular/_with_converter.py
@@ -4,7 +4,6 @@
class _WithConverter(tabular_dataset.TabularDataset):
-
def __init__(self, dataset, converter):
self._dataset = dataset
self._converter = converter
diff --git a/pytorch_pfn_extras/dataset/tabular/from_data.py b/pytorch_pfn_extras/dataset/tabular/from_data.py
index 165199b74..6d7c499ab 100644
--- a/pytorch_pfn_extras/dataset/tabular/from_data.py
+++ b/pytorch_pfn_extras/dataset/tabular/from_data.py
@@ -93,20 +93,20 @@ def from_data(data, *, size=None):
def _make_dataset(key, data, size):
if isinstance(data, (numpy.ndarray, torch.Tensor)):
if key is None:
- key = '_{}'.format(id(data))
+ key = "_{}".format(id(data))
return _Array(key, data)
elif isinstance(data, list):
if key is None:
- key = '_{}'.format(id(data))
+ key = "_{}".format(id(data))
return _List(key, data)
elif callable(data):
if key is None:
- raise ValueError('key(s) must be specified for callable')
+ raise ValueError("key(s) must be specified for callable")
if size is None:
- raise ValueError('size must be specified for callable')
+ raise ValueError("size must be specified for callable")
dataset = _Index(size)
if isinstance(key, str):
- key = key,
+ key = (key,)
if not isinstance(key, tuple):
key = tuple(key)
data = [((dataset.keys, key), data)]
@@ -114,7 +114,6 @@ def _make_dataset(key, data, size):
class _Array(tabular_dataset.TabularDataset):
-
def __init__(self, key, data):
self._key = key
self._data = data
@@ -124,7 +123,7 @@ def __len__(self):
@property
def keys(self):
- return self._key,
+ return (self._key,)
@property
def mode(self):
@@ -132,7 +131,7 @@ def mode(self):
def get_examples(self, indices, key_indices):
if key_indices is None:
- key_indices = 0,
+ key_indices = (0,)
if indices is None:
return (self._data,) * len(key_indices)
@@ -141,7 +140,6 @@ def get_examples(self, indices, key_indices):
class _List(tabular_dataset.TabularDataset):
-
def __init__(self, key, data):
self._key = key
self._data = data
@@ -151,7 +149,7 @@ def __len__(self):
@property
def keys(self):
- return self._key,
+ return (self._key,)
@property
def mode(self):
@@ -159,19 +157,19 @@ def mode(self):
def get_examples(self, indices, key_indices):
if key_indices is None:
- key_indices = 0,
+ key_indices = (0,)
if indices is None:
return (self._data,) * len(key_indices)
elif isinstance(indices, slice):
return (self._data[indices],) * len(key_indices)
else:
- return ([self._data[index] for index in indices],) \
- * len(key_indices)
+ return ([self._data[index] for index in indices],) * len(
+ key_indices
+ )
class _Index(tabular_dataset.TabularDataset):
-
def __init__(self, size):
self._len = size
@@ -180,7 +178,7 @@ def __len__(self):
@property
def keys(self):
- return 'index',
+ return ("index",)
@property
def mode(self):
@@ -194,6 +192,6 @@ def get_examples(self, indices, key_indices):
indices = list(range(start, stop, step))
if key_indices is None:
- key_indices = 0,
+ key_indices = (0,)
return (indices,) * len(key_indices)
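
A hedged sketch of `from_data`, whose helpers are reformatted above: passing a dict builds one single-key dataset per entry and joins them; the key names and the exact example layout (dict mode) are assumptions based on this module.

import numpy as np
import pytorch_pfn_extras as ppe

x = np.random.rand(10, 3)
y = np.arange(10)
dataset = ppe.dataset.tabular.from_data({"x": x, "y": y})

print(len(dataset), dataset.keys)  # 10 ('x', 'y')
print(dataset[0])                  # e.g. {'x': array([...]), 'y': 0}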
diff --git a/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py b/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py
index c003cd690..8b7abc212 100644
--- a/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py
+++ b/pytorch_pfn_extras/dataset/tabular/tabular_dataset.py
@@ -1,9 +1,8 @@
# mypy: ignore-errors
import numpy
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
from torch.utils.data import Dataset
@@ -201,8 +200,7 @@ def concat(self, *datasets):
Returns:
A concatenated dataset.
"""
- return ppe.dataset.tabular._concat._Concat(
- self, *datasets)
+ return ppe.dataset.tabular._concat._Concat(self, *datasets)
def join(self, *datasets):
"""Stack datasets along columns.
@@ -243,8 +241,7 @@ def transform(self, keys, transform):
Returns:
A transformed dataset.
"""
- return ppe.dataset.tabular._transform._Transform(
- self, keys, transform)
+ return ppe.dataset.tabular._transform._Transform(self, keys, transform)
def transform_batch(self, keys, transform_batch):
"""Apply a transform to examples.
@@ -274,7 +271,8 @@ def transform_batch(self, keys, transform_batch):
A transformed dataset.
"""
return ppe.dataset.tabular._transform._TransformBatch(
- self, keys, transform_batch)
+ self, keys, transform_batch
+ )
def with_converter(self, converter):
"""Override the behaviour of :meth:`convert`.
@@ -289,7 +287,8 @@ def with_converter(self, converter):
"""
return ppe.dataset.tabular._with_converter._WithConverter(
- self, converter)
+ self, converter
+ )
def get_example(self, i):
example = self.get_examples([i], None)
@@ -320,8 +319,7 @@ def __getitem__(self, index):
"""
if isinstance(index, slice):
current, stop, step = index.indices(len(self))
- return [self.get_example(i) for i in
- range(current, stop, step)]
+ return [self.get_example(i) for i in range(current, stop, step)]
elif isinstance(index, list) or isinstance(index, numpy.ndarray):
return [self.get_example(i) for i in index]
else:
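
A hedged sketch chaining the TabularDataset helpers cleaned up above (`join` and `slice`); the datasets, keys, and sizes are illustrative.

import numpy as np
import pytorch_pfn_extras as ppe

a = ppe.dataset.tabular.from_data({"x": np.arange(5)})
b = ppe.dataset.tabular.from_data({"y": np.arange(5) * 10})

joined = a.join(b)         # same length, disjoint keys -> columns stacked
first3 = joined.slice[:3]  # row slice keeps both keys
print(joined.keys, len(first3))  # ('x', 'y') 3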
diff --git a/pytorch_pfn_extras/distributed/__init__.py b/pytorch_pfn_extras/distributed/__init__.py
index 3ae322192..d2d43f287 100644
--- a/pytorch_pfn_extras/distributed/__init__.py
+++ b/pytorch_pfn_extras/distributed/__init__.py
@@ -1,3 +1,9 @@
-from pytorch_pfn_extras.distributed._dataset_util import create_distributed_subset_indices # NOQA
-from pytorch_pfn_extras.distributed._distributed_validation_sampler import DistributedValidationSampler # NOQA
-from pytorch_pfn_extras.distributed._initialize import initialize_ompi_environment # NOQA
+from pytorch_pfn_extras.distributed._dataset_util import ( # NOQA
+ create_distributed_subset_indices,
+)
+from pytorch_pfn_extras.distributed._distributed_validation_sampler import ( # NOQA
+ DistributedValidationSampler,
+)
+from pytorch_pfn_extras.distributed._initialize import ( # NOQA
+ initialize_ompi_environment,
+)
diff --git a/pytorch_pfn_extras/distributed/_dataset_util.py b/pytorch_pfn_extras/distributed/_dataset_util.py
index 2ece1f280..cdd3db571 100644
--- a/pytorch_pfn_extras/distributed/_dataset_util.py
+++ b/pytorch_pfn_extras/distributed/_dataset_util.py
@@ -5,7 +5,7 @@
def _shared_random_seed() -> int:
- seed = torch.randint(0, 2 ** 31, size=())
+ seed = torch.randint(0, 2**31, size=())
if torch.distributed.is_initialized(): # type: ignore
if torch.distributed.get_backend() == "nccl": # type: ignore
seed = seed.cuda()
diff --git a/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py b/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py
index f4165ce71..bbf1bd2ab 100644
--- a/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py
+++ b/pytorch_pfn_extras/distributed/_distributed_validation_sampler.py
@@ -1,11 +1,10 @@
-from typing import TypeVar, Optional, Iterator, Sized
+from typing import Iterator, Optional, Sized, TypeVar
import numpy as np
import torch
import torch.distributed as dist
-
-T_co = TypeVar('T_co', covariant=True)
+T_co = TypeVar("T_co", covariant=True)
class DistributedValidationSampler(torch.utils.data.Sampler):
@@ -18,23 +17,31 @@ class DistributedValidationSampler(torch.utils.data.Sampler):
so for training do not use this sampler (use PyTorch DistributedSampler instead).
"""
- def __init__(self,
- dataset: Sized,
- num_replicas: Optional[int] = None,
- rank: Optional[int] = None, shuffle: bool = True,
- seed: int = 0) -> None:
+ def __init__(
+ self,
+ dataset: Sized,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ ) -> None:
if num_replicas is None:
if not dist.is_available(): # type: ignore[no-untyped-call]
- raise RuntimeError("Requires distributed package to be available")
+ raise RuntimeError(
+ "Requires distributed package to be available"
+ )
num_replicas = dist.get_world_size() # type: ignore[no-untyped-call]
if rank is None:
if not dist.is_available(): # type: ignore[no-untyped-call]
- raise RuntimeError("Requires distributed package to be available")
+ raise RuntimeError(
+ "Requires distributed package to be available"
+ )
rank = dist.get_rank() # type: ignore[no-untyped-call]
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
- " [0, {}]".format(rank, num_replicas - 1))
+ " [0, {}]".format(rank, num_replicas - 1)
+ )
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@@ -42,7 +49,9 @@ def __init__(self,
self.seed = seed
self.dataset_len = len(dataset)
- self.num_samples = len(np.array_split(range(self.dataset_len), num_replicas)[rank])
+ self.num_samples = len(
+ np.array_split(range(self.dataset_len), num_replicas)[rank]
+ )
def __iter__(self) -> Iterator[T_co]:
if self.shuffle:
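
A hedged sketch of `DistributedValidationSampler` in a validation loader; `num_replicas`/`rank` are passed explicitly so the snippet runs without an initialized process group (in a real run both default to `torch.distributed` values), and the dataset is a placeholder.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_pfn_extras.distributed import DistributedValidationSampler

val_dataset = TensorDataset(torch.randn(103, 3))
sampler = DistributedValidationSampler(
    val_dataset, num_replicas=4, rank=0, shuffle=False
)
val_loader = DataLoader(val_dataset, batch_size=8, sampler=sampler)
print(sampler.num_samples)  # rank 0's share of the 103 samples (np.array_split)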
diff --git a/pytorch_pfn_extras/distributed/_initialize.py b/pytorch_pfn_extras/distributed/_initialize.py
index 8f5a64285..9d09c122b 100644
--- a/pytorch_pfn_extras/distributed/_initialize.py
+++ b/pytorch_pfn_extras/distributed/_initialize.py
@@ -12,7 +12,7 @@ def initialize_ompi_environment(
rank: int = 0,
local_rank: int = 0,
addr: str = "localhost",
- port: str = "1234"
+ port: str = "1234",
) -> Tuple[int, int, int]:
"""Initialize `torch.distributed` environments with values taken from
OpenMPI.
@@ -42,9 +42,10 @@ def initialize_ompi_environment(
addr = e.get("MASTER_ADDR", addr)
port = e.get("MASTER_PORT", port)
- if backend not in ("gloo" ,"nccl"):
+ if backend not in ("gloo", "nccl"):
raise ValueError(
- "Invalid value for backend, only 'gloo' and 'nccl' are supported")
+ "Invalid value for backend, only 'gloo' and 'nccl' are supported"
+ )
if init_method == "env":
init_method = "env://"
e["MASTER_ADDR"] = addr
@@ -56,12 +57,12 @@ def initialize_ompi_environment(
init_method = f"tcp://{addr}:{port}"
else:
raise ValueError(
- "Invalid value for init_method, only 'env' and 'tcp' are supported")
+ "Invalid value for init_method, only 'env' and 'tcp' are supported"
+ )
if world_size > 1 and not torch.distributed.is_initialized(): # type: ignore
torch.distributed.init_process_group( # type: ignore
- backend, init_method=init_method,
- world_size=world_size, rank=rank
+ backend, init_method=init_method, world_size=world_size, rank=rank
)
torch.distributed.barrier() # type: ignore
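
A hedged sketch of `initialize_ompi_environment` for a job launched with `mpirun -n 4 python train.py`; the (world_size, rank, local_rank) return order is an assumption, and the NCCL backend assumes a GPU node.

import torch
import pytorch_pfn_extras as ppe

world_size, rank, local_rank = ppe.distributed.initialize_ompi_environment(
    backend="nccl", init_method="env"
)
torch.cuda.set_device(local_rank)  # one GPU per local rank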
diff --git a/pytorch_pfn_extras/engine.py b/pytorch_pfn_extras/engine.py
index ed7a8c58c..3db3eea18 100644
--- a/pytorch_pfn_extras/engine.py
+++ b/pytorch_pfn_extras/engine.py
@@ -1,48 +1,62 @@
from typing import (
- Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type,
- Union, TYPE_CHECKING,
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
)
-import torch
-
import pytorch_pfn_extras.handler as handler_module
+import torch
from pytorch_pfn_extras.runtime import runtime_registry
from pytorch_pfn_extras.training._transform_model import default_transform_model
if TYPE_CHECKING:
+ from pytorch_pfn_extras import writing
from pytorch_pfn_extras.runtime._runtime import DeviceLike
from pytorch_pfn_extras.training import extension
- from pytorch_pfn_extras.training.trigger import TriggerLike
- from pytorch_pfn_extras.training._trainer import Trainer
from pytorch_pfn_extras.training._evaluator import Evaluator
+ from pytorch_pfn_extras.training._trainer import Trainer
from pytorch_pfn_extras.training.metrics import MetricType
- from pytorch_pfn_extras import writing
+ from pytorch_pfn_extras.training.trigger import TriggerLike
def create_trainer(
- models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
- optimizers: Union[torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer]],
- max_epochs: int,
- *,
- extensions: Optional[Sequence[Union['extension.ExtensionLike',
- 'extension.ExtensionEntry']]] = None,
- out_dir: str = 'result',
- stop_trigger: 'TriggerLike' = None,
- writer: Optional['writing.Writer'] = None,
- evaluator: Optional[Union[
- 'Evaluator', Tuple['Evaluator', 'TriggerLike'],
- Mapping[str, Union['Evaluator', Tuple['Evaluator', 'TriggerLike']]]
- ]] = None,
- device: 'DeviceLike' = 'cpu',
- logic: Optional[handler_module.BaseLogic] = None,
- transform_model: Callable[
- [str, torch.nn.Module], torch.nn.Module] = default_transform_model,
- handler_class: Optional[Type[handler_module.Handler]] = None,
- options: Optional[Dict[str, Any]] = None,
- runtime_options: Optional[Mapping[str, Any]] = None,
- profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
- **kwargs: Any,
-) -> 'Trainer':
+ models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
+ optimizers: Union[
+ torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer]
+ ],
+ max_epochs: int,
+ *,
+ extensions: Optional[
+ Sequence[Union["extension.ExtensionLike", "extension.ExtensionEntry"]]
+ ] = None,
+ out_dir: str = "result",
+ stop_trigger: "TriggerLike" = None,
+ writer: Optional["writing.Writer"] = None,
+ evaluator: Optional[
+ Union[
+ "Evaluator",
+ Tuple["Evaluator", "TriggerLike"],
+ Mapping[str, Union["Evaluator", Tuple["Evaluator", "TriggerLike"]]],
+ ]
+ ] = None,
+ device: "DeviceLike" = "cpu",
+ logic: Optional[handler_module.BaseLogic] = None,
+ transform_model: Callable[
+ [str, torch.nn.Module], torch.nn.Module
+ ] = default_transform_model,
+ handler_class: Optional[Type[handler_module.Handler]] = None,
+ options: Optional[Dict[str, Any]] = None,
+ runtime_options: Optional[Mapping[str, Any]] = None,
+ profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
+ **kwargs: Any,
+) -> "Trainer":
"""Creates a trainer object.
Args:
@@ -94,13 +108,14 @@ def create_trainer(
options = options.copy() if options else {}
# TODO(kmaehashi): deprecate specifying 'runtime' key in options
runtime_options = dict(
- runtime_options if runtime_options
- else options.pop('runtime', {}))
+ runtime_options if runtime_options else options.pop("runtime", {})
+ )
logic = handler_module.Logic() if logic is None else logic
handler_class = handler_class if handler_class else handler_module.Handler
entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec(
- device)
+ device
+ )
entry_runtime = entry_runtime_cls(device, runtime_options)
handler = handler_class(logic, entry_runtime, {})
@@ -108,14 +123,20 @@ def create_trainer(
handler.consume_options(options)
logic.consume_options(options)
if len(options) > 0:
- raise ValueError('Unknown options: ', options)
+ raise ValueError("Unknown options: ", options)
from pytorch_pfn_extras.training._trainer import Trainer
+
return Trainer(
- handler, evaluator=evaluator,
- models=models, optimizers=optimizers, max_epochs=max_epochs,
- extensions=extensions, out_dir=out_dir,
- stop_trigger=stop_trigger, writer=writer,
+ handler,
+ evaluator=evaluator,
+ models=models,
+ optimizers=optimizers,
+ max_epochs=max_epochs,
+ extensions=extensions,
+ out_dir=out_dir,
+ stop_trigger=stop_trigger,
+ writer=writer,
transform_model=transform_model,
profile=profile,
**kwargs,
@@ -123,17 +144,17 @@ def create_trainer(
def create_evaluator(
- models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
- *,
- progress_bar: bool = False,
- device: 'DeviceLike' = 'cpu',
- metrics: Optional[Sequence['MetricType']] = None,
- logic: Optional[handler_module.Logic] = None,
- handler_class: Optional[Type[handler_module.Handler]] = None,
- options: Optional[Dict[str, Any]] = None,
- runtime_options: Optional[Mapping[str, Any]] = None,
- profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
-) -> 'Evaluator':
+ models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
+ *,
+ progress_bar: bool = False,
+ device: "DeviceLike" = "cpu",
+ metrics: Optional[Sequence["MetricType"]] = None,
+ logic: Optional[handler_module.Logic] = None,
+ handler_class: Optional[Type[handler_module.Handler]] = None,
+ options: Optional[Dict[str, Any]] = None,
+ runtime_options: Optional[Mapping[str, Any]] = None,
+ profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
+) -> "Evaluator":
"""Creates an evaluator object. The return value of this function is
expected to be fed to `ppe.engine.create_trainer` as an argument.
@@ -171,13 +192,14 @@ def create_evaluator(
options = options.copy() if options else {}
# TODO(kmaehashi): deprecate specifying 'runtime' key in options
runtime_options = dict(
- runtime_options if runtime_options
- else options.pop('runtime', {}))
+ runtime_options if runtime_options else options.pop("runtime", {})
+ )
logic = handler_module.Logic() if logic is None else logic
handler_class = handler_class if handler_class else handler_module.Handler
entry_runtime_cls = runtime_registry.get_runtime_class_for_device_spec(
- device)
+ device
+ )
entry_runtime = entry_runtime_cls(device, runtime_options)
handler = handler_class(logic, entry_runtime, options)
@@ -185,9 +207,10 @@ def create_evaluator(
handler.consume_options(options)
logic.consume_options(options)
if len(options) > 0:
- raise ValueError('Unknown options: ', options)
+ raise ValueError("Unknown options: ", options)
from pytorch_pfn_extras.training._evaluator import Evaluator
+
return Evaluator(
handler,
models=models,
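
A hedged end-to-end sketch of `create_trainer`/`create_evaluator` with the reformatted signatures above; the model, optimizer, and the omitted data loaders are placeholders.

import torch
import pytorch_pfn_extras as ppe

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
device = "cuda" if torch.cuda.is_available() else "cpu"

trainer = ppe.engine.create_trainer(
    model,
    optimizer,
    max_epochs=5,
    evaluator=ppe.engine.create_evaluator(model, device=device, progress_bar=True),
    device=device,
    out_dir="result",
)
# trainer.run(train_loader, val_loader)  # loaders omitted in this sketch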
diff --git a/pytorch_pfn_extras/handler/__init__.py b/pytorch_pfn_extras/handler/__init__.py
index 09014488e..613fe8c5d 100644
--- a/pytorch_pfn_extras/handler/__init__.py
+++ b/pytorch_pfn_extras/handler/__init__.py
@@ -1,6 +1,15 @@
-from pytorch_pfn_extras.handler._code_block import CodeBlock, update_parameters, forward # NOQA
+from pytorch_pfn_extras.handler._code_block import ( # NOQA
+ CodeBlock,
+ forward,
+ update_parameters,
+)
from pytorch_pfn_extras.handler._handler import BaseHandler, Handler # NOQA
-from pytorch_pfn_extras.handler._logic import BaseLogic, Logic, CodeBlockLogic, ClousureLogic # NOQA
# Deprecated, only imported for backward compatibility
from pytorch_pfn_extras.handler._logic import torch_autocast # NOQA
+from pytorch_pfn_extras.handler._logic import ( # NOQA
+ BaseLogic,
+ ClousureLogic,
+ CodeBlockLogic,
+ Logic,
+)
diff --git a/pytorch_pfn_extras/handler/_code_block.py b/pytorch_pfn_extras/handler/_code_block.py
index 94d4396ad..d2e23f3a6 100644
--- a/pytorch_pfn_extras/handler/_code_block.py
+++ b/pytorch_pfn_extras/handler/_code_block.py
@@ -21,6 +21,7 @@ class CodeBlock:
backprop_to: Name of the values where backpropagation will be stopped.
state: Data that can be used during the CodeBlock execution.
"""
+
func: Callable
optimizers: List[torch.optim.Optimizer]
backprop: bool
@@ -104,11 +105,11 @@ def forward(block: Callable) -> CodeBlock:
if isinstance(block, torch.nn.Module):
module = block
else:
- module = getattr(block, '__self__', None)
+ module = getattr(block, "__self__", None)
assert module is not None
func = block
state = {}
- runtime = getattr(module, '_ppe_runtime', None)
+ runtime = getattr(module, "_ppe_runtime", None)
assert runtime is not None
return CodeBlock(
diff --git a/pytorch_pfn_extras/handler/_handler.py b/pytorch_pfn_extras/handler/_handler.py
index 36820dbba..bc506c62f 100644
--- a/pytorch_pfn_extras/handler/_handler.py
+++ b/pytorch_pfn_extras/handler/_handler.py
@@ -1,11 +1,19 @@
from typing import (
- Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional,
- Tuple, Union, TYPE_CHECKING,
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
)
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.handler._logic import BaseLogic
from pytorch_pfn_extras.training import Evaluator, Trainer
@@ -15,13 +23,12 @@
class BaseHandler:
-
def __init__(
- self,
- logic: BaseLogic,
- options: Dict[str, Any],
- *args: Any,
- **kwargs: Any
+ self,
+ logic: BaseLogic,
+ options: Dict[str, Any],
+ *args: Any,
+ **kwargs: Any,
) -> None:
"""Base class of Handler.
@@ -60,9 +67,7 @@ def train_setup(self, trainer: Trainer, loader: Iterable[Any]) -> None:
pass
def train_epoch_begin(
- self,
- trainer: Trainer,
- loader: Iterable[Any]
+ self, trainer: Trainer, loader: Iterable[Any]
) -> None:
"""A method called when starting a new epoch.
@@ -94,9 +99,9 @@ def train_cleanup(self, trainer: Trainer) -> None:
pass
def train_validation_begin(
- self,
- trainer: Trainer,
- evaluator: Evaluator,
+ self,
+ trainer: Trainer,
+ evaluator: Evaluator,
) -> None:
"""A method called when starting a validation.
@@ -109,9 +114,9 @@ def train_validation_begin(
pass
def train_validation_end(
- self,
- trainer: Trainer,
- evaluator: Evaluator,
+ self,
+ trainer: Trainer,
+ evaluator: Evaluator,
) -> None:
"""A method called after validation.
@@ -125,11 +130,11 @@ def train_validation_end(
pass
def train_step(
- self,
- trainer: Trainer,
- batch_idx: int,
- batch: Any,
- complete_fn: Callable[[int, Any], None],
+ self,
+ trainer: Trainer,
+ batch_idx: int,
+ batch: Any,
+ complete_fn: Callable[[int, Any], None],
) -> None:
"""A training step.
@@ -141,11 +146,11 @@ def train_step(
pass
def train_post_step(
- self,
- trainer: Trainer,
- batch_idx: int,
- batch: Any,
- outputs: Any,
+ self,
+ trainer: Trainer,
+ batch_idx: int,
+ batch: Any,
+ outputs: Any,
) -> None:
"""A method called after each training step.
@@ -156,11 +161,7 @@ def train_post_step(
# Called after train_step.
pass
- def eval_setup(
- self,
- evaluator: Evaluator,
- loader: Iterable[Any]
- ) -> None:
+ def eval_setup(self, evaluator: Evaluator, loader: Iterable[Any]) -> None:
"""A method called only once when starting a training run.
When evaluator is not given, this method is not called.
@@ -183,11 +184,11 @@ def eval_loop_begin(self, evaluator: Evaluator) -> None:
pass
def eval_step(
- self,
- evaluator: Evaluator,
- batch_idx: int,
- batch: Any,
- complete_fn: Callable[[int, Any], None],
+ self,
+ evaluator: Evaluator,
+ batch_idx: int,
+ batch: Any,
+ complete_fn: Callable[[int, Any], None],
) -> None:
"""Evaluation iteration.
@@ -209,11 +210,11 @@ def eval_loop_end(self, evaluator: Evaluator) -> None:
pass
def eval_post_step(
- self,
- evaluator: Evaluator,
- batch_idx: int,
- batch: Any,
- outputs: Any,
+ self,
+ evaluator: Evaluator,
+ batch_idx: int,
+ batch: Any,
+ outputs: Any,
) -> None:
"""A method called after each evaluation step.
@@ -225,16 +226,15 @@ def eval_post_step(
pass
-ModulesTuple = Tuple[str, torch.nn.Module, 'BaseRuntime']
+ModulesTuple = Tuple[str, torch.nn.Module, "BaseRuntime"]
class Handler(BaseHandler):
-
def __init__(
- self,
- logic: BaseLogic,
- entry_runtime: 'BaseRuntime',
- options: Dict[str, Any],
+ self,
+ logic: BaseLogic,
+ entry_runtime: "BaseRuntime",
+ options: Dict[str, Any],
) -> None:
"""A set of callback functions to perform device-specific operations.
@@ -260,14 +260,14 @@ def __init__(
def consume_options(self, options: Dict[str, Any]) -> None:
super().consume_options(options)
- self._eval_report_keys = options.pop('eval_report_keys', [])
- self._train_report_keys = options.pop('train_report_keys', [])
+ self._eval_report_keys = options.pop("eval_report_keys", [])
+ self._train_report_keys = options.pop("train_report_keys", [])
# Consume this argument for backward compatibility
- options.pop('async', False)
+ options.pop("async", False)
def _runtime_iterator(
- self,
- models: Mapping[str, torch.nn.Module],
+ self,
+ models: Mapping[str, torch.nn.Module],
) -> Generator[ModulesTuple, ModulesTuple, None]:
if not self._ppe_modules:
for n, m in models.items():
@@ -281,18 +281,19 @@ def _runtime_iterator(
yield sn, sm, rt
def _setup(
- self,
- models: Mapping[str, torch.nn.Module],
- loader: Union[Iterable[Any], Mapping[str, Iterable[Any]]],
- optimizers: Optional[Mapping[str, torch.optim.Optimizer]] = None,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ loader: Union[Iterable[Any], Mapping[str, Iterable[Any]]],
+ optimizers: Optional[Mapping[str, torch.optim.Optimizer]] = None,
) -> None:
# This requires loader to be always a dict
# should be avoided?
if not isinstance(loader, dict):
# The default model always has empty name when obtained from the
# modules
- loaders = {sn: loader
- for sn, _, _ in self._runtime_iterator(models)}
+ loaders = {
+ sn: loader for sn, _, _ in self._runtime_iterator(models)
+ }
else:
loaders = loader
if optimizers is None:
@@ -310,7 +311,8 @@ def _setup(
if len(self._ppe_modules) == 0:
raise RuntimeError(
- 'call `ppe.to(module, device)` before starting the training')
+ "call `ppe.to(module, device)` before starting the training"
+ )
def train_setup(self, trainer: Trainer, loader: Iterable[Any]) -> None:
"""A method called only once when starting a training run.
@@ -334,9 +336,7 @@ def train_cleanup(self, trainer: Trainer) -> None:
rt.train_cleanup(sm)
def train_epoch_begin(
- self,
- trainer: Trainer,
- loader: Iterable[Any]
+ self, trainer: Trainer, loader: Iterable[Any]
) -> None:
"""A method called when starting a new epoch.
@@ -361,9 +361,9 @@ def train_epoch_end(self, trainer: Trainer) -> None:
self._logic.train_epoch_end(trainer.models, trainer.epoch)
def train_validation_begin(
- self,
- trainer: Trainer,
- evaluator: Evaluator,
+ self,
+ trainer: Trainer,
+ evaluator: Evaluator,
) -> None:
"""A method called when starting a validation.
@@ -376,9 +376,9 @@ def train_validation_begin(
self._logic.train_validation_begin(evaluator.models)
def train_validation_end(
- self,
- trainer: Trainer,
- evaluator: Evaluator,
+ self,
+ trainer: Trainer,
+ evaluator: Evaluator,
) -> None:
"""A method called after validation.
@@ -396,11 +396,11 @@ def train_validation_end(
self._logic.train_validation_end(evaluator.models)
def train_step(
- self,
- trainer: Trainer,
- batch_idx: int,
- batch: Any,
- complete_fn: Callable[[int, Any], None],
+ self,
+ trainer: Trainer,
+ batch_idx: int,
+ batch: Any,
+ complete_fn: Callable[[int, Any], None],
) -> None:
"""A training step.
@@ -418,17 +418,15 @@ def train_step(
batch = self._entry_runtime.convert_batch(batch)
outs = self._logic.train_step(
- trainer.models, trainer.optimizers, batch_idx, batch)
+ trainer.models, trainer.optimizers, batch_idx, batch
+ )
self._logic.train_step_optimizers(
- trainer.models, trainer.optimizers, batch_idx)
+ trainer.models, trainer.optimizers, batch_idx
+ )
complete_fn(batch_idx, outs)
- def eval_setup(
- self,
- evaluator: Evaluator,
- loader: Iterable[Any]
- ) -> None:
+ def eval_setup(self, evaluator: Evaluator, loader: Iterable[Any]) -> None:
"""Called only once when starting a training run.
When evaluator is not given, this method is not called.
@@ -441,11 +439,11 @@ def eval_setup(
self._setup(evaluator.models, loader)
def eval_step(
- self,
- evaluator: Evaluator,
- batch_idx: int,
- batch: Any,
- complete_fn: Callable[[int, Any], None],
+ self,
+ evaluator: Evaluator,
+ batch_idx: int,
+ batch: Any,
+ complete_fn: Callable[[int, Any], None],
) -> None:
"""Evaluation iteration.
@@ -466,11 +464,11 @@ def eval_step(
complete_fn(batch_idx, outs)
def eval_post_step(
- self,
- evaluator: Evaluator,
- batch_idx: int,
- batch: Any,
- outputs: Any,
+ self,
+ evaluator: Evaluator,
+ batch_idx: int,
+ batch: Any,
+ outputs: Any,
) -> None:
"""A method called after each evaluation step.
@@ -497,11 +495,7 @@ def eval_loop_end(self, evaluator: Evaluator) -> None:
pass
def train_post_step(
- self,
- trainer: Trainer,
- batch_idx: int,
- batch: Any,
- outputs: Any
+ self, trainer: Trainer, batch_idx: int, batch: Any, outputs: Any
) -> None:
"""A method called after each training step.
diff --git a/pytorch_pfn_extras/handler/_logic.py b/pytorch_pfn_extras/handler/_logic.py
index b9dacc663..af9f78a53 100644
--- a/pytorch_pfn_extras/handler/_logic.py
+++ b/pytorch_pfn_extras/handler/_logic.py
@@ -1,10 +1,9 @@
import contextlib
import dataclasses
-from typing import Any, Dict, Generator, Iterable, Mapping, Optional
import warnings
+from typing import Any, Dict, Generator, Iterable, Mapping, Optional
import torch
-
from pytorch_pfn_extras.handler._code_block import forward, update_parameters
from pytorch_pfn_extras.runtime import _autocast
@@ -21,7 +20,7 @@ def torch_autocast(enabled: bool = True) -> Generator[None, None, None]:
def _normalize_outputs(outputs: Any) -> Dict[str, Any]:
target: Dict[str, Any]
- if isinstance(outputs, tuple) and hasattr(outputs, '_fields'):
+ if isinstance(outputs, tuple) and hasattr(outputs, "_fields"):
# namedtuple
target = outputs._asdict() # type: ignore[attr-defined]
elif isinstance(outputs, dict):
@@ -50,10 +49,10 @@ def consume_options(self, options: Dict[str, Any]) -> None:
pass
def train_epoch_begin(
- self,
- models: Mapping[str, torch.nn.Module],
- epoch: int,
- loader: Iterable[Any],
+ self,
+ models: Mapping[str, torch.nn.Module],
+ epoch: int,
+ loader: Iterable[Any],
) -> None:
"""A method called when starting a new epoch of training.
@@ -65,9 +64,9 @@ def train_epoch_begin(
pass
def train_epoch_end(
- self,
- models: Mapping[str, torch.nn.Module],
- epoch: int,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ epoch: int,
) -> None:
"""A method called when completing an epoch of training.
@@ -78,11 +77,11 @@ def train_epoch_end(
pass
def train_step(
- self,
- models: Mapping[str, torch.nn.Module],
- optimizers: Mapping[str, torch.optim.Optimizer],
- batch_idx: int,
- batch: Any,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ optimizers: Mapping[str, torch.optim.Optimizer],
+ batch_idx: int,
+ batch: Any,
) -> Any:
"""A method invokes the models forward and backward passes.
@@ -102,10 +101,10 @@ def train_step(
pass
def train_step_optimizers(
- self,
- models: Mapping[str, torch.nn.Module],
- optimizers: Mapping[str, torch.optim.Optimizer],
- batch_idx: int,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ optimizers: Mapping[str, torch.optim.Optimizer],
+ batch_idx: int,
) -> None:
"""A method in charge of stepping the provided optimizers.
@@ -118,8 +117,7 @@ def train_step_optimizers(
pass
def train_validation_begin(
- self,
- models: Mapping[str, torch.nn.Module]
+ self, models: Mapping[str, torch.nn.Module]
) -> None:
"""A method called when starting a validation.
@@ -129,8 +127,8 @@ def train_validation_begin(
pass
def train_validation_end(
- self,
- models: Mapping[str, torch.nn.Module],
+ self,
+ models: Mapping[str, torch.nn.Module],
) -> None:
"""A method called when the validation completes.
@@ -140,10 +138,10 @@ def train_validation_end(
pass
def eval_step(
- self,
- models: Mapping[str, torch.nn.Module],
- batch_idx: int,
- batch: Any,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ batch_idx: int,
+ batch: Any,
) -> Any:
"""A method for an evaluation step.
@@ -157,11 +155,10 @@ def eval_step(
class Logic(BaseLogic):
-
def __init__(
- self,
- model_name: str = 'main',
- options: Optional[Dict[str, Any]] = None,
+ self,
+ model_name: str = "main",
+ options: Optional[Dict[str, Any]] = None,
) -> None:
"""A set of methods that defines the training logic.
@@ -188,24 +185,29 @@ def __init__(
def consume_options(self, options: Dict[str, Any]) -> None:
super().consume_options(options)
- self.backward_outputs = options.pop('backward_outputs', None)
- self._grad_scaler = options.pop('grad_scaler', None)
+ self.backward_outputs = options.pop("backward_outputs", None)
+ self._grad_scaler = options.pop("grad_scaler", None)
- self._backward_fn = options.pop('backward_function', None)
+ self._backward_fn = options.pop("backward_function", None)
autocast_options = options.pop("autocast", False)
if isinstance(autocast_options, bool):
- autocast_options = {"enabled": autocast_options, "device_type": "cuda"}
+ autocast_options = {
+ "enabled": autocast_options,
+ "device_type": "cuda",
+ }
self._autocast = _autocast._AutocastManager(
autocast_options, self._grad_scaler is not None
)
if self._grad_scaler is not None:
if not isinstance(self._grad_scaler, torch.cuda.amp.GradScaler):
- raise RuntimeError('grad_scaler should be a '
- 'torch.cuda.amp.GradScaler object')
+ raise RuntimeError(
+ "grad_scaler should be a "
+ "torch.cuda.amp.GradScaler object"
+ )
def _forward(self, model: torch.nn.Module, batch: Any) -> Any:
- if isinstance(batch, tuple) and hasattr(batch, '_fields'):
+ if isinstance(batch, tuple) and hasattr(batch, "_fields"):
# namedtuple
return model(batch)
if isinstance(batch, dict):
@@ -218,10 +220,16 @@ def _backward(self, outputs: Dict[str, Any]) -> None:
to_backward = set()
if self.backward_outputs is None:
for _, v in outputs.items():
- if isinstance(v, torch.Tensor) and v.grad_fn is not None and (
- (
- v.numel() == 1
- and (v.dtype.is_floating_point or v.dtype.is_complex)
+ if (
+ isinstance(v, torch.Tensor)
+ and v.grad_fn is not None
+ and (
+ (
+ v.numel() == 1
+ and (
+ v.dtype.is_floating_point or v.dtype.is_complex
+ )
+ )
)
):
to_backward.add(v)
@@ -238,8 +246,8 @@ def _backward(self, outputs: Dict[str, Any]) -> None:
to_backward.add(v)
except KeyError:
warnings.warn(
- 'Couldn\'t find requested backward value: '
- f'{k} in {outputs.keys()}'
+ "Couldn't find requested backward value: "
+ f"{k} in {outputs.keys()}"
)
for v in to_backward:
@@ -249,10 +257,10 @@ def _backward(self, outputs: Dict[str, Any]) -> None:
self._backward_fn(v)
def train_epoch_begin(
- self,
- models: Mapping[str, torch.nn.Module],
- epoch: int,
- loader: Iterable[Any],
+ self,
+ models: Mapping[str, torch.nn.Module],
+ epoch: int,
+ loader: Iterable[Any],
) -> None:
"""A method called when starting a new epoch of training.
@@ -263,8 +271,9 @@ def train_epoch_begin(
"""
model = models[self.model_name]
model.train()
- if hasattr(loader, 'sampler') and hasattr(
- loader.sampler, 'set_epoch'): # type: ignore[attr-defined]
+ if hasattr(loader, "sampler") and hasattr(
+ loader.sampler, "set_epoch"
+ ): # type: ignore[attr-defined]
# Needed for `torch.utils.data.DistributedSampler`
loader.sampler.set_epoch(epoch) # type: ignore[attr-defined]
@@ -273,11 +282,11 @@ def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
model.eval()
def train_step(
- self,
- models: Mapping[str, torch.nn.Module],
- optimizers: Mapping[str, torch.optim.Optimizer],
- batch_idx: int,
- batch: Any,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ optimizers: Mapping[str, torch.optim.Optimizer],
+ batch_idx: int,
+ batch: Any,
) -> Any:
"""A method invokes the model forward and backward passes.
@@ -304,15 +313,16 @@ def train_step(
), "loss scaling with multiple outputs is not supported"
to_back_outs = {
k: self._grad_scaler.scale(v)
- for k, v in to_back_outs.items()}
+ for k, v in to_back_outs.items()
+ }
self._backward(to_back_outs)
return outs
def train_step_optimizers(
- self,
- models: Mapping[str, torch.nn.Module],
- optimizers: Mapping[str, torch.optim.Optimizer],
- batch_idx: int,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ optimizers: Mapping[str, torch.optim.Optimizer],
+ batch_idx: int,
) -> None:
"""A method in charge of stepping the provided optimizers.
@@ -332,8 +342,8 @@ def train_step_optimizers(
optimizer.step()
def train_validation_begin(
- self,
- models: Mapping[str, torch.nn.Module],
+ self,
+ models: Mapping[str, torch.nn.Module],
) -> None:
"""A method called when starting a validation.
@@ -348,10 +358,10 @@ def train_validation_end(self, models: Mapping[str, Any]) -> None:
model.train()
def eval_step(
- self,
- models: Mapping[str, torch.nn.Module],
- batch_idx: int,
- batch: Any,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ batch_idx: int,
+ batch: Any,
) -> Any:
"""A method for an evaluation step.
@@ -368,9 +378,9 @@ def eval_step(
class CodeBlockLogic(BaseLogic):
def __init__(
- self,
- model_name: str = 'main',
- options: Optional[Dict[str, Any]] = None,
+ self,
+ model_name: str = "main",
+ options: Optional[Dict[str, Any]] = None,
) -> None:
"""A set of methods that defines the training logic.
@@ -388,15 +398,15 @@ def __init__(
def consume_options(self, options: Dict[str, Any]) -> None:
super().consume_options(options)
- self.backward_outputs = options.pop('backward_outputs', None)
+ self.backward_outputs = options.pop("backward_outputs", None)
if self.backward_outputs is not None:
assert isinstance(self.backward_outputs, str)
def train_epoch_begin(
- self,
- models: Mapping[str, torch.nn.Module],
- epoch: int,
- loader: Iterable[Any],
+ self,
+ models: Mapping[str, torch.nn.Module],
+ epoch: int,
+ loader: Iterable[Any],
) -> None:
"""A method called when starting a new epoch of training.
@@ -407,8 +417,9 @@ def train_epoch_begin(
"""
model = models[self.model_name]
model.train()
- if hasattr(loader, 'sampler') and hasattr(
- loader.sampler, 'set_epoch'): # type: ignore[attr-defined]
+ if hasattr(loader, "sampler") and hasattr(
+ loader.sampler, "set_epoch"
+ ): # type: ignore[attr-defined]
# Needed for `torch.utils.data.DistributedSampler`
loader.sampler.set_epoch(epoch) # type: ignore[attr-defined]
@@ -417,11 +428,11 @@ def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
model.eval()
def train_step(
- self,
- models: Mapping[str, torch.nn.Module],
- optimizers: Mapping[str, torch.optim.Optimizer],
- batch_idx: int,
- batch: Any,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ optimizers: Mapping[str, torch.optim.Optimizer],
+ batch_idx: int,
+ batch: Any,
) -> Any:
"""A method invokes the model forward and backward passes.
@@ -448,8 +459,8 @@ def train_step(
)(batch)
def train_validation_begin(
- self,
- models: Mapping[str, torch.nn.Module],
+ self,
+ models: Mapping[str, torch.nn.Module],
) -> None:
"""A method called when starting a validation.
@@ -464,10 +475,10 @@ def train_validation_end(self, models: Mapping[str, Any]) -> None:
model.train()
def eval_step(
- self,
- models: Mapping[str, torch.nn.Module],
- batch_idx: int,
- batch: Any,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ batch_idx: int,
+ batch: Any,
) -> Any:
"""A method for an evaluation step.
@@ -492,18 +503,19 @@ def __float__(self) -> float:
class ClousureLogic(Logic):
-
def consume_options(self, options: Dict[str, Any]) -> None:
super().consume_options(options)
if self._grad_scaler is not None:
- raise RuntimeError('torch.cuda.amp.GradScaler does not support clousure step mode.')
+ raise RuntimeError(
+ "torch.cuda.amp.GradScaler does not support clousure step mode."
+ )
def train_step(
- self,
- models: Mapping[str, torch.nn.Module],
- optimizers: Mapping[str, torch.optim.Optimizer],
- batch_idx: int,
- batch: Any,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ optimizers: Mapping[str, torch.optim.Optimizer],
+ batch_idx: int,
+ batch: Any,
) -> Any:
"""A method invokes the model forward and backward passes and performs an optimization step.
@@ -517,18 +529,21 @@ def train_step(
batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
Input tensors fed to the model in the current step.
"""
+
def clousure() -> ClousureModelOutput:
with self._autocast.autocast():
optimizers[self.model_name].zero_grad()
outs = self._forward(models[self.model_name], batch)
to_back_outs = _normalize_outputs(outs)
if len(to_back_outs) > 1:
- raise RuntimeError("Clousure step with multiple outputs is not supported.")
+ raise RuntimeError(
+ "Clousure step with multiple outputs is not supported."
+ )
elif len(to_back_outs) == 0:
raise RuntimeError("No backward target found.")
self._backward(to_back_outs)
- loss, = to_back_outs.values()
+ (loss,) = to_back_outs.values()
return ClousureModelOutput(
outs=outs,
loss=loss,
@@ -537,14 +552,16 @@ def clousure() -> ClousureModelOutput:
optimizer = optimizers[self.model_name]
clousure_model_output: ClousureModelOutput = optimizer.step(clousure) # type: ignore
if not isinstance(clousure_model_output, ClousureModelOutput):
- raise RuntimeError(f"{type(clousure_model_output)} type object returned from optimizer.step with clousure. optimizer.step is expected to return ppe.handler.ClousureModelOutput.")
+ raise RuntimeError(
+ f"{type(clousure_model_output)} type object returned from optimizer.step with clousure. optimizer.step is expected to return ppe.handler.ClousureModelOutput."
+ )
return clousure_model_output.outs
def train_step_optimizers(
- self,
- models: Mapping[str, torch.nn.Module],
- optimizers: Mapping[str, torch.optim.Optimizer],
- batch_idx: int,
+ self,
+ models: Mapping[str, torch.nn.Module],
+ optimizers: Mapping[str, torch.optim.Optimizer],
+ batch_idx: int,
) -> None:
"""In clousure mode, the stepping of the optimizer cannot be changed.
diff --git a/pytorch_pfn_extras/logging.py b/pytorch_pfn_extras/logging.py
index fb583f111..be9938a5e 100644
--- a/pytorch_pfn_extras/logging.py
+++ b/pytorch_pfn_extras/logging.py
@@ -1,22 +1,21 @@
import logging
import os
-
-from logging import DEBUG, INFO, WARNING, ERROR, CRITICAL # NOQA
+from logging import CRITICAL, DEBUG, ERROR, INFO, WARNING # NOQA
from typing import Optional
-_logger_name = 'ppe'
-_logger_format = '[%(name)s] %(asctime)s: (%(levelname)s) %(message)s'
+_logger_name = "ppe"
+_logger_format = "[%(name)s] %(asctime)s: (%(levelname)s) %(message)s"
_logger = None
def _configure_logging(
- *,
- filename: Optional[str] = None,
- level: str = 'ERROR',
- format: str = _logger_format
+ *,
+ filename: Optional[str] = None,
+ level: str = "ERROR",
+ format: str = _logger_format,
) -> None:
global _logger
- filename = os.environ.get('PPE_LOG_FILENAME', filename)
+ filename = os.environ.get("PPE_LOG_FILENAME", filename)
if filename is None:
handler: logging.Handler = logging.StreamHandler()
else:
@@ -25,15 +24,20 @@ def _configure_logging(
# To dynamically change the level if needed
# basicConfig does not allow to change the level right after
_logger = logging.getLogger(_logger_name)
- level = os.environ.get('PPE_LOG_LEVEL', level)
- for lvl in (logging.DEBUG, logging.INFO,
- logging.WARNING, logging.ERROR, logging.CRITICAL):
+ level = os.environ.get("PPE_LOG_LEVEL", level)
+ for lvl in (
+ logging.DEBUG,
+ logging.INFO,
+ logging.WARNING,
+ logging.ERROR,
+ logging.CRITICAL,
+ ):
if logging.getLevelName(lvl) == level:
_logger.setLevel(lvl)
break
else:
_logger.setLevel(logging.INFO)
- _logger.warning('invalid PPE_LOG_LEVEL (%s); using INFO', level)
+ _logger.warning("invalid PPE_LOG_LEVEL (%s); using INFO", level)
_logger.addHandler(handler)
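
For reference, the two environment variables read above override the filename and level passed to _configure_logging. A minimal sketch; whether the library calls _configure_logging automatically is not shown in this hunk, so the explicit call to the private helper below is an assumption.

    import logging
    import os

    os.environ["PPE_LOG_LEVEL"] = "DEBUG"       # one of DEBUG/INFO/WARNING/ERROR/CRITICAL
    os.environ["PPE_LOG_FILENAME"] = "ppe.log"  # log to a file instead of a StreamHandler

    import pytorch_pfn_extras.logging as ppe_logging

    ppe_logging._configure_logging()            # private helper; explicit call assumed here
    logging.getLogger("ppe").debug("visible because PPE_LOG_LEVEL=DEBUG")
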
diff --git a/pytorch_pfn_extras/nn/__init__.py b/pytorch_pfn_extras/nn/__init__.py
index 3b2273113..f188e18d1 100644
--- a/pytorch_pfn_extras/nn/__init__.py
+++ b/pytorch_pfn_extras/nn/__init__.py
@@ -1,10 +1,12 @@
+from pytorch_pfn_extras.nn import parallel # NOQA
from pytorch_pfn_extras.nn.modules.ensure_shape import Ensure, ensure # NOQA
-from pytorch_pfn_extras.nn.modules.lazy_linear import LazyLinear # NOQA
-from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv1d # NOQA
-from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv2d # NOQA
-from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv3d # NOQA
+from pytorch_pfn_extras.nn.modules.extended_sequential import ( # NOQA
+ ExtendedSequential,
+)
from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm1d # NOQA
from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm2d # NOQA
from pytorch_pfn_extras.nn.modules.lazy_batchnorm import LazyBatchNorm3d # NOQA
-from pytorch_pfn_extras.nn.modules.extended_sequential import ExtendedSequential # NOQA
-from pytorch_pfn_extras.nn import parallel # NOQA
+from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv1d # NOQA
+from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv2d # NOQA
+from pytorch_pfn_extras.nn.modules.lazy_conv import LazyConv3d # NOQA
+from pytorch_pfn_extras.nn.modules.lazy_linear import LazyLinear # NOQA
diff --git a/pytorch_pfn_extras/nn/modules/ensure_shape.py b/pytorch_pfn_extras/nn/modules/ensure_shape.py
index a1f58d036..00043372d 100644
--- a/pytorch_pfn_extras/nn/modules/ensure_shape.py
+++ b/pytorch_pfn_extras/nn/modules/ensure_shape.py
@@ -18,17 +18,16 @@ class Ensure(torch.nn.Module):
"""
def __init__(
- self,
- *,
- shape: Optional[Tuple[Optional[int], ...]] = None,
- dtype: Optional[torch.dtype] = None,
- broadcastable: bool = False,
- can_cast: bool = False,
+ self,
+ *,
+ shape: Optional[Tuple[Optional[int], ...]] = None,
+ dtype: Optional[torch.dtype] = None,
+ broadcastable: bool = False,
+ can_cast: bool = False,
):
super().__init__() # type: ignore[no-untyped-call]
if shape is None and dtype is None:
- raise ValueError(
- 'shape, dtype or both arguments must be specified')
+ raise ValueError("shape, dtype or both arguments must be specified")
self._dtype = dtype
self._broadcastable = broadcastable
self._can_cast = can_cast
@@ -36,10 +35,7 @@ def __init__(
# so we can compare the shapes using broadcast semantics
c_shape: Optional[Tuple[int, ...]] = None
if shape is not None:
- non_none_tuple = tuple(
- [x if x is not None else 1
- for x in shape]
- )
+ non_none_tuple = tuple([x if x is not None else 1 for x in shape])
if None in shape:
self._broadcastable = True
c_shape = non_none_tuple
@@ -74,29 +70,33 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self._broadcastable:
if not self._broadcast(t_shape, self._shape):
raise ValueError(
- f'Shapes {self._shape} and {input.shape} are non'
- ' broadcastable')
+ f"Shapes {self._shape} and {input.shape} are non"
+ " broadcastable"
+ )
else:
raise ValueError(
- f'Expected {self._shape}, input shape is {input.shape}')
+ f"Expected {self._shape}, input shape is {input.shape}"
+ )
if self._dtype is not None and input.dtype != self._dtype:
if self._can_cast:
if not torch.can_cast(input.dtype, self._dtype):
raise ValueError(
- f'Input dtype {input.dtype} can\'t be casted to'
- f' {self._dtype}')
+ f"Input dtype {input.dtype} can't be casted to"
+ f" {self._dtype}"
+ )
else:
raise ValueError(
- f'Expected {self._dtype}, input dtype is {input.dtype}')
+ f"Expected {self._dtype}, input dtype is {input.dtype}"
+ )
return input
def ensure(
- tensor: torch.Tensor,
- shape: Optional[Tuple[Optional[int], ...]] = None,
- dtype: Optional[torch.dtype] = None,
- broadcastable: bool = False,
- can_cast: bool = False
+ tensor: torch.Tensor,
+ shape: Optional[Tuple[Optional[int], ...]] = None,
+ dtype: Optional[torch.dtype] = None,
+ broadcastable: bool = False,
+ can_cast: bool = False,
) -> None:
"""Checks the shape and type of a tensor.
@@ -111,5 +111,5 @@ def ensure(
can_cast: Check if the input tensor can be casted to the provided type.
"""
Ensure(
- shape=shape, dtype=dtype,
- broadcastable=broadcastable, can_cast=can_cast)(tensor)
+ shape=shape, dtype=dtype, broadcastable=broadcastable, can_cast=can_cast
+ )(tensor)
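
A usage sketch for the checks above, assuming the public entry points are ppe.nn.Ensure and ppe.nn.ensure as re-exported in the nn/__init__ hunk above; the tensor shapes and dtypes are illustrative.

    import torch
    import pytorch_pfn_extras as ppe

    x = torch.zeros(3, 4)
    ppe.nn.ensure(x, shape=(3, 4))                        # exact shape match
    ppe.nn.ensure(x, shape=(None, 4))                     # None acts as a wildcard (broadcastable)
    ppe.nn.ensure(x, dtype=torch.float64, can_cast=True)  # float32 can be cast to float64

    # The module form can be placed inside a network to validate intermediate tensors.
    checker = ppe.nn.Ensure(shape=(None, 4), dtype=torch.float32)
    y = checker(x)
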
diff --git a/pytorch_pfn_extras/nn/modules/extended_sequential.py b/pytorch_pfn_extras/nn/modules/extended_sequential.py
index 222328808..9f9918d55 100644
--- a/pytorch_pfn_extras/nn/modules/extended_sequential.py
+++ b/pytorch_pfn_extras/nn/modules/extended_sequential.py
@@ -1,10 +1,10 @@
-import torch
import copy
-from typing import TypeVar
import warnings
+from typing import TypeVar
+import torch
-Model = TypeVar('Model', torch.nn.Module, 'ExtendedSequential')
+Model = TypeVar("Model", torch.nn.Module, "ExtendedSequential")
def _reset_parameters(model: Model) -> Model:
@@ -15,38 +15,39 @@ def _reset_parameters(model: Model) -> Model:
for submodel in model.values():
_reset_parameters(submodel)
else:
- if hasattr(model, 'reset_parameters'):
+ if hasattr(model, "reset_parameters"):
model.reset_parameters() # type: ignore [operator]
- elif hasattr(model, '_reset_parameters'):
+ elif hasattr(model, "_reset_parameters"):
model._reset_parameters() # type: ignore [operator]
else:
- if (len(list(model.parameters())) != 0
- or len(list(model.buffers())) != 0):
- warnings.warn('Cannot reset the parameters of module {}. '
- 'Consider adding `reset_parameters` or '
- '`_reset_parameters` '
- 'functions to the module'.format(model),
- UserWarning)
+ if (
+ len(list(model.parameters())) != 0
+ or len(list(model.buffers())) != 0
+ ):
+ warnings.warn(
+ "Cannot reset the parameters of module {}. "
+ "Consider adding `reset_parameters` or "
+ "`_reset_parameters` "
+ "functions to the module".format(model),
+ UserWarning,
+ )
return model
class ExtendedSequential(torch.nn.Sequential):
- """Sequential module with extended features from chainer.
+ """Sequential module with extended features from chainer."""
- """
- def _copy_model(self, mode: str) -> 'ExtendedSequential':
- if mode == 'init':
+ def _copy_model(self, mode: str) -> "ExtendedSequential":
+ if mode == "init":
return _reset_parameters(copy.deepcopy(self))
- elif mode == 'copy':
+ elif mode == "copy":
return copy.deepcopy(self)
else:
# mode == share
return copy.copy(self)
- def repeat(
- self, n_repeat: int, mode: str = 'init'
- ) -> 'ExtendedSequential':
+ def repeat(self, n_repeat: int, mode: str = "init") -> "ExtendedSequential":
"""Repeats this Sequential multiple times.
This method returns a :class:`~torch.nn.Sequential` object which has
@@ -85,10 +86,11 @@ def repeat(
if n_repeat <= 0:
return ExtendedSequential()
- if mode not in ['copy', 'share', 'init']:
+ if mode not in ["copy", "share", "init"]:
raise ValueError(
- 'The \'mode\' argument should be either \'init\','
- '\'copy\', or \'share\'. But {} was given.'.format(mode))
+ "The 'mode' argument should be either 'init',"
+ "'copy', or 'share'. But {} was given.".format(mode)
+ )
model_list = []
for _ in range(n_repeat):
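
A usage sketch for repeat() and the three copy modes handled in _copy_model above, assuming ppe.nn.ExtendedSequential is the class re-exported in the nn/__init__ hunk above; the layer sizes are illustrative.

    import torch
    import pytorch_pfn_extras as ppe

    block = ppe.nn.ExtendedSequential(
        torch.nn.Linear(8, 8),
        torch.nn.ReLU(),
    )
    stacked = block.repeat(3, mode="init")   # three deep copies with re-initialized parameters
    copied = block.repeat(2, mode="copy")    # deep copies keeping the current parameter values
    shared = block.repeat(2, mode="share")   # shallow copies sharing the same parameters
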
diff --git a/pytorch_pfn_extras/nn/modules/lazy.py b/pytorch_pfn_extras/nn/modules/lazy.py
index 51e8399c2..8c2cf3f4d 100644
--- a/pytorch_pfn_extras/nn/modules/lazy.py
+++ b/pytorch_pfn_extras/nn/modules/lazy.py
@@ -1,6 +1,6 @@
import inspect
-from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
import warnings
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
@@ -58,12 +58,14 @@ def lazy_parmeters_determined(self) -> bool:
parameters are determined. Note that this may be called during
``__init__``.
"""
- return self._lazy_ready and all([
- not isinstance(getattr(self, x), UninitializedParameter)
- for x in self.lazy_parameter_names])
-
- def state_dict(
- self: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+ return self._lazy_ready and all(
+ [
+ not isinstance(getattr(self, x), UninitializedParameter)
+ for x in self.lazy_parameter_names
+ ]
+ )
+
+ def state_dict(self: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Returns a dictionary containing a whole state of the module.
This function overrides the default behavior to exclude uninitialized
@@ -81,8 +83,15 @@ def state_dict(
return destination # type: ignore[no-any-return]
def _lazy_load_hook( # type: ignore[no-untyped-def]
- self, state_dict, prefix, local_metadata, strict,
- missing_keys, unexpected_keys, error_msgs):
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
"""load_state_dict pre-hook function for lazy buffers and parameters.
The purpose of this hook is to check the current state and/or
@@ -98,8 +107,9 @@ def _lazy_load_hook( # type: ignore[no-untyped-def]
state_initialized = state_dict[key].shape != (0,)
if module_initialized and not state_initialized:
raise RuntimeError(
- 'Can\'t load non-initialized buffers in already '
- 'initialized modules')
+ "Can't load non-initialized buffers in already "
+ "initialized modules"
+ )
elif not module_initialized and state_initialized:
# Here we need to avoid a tensor size mismatch
# this is a regular tensor without a materialize
@@ -114,12 +124,14 @@ def _lazy_load_hook( # type: ignore[no-untyped-def]
# parameters (see comments of ``state_dict``).
key = prefix + name
module_initialized = not isinstance(
- getattr(self, name), UninitializedParameter)
+ getattr(self, name), UninitializedParameter
+ )
state_initialized = key in state_dict
if module_initialized and not state_initialized:
raise RuntimeError(
- 'Can\'t load uninitialized parameters in already '
- 'initialized modules')
+ "Can't load uninitialized parameters in already "
+ "initialized modules"
+ )
elif not module_initialized and state_initialized:
getattr(self, name).materialize(state_dict[key].shape)
elif key not in state_dict and not module_initialized:
@@ -128,15 +140,15 @@ def _lazy_load_hook( # type: ignore[no-untyped-def]
class UninitializedParameter(torch.nn.Parameter):
-
def __repr__(self) -> str: # type: ignore[override]
- return 'Uninitialized lazy parameter'
+ return "Uninitialized lazy parameter"
- def share_memory_(self) -> 'UninitializedParameter':
+ def share_memory_(self) -> "UninitializedParameter":
raise RuntimeError(
- 'Can\'t share memory on an unitialized parameter. '
- 'Run forward to initialize the network before calling '
- '`module.share_memory()`.')
+ "Can't share memory on an unitialized parameter. "
+ "Run forward to initialize the network before calling "
+ "`module.share_memory()`."
+ )
@property
def is_leaf(self) -> bool: # type: ignore[override]
@@ -145,18 +157,20 @@ def is_leaf(self) -> bool: # type: ignore[override]
# for parameters; optimizers check for this attribute and raise an
# error if non-leaf tensors are detected.
frame = inspect.currentframe()
- package_name = frame.f_back.f_globals['__package__'] # type: ignore
- if package_name.startswith('torch.optim'):
- warnings.warn('''
+ package_name = frame.f_back.f_globals["__package__"] # type: ignore
+ if package_name.startswith("torch.optim"):
+ warnings.warn(
+ """
Use of uninitialized lazy parameter in Optimizer has been detected.
- Maybe you forgot to run forward before passing `module.parameters()` to the optimizer?''') # NOQA
+ Maybe you forgot to run forward before passing `module.parameters()` to the optimizer?"""
+ ) # NOQA
return True
def materialize(
- self,
- shape: Tuple[int, ...],
- device: Optional['DeviceLike'] = None,
- dtype: Optional[torch.dtype] = None,
+ self,
+ shape: Tuple[int, ...],
+ device: Optional["DeviceLike"] = None,
+ dtype: Optional[torch.dtype] = None,
) -> None:
r"""Create a Parameter with the same properties of the uninitialized
one. Given a shape, it materializes a parameter in the same device
diff --git a/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py b/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py
index cff046034..a7a97cd1b 100644
--- a/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py
+++ b/pytorch_pfn_extras/nn/modules/lazy_batchnorm.py
@@ -1,33 +1,35 @@
from typing import Any, Optional
import torch
-
-from pytorch_pfn_extras.nn.modules.lazy import LazyInitializationMixin
-from pytorch_pfn_extras.nn.modules.lazy import UninitializedParameter
+from pytorch_pfn_extras.nn.modules.lazy import (
+ LazyInitializationMixin,
+ UninitializedParameter,
+)
class _LazyBatchNorm( # type: ignore[misc]
- LazyInitializationMixin,
- torch.nn.modules.batchnorm._BatchNorm
+ LazyInitializationMixin, torch.nn.modules.batchnorm._BatchNorm
):
-
running_mean: Any
running_var: Any
- lazy_parameter_names = ('weight', 'bias')
+ lazy_parameter_names = ("weight", "bias")
- def __init__(self, num_features: Optional[int], *args: Any, **kwargs: Any) -> None:
+ def __init__(
+ self, num_features: Optional[int], *args: Any, **kwargs: Any
+ ) -> None:
super().__init__(num_features or 0, *args, **kwargs)
if not self.affine:
raise ValueError(
- 'LazyBatchNorm is not compatible with affine=False.'
- ' Use the regular BatchNorm layers instead')
+ "LazyBatchNorm is not compatible with affine=False."
+ " Use the regular BatchNorm layers instead"
+ )
# weight and bias are registered in the mixin
if num_features is None:
self.num_features: Optional[int] = None # type: ignore[assignment]
if self.track_running_stats:
# these buffers are not always needed
# so we avoid explicit initializations
- self.lazy_buffer_names = ('running_mean', 'running_var')
+ self.lazy_buffer_names = ("running_mean", "running_var")
def reset_parameters(self) -> None:
if self.lazy_parmeters_determined:
@@ -46,13 +48,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.track_running_stats:
assert isinstance(self.running_mean, torch.Tensor)
self.running_mean = torch.zeros(
- self.num_features, device=self.running_mean.device,
- dtype=self.running_mean.dtype
+ self.num_features,
+ device=self.running_mean.device,
+ dtype=self.running_mean.dtype,
)
assert isinstance(self.running_var, torch.Tensor)
self.running_var = torch.ones(
- self.num_features, device=self.running_var.device,
- dtype=self.running_mean.dtype
+ self.num_features,
+ device=self.running_var.device,
+ dtype=self.running_mean.dtype,
)
self.reset_parameters()
return super().forward(input)
@@ -64,6 +68,7 @@ class LazyBatchNorm1d(_LazyBatchNorm, torch.nn.BatchNorm1d): # type: ignore[mis
When ``num_features`` is ``None``, it is determined at the first time of
the forward step.
"""
+
pass
@@ -73,6 +78,7 @@ class LazyBatchNorm2d(_LazyBatchNorm, torch.nn.BatchNorm2d): # type: ignore[mis
When ``num_features`` is ``None``, it is determined at the first time of
the forward step.
"""
+
pass
@@ -82,4 +88,5 @@ class LazyBatchNorm3d(_LazyBatchNorm, torch.nn.BatchNorm3d): # type: ignore[mis
When ``num_features`` is ``None``, it is determined at the first time of
the forward step.
"""
+
pass
diff --git a/pytorch_pfn_extras/nn/modules/lazy_conv.py b/pytorch_pfn_extras/nn/modules/lazy_conv.py
index 32a9e5869..51f00b666 100644
--- a/pytorch_pfn_extras/nn/modules/lazy_conv.py
+++ b/pytorch_pfn_extras/nn/modules/lazy_conv.py
@@ -1,18 +1,18 @@
from typing import Any, Optional
import torch
-
from pytorch_pfn_extras.nn.modules.lazy import (
- LazyInitializationMixin, UninitializedParameter
+ LazyInitializationMixin,
+ UninitializedParameter,
)
class _LazyConvNd(LazyInitializationMixin):
-
- lazy_parameter_names = ('weight',)
+ lazy_parameter_names = ("weight",)
def __init__(
- self: Any, in_channels: Optional[int], *args: Any, **kwargs: Any) -> None:
+ self: Any, in_channels: Optional[int], *args: Any, **kwargs: Any
+ ) -> None:
super().__init__(in_channels or 0, *args, **kwargs)
if in_channels is None:
self.in_channels: Optional[int] = None
@@ -22,11 +22,17 @@ def forward(self: Any, input: torch.Tensor) -> torch.Tensor:
if isinstance(self.weight, UninitializedParameter):
self.in_channels = input.shape[1]
if self.transposed:
- shape = (self.in_channels, self.out_channels // self.groups,
- *self.kernel_size)
+ shape = (
+ self.in_channels,
+ self.out_channels // self.groups,
+ *self.kernel_size,
+ )
else:
- shape = (self.out_channels, self.in_channels // self.groups,
- *self.kernel_size)
+ shape = (
+ self.out_channels,
+ self.in_channels // self.groups,
+ *self.kernel_size,
+ )
self.weight = torch.nn.Parameter(self.weight.new_empty(*shape))
self.reset_parameters()
return super().forward(input) # type: ignore
@@ -44,6 +50,7 @@ class LazyConv1d(_LazyConvNd, torch.nn.Conv1d): # type: ignore[misc]
When ``in_channels`` is ``None``, it is determined at the first time of
the forward step.
"""
+
pass
@@ -53,6 +60,7 @@ class LazyConv2d(_LazyConvNd, torch.nn.Conv2d): # type: ignore[misc]
When ``in_channels`` is ``None``, it is determined at the first time of
the forward step.
"""
+
pass
@@ -62,4 +70,5 @@ class LazyConv3d(_LazyConvNd, torch.nn.Conv3d): # type: ignore[misc]
When ``in_channels`` is ``None``, it is determined at the first time of
the forward step.
"""
+
pass
diff --git a/pytorch_pfn_extras/nn/modules/lazy_linear.py b/pytorch_pfn_extras/nn/modules/lazy_linear.py
index aabc2e892..519dcde4f 100644
--- a/pytorch_pfn_extras/nn/modules/lazy_linear.py
+++ b/pytorch_pfn_extras/nn/modules/lazy_linear.py
@@ -1,9 +1,10 @@
from typing import Any, Optional
import torch
-
-from pytorch_pfn_extras.nn.modules.lazy import UninitializedParameter
-from pytorch_pfn_extras.nn.modules.lazy import LazyInitializationMixin
+from pytorch_pfn_extras.nn.modules.lazy import (
+ LazyInitializationMixin,
+ UninitializedParameter,
+)
class LazyLinear(LazyInitializationMixin, torch.nn.Linear): # type: ignore[misc]
@@ -13,9 +14,11 @@ class LazyLinear(LazyInitializationMixin, torch.nn.Linear): # type: ignore[misc
the forward step.
"""
- lazy_parameter_names = ('weight',)
+ lazy_parameter_names = ("weight",)
- def __init__(self, in_features: Optional[int], *args: Any, **kwargs: Any) -> None:
+ def __init__(
+ self, in_features: Optional[int], *args: Any, **kwargs: Any
+ ) -> None:
super().__init__(in_features or 0, *args, **kwargs)
if in_features is None:
self.in_features = None # type: ignore[assignment]
@@ -24,8 +27,9 @@ def __init__(self, in_features: Optional[int], *args: Any, **kwargs: Any) -> Non
def forward(self, input: torch.Tensor) -> torch.Tensor:
if isinstance(self.weight, UninitializedParameter):
self.in_features = input.shape[-1]
- self.weight = torch.nn.Parameter(self.weight.new_empty(
- self.out_features, self.in_features))
+ self.weight = torch.nn.Parameter(
+ self.weight.new_empty(self.out_features, self.in_features)
+ )
self.reset_parameters()
return super().forward(input)
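
A usage sketch of the lazy-initialization flow above: ppe.nn.LazyLinear accepts None for in_features and materializes the weight on the first forward, so the optimizer should be built afterwards. The sizes below are illustrative.

    import torch
    import pytorch_pfn_extras as ppe

    layer = ppe.nn.LazyLinear(None, 10)   # in_features determined at the first forward
    x = torch.randn(4, 32)
    y = layer(x)                          # weight materialized with shape (10, 32)

    # Create the optimizer only after the parameters are materialized; otherwise the
    # UninitializedParameter warning from lazy.py above is emitted.
    optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
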
diff --git a/pytorch_pfn_extras/nn/parallel/__init__.py b/pytorch_pfn_extras/nn/parallel/__init__.py
index 68840f079..458e3b5e4 100644
--- a/pytorch_pfn_extras/nn/parallel/__init__.py
+++ b/pytorch_pfn_extras/nn/parallel/__init__.py
@@ -1 +1,3 @@
-from pytorch_pfn_extras.nn.parallel.distributed import DistributedDataParallel # NOQA
+from pytorch_pfn_extras.nn.parallel.distributed import ( # NOQA
+ DistributedDataParallel,
+)
diff --git a/pytorch_pfn_extras/nn/parallel/distributed.py b/pytorch_pfn_extras/nn/parallel/distributed.py
index 4724cc40a..ab3ccb600 100644
--- a/pytorch_pfn_extras/nn/parallel/distributed.py
+++ b/pytorch_pfn_extras/nn/parallel/distributed.py
@@ -1,27 +1,34 @@
import logging
-
-from contextlib import contextmanager
+import threading
from collections import OrderedDict
+from contextlib import contextmanager
from typing import (
- Any, Callable, Dict, Generator, List, Mapping, Optional, Sequence, Tuple,
- TypeVar, Union
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+ Union,
)
import torch
-from torch import nn
+from pytorch_pfn_extras.profiler import record
from torch import distributed as dist
-from torch.utils import hooks
+from torch import nn
from torch.autograd import Variable
from torch.autograd.profiler import record_function
-import threading
-
-from pytorch_pfn_extras.profiler import record
+from torch.utils import hooks
logger = logging.getLogger(__name__)
Tensors = Union[Tuple[torch.Tensor, ...], torch.Tensor]
DistFunc = Callable[[Sequence[torch.Tensor], Optional[dist.ProcessGroup]], None]
-HookFun = Callable[['DistributedDataParallel'], None]
+HookFun = Callable[["DistributedDataParallel"], None]
class _ForEachWrapper:
@@ -31,22 +38,25 @@ class _ForEachWrapper:
- torch with python for loop
- torch._foreach_xxx
"""
+
def __init__(self) -> None:
self.flatten = torch._utils._flatten_dense_tensors
self.unflatten = torch._utils._unflatten_dense_tensors
- self._enable_foreach = (hasattr(torch, "_foreach_add")
- and hasattr(torch, '_foreach_zero_'))
+ self._enable_foreach = hasattr(torch, "_foreach_add") and hasattr(
+ torch, "_foreach_zero_"
+ )
if not self._enable_foreach:
logger.warning(
"torch does not have _foreach_xxx functions."
- " Please use newer torch")
+ " Please use newer torch"
+ )
def multi_tensor_scale(
- self,
- src: Sequence[torch.Tensor],
- dst: Sequence[torch.Tensor],
- scale: float,
+ self,
+ src: Sequence[torch.Tensor],
+ dst: Sequence[torch.Tensor],
+ scale: float,
) -> None:
with torch.no_grad(): # type: ignore[no-untyped-call]
# _foreach_zero for long type is not supported in CUDA
@@ -75,16 +85,18 @@ def get_foreach_wrapper() -> _ForEachWrapper:
def _reduce(
- values: Sequence[torch.Tensor],
- group: Optional[dist.ProcessGroup],
+ values: Sequence[torch.Tensor],
+ group: Optional[dist.ProcessGroup],
) -> None:
size = sum([v.numel() for v in values])
# flatten values to improve the runtime performance of all-reduce
- coalesced = torch.empty(size, device=values[0].device,
- dtype=values[0].dtype)
+ coalesced = torch.empty(
+ size, device=values[0].device, dtype=values[0].dtype
+ )
coalesced_views = get_foreach_wrapper().unflatten( # type: ignore[no-untyped-call]
- coalesced, values)
+ coalesced, values
+ )
get_foreach_wrapper().multi_tensor_scale(values, coalesced_views, 1.0)
with record(
@@ -94,28 +106,32 @@ def _reduce(
# unflatten values
get_foreach_wrapper().multi_tensor_scale(
- coalesced_views, values,
- 1.0 / dist.get_world_size(group) # type: ignore[no-untyped-call]
+ coalesced_views,
+ values,
+ 1.0 / dist.get_world_size(group), # type: ignore[no-untyped-call]
)
def _broadcast(
- values: Sequence[torch.Tensor],
- group: Optional[dist.ProcessGroup]
+ values: Sequence[torch.Tensor], group: Optional[dist.ProcessGroup]
) -> None:
with torch.no_grad(): # type: ignore[no-untyped-call]
coalesced = get_foreach_wrapper().flatten( # type: ignore[no-untyped-call]
- values)
+ values
+ )
with record(
"torch.distributed.broadcast", use_cuda=torch.cuda.is_available()
):
dist.broadcast(coalesced, 0, group=group) # type: ignore[no-untyped-call]
src = get_foreach_wrapper().unflatten( # type: ignore[no-untyped-call]
- coalesced, values)
+ coalesced, values
+ )
get_foreach_wrapper().multi_tensor_scale(src, values, 1.0)
-def _group_by_type(values: Sequence[Optional[torch.Tensor]]) -> List[List[torch.Tensor]]:
+def _group_by_type(
+ values: Sequence[Optional[torch.Tensor]],
+) -> List[List[torch.Tensor]]:
groups: Dict[torch.dtype, List[torch.Tensor]] = {}
for value in values:
if value is None:
@@ -151,19 +167,24 @@ class DistributedDataParallel(nn.Module):
broadcast_function: Broadcast function
"""
- _unused_parameters = ["device_ids", "output_device", "dim",
- "find_unused_parameters", "check_reduction",
- "gradient_as_bucket_view"]
+ _unused_parameters = [
+ "device_ids",
+ "output_device",
+ "dim",
+ "find_unused_parameters",
+ "check_reduction",
+ "gradient_as_bucket_view",
+ ]
def __init__(
- self,
- module: nn.Module,
- broadcast_buffers: bool = True,
- negotiate_grads: bool = True,
- process_group: Optional[dist.ProcessGroup] = None,
- reduce_function: Optional[DistFunc] = None,
- broadcast_function: Optional[DistFunc] = None,
- **kwargs: Any
+ self,
+ module: nn.Module,
+ broadcast_buffers: bool = True,
+ negotiate_grads: bool = True,
+ process_group: Optional[dist.ProcessGroup] = None,
+ reduce_function: Optional[DistFunc] = None,
+ broadcast_function: Optional[DistFunc] = None,
+ **kwargs: Any,
) -> None:
"""
This module receives keyword arguments for the compatibility with
@@ -204,9 +225,9 @@ def __init__(
# synchronize initial parameters and buffers
params = dict(self.named_parameters())
buffers = dict(self.named_buffers())
- values = \
- [buffers[name] for name in self._sorted_buffer_keys] + \
- [params[name] for name in self._sorted_param_keys]
+ values = [buffers[name] for name in self._sorted_buffer_keys] + [
+ params[name] for name in self._sorted_param_keys
+ ]
if dist.is_initialized(): # type: ignore[no-untyped-call]
groups = _group_by_type(values)
for group in groups:
@@ -219,8 +240,7 @@ def __init__(
@contextmanager
def no_sync(self) -> Generator[None, None, None]:
- """A context manager to disable synchronization after backward
- """
+ """A context manager to disable synchronization after backward"""
prev = self._require_sync
self._require_sync = False
try:
@@ -234,13 +254,13 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
return self.module(*args, **kwargs)
def load_state_dict(
- self,
- state_dict: 'Mapping[str, torch.Tensor]',
- strict: bool = True,
+ self,
+ state_dict: "Mapping[str, torch.Tensor]",
+ strict: bool = True,
) -> None:
self.module.load_state_dict(state_dict, strict=strict) # type: ignore[arg-type]
- T_destination = TypeVar('T_destination', bound=Mapping[str, torch.Tensor])
+ T_destination = TypeVar("T_destination", bound=Mapping[str, torch.Tensor])
def state_dict(self) -> Dict[str, Any]: # type: ignore[override]
return self.module.state_dict()
@@ -257,10 +277,7 @@ def register_comm_hook(self, hook: HookFun) -> hooks.RemovableHandle:
return handle
def _backward_hook(
- self,
- module: torch.nn.Module,
- gin: Tensors,
- gout: Tensors
+ self, module: torch.nn.Module, gin: Tensors, gout: Tensors
) -> None:
def _synchronize() -> None:
if not self._require_sync:
@@ -270,14 +287,17 @@ def _synchronize() -> None:
hook(self)
with record_function(
- "ppe.nn.parallel.DistributedDataParallel.synchronize"):
+ "ppe.nn.parallel.DistributedDataParallel.synchronize"
+ ):
params = dict(self.named_parameters())
if self._negotiate_grads:
# find parameters that have gradients
has_grads = torch.tensor(
- [params[name].grad is not None
- for name in self._sorted_param_keys],
- device=self._device
+ [
+ params[name].grad is not None
+ for name in self._sorted_param_keys
+ ],
+ device=self._device,
)
# cast to long because bool may not be used in all_reduce
@@ -288,19 +308,25 @@ def _synchronize() -> None:
use_cuda=torch.cuda.is_available(),
):
dist.all_reduce( # type: ignore[no-untyped-call]
- has_grads, op=dist.ReduceOp.MAX)
+ has_grads, op=dist.ReduceOp.MAX
+ )
- for name, has_grad in zip(self._sorted_param_keys,
- has_grads.bool().cpu()):
+ for name, has_grad in zip(
+ self._sorted_param_keys, has_grads.bool().cpu()
+ ):
# create zero tensor as a gradient if a parameter
# does not have the gradient and other processes
# need to synchronize this parameter.
if has_grad and params[name].grad is None:
- params[name].grad = \
- torch.zeros_like(params[name].data)
-
- grads = [params[name].grad for name in self._sorted_param_keys
- if params[name].grad is not None]
+ params[name].grad = torch.zeros_like(
+ params[name].data
+ )
+
+ grads = [
+ params[name].grad
+ for name in self._sorted_param_keys
+ if params[name].grad is not None
+ ]
groups = _group_by_type(grads)
with record(
"pytorch_pfn_extras.nn.parallel."
@@ -320,8 +346,7 @@ def _synchronize() -> None:
use_cuda=torch.cuda.is_available(),
):
for group in groups:
- self._broadcast_function(
- group, self._process_group)
+ self._broadcast_function(group, self._process_group)
# PyTorch will invoke `_synchronize` after the backward computation.
Variable._execution_engine.queue_callback(_synchronize)
@@ -344,6 +369,7 @@ def _input_to_device(self, obj: Any) -> Any:
if isinstance(obj, list) and len(obj) > 0:
return [self._input_to_device(x) for x in obj]
if isinstance(obj, dict) and len(obj) > 0:
- return {key: self._input_to_device(value)
- for key, value in obj.items()}
+ return {
+ key: self._input_to_device(value) for key, value in obj.items()
+ }
return obj
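
A usage sketch for the module above, assuming the default process group has already been initialized (e.g. via dist.init_process_group under torchrun); the model and tensor shapes are illustrative.

    import torch
    import torch.distributed as dist  # noqa: F401  (init_process_group assumed done elsewhere)
    import pytorch_pfn_extras as ppe

    model = torch.nn.Linear(8, 2).cuda()
    ddp_model = ppe.nn.parallel.DistributedDataParallel(model)

    out = ddp_model(torch.randn(4, 8).cuda())
    out.sum().backward()        # gradients are all-reduced in the backward hook above

    with ddp_model.no_sync():   # skip gradient synchronization for this backward
        ddp_model(torch.randn(4, 8).cuda()).sum().backward()
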
diff --git a/pytorch_pfn_extras/profiler/__init__.py b/pytorch_pfn_extras/profiler/__init__.py
index 4b63035f9..88abf507a 100644
--- a/pytorch_pfn_extras/profiler/__init__.py
+++ b/pytorch_pfn_extras/profiler/__init__.py
@@ -1,5 +1,5 @@
from pytorch_pfn_extras.profiler._record import record # NOQA
from pytorch_pfn_extras.profiler._record import record_function # NOQA
from pytorch_pfn_extras.profiler._record import record_iterable # NOQA
-from pytorch_pfn_extras.profiler._time_summary import get_time_summary # NOQA
from pytorch_pfn_extras.profiler._time_summary import TimeSummary # NOQA
+from pytorch_pfn_extras.profiler._time_summary import get_time_summary # NOQA
diff --git a/pytorch_pfn_extras/profiler/_record.py b/pytorch_pfn_extras/profiler/_record.py
index a2cc14a2b..e77635319 100644
--- a/pytorch_pfn_extras/profiler/_record.py
+++ b/pytorch_pfn_extras/profiler/_record.py
@@ -1,10 +1,17 @@
-from contextlib import contextmanager
import inspect
-from typing import Any, Callable, Generator, Iterable, Optional, TypeVar, TYPE_CHECKING
import types
+from contextlib import contextmanager
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generator,
+ Iterable,
+ Optional,
+ TypeVar,
+)
import torch
-
from pytorch_pfn_extras.profiler import _time_summary
from pytorch_pfn_extras.runtime import runtime_registry
@@ -18,7 +25,7 @@ def _infer_tag_name(frame: Optional[types.FrameType], depth: int) -> str:
frame = frame.f_back
assert frame is not None
frame_info = inspect.getframeinfo(frame, context=0)
- return '{}:{}:{}'.format(
+ return "{}:{}:{}".format(
inspect.getmodulename(frame_info.filename),
frame_info.lineno,
frame_info.function,
@@ -38,13 +45,12 @@ def complete(self) -> None:
@contextmanager
def record(
- tag: Optional[str],
- metric: Optional[str] = None,
- use_cuda: bool = False,
- enable: bool = True,
- device: 'DeviceLike' = 'cpu'
+ tag: Optional[str],
+ metric: Optional[str] = None,
+ use_cuda: bool = False,
+ enable: bool = True,
+ device: "DeviceLike" = "cpu",
) -> Generator[_time_summary._ReportNotification, None, None]:
-
if not enable:
yield _DummyReportNotification()
return
@@ -55,8 +61,7 @@ def record(
if metric is None:
metric = tag
- runtime_cls = runtime_registry.get_runtime_class_for_device_spec(
- device)
+ runtime_cls = runtime_registry.get_runtime_class_for_device_spec(device)
runtime_tracer = runtime_cls.trace
if use_cuda:
@@ -71,13 +76,13 @@ def record(
torch.cuda.nvtx.range_pop() # type: ignore[no-untyped-call]
-_T = TypeVar('_T')
+_T = TypeVar("_T")
def record_function(
- tag: Optional[str],
- use_cuda: bool = False,
- enable: bool = True,
+ tag: Optional[str],
+ use_cuda: bool = False,
+ enable: bool = True,
) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
def wrapper(f: Callable[..., _T]) -> Callable[..., _T]:
def wrapped(*args: Any, **kwargs: Any) -> _T:
@@ -90,11 +95,11 @@ def wrapped(*args: Any, **kwargs: Any) -> _T:
def record_iterable(
- tag: Optional[str],
- iter: Iterable[_T],
- divide_metric: bool = False,
- use_cuda: bool = False,
- enable: bool = True,
+ tag: Optional[str],
+ iter: Iterable[_T],
+ divide_metric: bool = False,
+ use_cuda: bool = False,
+ enable: bool = True,
) -> Iterable[_T]:
if tag is None:
tag = _infer_tag_name(inspect.currentframe(), depth=1)
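
A usage sketch for the three entry points re-exported from pytorch_pfn_extras.profiler in the __init__ hunk above; the model and dataset are illustrative placeholders.

    import torch
    import pytorch_pfn_extras as ppe

    model = torch.nn.Linear(8, 2)                # illustrative
    dataset = [torch.randn(8) for _ in range(4)]

    with ppe.profiler.record("forward"):         # tag defaults to module:line:function when None
        y = model(dataset[0])

    @ppe.profiler.record_function("inference")
    def run(batch):
        return model(batch)

    for sample in ppe.profiler.record_iterable("load", dataset):
        run(sample)
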
diff --git a/pytorch_pfn_extras/profiler/_time_summary.py b/pytorch_pfn_extras/profiler/_time_summary.py
index d62bfd076..aa5530bf0 100644
--- a/pytorch_pfn_extras/profiler/_time_summary.py
+++ b/pytorch_pfn_extras/profiler/_time_summary.py
@@ -1,28 +1,27 @@
import atexit
-from contextlib import contextmanager
+import multiprocessing as mp
import os
-import time
-from typing import Callable, Dict, Generator, Optional, Tuple
-import threading
import queue
-import multiprocessing as mp
-import torch
+import threading
+import time
import weakref
+from contextlib import contextmanager
+from typing import Callable, Dict, Generator, Optional, Tuple
+import torch
from pytorch_pfn_extras.reporting import DictSummary
-
Events = Tuple[torch.cuda.Event, torch.cuda.Event]
class _ReportNotification:
def __init__(
- self,
- summary: 'TimeSummary',
- tag: str,
- use_cuda: bool,
- begin_event: Optional[torch.cuda.Event],
- begin: float,
+ self,
+ summary: "TimeSummary",
+ tag: str,
+ use_cuda: bool,
+ begin_event: Optional[torch.cuda.Event],
+ begin: float,
) -> None:
self._is_completed = True
self._summary = summary
@@ -36,19 +35,22 @@ def defer(self) -> None:
def complete(self) -> None:
self._summary.complete_report(
- self._tag, self._use_cuda, self._begin_event, self._begin)
+ self._tag, self._use_cuda, self._begin_event, self._begin
+ )
class _CPUWorker:
def __init__(
- self,
- add: Callable[[str, float], None],
- max_queue_size: int,
+ self,
+ add: Callable[[str, float], None],
+ max_queue_size: int,
) -> None:
self._add = add
self._max_queue_size = max_queue_size
self._initialized = False
- self._queue: Optional[mp.JoinableQueue[Optional[Tuple[str, float]]]] = None
+ self._queue: Optional[
+ mp.JoinableQueue[Optional[Tuple[str, float]]]
+ ] = None
self._thread: Optional[threading.Thread] = None
self._thread_exited = False
@@ -108,17 +110,17 @@ def _worker(self) -> None:
class _CUDAWorker:
def __init__(
- self,
- add: Callable[[str, float], None],
- max_queue_size: int,
+ self,
+ add: Callable[[str, float], None],
+ max_queue_size: int,
) -> None:
self._add = add
self._max_queue_size = max_queue_size
self._initialized = False
self._thread: Optional[threading.Thread] = None
- self._queue: Optional['queue.Queue[Optional[_QueueElem]]'] = None
+ self._queue: Optional["queue.Queue[Optional[_QueueElem]]"] = None
self._event_lock = threading.Lock()
- self._events: Optional['queue.Queue[torch.cuda.Event]'] = None
+ self._events: Optional["queue.Queue[torch.cuda.Event]"] = None
self._thread_exited = False
def initialize(self) -> None:
@@ -146,9 +148,9 @@ def synchronize(self) -> None:
self._queue.join()
def put(
- self,
- name: str,
- events: Tuple[torch.cuda.Event, torch.cuda.Event],
+ self,
+ name: str,
+ events: Tuple[torch.cuda.Event, torch.cuda.Event],
) -> None:
assert self._queue is not None
assert not self._thread_exited
@@ -181,13 +183,14 @@ def get_cuda_event(self) -> torch.cuda.Event:
with self._event_lock:
if self._events.empty():
event = torch.cuda.Event( # type: ignore[no-untyped-call]
- enable_timing=True)
+ enable_timing=True
+ )
self._events.put(event)
return self._events.get()
class _Finalizer:
- def __init__(self, ts: 'TimeSummary') -> None:
+ def __init__(self, ts: "TimeSummary") -> None:
self._ts = weakref.ref(ts)
def __call__(self) -> None:
@@ -209,7 +212,9 @@ class TimeSummary:
when the instance is created.
"""
- def __init__(self, *, max_queue_size: int = 1000, auto_init: bool = True) -> None:
+ def __init__(
+ self, *, max_queue_size: int = 1000, auto_init: bool = True
+ ) -> None:
self._summary_lock = threading.Lock()
self._summary = DictSummary()
self._additional_stats: Dict[str, float] = {}
@@ -217,7 +222,9 @@ def __init__(self, *, max_queue_size: int = 1000, auto_init: bool = True) -> Non
self._cpu_worker = _CPUWorker(self._add_from_worker, max_queue_size)
self._cuda_worker: Optional[_CUDAWorker] = None
if torch.cuda.is_available():
- self._cuda_worker = _CUDAWorker(self._add_from_worker, max_queue_size)
+ self._cuda_worker = _CUDAWorker(
+ self._add_from_worker, max_queue_size
+ )
self._initialized = False
self._master_pid = os.getpid()
@@ -242,7 +249,8 @@ def initialize(self) -> None:
raise RuntimeError(
"TimeSummary must be initialized in the same process as the "
"one created the instance. Please call initialize() in the "
- "main process.")
+ "main process."
+ )
self._cpu_worker.initialize()
if self._cuda_worker is not None:
self._cuda_worker.initialize()
@@ -276,8 +284,8 @@ def add(self, name: str, value: float) -> None:
@contextmanager
def summary(
- self,
- clear: bool = False,
+ self,
+ clear: bool = False,
) -> Generator[Tuple[DictSummary, Dict[str, float]], None, None]:
self.initialize()
try:
@@ -289,11 +297,11 @@ def summary(
self._additional_stats = {}
def complete_report(
- self,
- tag: str,
- use_cuda: bool,
- begin_event: Optional[torch.cuda.Event],
- begin: float,
+ self,
+ tag: str,
+ use_cuda: bool,
+ begin_event: Optional[torch.cuda.Event],
+ begin: float,
) -> None:
end = time.time()
assert self._cpu_worker._queue is not None
@@ -305,13 +313,14 @@ def complete_report(
end_event = self._cuda_worker.get_cuda_event()
end_event.record() # type: ignore[no-untyped-call]
self._cuda_worker._queue.put(
- (f"{tag}.cuda", (begin_event, end_event)))
+ (f"{tag}.cuda", (begin_event, end_event))
+ )
@contextmanager
def report(
- self,
- tag: str,
- use_cuda: bool = False,
+ self,
+ tag: str,
+ use_cuda: bool = False,
) -> Generator[_ReportNotification, None, None]:
"""Context manager to automatically report execution times.
@@ -332,7 +341,8 @@ def report(
try:
begin = time.time()
notification = _ReportNotification(
- self, tag, use_cuda, begin_event, begin)
+ self, tag, use_cuda, begin_event, begin
+ )
yield notification
finally:
if notification._is_completed:
@@ -343,6 +353,6 @@ def report(
def get_time_summary() -> TimeSummary:
- if not hasattr(_thread_local, 'time_summary'):
+ if not hasattr(_thread_local, "time_summary"):
_thread_local.time_summary = TimeSummary(auto_init=False)
return _thread_local.time_summary # type: ignore[no-any-return]
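
A usage sketch for the thread-local summary above: get_time_summary() creates the instance with auto_init=False, so initialize() must run in the process that created it. The timed work below is illustrative.

    import time
    import pytorch_pfn_extras as ppe

    ts = ppe.profiler.get_time_summary()
    ts.initialize()                       # must run in the process that created ts

    with ts.report("iteration"):
        time.sleep(0.01)                  # stands in for real work

    with ts.summary(clear=True) as (summary, additional_stats):
        print(summary.compute_mean())     # e.g. {"iteration": ~0.01}
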
diff --git a/pytorch_pfn_extras/reporting.py b/pytorch_pfn_extras/reporting.py
index 0676e35ba..a4132500e 100644
--- a/pytorch_pfn_extras/reporting.py
+++ b/pytorch_pfn_extras/reporting.py
@@ -2,17 +2,25 @@
import contextlib
import threading
import types
+import warnings
from typing import (
- Any, Callable, Dict, Generator, List, Mapping, Optional, Sequence,
- Tuple, Type, Union,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ overload,
)
-from typing import overload
-import warnings
import numpy
import torch
-
Scalar = Union[torch.Tensor, numpy.ndarray, numpy.floating, float]
FloatLikeValue = Union[Scalar, float]
Value = Union[Scalar, Callable[[], float]]
@@ -33,7 +41,8 @@ def _nograd(value: Value) -> Value:
def _nograd(
- value: Union[FloatLikeValue, Value]) -> Union[FloatLikeValue, Value]:
+ value: Union[FloatLikeValue, Value]
+) -> Union[FloatLikeValue, Value]:
if isinstance(value, torch.Tensor):
return value.detach()
return value
@@ -99,10 +108,10 @@ def __enter__(self) -> None:
_get_reporters().append(self)
def __exit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_value: Optional[BaseException],
- traceback: Optional[types.TracebackType],
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
) -> None:
"""Recovers the previous reporter object to the current."""
_get_reporters().pop()
@@ -148,9 +157,7 @@ def add_observer(self, name: str, observer: torch.nn.Module) -> None:
self._observer_names[id(observer)] = name
def add_observers(
- self,
- prefix: str,
- observers: Sequence[Tuple[str, torch.nn.Module]]
+ self, prefix: str, observers: Sequence[Tuple[str, torch.nn.Module]]
) -> None:
"""Registers multiple observers at once.
@@ -165,9 +172,9 @@ def add_observers(
self._observer_names[id(observer)] = prefix + name
def report(
- self,
- values: Mapping[str, Value],
- observer: Optional[torch.nn.Module] = None,
+ self,
+ values: Mapping[str, Value],
+ observer: Optional[torch.nn.Module] = None,
) -> None:
"""Reports observed values.
@@ -193,10 +200,11 @@ def report(
observer_id = id(observer)
if observer_id not in self._observer_names:
raise KeyError(
- 'Given observer is not registered to the reporter.')
+ "Given observer is not registered to the reporter."
+ )
observer_name = self._observer_names[observer_id]
for key, value in values.items():
- name = '%s/%s' % (observer_name, key)
+ name = "%s/%s" % (observer_name, key)
self.observation[name] = value
else:
self.observation.update(values)
@@ -216,8 +224,8 @@ def get_current_reporter() -> Reporter:
def report(
- values: Mapping[str, Value],
- observer: Optional[torch.nn.Module] = None,
+ values: Mapping[str, Value],
+ observer: Optional[torch.nn.Module] = None,
) -> None:
"""Reports observed values with the current reporter object.
@@ -352,20 +360,22 @@ def state_dict(self) -> Dict[str, Any]:
try:
# Save the stats as python scalars in order to avoid
# different device errors when loading them back
- state = {'_x': float(self._x),
- '_x2': float(self._x2),
- '_n': int(self._n)}
+ state = {
+ "_x": float(self._x),
+ "_x2": float(self._x2),
+ "_n": int(self._n),
+ }
except KeyError:
- warnings.warn('The previous statistics are not saved.')
+ warnings.warn("The previous statistics are not saved.")
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
# Casting here is because of backward compatibility
# Restore previously taken snapshots with autoload
self._add_deferred_values()
- self._x = float(_nograd(to_load['_x']))
- self._x2 = float(_nograd(to_load['_x2']))
- self._n = int(_nograd(to_load['_n']))
+ self._x = float(_nograd(to_load["_x"]))
+ self._x2 = float(_nograd(to_load["_x2"]))
+ self._n = int(_nograd(to_load["_n"]))
def __add__(self, other: "Summary") -> "Summary":
s = Summary()
@@ -404,10 +414,11 @@ def add(self, d: Mapping[str, Union[Value, Tuple[Value, Scalar]]]) -> None:
w: Scalar = 1
if isinstance(v, tuple):
v, w = v
- if not numpy.isscalar(w) and not getattr(w, 'ndim', -1) == 0:
+ if not numpy.isscalar(w) and not getattr(w, "ndim", -1) == 0:
raise ValueError(
- 'Given weight to {} was not scalar.'.format(k))
- if callable(v) or numpy.isscalar(v) or getattr(v, 'ndim', -1) == 0:
+ "Given weight to {} was not scalar.".format(k)
+ )
+ if callable(v) or numpy.isscalar(v) or getattr(v, "ndim", -1) == 0:
summaries[k].add(v, weight=w)
def compute_mean(self) -> Dict[str, Scalar]:
@@ -420,8 +431,10 @@ def compute_mean(self) -> Dict[str, Scalar]:
dict: Dictionary of mean values.
"""
- return {name: summary.compute_mean()
- for name, summary in self._summaries.items()}
+ return {
+ name: summary.compute_mean()
+ for name, summary in self._summaries.items()
+ }
def make_statistics(self) -> Dict[str, Scalar]:
"""Creates a dictionary of statistics.
@@ -439,13 +452,14 @@ def make_statistics(self) -> Dict[str, Scalar]:
for name, summary in self._summaries.items():
mean, std = summary.make_statistics()
stats[name] = mean
- stats[name + '.std'] = std
+ stats[name + ".std"] = std
return stats
def state_dict(self) -> Dict[str, Any]:
return {
- name: summ.state_dict() for name, summ in self._summaries.items()}
+ name: summ.state_dict() for name, summ in self._summaries.items()
+ }
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
self._summaries.clear()
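
A usage sketch for the reporting API above; the observer module is illustrative, and values reported with an observer are stored under "<observer name>/<key>".

    import torch
    import pytorch_pfn_extras as ppe

    model = torch.nn.Linear(4, 1)          # any module can act as an observer
    reporter = ppe.reporting.Reporter()
    reporter.add_observer("main", model)

    with reporter:
        # The module-level report() uses the current (innermost) reporter.
        ppe.reporting.report({"loss": 0.5}, model)

    print(reporter.observation)            # {"main/loss": 0.5}
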
diff --git a/pytorch_pfn_extras/runtime/__init__.py b/pytorch_pfn_extras/runtime/__init__.py
index 40bd83c15..0fb8258f8 100644
--- a/pytorch_pfn_extras/runtime/__init__.py
+++ b/pytorch_pfn_extras/runtime/__init__.py
@@ -1,6 +1,5 @@
+from pytorch_pfn_extras.runtime._registry import _RuntimeRegistry # NOQA
from pytorch_pfn_extras.runtime._runtime import BaseRuntime # NOQA
from pytorch_pfn_extras.runtime._runtime import PyTorchRuntime # NOQA
-from pytorch_pfn_extras.runtime._registry import _RuntimeRegistry # NOQA
-
runtime_registry = _RuntimeRegistry(PyTorchRuntime)
diff --git a/pytorch_pfn_extras/runtime/_autocast.py b/pytorch_pfn_extras/runtime/_autocast.py
index 228f1ce1a..fe86b1b5c 100644
--- a/pytorch_pfn_extras/runtime/_autocast.py
+++ b/pytorch_pfn_extras/runtime/_autocast.py
@@ -1,13 +1,16 @@
import contextlib
from typing import Any, Dict, Generator
+
from pytorch_pfn_extras._torch_version import requires
_cuda_amp_available = False
try:
import torch.cuda.amp
+
_cuda_amp_available = torch.cuda.is_available() and hasattr(
- torch.cuda.amp, 'autocast')
+ torch.cuda.amp, "autocast"
+ )
except ImportError:
pass
@@ -19,22 +22,23 @@ def __init__(
has_grad_scaler: bool,
) -> None:
autocast_options = autocast_options.copy()
- self._enabled = autocast_options.pop('enabled', True)
- self._device_type = autocast_options.pop('device_type', 'cuda')
+ self._enabled = autocast_options.pop("enabled", True)
+ self._device_type = autocast_options.pop("device_type", "cuda")
self._options = autocast_options
self._use_old_ac = not requires("1.10.0")
- if (
- self._enabled and self._use_old_ac and self._device_type != 'cuda'
- ):
- raise RuntimeError("Autocast only work with CUDA devices for PyTorch 1.9")
+ if self._enabled and self._use_old_ac and self._device_type != "cuda":
+ raise RuntimeError(
+ "Autocast only work with CUDA devices for PyTorch 1.9"
+ )
if not _cuda_amp_available:
- if (
- has_grad_scaler
- or (self._enabled and self._device_type == "cuda")
+ if has_grad_scaler or (
+ self._enabled and self._device_type == "cuda"
):
- raise RuntimeError('Requested AMP features but torch.cuda.amp'
- ' is not enabled')
+ raise RuntimeError(
+ "Requested AMP features but torch.cuda.amp"
+ " is not enabled"
+ )
@contextlib.contextmanager
def autocast(self, enabled: bool = True) -> Generator[None, None, None]:
diff --git a/pytorch_pfn_extras/runtime/_registry.py b/pytorch_pfn_extras/runtime/_registry.py
index 768b7a0c2..295b64201 100644
--- a/pytorch_pfn_extras/runtime/_registry.py
+++ b/pytorch_pfn_extras/runtime/_registry.py
@@ -1,8 +1,7 @@
from typing import Dict, Type
import torch
-
-from pytorch_pfn_extras.runtime._runtime import DeviceLike, BaseRuntime
+from pytorch_pfn_extras.runtime._runtime import BaseRuntime, DeviceLike
class _RuntimeRegistry:
@@ -11,17 +10,18 @@ def __init__(self, fallback_class: Type[BaseRuntime]):
self._fallback_class = fallback_class
def register(
- self,
- device_type: str,
- runtime_class: Type[BaseRuntime],
+ self,
+ device_type: str,
+ runtime_class: Type[BaseRuntime],
) -> None:
self._runtimes[device_type] = runtime_class
def get_runtime_class_for_device_spec(
- self, device: DeviceLike) -> Type[BaseRuntime]:
+ self, device: DeviceLike
+ ) -> Type[BaseRuntime]:
if isinstance(device, torch.device):
device_type = device.type
else:
assert isinstance(device, str)
- device_type = device.split(':')[0]
+ device_type = device.split(":")[0]
return self._runtimes.get(device_type, self._fallback_class)
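
A usage sketch of the registry above; runtime_registry is the module-level instance created in runtime/__init__.py with PyTorchRuntime as the fallback class, and the custom device type name is illustrative.

    import pytorch_pfn_extras as ppe
    from pytorch_pfn_extras.runtime import runtime_registry

    class MyRuntime(ppe.runtime.BaseRuntime):   # illustrative custom runtime
        pass

    runtime_registry.register("mydev", MyRuntime)

    assert runtime_registry.get_runtime_class_for_device_spec("mydev:0") is MyRuntime
    # Unregistered device types fall back to the class the registry was created with.
    print(runtime_registry.get_runtime_class_for_device_spec("cpu"))
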
diff --git a/pytorch_pfn_extras/runtime/_runtime.py b/pytorch_pfn_extras/runtime/_runtime.py
index 8e27d284f..e374045da 100644
--- a/pytorch_pfn_extras/runtime/_runtime.py
+++ b/pytorch_pfn_extras/runtime/_runtime.py
@@ -1,12 +1,18 @@
import contextlib
import types
-
from typing import (
- Any, Dict, Generator, Iterable, Optional, Set, Tuple, Union, TYPE_CHECKING
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ Optional,
+ Set,
+ Tuple,
+ Union,
)
import torch
-
from pytorch_pfn_extras.handler._code_block import CodeBlock
from pytorch_pfn_extras.runtime import _autocast
@@ -35,7 +41,9 @@ class BaseRuntime:
"""
def __init__(
- self, device_spec: DeviceLike, options: Dict[str, Any],
+ self,
+ device_spec: DeviceLike,
+ options: Dict[str, Any],
) -> None:
self.device_spec = device_spec
self.options = options
@@ -56,7 +64,8 @@ def convert_batch(self, args: Any) -> Any:
if isinstance(args, tuple) and hasattr(args, "_fields"):
# namedtuple
return args._replace( # type: ignore[attr-defined]
- **self._convert_batch_dict(args._asdict())) # type: ignore
+ **self._convert_batch_dict(args._asdict()) # type: ignore
+ )
if isinstance(args, dict):
return self._convert_batch_dict(args)
if isinstance(args, (list, tuple)):
@@ -142,7 +151,7 @@ def train_epoch_end(self, module: torch.nn.Module) -> None:
def train_pre_step(
self,
- trainer: 'Trainer',
+ trainer: "Trainer",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -165,7 +174,7 @@ def train_pre_step(
def train_post_step(
self,
- trainer: 'Trainer',
+ trainer: "Trainer",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -221,7 +230,7 @@ def train_validation_end(self, module: torch.nn.Module) -> None:
def eval_pre_step(
self,
- evaluator: 'Evaluator',
+ evaluator: "Evaluator",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -241,7 +250,7 @@ def eval_pre_step(
def eval_post_step(
self,
- evaluator: 'Evaluator',
+ evaluator: "Evaluator",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -262,7 +271,11 @@ def eval_post_step(
"""
raise NotImplementedError()
- def execute(self, code_block: CodeBlock, batch: Any,) -> Any:
+ def execute(
+ self,
+ code_block: CodeBlock,
+ batch: Any,
+ ) -> Any:
"""Method called by the CodeBlocks API to do device dependent execution.
Args:
@@ -297,7 +310,9 @@ def map(
@classmethod
@contextlib.contextmanager
- def trace(cls, event_name: Optional[str], arg: Any) -> Generator[None, None, None]:
+ def trace(
+ cls, event_name: Optional[str], arg: Any
+ ) -> Generator[None, None, None]:
"""Context manager for tracing PPE events in the custom device tools.
Args:
@@ -328,13 +343,18 @@ class PyTorchRuntime(BaseRuntime):
"""
def __init__(
- self, device_spec: DeviceLike, options: Dict[str, Any],
+ self,
+ device_spec: DeviceLike,
+ options: Dict[str, Any],
) -> None:
super().__init__(device_spec, options)
self._grad_scaler = options.get("grad_scaler", None)
autocast_options = options.get("autocast", False)
if isinstance(autocast_options, bool):
- autocast_options = {"enabled": autocast_options, "device_type": "cuda"}
+ autocast_options = {
+ "enabled": autocast_options,
+ "device_type": "cuda",
+ }
self._autocast = _autocast._AutocastManager(
autocast_options, self._grad_scaler is not None
)
@@ -376,7 +396,7 @@ def train_validation_end(self, module: torch.nn.Module) -> None:
def train_pre_step(
self,
- trainer: 'Trainer',
+ trainer: "Trainer",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -385,7 +405,7 @@ def train_pre_step(
def train_post_step(
self,
- trainer: 'Trainer',
+ trainer: "Trainer",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -395,7 +415,7 @@ def train_post_step(
def eval_pre_step(
self,
- evaluator: 'Evaluator',
+ evaluator: "Evaluator",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -404,7 +424,7 @@ def eval_pre_step(
def eval_post_step(
self,
- evaluator: 'Evaluator',
+ evaluator: "Evaluator",
module: torch.nn.Module,
batch_idx: int,
batch: Any,
@@ -412,7 +432,11 @@ def eval_post_step(
) -> None:
pass
- def execute(self, code_block: CodeBlock, batch: Any,) -> Any:
+ def execute(
+ self,
+ code_block: CodeBlock,
+ batch: Any,
+ ) -> Any:
# Run forward, backward and optimize steps depending on codeblock opts
if self._grad_scaler is None:
@@ -439,10 +463,7 @@ def _scale(x: torch.Tensor) -> torch.Tensor:
isinstance(v, torch.Tensor)
and v.grad_fn is not None
and v.numel() == 1
- and (
- v.dtype.is_floating_point
- or v.dtype.is_complex
- )
+ and (v.dtype.is_floating_point or v.dtype.is_complex)
):
_scale(v).backward() # type: ignore[no-untyped-call]
else:
@@ -483,7 +504,9 @@ def map(
@classmethod
@contextlib.contextmanager
- def trace(cls, event_name: Optional[str], arg: Any) -> Generator[None, None, None]:
+ def trace(
+ cls, event_name: Optional[str], arg: Any
+ ) -> Generator[None, None, None]:
"""Context manager for tracing PPE events in the custom device tools.
Args:
@@ -515,7 +538,9 @@ def _getstate_without_runtime(self): # type: ignore
# remove runtime class and getstate
def _remove_runtime_class(state): # type: ignore
- state = {k: v for k, v in state.items() if k != _RUNTIME_TAG_NAME}
+ state = {
+ k: v for k, v in state.items() if k != _RUNTIME_TAG_NAME
+ }
for k, v in state.items():
if isinstance(v, dict):
state[k] = _remove_runtime_class(v) # type: ignore
@@ -528,6 +553,7 @@ def _remove_runtime_class(state): # type: ignore
return state
return _remove_runtime_class(state) # type: ignore
+
return _getstate_without_runtime
getstate = None
@@ -537,7 +563,7 @@ def _remove_runtime_class(state): # type: ignore
setattr( # NOQA
module,
"__getstate__",
- types.MethodType(mk_getstate(getstate), module) # type: ignore
+ types.MethodType(mk_getstate(getstate), module), # type: ignore
)
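For reference, a minimal sketch (not part of the patch) of how the registry reformatted above is used; MyDeviceRuntime and the "mydev" device-type string are hypothetical:

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.runtime._runtime import BaseRuntime

class MyDeviceRuntime(BaseRuntime):
    # Placeholder runtime; a real one overrides the move/step hooks defined
    # in _runtime.py (e.g. move_tensor, train_pre_step, eval_pre_step).
    pass

# Map the "mydev" device-type prefix to the runtime class so that
# get_runtime_class_for_device_spec("mydev:0") resolves to it.
ppe.runtime.runtime_registry.register("mydev", MyDeviceRuntime)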
diff --git a/pytorch_pfn_extras/runtime/_to.py b/pytorch_pfn_extras/runtime/_to.py
index 4db4b0556..32770d9b0 100644
--- a/pytorch_pfn_extras/runtime/_to.py
+++ b/pytorch_pfn_extras/runtime/_to.py
@@ -1,21 +1,19 @@
from typing import Any, Dict, Optional, Type, TypeVar
-import torch
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras.runtime._runtime import DeviceLike, BaseRuntime
-
+import torch
+from pytorch_pfn_extras.runtime._runtime import BaseRuntime, DeviceLike
-ModuleOrTensor = TypeVar('ModuleOrTensor', torch.nn.Module, torch.Tensor)
+ModuleOrTensor = TypeVar("ModuleOrTensor", torch.nn.Module, torch.Tensor)
def to(
- module_or_tensor: ModuleOrTensor,
- device: DeviceLike,
- *,
- options: Optional[Dict[str, Any]] = None,
- runtime_class: Optional[Type[BaseRuntime]] = None,
- config: Optional[Dict[str, Any]] = None,
+ module_or_tensor: ModuleOrTensor,
+ device: DeviceLike,
+ *,
+ options: Optional[Dict[str, Any]] = None,
+ runtime_class: Optional[Type[BaseRuntime]] = None,
+ config: Optional[Dict[str, Any]] = None,
) -> ModuleOrTensor:
"""A function to transfer the given object to the given device.
@@ -52,7 +50,7 @@ def to(
if config is not None:
options = config
elif config is not None:
- raise ValueError('options and config cannot be specified together')
+ raise ValueError("options and config cannot be specified together")
if runtime_class is None:
registry = ppe.runtime.runtime_registry
@@ -67,4 +65,4 @@ def to(
elif isinstance(obj, torch.Tensor):
return runtime.move_tensor(obj)
else:
- raise ValueError('Unsupported type for module_or_tensor')
+ raise ValueError("Unsupported type for module_or_tensor")
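And a usage sketch of the to() helper above (again not part of the patch; the model and the device string are placeholders):

import torch
import pytorch_pfn_extras as ppe

model = torch.nn.Linear(8, 4)
# Dispatches through the runtime registry: the runtime registered for the
# "cuda" device type handles the transfer, falling back to PyTorchRuntime.
model = ppe.to(model, "cuda:0")
# Tensors go through move_tensor() of the selected runtime.
x = ppe.to(torch.zeros(2, 8), "cuda:0")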
diff --git a/pytorch_pfn_extras/torchscript.py b/pytorch_pfn_extras/torchscript.py
index f452a71fe..57788379c 100644
--- a/pytorch_pfn_extras/torchscript.py
+++ b/pytorch_pfn_extras/torchscript.py
@@ -1,18 +1,23 @@
from typing import Any, Callable, List, Tuple
+
import torch
# Run jit pass with post lint
-def run_jit_pass(p: Callable, g: torch._C.Graph, *args: Any, **kwargs: Any) -> None:
+def run_jit_pass(
+ p: Callable, g: torch._C.Graph, *args: Any, **kwargs: Any
+) -> None:
p(g, *args, **kwargs)
torch._C._jit_pass_lint(g)
-def find_inplace(g: torch._C.Graph) -> Tuple[torch._C.Graph, List[torch._C.Node]]:
+def find_inplace(
+ g: torch._C.Graph,
+) -> Tuple[torch._C.Graph, List[torch._C.Node]]:
g = g.copy()
run_jit_pass(torch._C._jit_pass_inline, g)
nodes = []
for n in g.nodes():
- if n.kind().endswith('_'):
+ if n.kind().endswith("_"):
nodes.append(n)
return g, nodes
diff --git a/pytorch_pfn_extras/training/__init__.py b/pytorch_pfn_extras/training/__init__.py
index ffe617918..8fae0df54 100644
--- a/pytorch_pfn_extras/training/__init__.py
+++ b/pytorch_pfn_extras/training/__init__.py
@@ -1,14 +1,16 @@
-from pytorch_pfn_extras.training.extension import Extension # NOQA
-from pytorch_pfn_extras.training.extension import ExtensionEntry # NOQA
-from pytorch_pfn_extras.training.extension import make_extension # NOQA
+from pytorch_pfn_extras.training import extensions # NOQA
+from pytorch_pfn_extras.training._evaluator import DistributedEvaluator # NOQA
+from pytorch_pfn_extras.training._evaluator import Evaluator # NOQA
+from pytorch_pfn_extras.training._manager_protocol import ( # NOQA
+ ExtensionsManagerProtocol,
+)
+from pytorch_pfn_extras.training._trainer import Trainer # NOQA
from pytorch_pfn_extras.training.extension import PRIORITY_EDITOR # NOQA
from pytorch_pfn_extras.training.extension import PRIORITY_READER # NOQA
from pytorch_pfn_extras.training.extension import PRIORITY_WRITER # NOQA
-from pytorch_pfn_extras.training import extensions # NOQA
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol # NOQA
+from pytorch_pfn_extras.training.extension import Extension # NOQA
+from pytorch_pfn_extras.training.extension import ExtensionEntry # NOQA
+from pytorch_pfn_extras.training.extension import make_extension # NOQA
from pytorch_pfn_extras.training.manager import ExtensionsManager # NOQA
from pytorch_pfn_extras.training.manager import IgniteExtensionsManager # NOQA
from pytorch_pfn_extras.training.metrics import AccuracyMetric # NOQA
-from pytorch_pfn_extras.training._trainer import Trainer # NOQA
-from pytorch_pfn_extras.training._evaluator import Evaluator # NOQA
-from pytorch_pfn_extras.training._evaluator import DistributedEvaluator # NOQA
diff --git a/pytorch_pfn_extras/training/_evaluator.py b/pytorch_pfn_extras/training/_evaluator.py
index d48214dbc..f6bcefb79 100644
--- a/pytorch_pfn_extras/training/_evaluator.py
+++ b/pytorch_pfn_extras/training/_evaluator.py
@@ -1,22 +1,27 @@
import contextlib
import queue
from typing import (
- Any, Callable, Generator, Iterable, Mapping, Optional, Sequence,
- Union, TYPE_CHECKING,
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generator,
+ Iterable,
+ Mapping,
+ Optional,
+ Sequence,
+ Union,
)
import torch
import torch.distributed
-
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training.extensions import evaluator
-
from pytorch_pfn_extras.training.metrics import Batch as DictBatch
if TYPE_CHECKING:
from pytorch_pfn_extras.handler import BaseHandler
- from pytorch_pfn_extras.training.metrics import MetricType
from pytorch_pfn_extras.reporting import Observation
+ from pytorch_pfn_extras.training.metrics import MetricType
@contextlib.contextmanager
@@ -27,9 +32,9 @@ def _nullcontext() -> Generator[None, None, None]:
@contextlib.contextmanager
def _progress_bar(
- name: str,
- required: bool,
- size: int,
+ name: str,
+ required: bool,
+ size: int,
) -> Generator[Callable[[int], None], None, None]:
if required:
progress = evaluator.IterationStatus(size)
@@ -38,6 +43,7 @@ def _progress_bar(
def update(i: int) -> None:
progress.current_position = i
pbar.update()
+
yield update
pbar.close()
@@ -47,21 +53,22 @@ def update(i: int) -> None:
class Evaluator:
def __init__(
- self,
- handler: 'BaseHandler',
- models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
- *,
- progress_bar: bool = False,
- metrics: Optional[Sequence['MetricType']] = None,
- profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
+ self,
+ handler: "BaseHandler",
+ models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
+ *,
+ progress_bar: bool = False,
+ metrics: Optional[Sequence["MetricType"]] = None,
+ profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
):
super().__init__()
if not isinstance(models, dict):
if not isinstance(models, torch.nn.Module):
raise ValueError(
- 'model must be an instance of dict or toch.nn.Module')
- self.models = {'main': models}
+                    "model must be an instance of dict or torch.nn.Module"
+ )
+ self.models = {"main": models}
else:
self.models = models
@@ -72,8 +79,7 @@ def __init__(
self._profile = profile
for name, model in self.models.items():
self._reporter.add_observer(name, model)
- self._reporter.add_observers(
- name, model.named_modules())
+ self._reporter.add_observers(name, model.named_modules())
def _process_metrics(self, ins: DictBatch, outs: DictBatch) -> DictBatch:
for metric in self._metrics:
@@ -81,15 +87,16 @@ def _process_metrics(self, ins: DictBatch, outs: DictBatch) -> DictBatch:
return outs
def _complete_step(
- self, idx: int, outs: DictBatch, *, is_deferred: bool = False
+ self, idx: int, outs: DictBatch, *, is_deferred: bool = False
) -> None:
c_idx = self._idxs.get()
        # Ensure that iterations complete in order
if c_idx != idx:
raise RuntimeError(
- 'Completed a not expected iteration. '
- '{} was expected but completion of {} happened'.format(
- c_idx, idx)
+                "Completed an unexpected iteration. "
+ "{} was expected but completion of {} happened".format(
+ c_idx, idx
+ )
)
x = self._inputs.get()
observed = self._observed.get()
@@ -106,10 +113,7 @@ def _gather_summaries(self) -> None:
pass
def run(
- self,
- loader: Iterable[Any],
- *,
- eval_len: Optional[int] = None
+ self, loader: Iterable[Any], *, eval_len: Optional[int] = None
) -> None:
"""Executes the evaluation loop.
@@ -120,9 +124,9 @@ def run(
The number of iterations per one evaluation epoch.
"""
# Note: setup_manager is done by the Trainer.
- self._idxs: 'queue.Queue[int]' = queue.Queue()
- self._inputs: 'queue.Queue[DictBatch]' = queue.Queue()
- self._observed: 'queue.Queue[Observation]' = queue.Queue()
+ self._idxs: "queue.Queue[int]" = queue.Queue()
+ self._inputs: "queue.Queue[DictBatch]" = queue.Queue()
+ self._observed: "queue.Queue[Observation]" = queue.Queue()
if eval_len is None:
eval_len = len(loader) # type: ignore[arg-type]
@@ -131,7 +135,7 @@ def run(
self._summary = reporting.DictSummary()
observation: Observation = {}
self.handler.eval_loop_begin(self)
- self._pbar = _progress_bar('validation', self._progress_bar, eval_len)
+ self._pbar = _progress_bar("validation", self._progress_bar, eval_len)
self._update = self._pbar.__enter__()
loader_iter = iter(loader)
with self._profile or _nullcontext() as prof:
@@ -146,7 +150,8 @@ def run(
self._observed.put(observation)
with self._reporter.scope(observation):
self.handler.eval_step(
- self, idx, x, self._complete_step)
+ self, idx, x, self._complete_step
+ )
# Some of the DataLoaders might need an explicit break
# since they could start cycling on their data
if (idx + 1) == eval_len:
@@ -165,14 +170,16 @@ def run(
class DistributedEvaluator(Evaluator):
def __init__(
- self,
- handler: 'BaseHandler',
- models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
- *,
- progress_bar: bool = False,
- metrics: Optional[Sequence['MetricType']] = None,
+ self,
+ handler: "BaseHandler",
+ models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
+ *,
+ progress_bar: bool = False,
+ metrics: Optional[Sequence["MetricType"]] = None,
):
- super().__init__(handler, models, progress_bar=progress_bar, metrics=metrics)
+ super().__init__(
+ handler, models, progress_bar=progress_bar, metrics=metrics
+ )
if not torch.distributed.is_initialized(): # type: ignore[no-untyped-call]
raise RuntimeError("PyTorch distributed module is not initialized.")
diff --git a/pytorch_pfn_extras/training/_manager_protocol.py b/pytorch_pfn_extras/training/_manager_protocol.py
index c0a9ad401..624acff1d 100644
--- a/pytorch_pfn_extras/training/_manager_protocol.py
+++ b/pytorch_pfn_extras/training/_manager_protocol.py
@@ -1,17 +1,15 @@
-from typing import Mapping, Optional, TYPE_CHECKING
-from typing_extensions import Protocol
+from typing import TYPE_CHECKING, Mapping, Optional
import torch
+from typing_extensions import Protocol
if TYPE_CHECKING:
+ from pytorch_pfn_extras import reporting, writing
from pytorch_pfn_extras.training import trigger as trigger_module
from pytorch_pfn_extras.training.extension import Extension
- from pytorch_pfn_extras import writing
- from pytorch_pfn_extras import reporting
class ExtensionsManagerProtocol(Protocol):
-
@property
def iteration(self) -> int:
...
@@ -53,7 +51,7 @@ def stop_trigger(self) -> bool:
...
@property
- def _stop_trigger(self) -> 'trigger_module.Trigger':
+ def _stop_trigger(self) -> "trigger_module.Trigger":
...
@property
@@ -61,16 +59,16 @@ def out(self) -> str:
...
@property
- def writer(self) -> Optional['writing.Writer']:
+ def writer(self) -> Optional["writing.Writer"]:
...
@property
- def reporter(self) -> 'reporting.Reporter':
+ def reporter(self) -> "reporting.Reporter":
...
- def get_extension(self, name: str) -> 'Extension':
+ def get_extension(self, name: str) -> "Extension":
...
@property
- def observation(self) -> 'reporting.Observation':
+ def observation(self) -> "reporting.Observation":
...
diff --git a/pytorch_pfn_extras/training/_trainer.py b/pytorch_pfn_extras/training/_trainer.py
index 6e6d274b2..fe1c46d1c 100644
--- a/pytorch_pfn_extras/training/_trainer.py
+++ b/pytorch_pfn_extras/training/_trainer.py
@@ -4,24 +4,33 @@
import time
import warnings
from typing import (
- Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union, TYPE_CHECKING
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
)
+import pytorch_pfn_extras.reporting as reporting
import torch
-
from pytorch_pfn_extras import training
+from pytorch_pfn_extras.profiler import record
from pytorch_pfn_extras.training import extension as extension
from pytorch_pfn_extras.training import trigger as trigger_module
-import pytorch_pfn_extras.reporting as reporting
-from pytorch_pfn_extras.profiler import record
-
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
from pytorch_pfn_extras.training.trigger import Trigger, TriggerLike
if TYPE_CHECKING:
from pytorch_pfn_extras import handler as handler_module
- from pytorch_pfn_extras.training._evaluator import Evaluator
from pytorch_pfn_extras.profiler._time_summary import _ReportNotification
+ from pytorch_pfn_extras.training._evaluator import Evaluator
@contextlib.contextmanager
@@ -32,67 +41,89 @@ def _nullcontext() -> Generator[None, None, None]:
class Trainer:
def __init__(
- self,
- handler: 'handler_module.BaseHandler',
- *,
- evaluator: Optional[Union[
- 'Evaluator', Tuple['Evaluator', TriggerLike],
- Mapping[str, Union['Evaluator', Tuple['Evaluator', TriggerLike]]]]],
- models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
- profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
- **kwargs: Any,
+ self,
+ handler: "handler_module.BaseHandler",
+ *,
+ evaluator: Optional[
+ Union[
+ "Evaluator",
+ Tuple["Evaluator", TriggerLike],
+ Mapping[
+ str, Union["Evaluator", Tuple["Evaluator", TriggerLike]]
+ ],
+ ]
+ ],
+ models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
+ profile: Optional[torch.profiler.profile] = None, # type: ignore[name-defined]
+ **kwargs: Any,
):
self.handler = handler
- self._manager: Optional['training.ExtensionsManager'] = None
+ self._manager: Optional["training.ExtensionsManager"] = None
        # The following are used when setting up a manager instance
if not isinstance(models, dict):
if not isinstance(models, torch.nn.Module):
raise ValueError(
- 'model must be an instance of dict or toch.nn.Module')
- self._models = {'main': models}
+                    "model must be an instance of dict or torch.nn.Module"
+ )
+ self._models = {"main": models}
else:
self._models = models
self._kwargs = kwargs
self._profile = profile
- self._enable_profile = kwargs.get('enable_profile', profile is not None)
+ self._enable_profile = kwargs.get("enable_profile", profile is not None)
self._extensions: List[ # list of (args, kwargs)
- Tuple[Tuple[
- Union['extension.ExtensionLike', extension.ExtensionEntry],
- Optional[str], 'TriggerLike', Optional[int]
- ], Dict[str, Any]]] = []
+ Tuple[
+ Tuple[
+ Union["extension.ExtensionLike", extension.ExtensionEntry],
+ Optional[str],
+ "TriggerLike",
+ Optional[int],
+ ],
+ Dict[str, Any],
+ ]
+ ] = []
self._manager_state: Optional[Dict[str, Any]] = None
- self._evaluators: Dict[str, Tuple['Evaluator', TriggerLike]] = {}
+ self._evaluators: Dict[str, Tuple["Evaluator", TriggerLike]] = {}
if evaluator is None:
evaluator = {}
elif not isinstance(evaluator, collections.abc.Mapping):
evaluator = {"Evaluator": evaluator}
if isinstance(evaluator, collections.abc.Mapping):
for n, e in evaluator.items():
- self._evaluators[n] = e if isinstance(e, tuple) else (e, (1, 'epoch'))
+ self._evaluators[n] = (
+ e if isinstance(e, tuple) else (e, (1, "epoch"))
+ )
self.val_loader = None
def extend(
- self,
- extension: Union['extension.ExtensionLike', extension.ExtensionEntry],
- name: Optional[str] = None,
- trigger: 'TriggerLike' = None,
- priority: Optional[int] = None,
- *,
- call_before_training: bool = False,
- **kwargs: Any,
+ self,
+ extension: Union["extension.ExtensionLike", extension.ExtensionEntry],
+ name: Optional[str] = None,
+ trigger: "TriggerLike" = None,
+ priority: Optional[int] = None,
+ *,
+ call_before_training: bool = False,
+ **kwargs: Any,
) -> None:
if self._manager is not None:
- raise RuntimeError('cannot extend after starting the engine')
+ raise RuntimeError("cannot extend after starting the engine")
self._extensions.append(
- ((extension, name, trigger, priority),
- dict(call_before_training=call_before_training, **kwargs)))
+ (
+ (extension, name, trigger, priority),
+ dict(call_before_training=call_before_training, **kwargs),
+ )
+ )
- def _setup_manager(self, iters_per_epoch: int) -> 'training.ExtensionsManager':
+ def _setup_manager(
+ self, iters_per_epoch: int
+ ) -> "training.ExtensionsManager":
from pytorch_pfn_extras.training import ExtensionsManager
+
self._manager = ExtensionsManager(
- self._models, iters_per_epoch=iters_per_epoch, **self._kwargs)
+ self._models, iters_per_epoch=iters_per_epoch, **self._kwargs
+ )
for ex_args, ex_kwargs in self._extensions:
self._manager.extend(*ex_args, **ex_kwargs)
if self._manager_state is not None:
@@ -100,9 +131,9 @@ def _setup_manager(self, iters_per_epoch: int) -> 'training.ExtensionsManager':
return self._manager
@property
- def manager(self) -> 'training.ExtensionsManager':
+ def manager(self) -> "training.ExtensionsManager":
if self._manager is None:
- raise RuntimeError('the engine is not started yet')
+ raise RuntimeError("the engine is not started yet")
return self._manager
@property
@@ -148,34 +179,37 @@ def stop_trigger(self, trigger: Trigger) -> None:
self._stop_trigger = trigger
@property
- def evaluator(self) -> Optional['Evaluator']:
+ def evaluator(self) -> Optional["Evaluator"]:
if len(self._evaluators) == 0:
return None
if len(self._evaluators) == 1:
return next(iter(self._evaluators.values()))[0]
- raise ValueError('multiple evaluators are registered.')
+ raise ValueError("multiple evaluators are registered.")
def get_optimizer(self, name: str) -> torch.optim.Optimizer:
return self.manager.optimizers[name]
- def set_optimizer(self, name: str, optimizer: torch.optim.Optimizer) -> None:
+ def set_optimizer(
+ self, name: str, optimizer: torch.optim.Optimizer
+ ) -> None:
self.manager.optimizers[name] = optimizer # type: ignore[index]
def is_epoch_last_iter(self, idx: int) -> bool:
return (idx + 1) == (self.manager._iters_per_epoch)
def _complete_step(
- self,
- idx: int,
- outs: Any,
+ self,
+ idx: int,
+ outs: Any,
) -> None:
c_idx = self._idxs.get()
        # Ensure that iterations complete in order
if c_idx != idx:
raise RuntimeError(
- 'Completed a not expected iteration. '
- '{} was expected but completion of {} happened'.format(
- c_idx, idx)
+                "Completed an unexpected iteration. "
+ "{} was expected but completion of {} happened".format(
+ c_idx, idx
+ )
)
x = self._inputs.get()
begin = self._times.get()
@@ -187,12 +221,14 @@ def _complete_step(
self.handler.train_post_step(self, idx, x, outs)
reporting.report({"elapsed_time": time.time() - begin})
- def run(self,
- train_loader: Iterable[Any],
- val_loader: Optional[Iterable[Any]] = None,
- *,
- train_len: Optional[int] = None,
- eval_len: Optional[int] = None) -> None:
+ def run(
+ self,
+ train_loader: Iterable[Any],
+ val_loader: Optional[Iterable[Any]] = None,
+ *,
+ train_len: Optional[int] = None,
+ eval_len: Optional[int] = None,
+ ) -> None:
"""Executes the training loop.
Args:
@@ -222,11 +258,11 @@ def run(self,
class _EvaluatorExt:
def __init__(
- self,
- trainer: 'Trainer',
- evaluator: 'Evaluator',
- val_loader: Optional[Iterable[Any]],
- eval_len: Optional[int],
+ self,
+ trainer: "Trainer",
+ evaluator: "Evaluator",
+ val_loader: Optional[Iterable[Any]],
+ eval_len: Optional[int],
) -> None:
self.needs_model_state = True
self._trainer = trainer
@@ -238,7 +274,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
evaluator = self._evaluator
if self._val_loader is None:
raise ValueError('"val_loader" is not given.')
- evaluator.handler.train_validation_begin(self._trainer, evaluator)
+ evaluator.handler.train_validation_begin(
+ self._trainer, evaluator
+ )
evaluator.run(self._val_loader, eval_len=self._eval_len)
evaluator.handler.train_validation_end(self._trainer, evaluator)
@@ -257,11 +295,12 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
if len(self._evaluators) == 0:
if val_loader is not None:
warnings.warn(
- '`val_loader` is given whereas the evaluator is missing.',
- UserWarning)
+ "`val_loader` is given whereas the evaluator is missing.",
+ UserWarning,
+ )
else:
if val_loader is None:
- raise ValueError('`val_loader` is required')
+ raise ValueError("`val_loader` is required")
for _, (evaluator, _) in self._evaluators.items():
evaluator.handler.eval_setup(evaluator, val_loader)
@@ -271,27 +310,30 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
# When iterations are completed in the callback
        # This is needed to avoid constantly passing parameters
- self._idxs: 'queue.Queue[int]' = queue.Queue()
- self._inputs: 'queue.Queue[Any]' = queue.Queue()
- self._times: 'queue.Queue[float]' = queue.Queue()
- self._observed: 'queue.Queue[reporting.Observation]' = queue.Queue()
+ self._idxs: "queue.Queue[int]" = queue.Queue()
+ self._inputs: "queue.Queue[Any]" = queue.Queue()
+ self._times: "queue.Queue[float]" = queue.Queue()
+ self._observed: "queue.Queue[reporting.Observation]" = (
+ queue.Queue()
+ )
# Iterator must be created after `train_epoch_begin` as it may be
# using a DistributedSampler.
loader_iter = iter(train_loader)
- self._profile_records: 'queue.Queue[List[_ReportNotification]]' \
- = queue.Queue()
+ self._profile_records: "queue.Queue[List[_ReportNotification]]" = (
+ queue.Queue()
+ )
for idx in range(train_len):
with record(
"pytorch_pfn_extras.training.Trainer:iteration",
use_cuda=torch.cuda.is_available(),
enable=self._enable_profile,
- device=device
+ device=device,
) as ntf0:
try:
with record(
"pytorch_pfn_extras.training.Trainer:get_data",
enable=self._enable_profile,
- device=device
+ device=device,
):
x = next(loader_iter)
except StopIteration:
@@ -299,7 +341,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
with record(
"pytorch_pfn_extras.training.Trainer:get_data",
enable=self._enable_profile,
- device=device
+ device=device,
):
x = next(loader_iter)
begin = time.time()
@@ -311,19 +353,24 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
"pytorch_pfn_extras.training.Trainer:run_iteration",
use_cuda=torch.cuda.is_available(),
enable=self._enable_profile,
- device=device
- ) as ntf1, \
- self.manager.run_iteration():
+ device=device,
+ ) as ntf1, self.manager.run_iteration():
self._observed.put(self.manager.observation)
with record(
"pytorch_pfn_extras.training.Trainer:train_step",
use_cuda=torch.cuda.is_available(),
enable=self._enable_profile,
- device=device
+ device=device,
) as ntf2:
- self._profile_records.put([ntf0, ntf1, ntf2])
+ self._profile_records.put(
+ [ntf0, ntf1, ntf2]
+ )
self.handler.train_step(
- self, idx, x, complete_fn=self._complete_step)
+ self,
+ idx,
+ x,
+ complete_fn=self._complete_step,
+ )
# Check if the callback was called
except Exception:
# The manager has errored and called the extensions
@@ -340,7 +387,10 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
# And will keep yielding results even if the epoch
# is completed. We forcefully exit at the end of
# every epoch
- if self.is_epoch_last_iter(idx) or self.manager.stop_trigger:
+ if (
+ self.is_epoch_last_iter(idx)
+ or self.manager.stop_trigger
+ ):
break
        # In handlers that support a completely async model, train_epoch_end
        # will take care of completing pending work
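A usage sketch of the Trainer API above (not part of the patch; handler, model, optimizer, evaluator and the data loaders are assumed to be built elsewhere, and the extra keyword arguments are simply forwarded to ExtensionsManager as shown in _setup_manager):

from pytorch_pfn_extras.training import Trainer
import pytorch_pfn_extras.training.extensions as extensions

trainer = Trainer(
    handler,                              # a ppe.handler.BaseHandler instance
    evaluator=(evaluator, (1, "epoch")),  # bare Evaluator, (evaluator, trigger) or a dict
    models=model,                         # wrapped as {"main": model} when not a dict
    optimizers={"main": optimizer},       # forwarded to ExtensionsManager
    max_epochs=10,                        # forwarded to ExtensionsManager
)
trainer.extend(extensions.LogReport())
trainer.run(train_loader, val_loader)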
diff --git a/pytorch_pfn_extras/training/_transform_model.py b/pytorch_pfn_extras/training/_transform_model.py
index 5fa416e4d..9a0dd4ff1 100644
--- a/pytorch_pfn_extras/training/_transform_model.py
+++ b/pytorch_pfn_extras/training/_transform_model.py
@@ -1,12 +1,10 @@
import typing
import torch
-from torch.nn.parallel import DistributedDataParallel
-
from pytorch_pfn_extras.nn.parallel import (
DistributedDataParallel as PpeDistributedDataParallel,
)
-
+from torch.nn.parallel import DistributedDataParallel
_TransformModel = typing.Callable[[str, torch.nn.Module], torch.nn.Module]
diff --git a/pytorch_pfn_extras/training/_trigger_util.py b/pytorch_pfn_extras/training/_trigger_util.py
index c35fb6d20..a2e957b64 100644
--- a/pytorch_pfn_extras/training/_trigger_util.py
+++ b/pytorch_pfn_extras/training/_trigger_util.py
@@ -1,10 +1,13 @@
-from typing import Any, Callable, Dict, Union, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, Union
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
class Trigger:
"""Base class for triggers."""
+
def may_fire(self, iteration: int, epoch_len: int) -> bool:
"""Flags if the trigger may fire at the current iteration
@@ -24,7 +27,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
class _CallableTrigger(Trigger):
- def __init__(self, func: Callable[[ExtensionsManagerProtocol], bool]) -> None:
+ def __init__(
+ self, func: Callable[[ExtensionsManagerProtocol], bool]
+ ) -> None:
self.func = func
def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
diff --git a/pytorch_pfn_extras/training/_util.py b/pytorch_pfn_extras/training/_util.py
index 462dae9cd..ce3759c03 100644
--- a/pytorch_pfn_extras/training/_util.py
+++ b/pytorch_pfn_extras/training/_util.py
@@ -12,7 +12,7 @@ def _get_ignite_version(version: str) -> List[int]:
# major and minor ids can be only integers.
# Some examples of versions are:
# 0.1.0, 0.1.1, 0.3.0.dev20191007, 0.3.0.
- version_regexp = r'^[0-9]+\.[0-9]+\.[0-9]+(\.[0-9a-zA-Z]+)?$'
+ version_regexp = r"^[0-9]+\.[0-9]+\.[0-9]+(\.[0-9a-zA-Z]+)?$"
if re.search(version_regexp, version):
- return [int(x) for x in version.split('.')[:2]]
- raise ValueError('Invalid version format')
+ return [int(x) for x in version.split(".")[:2]]
+ raise ValueError("Invalid version format")
diff --git a/pytorch_pfn_extras/training/extension.py b/pytorch_pfn_extras/training/extension.py
index 05353d9a8..caa79f62e 100644
--- a/pytorch_pfn_extras/training/extension.py
+++ b/pytorch_pfn_extras/training/extension.py
@@ -1,11 +1,14 @@
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
import types
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from pytorch_pfn_extras.training import _trigger_util
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
if TYPE_CHECKING:
from pytorch_pfn_extras.training._trigger_util import TriggerLike
+
ExtensionLike = Callable[[ExtensionsManagerProtocol], Any]
@@ -46,7 +49,8 @@ class Extension:
:meth:`pytorch_pfn_extras.ExtensionsManager.extend` for details.
"""
- trigger: 'TriggerLike' = (1, 'iteration')
+
+ trigger: "TriggerLike" = (1, "iteration")
priority: int = PRIORITY_READER
name: Optional[str] = None
needs_model_state = False
@@ -77,15 +81,18 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> Any:
"""
raise NotImplementedError(
- 'Extension implementation must override __call__.')
+ "Extension implementation must override __call__."
+ )
def __getattr__(self, name: str) -> Any:
- if name == 'invoke_before_training':
+ if name == "invoke_before_training":
raise AttributeError(
- 'invoke_before_training has been removed since Chainer '
- 'v2.0.0. Use Extension.initialize instead.')
- raise AttributeError('{} object has no attribute {}'.format(
- type(self).__name__, name))
+ "invoke_before_training has been removed since Chainer "
+ "v2.0.0. Use Extension.initialize instead."
+ )
+ raise AttributeError(
+ "{} object has no attribute {}".format(type(self).__name__, name)
+ )
def finalize(self, manager: ExtensionsManagerProtocol) -> None:
"""Finalizes the extension.
@@ -113,10 +120,10 @@ def initialize(self, manager: ExtensionsManagerProtocol) -> None:
pass
def on_error(
- self,
- manager: ExtensionsManagerProtocol,
- exc: Exception,
- tb: types.TracebackType,
+ self,
+ manager: ExtensionsManagerProtocol,
+ exc: Exception,
+ tb: types.TracebackType,
) -> None:
"""Handles the error raised during training before finalization.
@@ -148,47 +155,47 @@ def load_state_dict(self, to_load: Dict[str, Any]) -> None:
class _WrappedExtension(Extension):
-
- def __init__(self, ext: 'ExtensionLike') -> None:
+ def __init__(self, ext: "ExtensionLike") -> None:
self._ext = ext
- self.trigger = getattr(self._ext, 'trigger', Extension.trigger)
- self.priority = getattr(self._ext, 'priority', Extension.priority)
+ self.trigger = getattr(self._ext, "trigger", Extension.trigger)
+ self.priority = getattr(self._ext, "priority", Extension.priority)
super().__init__()
@property
def default_name(self) -> str:
- return getattr(self._ext, 'default_name', None) or super().default_name
+ return getattr(self._ext, "default_name", None) or super().default_name
def __call__(self, manager: ExtensionsManagerProtocol) -> None:
self._ext(manager)
def finalize(self, manager: ExtensionsManagerProtocol) -> None:
- getattr(self._ext, 'finalize', super().finalize)(manager)
+ getattr(self._ext, "finalize", super().finalize)(manager)
def initialize(self, manager: ExtensionsManagerProtocol) -> None:
- getattr(self._ext, 'initialize', super().initialize)(manager)
+ getattr(self._ext, "initialize", super().initialize)(manager)
def on_error(
- self,
- manager: ExtensionsManagerProtocol,
- exc: Exception,
- tb: types.TracebackType,
+ self,
+ manager: ExtensionsManagerProtocol,
+ exc: Exception,
+ tb: types.TracebackType,
) -> None:
- getattr(self._ext, 'on_error', super().on_error)(manager, exc, tb)
+ getattr(self._ext, "on_error", super().on_error)(manager, exc, tb)
_OnErrorType = Callable[
- [ExtensionsManagerProtocol, Exception, types.TracebackType], None]
+ [ExtensionsManagerProtocol, Exception, types.TracebackType], None
+]
def make_extension(
- trigger: 'TriggerLike' = Extension.trigger,
- default_name: Optional[str] = None,
- priority: int = Extension.priority,
- finalizer: 'ExtensionLike' = lambda manager: None,
- initializer: 'ExtensionLike' = lambda manager: None,
- on_error: _OnErrorType = lambda manager, exc, tb: None,
-) -> Callable[['ExtensionLike'], 'ExtensionLike']:
+ trigger: "TriggerLike" = Extension.trigger,
+ default_name: Optional[str] = None,
+ priority: int = Extension.priority,
+ finalizer: "ExtensionLike" = lambda manager: None,
+ initializer: "ExtensionLike" = lambda manager: None,
+ on_error: _OnErrorType = lambda manager, exc, tb: None,
+) -> Callable[["ExtensionLike"], "ExtensionLike"]:
"""Decorator to make given function into an extension.
This decorator just adds some attributes to a given function. The value of
@@ -210,7 +217,8 @@ def make_extension(
called after an error is raised during the training loop.
"""
- def decorator(ext: 'ExtensionLike') -> 'ExtensionLike':
+
+ def decorator(ext: "ExtensionLike") -> "ExtensionLike":
ext.trigger = trigger # type: ignore
ext.default_name = default_name or ext.__name__ # type: ignore
ext.priority = priority # type: ignore
@@ -222,7 +230,7 @@ def decorator(ext: 'ExtensionLike') -> 'ExtensionLike':
return decorator
-def _as_extension(ext: 'ExtensionLike') -> Extension:
+def _as_extension(ext: "ExtensionLike") -> Extension:
return ext if isinstance(ext, Extension) else _WrappedExtension(ext)
@@ -243,39 +251,42 @@ class ExtensionEntry:
"""
def __init__(
- self,
- extension: 'ExtensionLike',
- *,
- name: Optional[str] = None,
- priority: Optional[int] = None,
- trigger: Optional['TriggerLike'] = None,
- call_before_training: bool = False,
+ self,
+ extension: "ExtensionLike",
+ *,
+ name: Optional[str] = None,
+ priority: Optional[int] = None,
+ trigger: Optional["TriggerLike"] = None,
+ call_before_training: bool = False,
) -> None:
self.extension = _as_extension(extension)
self.priority = priority or self.extension.priority
self.call_before_training = call_before_training
self._update_trigger(trigger or self.extension.trigger)
- self._update_name(name or self.extension.name or self.extension.default_name)
+ self._update_name(
+ name or self.extension.name or self.extension.default_name
+ )
- def _update_trigger(self, trigger: 'TriggerLike') -> None:
+ def _update_trigger(self, trigger: "TriggerLike") -> None:
self.trigger = _trigger_util.get_trigger(trigger)
def _update_name(self, name: str) -> None:
- if name == 'training':
+ if name == "training":
raise ValueError(
- 'the name "training" is prohibited as an extension name')
+ 'the name "training" is prohibited as an extension name'
+ )
self.name = name
self.extension.name = name
def state_dict(self) -> Dict[str, Any]:
state = {}
- state['extension'] = self.extension.state_dict()
- state['trigger'] = self.trigger.state_dict()
+ state["extension"] = self.extension.state_dict()
+ state["trigger"] = self.trigger.state_dict()
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- if 'extension' in to_load:
- self.extension.load_state_dict(to_load['extension'])
- if 'trigger' in to_load:
- self.trigger.load_state_dict(to_load['trigger'])
+ if "extension" in to_load:
+ self.extension.load_state_dict(to_load["extension"])
+ if "trigger" in to_load:
+ self.trigger.load_state_dict(to_load["trigger"])
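A sketch of the make_extension decorator reformatted above (not part of the patch; the reported key and value are illustrative):

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import make_extension

@make_extension(trigger=(1, "epoch"), default_name="report_constant")
def report_constant(manager):
    # Any callable accepting an ExtensionsManagerProtocol can act as an
    # extension; the decorator only attaches trigger/priority attributes.
    ppe.reporting.report({"constant": 1.0})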
diff --git a/pytorch_pfn_extras/training/extensions/__init__.py b/pytorch_pfn_extras/training/extensions/__init__.py
index 26c1074f3..5a3285a26 100644
--- a/pytorch_pfn_extras/training/extensions/__init__.py
+++ b/pytorch_pfn_extras/training/extensions/__init__.py
@@ -1,30 +1,62 @@
from pytorch_pfn_extras.training.extensions import snapshot_writers # NOQA
+from pytorch_pfn_extras.training.extensions import util as _util
from pytorch_pfn_extras.training.extensions._snapshot import snapshot # NOQA
-from pytorch_pfn_extras.training.extensions._snapshot import snapshot_object # NOQA
+from pytorch_pfn_extras.training.extensions._snapshot import ( # NOQA
+ snapshot_object,
+)
from pytorch_pfn_extras.training.extensions.best_value import BestValue # NOQA
from pytorch_pfn_extras.training.extensions.best_value import MaxValue # NOQA
from pytorch_pfn_extras.training.extensions.best_value import MinValue # NOQA
-from pytorch_pfn_extras.training.extensions.evaluator import Evaluator, DistributedEvaluator, IgniteEvaluator # NOQA
-from pytorch_pfn_extras.training.extensions.fail_on_non_number import FailOnNonNumber # NOQA
+from pytorch_pfn_extras.training.extensions.evaluator import ( # NOQA
+ DistributedEvaluator,
+ Evaluator,
+ IgniteEvaluator,
+)
+from pytorch_pfn_extras.training.extensions.fail_on_non_number import ( # NOQA
+ FailOnNonNumber,
+)
from pytorch_pfn_extras.training.extensions.log_report import LogReport # NOQA
-from pytorch_pfn_extras.training.extensions.lr_scheduler import LRScheduler # NOQA
-from pytorch_pfn_extras.training.extensions.micro_average import MicroAverage # NOQA
-from pytorch_pfn_extras.training.extensions.profile_report import ProfileReport # NOQA
-from pytorch_pfn_extras.training.extensions.parameter_statistics import ParameterStatistics # NOQA
-from pytorch_pfn_extras.training.extensions.plot_report import PlotReport # NOQA
-from pytorch_pfn_extras.training.extensions.profile_report import ProfileReport # NOQA
-from pytorch_pfn_extras.training.extensions.slack import Slack, SlackWebhook # NOQA
-from pytorch_pfn_extras.training.extensions.value_observation import observe_lr # NOQA
-from pytorch_pfn_extras.training.extensions.value_observation import observe_value # NOQA
-from pytorch_pfn_extras.training.extensions.variable_statistics_plot import VariableStatisticsPlot # NOQA
-from pytorch_pfn_extras.training.extensions import util as _util
-
-from pytorch_pfn_extras.training.extensions.print_report import PrintReport as PrintReportCLI # NOQA
-from pytorch_pfn_extras.training.extensions.progress_bar import ProgressBar as ProgressBarCLI # NOQA
+from pytorch_pfn_extras.training.extensions.lr_scheduler import ( # NOQA
+ LRScheduler,
+)
+from pytorch_pfn_extras.training.extensions.micro_average import ( # NOQA
+ MicroAverage,
+)
+from pytorch_pfn_extras.training.extensions.parameter_statistics import ( # NOQA
+ ParameterStatistics,
+)
+from pytorch_pfn_extras.training.extensions.plot_report import ( # NOQA
+ PlotReport,
+)
+from pytorch_pfn_extras.training.extensions.print_report import (
+ PrintReport as PrintReportCLI, # NOQA
+)
+from pytorch_pfn_extras.training.extensions.profile_report import ( # NOQA
+ ProfileReport,
+)
+from pytorch_pfn_extras.training.extensions.progress_bar import (
+ ProgressBar as ProgressBarCLI, # NOQA
+)
+from pytorch_pfn_extras.training.extensions.slack import ( # NOQA
+ Slack,
+ SlackWebhook,
+)
+from pytorch_pfn_extras.training.extensions.value_observation import ( # NOQA
+ observe_lr,
+ observe_value,
+)
+from pytorch_pfn_extras.training.extensions.variable_statistics_plot import ( # NOQA
+ VariableStatisticsPlot,
+)
try:
- from pytorch_pfn_extras.training.extensions.print_report_notebook import PrintReportNotebook # NOQA
- from pytorch_pfn_extras.training.extensions.progress_bar_notebook import ProgressBarNotebook # NOQA
+ from pytorch_pfn_extras.training.extensions.print_report_notebook import ( # NOQA
+ PrintReportNotebook,
+ )
+ from pytorch_pfn_extras.training.extensions.progress_bar_notebook import ( # NOQA
+ ProgressBarNotebook,
+ )
+
_ipython_module_available = True
except ImportError:
_ipython_module_available = False
diff --git a/pytorch_pfn_extras/training/extensions/_snapshot.py b/pytorch_pfn_extras/training/extensions/_snapshot.py
index 1fcb275c3..ce6dadcc1 100644
--- a/pytorch_pfn_extras/training/extensions/_snapshot.py
+++ b/pytorch_pfn_extras/training/extensions/_snapshot.py
@@ -4,18 +4,19 @@
import torch
import torch.distributed
-
-from pytorch_pfn_extras import logging
+from pytorch_pfn_extras import logging, writing
from pytorch_pfn_extras.training import extension
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-from pytorch_pfn_extras import writing
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
logger = logging._get_root_logger()
-def _find_snapshot_files(fmt: str, path: str, fs: Any) -> List[Tuple[float, str]]:
- '''Only prefix and suffix match
+def _find_snapshot_files(
+ fmt: str, path: str, fs: Any
+) -> List[Tuple[float, str]]:
+ """Only prefix and suffix match
TODO(kuenishi): currently clean format string such as
"snapshot{.iteration}.npz" can only be parsed, but tricky (or
@@ -33,16 +34,20 @@ def _find_snapshot_files(fmt: str, path: str, fs: Any) -> List[Tuple[float, str]
        A sorted list of ``(mtime, filename)`` pairs for files whose
        names match the format ``fmt`` directly under ``path``.
- '''
- prefix = fmt.split('{')[0]
- suffix = fmt.split('}')[-1]
+ """
+ prefix = fmt.split("{")[0]
+ suffix = fmt.split("}")[-1]
- matched_files = (file for file in fs.list(path)
- if file.startswith(prefix) and file.endswith(suffix))
+ matched_files = (
+ file
+ for file in fs.list(path)
+ if file.startswith(prefix) and file.endswith(suffix)
+ )
def _prepend_mtime(f: str) -> Any:
t = fs.stat(os.path.join(path, f)).last_modified
return (t, f)
+
return sorted(_prepend_mtime(file) for file in matched_files)
@@ -63,7 +68,7 @@ def _find_latest_snapshot(fmt: str, path: str, fs: Any) -> Optional[str]:
"""
snapshot_files = _find_snapshot_files(fmt, path, fs)
- logger.debug('found snapshot files {}'.format(snapshot_files))
+ logger.debug("found snapshot files {}".format(snapshot_files))
if len(snapshot_files) > 0:
_, filename = snapshot_files[-1]
return filename
@@ -71,7 +76,7 @@ def _find_latest_snapshot(fmt: str, path: str, fs: Any) -> Optional[str]:
def _find_stale_snapshots(
- fmt: str, path: str, n_retains: int, fs: Any
+ fmt: str, path: str, n_retains: int, fs: Any
) -> Generator[str, None, None]:
"""Finds stale snapshots in a directory, retaining several files
@@ -100,7 +105,8 @@ def _find_stale_snapshots(
def snapshot_object(
- target: Any, filename: str, savefun: Any = None, **kwargs: Any) -> '_Snapshot':
+ target: Any, filename: str, savefun: Any = None, **kwargs: Any
+) -> "_Snapshot":
"""snapshot_object(target, filename, savefun=None, \
*, condition=None, writer=None, snapshot_on_error=False, \
n_retains=-1, autoload=False)
@@ -160,22 +166,21 @@ def snapshot_object(
- :meth:`pytorch_pfn_extras.training.extensions.snapshot`
"""
- return snapshot(target=target, filename=filename, savefun=savefun,
- **kwargs)
+ return snapshot(target=target, filename=filename, savefun=savefun, **kwargs)
def snapshot(
- savefun: Any = None,
- filename: str = 'snapshot_iter_{.iteration}',
- *,
- target: Any = None,
- condition: Any = None,
- writer: Optional[writing.Writer] = None,
- snapshot_on_error: bool = False,
- n_retains: int = -1,
- autoload: bool = False,
- saver_rank: Optional[int] = None,
-) -> '_Snapshot':
+ savefun: Any = None,
+ filename: str = "snapshot_iter_{.iteration}",
+ *,
+ target: Any = None,
+ condition: Any = None,
+ writer: Optional[writing.Writer] = None,
+ snapshot_on_error: bool = False,
+ n_retains: int = -1,
+ autoload: bool = False,
+ saver_rank: Optional[int] = None,
+) -> "_Snapshot":
"""
Returns a trainer extension to take snapshots of the trainer.
@@ -277,17 +282,31 @@ def __call__(self, x):
"""
if savefun is not None and writer is not None:
raise TypeError(
- 'savefun and writer arguments cannot be specified together.')
+ "savefun and writer arguments cannot be specified together."
+ )
if saver_rank is None:
return _Snapshot(
- target=target, condition=condition, writer=writer,
- filename=filename, snapshot_on_error=snapshot_on_error,
- n_retains=n_retains, autoload=autoload, savefun=savefun)
+ target=target,
+ condition=condition,
+ writer=writer,
+ filename=filename,
+ snapshot_on_error=snapshot_on_error,
+ n_retains=n_retains,
+ autoload=autoload,
+ savefun=savefun,
+ )
return _DistributedSnapshot(
- target=target, condition=condition, writer=writer, filename=filename,
- snapshot_on_error=snapshot_on_error, n_retains=n_retains,
- autoload=autoload, saver_rank=saver_rank, savefun=savefun)
+ target=target,
+ condition=condition,
+ writer=writer,
+ filename=filename,
+ snapshot_on_error=snapshot_on_error,
+ n_retains=n_retains,
+ autoload=autoload,
+ saver_rank=saver_rank,
+ savefun=savefun,
+ )
def _always_true() -> bool:
@@ -308,20 +327,21 @@ class _Snapshot(extension.Extension):
The default priority is -100, which is lower than that of most
built-in extensions.
"""
- trigger = 1, 'epoch'
+
+ trigger = 1, "epoch"
priority = extension.PRIORITY_SNAPSHOT
needs_model_state = True
def __init__(
- self,
- target: Any = None,
- condition: Any = None,
- writer: Optional[writing.Writer] = None,
- filename: str = 'snapshot_iter_{.iteration}',
- snapshot_on_error: bool = False,
- n_retains: int = -1,
- autoload: bool = False,
- savefun: Any = None,
+ self,
+ target: Any = None,
+ condition: Any = None,
+ writer: Optional[writing.Writer] = None,
+ filename: str = "snapshot_iter_{.iteration}",
+ snapshot_on_error: bool = False,
+ n_retains: int = -1,
+ autoload: bool = False,
+ savefun: Any = None,
) -> None:
if condition is None:
condition = _always_true
@@ -335,7 +355,8 @@ def __init__(
self._savefun = savefun
def initialize( # type: ignore[override]
- self, manager: ExtensionsManagerProtocol) -> Optional[str]:
+ self, manager: ExtensionsManagerProtocol
+ ) -> Optional[str]:
target = manager if self._target is None else self._target
writer = manager.writer if self.writer is None else self.writer
self.writer = writer
@@ -347,16 +368,22 @@ def initialize( # type: ignore[override]
            # terms of mtime, and tries to load it into the target or
# manager.
assert writer is not None
- loaded_fn = _find_latest_snapshot(self.filename, writer.out_dir, writer.fs)
+ loaded_fn = _find_latest_snapshot(
+ self.filename, writer.out_dir, writer.fs
+ )
if loaded_fn:
- snapshot_file = writer.fs.open(os.path.join(writer.out_dir, loaded_fn), 'rb')
+ snapshot_file = writer.fs.open(
+ os.path.join(writer.out_dir, loaded_fn), "rb"
+ )
# As described above (at ``autoload`` option),
# snapshot files to be autoloaded must be saved by
                # ``save_npz``. In order to support a general format,
                # we need to first reconstruct the design of savefun
# and loadfun.
- state = torch.load(snapshot_file, # type: ignore[no-untyped-call]
- map_location=torch.device("cpu"))
+ state = torch.load(
+ snapshot_file, # type: ignore[no-untyped-call]
+ map_location=torch.device("cpu"),
+ )
if type(target) is dict:
for k in target:
target[k].load_state_dict(state[k])
@@ -364,9 +391,11 @@ def initialize( # type: ignore[override]
target.load_state_dict(state)
snapshot_file.close()
- if (hasattr(writer, '_add_cleanup_hook')
- and self.n_retains > 0
- and isinstance(self.filename, str)):
+ if (
+ hasattr(writer, "_add_cleanup_hook")
+ and self.n_retains > 0
+ and isinstance(self.filename, str)
+ ):
            # This block sets up a method for automatic cleanup of stale
            # snapshots, when the ``n_retains`` argument is a positive
            # number. When the given snapshot writer is Chainer's
@@ -375,8 +404,9 @@ def initialize( # type: ignore[override]
# injected here.
def _cleanup() -> None:
assert writer is not None
- files = _find_stale_snapshots(self.filename, writer.out_dir,
- self.n_retains, writer.fs)
+ files = _find_stale_snapshots(
+ self.filename, writer.out_dir, self.n_retains, writer.fs
+ )
for file in files:
writer.fs.remove(os.path.join(writer.out_dir, file))
@@ -386,10 +416,10 @@ def _cleanup() -> None:
return loaded_fn
def on_error(
- self,
- manager: ExtensionsManagerProtocol,
- exc: Exception,
- tb: types.TracebackType,
+ self,
+ manager: ExtensionsManagerProtocol,
+ exc: Exception,
+ tb: types.TracebackType,
) -> None:
super().on_error(manager, exc, tb)
if self._snapshot_on_error:
@@ -406,8 +436,7 @@ def _make_snapshot(self, manager: ExtensionsManagerProtocol) -> None:
# We need to get a dictionary with the state here
if type(target) is dict:
- serialized_target = {
- k: v.state_dict() for k, v in target.items()}
+ serialized_target = {k: v.state_dict() for k, v in target.items()}
else:
serialized_target = target.state_dict()
filename = self.filename
@@ -417,7 +446,8 @@ def _make_snapshot(self, manager: ExtensionsManagerProtocol) -> None:
filename = filename.format(manager)
outdir = manager.out
writer( # type: ignore
- filename, outdir, serialized_target, savefun=self._savefun)
+ filename, outdir, serialized_target, savefun=self._savefun
+ )
def finalize(self, manager: ExtensionsManagerProtocol) -> None:
self.writer.finalize() # type: ignore
@@ -437,34 +467,46 @@ class _DistributedSnapshot(_Snapshot):
The default priority is lower than that of most
built-in extensions.
"""
- trigger = 1, 'epoch'
+
+ trigger = 1, "epoch"
priority = extension.PRIORITY_SNAPSHOT
def __init__(
- self,
- target: Any = None,
- condition: Any = None,
- writer: Optional[writing.Writer] = None,
- filename: str = 'snapshot_iter_{.iteration}',
- snapshot_on_error: bool = False,
- n_retains: int = -1,
- autoload: bool = False,
- saver_rank: int = 0,
- savefun: Any = None,
+ self,
+ target: Any = None,
+ condition: Any = None,
+ writer: Optional[writing.Writer] = None,
+ filename: str = "snapshot_iter_{.iteration}",
+ snapshot_on_error: bool = False,
+ n_retains: int = -1,
+ autoload: bool = False,
+ saver_rank: int = 0,
+ savefun: Any = None,
):
- super().__init__(target, condition, writer, filename,
- snapshot_on_error, n_retains,
- autoload, savefun)
+ super().__init__(
+ target,
+ condition,
+ writer,
+ filename,
+ snapshot_on_error,
+ n_retains,
+ autoload,
+ savefun,
+ )
# To support distributed snapshots
if not torch.distributed.is_initialized(): # type: ignore[no-untyped-call]
- raise RuntimeError('The Distributed Snapshot extension',
- ' requires torch.distributed to be initialized')
+ raise RuntimeError(
+                "The Distributed Snapshot extension"
+                " requires torch.distributed to be initialized"
+ )
self._saver_rank = saver_rank
self._size = torch.distributed.get_world_size() # type: ignore[no-untyped-call]
self._rank = torch.distributed.get_rank() # type: ignore[no-untyped-call]
if not (0 <= saver_rank < self._size):
- raise ValueError('Distributed snapshot requires a saver rank'
- ' in the range [0-{})'.format(self._size))
+ raise ValueError(
+ "Distributed snapshot requires a saver rank"
+ " in the range [0-{})".format(self._size)
+ )
def __call__(self, manager: ExtensionsManagerProtocol) -> None:
if self.condition():
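A usage sketch of the snapshot extension above (not part of the patch; manager is an existing ExtensionsManager):

import pytorch_pfn_extras.training.extensions as extensions

# Snapshot the manager every epoch, keep only the three newest files,
# and resume from the latest one automatically on restart.
manager.extend(
    extensions.snapshot(n_retains=3, autoload=True),
    trigger=(1, "epoch"),
)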
diff --git a/pytorch_pfn_extras/training/extensions/best_value.py b/pytorch_pfn_extras/training/extensions/best_value.py
index 0c0f0b4b9..7bec96c42 100644
--- a/pytorch_pfn_extras/training/extensions/best_value.py
+++ b/pytorch_pfn_extras/training/extensions/best_value.py
@@ -1,9 +1,9 @@
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
-
-from pytorch_pfn_extras.training import extension
-from pytorch_pfn_extras.training import triggers
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from pytorch_pfn_extras.training import extension, triggers
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
if TYPE_CHECKING:
from pytorch_pfn_extras.training._trigger_util import TriggerLike
@@ -24,13 +24,13 @@ class BestValue(extension.Extension):
:class:`~pytorch_pfn_extras.triggers.BestValueTrigger`.
"""
- default_name = 'best_value'
+ default_name = "best_value"
def __init__(
- self,
- key: str,
- compare: Callable[[float, float], bool],
- trigger: 'TriggerLike' = (1, 'epoch'),
+ self,
+ key: str,
+ compare: Callable[[float, float], bool],
+ trigger: "TriggerLike" = (1, "epoch"),
) -> None:
self._best_epoch: Optional[int] = None
self._best_it: Optional[int] = None
@@ -42,8 +42,10 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
def _check_best_value_exists(self) -> None:
if self._best_trigger._best_value is None:
- raise RuntimeError("Best observation hasn't been obtained. "
- "Run the BestValue extension at least once")
+ raise RuntimeError(
+ "Best observation hasn't been obtained. "
+ "Run the BestValue extension at least once"
+ )
@property
def best_value(self) -> float:
@@ -74,15 +76,15 @@ def best_epoch(self) -> int:
def state_dict(self) -> Dict[str, Any]:
return {
- '_best_trigger': self._best_trigger.state_dict(),
- '_best_it': self._best_it,
- '_best_epoch': self._best_epoch
+ "_best_trigger": self._best_trigger.state_dict(),
+ "_best_it": self._best_it,
+ "_best_epoch": self._best_epoch,
}
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- self._best_trigger.load_state_dict(to_load['_best_trigger'])
- self._best_it = to_load['_best_it']
- self._best_epoch = to_load['_best_epoch']
+ self._best_trigger.load_state_dict(to_load["_best_trigger"])
+ self._best_it = to_load["_best_it"]
+ self._best_epoch = to_load["_best_epoch"]
class MaxValue(BestValue):
@@ -97,11 +99,12 @@ class MaxValue(BestValue):
:class:`~pytorch_pfn_extras.triggers.BestValueTrigger`.
"""
- default_name = 'max_value'
+ default_name = "max_value"
- def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')):
+ def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")):
super().__init__(
- key, lambda max_value, new_value: new_value > max_value, trigger)
+ key, lambda max_value, new_value: new_value > max_value, trigger
+ )
class MinValue(BestValue):
@@ -116,8 +119,9 @@ class MinValue(BestValue):
:class:`~pytorch_pfn_extras.triggers.BestValueTrigger`.
"""
- default_name = 'min_value'
+ default_name = "min_value"
- def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')):
+ def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")):
super().__init__(
- key, lambda min_value, new_value: new_value < min_value, trigger)
+ key, lambda min_value, new_value: new_value < min_value, trigger
+ )
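A sketch of the best-value extensions above (not part of the patch; the observation key is illustrative):

import pytorch_pfn_extras.training.extensions as extensions

# Track the best validation accuracy observed so far; the best value and
# the epoch it occurred in survive snapshots via state_dict/load_state_dict.
best = extensions.MaxValue("validation/main/accuracy")
manager.extend(best, trigger=(1, "epoch"))
# best.best_value / best.best_epoch raise until the extension has fired once.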
diff --git a/pytorch_pfn_extras/training/extensions/evaluator.py b/pytorch_pfn_extras/training/extensions/evaluator.py
index ce98498d5..8fd57f650 100644
--- a/pytorch_pfn_extras/training/extensions/evaluator.py
+++ b/pytorch_pfn_extras/training/extensions/evaluator.py
@@ -1,19 +1,27 @@
import contextlib
import datetime
from typing import (
- Any, Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union,
TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ TextIO,
+ Union,
)
import numpy
import torch
import torch.distributed
-
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training import extension
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
from pytorch_pfn_extras.training.extensions import util
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
_MetricType = Callable[[Any, Any, Any], None]
_Scalar = Union[torch.Tensor, numpy.ndarray, numpy.floating, float]
@@ -83,29 +91,32 @@ class Evaluator(extension.Extension):
eval_func: Evaluation function called at each iteration.
"""
- trigger = 1, 'epoch'
- default_name = 'validation'
+
+ trigger = 1, "epoch"
+ default_name = "validation"
priority = extension.PRIORITY_WRITER
def __init__(
- self,
- iterator: Union[torch.utils.data.DataLoader[Any],
- Dict[str, torch.utils.data.DataLoader[Any]]],
- target: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
- eval_hook: Optional[Callable[['Evaluator'], None]] = None,
- eval_func: Optional[Callable[..., Any]] = None,
- **kwargs: Any,
+ self,
+ iterator: Union[
+ torch.utils.data.DataLoader[Any],
+ Dict[str, torch.utils.data.DataLoader[Any]],
+ ],
+ target: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
+ eval_hook: Optional[Callable[["Evaluator"], None]] = None,
+ eval_func: Optional[Callable[..., Any]] = None,
+ **kwargs: Any,
) -> None:
- progress_bar = kwargs.get('progress_bar', False)
- metrics = kwargs.get('metrics', [])
+ progress_bar = kwargs.get("progress_bar", False)
+ metrics = kwargs.get("metrics", [])
if isinstance(iterator, torch.utils.data.DataLoader):
- self._iterators = {'main': iterator}
+ self._iterators = {"main": iterator}
else:
self._iterators = iterator
if isinstance(target, torch.nn.Module):
- target = {'main': target}
+ target = {"main": target}
self._targets = target
self.name = None
@@ -118,7 +129,7 @@ def eval_func(self, *args: Any, **kwargs: Any) -> Any:
if self._eval_func:
func = self._eval_func
else:
- func = self._targets['main']
+ func = self._targets["main"]
return func(*args, **kwargs)
def get_iterator(self, name: str) -> torch.utils.data.DataLoader[Any]:
@@ -152,8 +163,8 @@ def add_metric(self, metric_fn: _MetricType) -> None:
self._metrics.append(metric_fn)
def __call__(
- self,
- manager: Optional[ExtensionsManagerProtocol] = None,
+ self,
+ manager: Optional[ExtensionsManagerProtocol] = None,
) -> Optional[Dict[str, _Scalar]]:
"""Executes the evaluator extension.
@@ -175,13 +186,12 @@ def __call__(
# set up a reporter
reporter = reporting.Reporter()
if self.name is not None:
- prefix = self.name + '/'
+ prefix = self.name + "/"
else:
- prefix = ''
+ prefix = ""
for name, target in self._targets.items():
reporter.add_observer(prefix + name, target)
- reporter.add_observers(prefix + name,
- target.named_modules())
+ reporter.add_observers(prefix + name, target.named_modules())
with reporter:
with torch.no_grad(): # type: ignore[no-untyped-call]
@@ -190,7 +200,9 @@ def __call__(
reporting.report(result)
return result
- def _gather_summaries(self, summary: reporting.DictSummary) -> reporting.DictSummary:
+ def _gather_summaries(
+ self, summary: reporting.DictSummary
+ ) -> reporting.DictSummary:
return summary
def evaluate(self) -> Dict[str, _Scalar]:
@@ -207,7 +219,7 @@ def evaluate(self) -> Dict[str, _Scalar]:
:func:`~pytorch_pfn_extras.report` without specifying any observer.
"""
- iterator = self._iterators['main']
+ iterator = self._iterators["main"]
if self.eval_hook:
self.eval_hook(self)
@@ -226,7 +238,7 @@ def evaluate(self) -> Dict[str, _Scalar]:
progress.current_position = idx
observation: Dict[str, Any] = {}
with reporting.report_scope(observation):
- if isinstance(batch, tuple) and hasattr(batch, '_fields'):
+ if isinstance(batch, tuple) and hasattr(batch, "_fields"):
outs = self.eval_func(batch)
elif isinstance(batch, (tuple, list)):
outs = self.eval_func(*batch)
@@ -287,31 +299,39 @@ class DistributedEvaluator(Evaluator):
"""
def __init__(
- self,
- iterator: Union[torch.utils.data.DataLoader[Any],
- Dict[str, torch.utils.data.DataLoader[Any]]],
- target: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
- eval_hook: Optional[Callable[['Evaluator'], None]] = None,
- eval_func: Optional[Callable[..., Any]] = None,
- **kwargs: Any,
+ self,
+ iterator: Union[
+ torch.utils.data.DataLoader[Any],
+ Dict[str, torch.utils.data.DataLoader[Any]],
+ ],
+ target: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
+ eval_hook: Optional[Callable[["Evaluator"], None]] = None,
+ eval_func: Optional[Callable[..., Any]] = None,
+ **kwargs: Any,
) -> None:
if not torch.distributed.is_initialized(): # type: ignore[no-untyped-call]
- msg = "PyTorch distributed module is not initialized. " \
- "Initialize process group or use non-distributed Evaluator."
+ msg = (
+ "PyTorch distributed module is not initialized. "
+ "Initialize process group or use non-distributed Evaluator."
+ )
raise RuntimeError(msg)
- if 'progress_bar' in kwargs:
+ if "progress_bar" in kwargs:
rank = torch.distributed.get_rank() # type: ignore[no-untyped-call]
- kwargs['progress_bar'] &= (rank == 0)
+ kwargs["progress_bar"] &= rank == 0
super().__init__(iterator, target, eval_hook, eval_func, **kwargs)
- def _gather_summaries(self, summary: reporting.DictSummary) -> reporting.DictSummary:
+ def _gather_summaries(
+ self, summary: reporting.DictSummary
+ ) -> reporting.DictSummary:
return sum(_dist_gather(summary), reporting.DictSummary())
@contextlib.contextmanager
-def _in_eval_mode(targets: Iterable[torch.nn.Module]) -> Generator[None, None, None]:
+def _in_eval_mode(
+ targets: Iterable[torch.nn.Module],
+) -> Generator[None, None, None]:
targets = list(targets)
was_train = [t.training for t in targets]
try:
@@ -335,19 +355,22 @@ def epoch_detail(self) -> float:
class _IteratorProgressBar(util.ProgressBar):
-
def __init__(
- self,
- name: str,
- iterator: IterationStatus,
- bar_length: int = 50,
- out: Optional[TextIO] = None,
+ self,
+ name: str,
+ iterator: IterationStatus,
+ bar_length: int = 50,
+ out: Optional[TextIO] = None,
):
- if not (hasattr(iterator, 'current_position')
- and hasattr(iterator, 'epoch_detail')):
- raise TypeError('Iterator must have the following attributes '
- 'to enable a progress bar: '
- 'current_position, epoch_detail')
+ if not (
+ hasattr(iterator, "current_position")
+ and hasattr(iterator, "epoch_detail")
+ ):
+ raise TypeError(
+ "Iterator must have the following attributes "
+ "to enable a progress bar: "
+ "current_position, epoch_detail"
+ )
self._name = name
self._iterator = iterator
self._bar_length = bar_length
@@ -357,28 +380,34 @@ def __init__(
def get_lines(self) -> List[str]:
iteration = self._iterator.current_position
epoch_detail = self._iterator.epoch_detail
- epoch_size = getattr(self._iterator, '_epoch_size', None)
+ epoch_size = getattr(self._iterator, "_epoch_size", None)
lines = []
rate = epoch_detail
- marks = '#' * int(rate * self._bar_length)
- rest_marks = '.' * (self._bar_length - len(marks))
- lines.append('{} [{}{}] {:6.2%}\n'.format(
- self._name, marks, rest_marks, rate))
+ marks = "#" * int(rate * self._bar_length)
+ rest_marks = "." * (self._bar_length - len(marks))
+ lines.append(
+ "{} [{}{}] {:6.2%}\n".format(self._name, marks, rest_marks, rate)
+ )
if epoch_size:
- lines.append(f'{{:{len(self._name)}}} / {{}} iterations\n'
- .format(iteration, epoch_size))
+ lines.append(
+ f"{{:{len(self._name)}}} / {{}} iterations\n".format(
+ iteration, epoch_size
+ )
+ )
else:
- lines.append(f'{{:{len(self._name)}}} iterations\n'
- .format(iteration))
+ lines.append(
+ f"{{:{len(self._name)}}} iterations\n".format(iteration)
+ )
speed_t, speed_e = self.update_speed(iteration, epoch_detail)
estimated_time = (1.0 - epoch_detail) / speed_e
- itps = f'{{:{len(self._name)}.5g}} iters/sec.'.format(speed_t)
- eta = 'Estimated time to finish: {}.\n' \
- .format(datetime.timedelta(seconds=estimated_time))
+ itps = f"{{:{len(self._name)}.5g}} iters/sec.".format(speed_t)
+ eta = "Estimated time to finish: {}.\n".format(
+ datetime.timedelta(seconds=estimated_time)
+ )
lines.append("{} {}".format(itps, eta))
return lines
@@ -389,20 +418,21 @@ def get_lines(self) -> List[str]:
class IgniteEvaluator(Evaluator):
def __init__(
- self,
- evaluator: 'Engine',
- iterator: Union[torch.utils.data.DataLoader[Any],
- Dict[str, torch.utils.data.DataLoader[Any]]],
- target: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
- **kwargs: Any,
+ self,
+ evaluator: "Engine",
+ iterator: Union[
+ torch.utils.data.DataLoader[Any],
+ Dict[str, torch.utils.data.DataLoader[Any]],
+ ],
+ target: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
+ **kwargs: Any,
):
super().__init__(iterator, target, None, **kwargs)
self.evaluator = evaluator
self.set_evaluator_handlers()
def set_evaluator_handlers(self) -> None:
- from ignite.engine import Engine
- from ignite.engine import Events
+ from ignite.engine import Engine, Events
# Register handlers to retrieve the Average metrics and report them
@self.evaluator.on(Events.ITERATION_STARTED)
@@ -412,6 +442,7 @@ def set_evaluation_started(engine: Engine) -> None:
self.cm.__enter__()
if self._progress_bar:
+
@self.evaluator.on(Events.ITERATION_STARTED)
def update_progress_bar(engine: Engine) -> None:
self.progress.current_position = engine.state.iteration
@@ -428,12 +459,11 @@ def set_evaluation_completed(engine: Engine) -> None:
with reporting.report_scope(ignite_metrics):
metrics = self.evaluator.state.metrics
for metric in metrics:
- reporting.report(
- {'val/{}'.format(metric): metrics[metric]})
+ reporting.report({"val/{}".format(metric): metrics[metric]})
self.summary.add(ignite_metrics)
def evaluate(self) -> Dict[str, _Scalar]:
- iterator = self._iterators['main']
+ iterator = self._iterators["main"]
self.summary = reporting.DictSummary()
self.progress = IterationStatus(len(iterator))
if self._progress_bar:
diff --git a/pytorch_pfn_extras/training/extensions/fail_on_non_number.py b/pytorch_pfn_extras/training/extensions/fail_on_non_number.py
index c61b428bd..9e47662eb 100644
--- a/pytorch_pfn_extras/training/extensions/fail_on_non_number.py
+++ b/pytorch_pfn_extras/training/extensions/fail_on_non_number.py
@@ -1,7 +1,8 @@
import torch
-
from pytorch_pfn_extras.training import extension
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
class FailOnNonNumber(extension.Extension):
@@ -26,10 +27,12 @@ def __init__(self, *, check_grad: bool = True):
def __call__(self, manager: ExtensionsManagerProtocol) -> None:
for name, model in manager.models.items():
for param in model.parameters():
- if (not torch.isfinite(param).all()
- or (self._check_grad
- and param.grad is not None
- and not torch.isfinite(param.grad).all())):
+ if not torch.isfinite(param).all() or (
+ self._check_grad
+ and param.grad is not None
+ and not torch.isfinite(param.grad).all()
+ ):
raise RuntimeError(
- 'Kill the process since parameters in optimizer'
- ' \'{}\' diverge. R.I.P.'.format(name))
+ "Kill the process since parameters in optimizer"
+ " '{}' diverge. R.I.P.".format(name)
+ )
diff --git a/pytorch_pfn_extras/training/extensions/log_report.py b/pytorch_pfn_extras/training/extensions/log_report.py
index 0f07032d3..c78448d0e 100644
--- a/pytorch_pfn_extras/training/extensions/log_report.py
+++ b/pytorch_pfn_extras/training/extensions/log_report.py
@@ -5,7 +5,9 @@
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
Observation = Mapping[str, reporting.Scalar]
@@ -18,37 +20,38 @@
class LogWriterSaveFunc:
-
def __init__(self, format: str, append: bool) -> None:
self._format = format
self._append = append
def __call__(self, target: Dict[str, Any], file_o: Any) -> None:
- if self._format == 'json':
+ if self._format == "json":
if self._append:
raise ValueError(
- 'LogReport does not support json format with append mode.')
+ "LogReport does not support json format with append mode."
+ )
log = json.dumps(target, indent=4)
- elif self._format == 'json-lines':
+ elif self._format == "json-lines":
# Add a new line at the end for subsequent appends
- log = '\n'.join([json.dumps(x) for x in target]) + '\n'
- elif self._format == 'yaml':
+ log = "\n".join([json.dumps(x) for x in target]) + "\n"
+ elif self._format == "yaml":
import yaml
# This is to dump ordered dicts as regular dicts
def dict_representer(dumper: Any, data: Any) -> Any:
return dumper.represent_dict(data.items())
+
yaml.add_representer( # type: ignore[no-untyped-call]
- collections.OrderedDict, dict_representer)
+ collections.OrderedDict, dict_representer
+ )
# yaml.add_constructor(_mapping_tag, dict_constructor)
log = yaml.dump(target)
else:
- raise ValueError('Unknown format: {}'.format(self._format))
- file_o.write(bytes(log.encode('ascii')))
+ raise ValueError("Unknown format: {}".format(self._format))
+ file_o.write(bytes(log.encode("ascii")))
class _LogBuffer:
-
def __init__(self) -> None:
self.lookers: Dict[int, int] = {}
self._log: List[Observation] = []
@@ -57,22 +60,22 @@ def __init__(self) -> None:
def _trim(self) -> None:
min_looker_index = min(self.lookers.values())
if min_looker_index > self._offset:
- self._log = self._log[min_looker_index - self._offset:]
+ self._log = self._log[min_looker_index - self._offset :]
self._offset = min_looker_index
def append(self, observation: Observation) -> None:
self._log.append(observation)
def _get(self, looker_id: int) -> List[Observation]:
- return self._log[self.lookers[looker_id] - self._offset:]
+ return self._log[self.lookers[looker_id] - self._offset :]
def _clear(self, looker_id: int) -> None:
if looker_id not in self.lookers:
- raise ValueError(f'looker {looker_id} is not registered')
+ raise ValueError(f"looker {looker_id} is not registered")
self.lookers[looker_id] = len(self._log) + self._offset
self._trim()
- def emit_new_looker(self) -> '_LogLooker':
+ def emit_new_looker(self) -> "_LogLooker":
looker_id = len(self.lookers)
assert looker_id not in self.lookers
self.lookers[looker_id] = len(self._log) + self._offset
@@ -83,7 +86,6 @@ def size(self) -> int:
class _LogLooker:
-
def __init__(self, log_buffer: _LogBuffer, looker_id: int) -> None:
self._log_buffer = log_buffer
self._looker_id = looker_id
@@ -154,14 +156,14 @@ class LogReport(extension.Extension):
"""
def __init__(
- self,
- keys: Optional[Iterable[str]] = None,
- trigger: trigger_module.TriggerLike = (1, 'epoch'),
- postprocess: Optional[Callable[[Mapping[str, Any]], None]] = None,
- filename: Optional[str] = None,
- append: bool = False,
- format: Optional[str] = None,
- **kwargs: Any,
+ self,
+ keys: Optional[Iterable[str]] = None,
+ trigger: trigger_module.TriggerLike = (1, "epoch"),
+ postprocess: Optional[Callable[[Mapping[str, Any]], None]] = None,
+ filename: Optional[str] = None,
+ append: bool = False,
+ format: Optional[str] = None,
+ **kwargs: Any,
):
self._keys = keys
self._trigger = trigger_module.get_trigger(trigger)
@@ -170,20 +172,20 @@ def __init__(
self._log_looker = self._log_buffer.emit_new_looker()
# When using a writer, it needs to have a savefun defined
# to deal with a string.
- self._writer = kwargs.get('writer', None)
+ self._writer = kwargs.get("writer", None)
if filename is None:
- filename = 'log'
+ filename = "log"
if format is None:
- if filename.endswith('.jsonl'):
- format = 'json-lines'
- elif filename.endswith('.yaml'):
- format = 'yaml'
+ if filename.endswith(".jsonl"):
+ format = "json-lines"
+ elif filename.endswith(".yaml"):
+ format = "yaml"
else:
- format = 'json'
- elif format not in ('json', 'json-lines', 'yaml'):
- raise ValueError(f'unsupported log format: {format}')
+ format = "json"
+ elif format not in ("json", "json-lines", "yaml"):
+ raise ValueError(f"unsupported log format: {format}")
self._filename = filename
self._append = append
@@ -210,9 +212,9 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
for name, value in stats.items():
stats_cpu[name] = float(value) # copy to CPU
- stats_cpu['epoch'] = manager.epoch
- stats_cpu['iteration'] = manager.iteration
- stats_cpu['elapsed_time'] = manager.elapsed_time
+ stats_cpu["epoch"] = manager.epoch
+ stats_cpu["iteration"] = manager.iteration
+ stats_cpu["elapsed_time"] = manager.elapsed_time
if self._postprocess is not None:
self._postprocess(stats_cpu)
@@ -223,8 +225,13 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
log_name = self._filename.format(**stats_cpu)
out = manager.out
savefun = LogWriterSaveFunc(self._format, self._append)
- writer(log_name, out, self._log_looker.get(),
- savefun=savefun, append=self._append)
+ writer(
+ log_name,
+ out,
+ self._log_looker.get(),
+ savefun=savefun,
+ append=self._append,
+ )
if self._append:
self._log_looker.clear()
@@ -238,26 +245,26 @@ def log(self) -> List[Mapping[str, Any]]:
def state_dict(self) -> Dict[str, Any]:
state: Dict[str, Any] = {}
- if hasattr(self._trigger, 'state_dict'):
- state['_trigger'] = self._trigger.state_dict()
+ if hasattr(self._trigger, "state_dict"):
+ state["_trigger"] = self._trigger.state_dict()
try:
- state['_summary'] = self._summary.state_dict()
+ state["_summary"] = self._summary.state_dict()
except KeyError:
pass
- state['_log'] = json.dumps(self._log_buffer._log)
+ state["_log"] = json.dumps(self._log_buffer._log)
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- if hasattr(self._trigger, 'load_state_dict'):
- self._trigger.load_state_dict(to_load['_trigger'])
- self._summary.load_state_dict(to_load['_summary'])
- self._log_buffer._log = json.loads(to_load['_log'])
+ if hasattr(self._trigger, "load_state_dict"):
+ self._trigger.load_state_dict(to_load["_trigger"])
+ self._summary.load_state_dict(to_load["_summary"])
+ self._log_buffer._log = json.loads(to_load["_log"])
def _init_summary(self) -> None:
self._summary = reporting.DictSummary()
- def to_dataframe(self) -> 'pandas.DataFrame':
+ def to_dataframe(self) -> "pandas.DataFrame":
if not _pandas_available:
raise ImportError(
"Need to install pandas to use `to_dataframe` method."
diff --git a/pytorch_pfn_extras/training/extensions/lr_scheduler.py b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
index 28be7f758..fc794554a 100644
--- a/pytorch_pfn_extras/training/extensions/lr_scheduler.py
+++ b/pytorch_pfn_extras/training/extensions/lr_scheduler.py
@@ -2,25 +2,33 @@
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
from torch.optim.lr_scheduler import ReduceLROnPlateau
-def _get_value_from_log_report(manager: ExtensionsManagerProtocol, key: Any) -> Any:
+def _get_value_from_log_report(
+ manager: ExtensionsManagerProtocol, key: Any
+) -> Any:
# Find and return the latest reported "key" from LogReport
if key is None:
return None
if key not in manager.observation:
raise ValueError(
- '{} is not found in the reported values {}'.format(
- key, manager.observation))
+ "{} is not found in the reported values {}".format(
+ key, manager.observation
+ )
+ )
return manager.observation[key]
-def _default_stepper(manager: ExtensionsManagerProtocol, scheduler: Any) -> None:
+def _default_stepper(
+ manager: ExtensionsManagerProtocol, scheduler: Any
+) -> None:
if isinstance(scheduler, ReduceLROnPlateau):
- LRScheduler.step_by_value('val/loss')(manager, scheduler)
+ LRScheduler.step_by_value("val/loss")(manager, scheduler)
else:
scheduler.step()
@@ -43,11 +51,12 @@ class LRScheduler(extension.Extension):
"""
def __init__(
- self,
- scheduler: Any, *,
- stepper: Any = _default_stepper,
- trigger: trigger_module.TriggerLike = (1, 'epoch'),
- is_async: bool = True,
+ self,
+ scheduler: Any,
+ *,
+ stepper: Any = _default_stepper,
+ trigger: trigger_module.TriggerLike = (1, "epoch"),
+ is_async: bool = True,
) -> None:
self.scheduler = scheduler
self.trigger = trigger_module.get_trigger(trigger)
@@ -59,12 +68,15 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
@staticmethod
def step_by_value(key: Optional[str]) -> Any:
- def _stepper(manager: ExtensionsManagerProtocol, scheduler: Any) -> None:
+ def _stepper(
+ manager: ExtensionsManagerProtocol, scheduler: Any
+ ) -> None:
scheduler.step(_get_value_from_log_report(manager, key))
+
return _stepper
def state_dict(self) -> Dict[str, Any]:
- return {'scheduler': self.scheduler.state_dict()}
+ return {"scheduler": self.scheduler.state_dict()}
def load_state_dict(self, state: Dict[str, Any]) -> None:
- self.scheduler.load_state_dict(state['scheduler'])
+ self.scheduler.load_state_dict(state["scheduler"])
diff --git a/pytorch_pfn_extras/training/extensions/micro_average.py b/pytorch_pfn_extras/training/extensions/micro_average.py
index 10f2df441..984fc00f8 100644
--- a/pytorch_pfn_extras/training/extensions/micro_average.py
+++ b/pytorch_pfn_extras/training/extensions/micro_average.py
@@ -3,7 +3,9 @@
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
class MicroAverage(extension.Extension):
@@ -64,24 +66,26 @@ class MicroAverage(extension.Extension):
priority = extension.PRIORITY_EDITOR
def __init__(
- self,
- numerator_key: str,
- denominator_key: str,
- result_key: str,
- trigger: trigger_module.TriggerLike = (1, 'epoch'),
+ self,
+ numerator_key: str,
+ denominator_key: str,
+ result_key: str,
+ trigger: trigger_module.TriggerLike = (1, "epoch"),
) -> None:
self._trigger = trigger_module.get_trigger(trigger)
self._numerator_key = numerator_key
self._denominator_key = denominator_key
self._result_key = result_key
- self._numerator = 0.
- self._denominator = 0.
+ self._numerator = 0.0
+ self._denominator = 0.0
def __call__(self, manager: ExtensionsManagerProtocol) -> None:
observation: Any = manager.observation
- if not (self._numerator_key in observation
- and self._denominator_key in observation):
+ if not (
+ self._numerator_key in observation
+ and self._denominator_key in observation
+ ):
return
self._numerator += observation[self._numerator_key]
@@ -94,10 +98,12 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
reporting.report({self._result_key: result})
def state_dict(self) -> Dict[str, Any]:
- state = {'_numerator': self._numerator,
- '_denominator': self._denominator}
+ state = {
+ "_numerator": self._numerator,
+ "_denominator": self._denominator,
+ }
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- self._numerator = to_load['_numerator']
- self._denominator = to_load['_denominator']
+ self._numerator = to_load["_numerator"]
+ self._denominator = to_load["_denominator"]
diff --git a/pytorch_pfn_extras/training/extensions/parameter_statistics.py b/pytorch_pfn_extras/training/extensions/parameter_statistics.py
index 66975cc91..24e515d01 100644
--- a/pytorch_pfn_extras/training/extensions/parameter_statistics.py
+++ b/pytorch_pfn_extras/training/extensions/parameter_statistics.py
@@ -1,19 +1,19 @@
from typing import Any, Optional
import torch
-
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
_default_statistics = {
- 'mean': lambda x: torch.mean(x),
- 'std': lambda x: torch.std(x),
- 'min': lambda x: torch.min(x),
- 'max': lambda x: torch.max(x),
- 'zeros': lambda x: (x == 0).sum(),
+ "mean": lambda x: torch.mean(x),
+ "std": lambda x: torch.std(x),
+ "min": lambda x: torch.min(x),
+ "max": lambda x: torch.max(x),
+ "zeros": lambda x: (x == 0).sum(),
# 'percentile': lambda x: backend.get_array_module(x).percentile(
# x, (0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87))
}
@@ -72,41 +72,40 @@ class ParameterStatistics(extension.Extension):
(0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87))``)
"""
- default_name = 'parameter_statistics'
+
+ default_name = "parameter_statistics"
priority = extension.PRIORITY_WRITER
# prefix ends with a '/' and param_name is preceded by a '/'
- report_key_template = ('{prefix}{param_name}/{attr_name}/'
- '{function_name}')
+ report_key_template = "{prefix}{param_name}/{attr_name}/" "{function_name}"
default_statistics = _default_statistics
def __init__(
- self,
- links: Any,
- statistics: Any = 'default',
- report_params: bool = True,
- report_grads: bool = True,
- prefix: Optional[str] = None,
- trigger: trigger_module.TriggerLike = (1, 'epoch'),
- skip_nan_params: bool = False,
+ self,
+ links: Any,
+ statistics: Any = "default",
+ report_params: bool = True,
+ report_grads: bool = True,
+ prefix: Optional[str] = None,
+ trigger: trigger_module.TriggerLike = (1, "epoch"),
+ skip_nan_params: bool = False,
):
-
if not isinstance(links, (list, tuple)):
- links = links,
+ links = (links,)
self._links = links
if statistics is None:
statistics = {}
- elif statistics == 'default':
+ elif statistics == "default":
statistics = self.default_statistics
self._statistics = dict(statistics)
attrs = []
if report_params:
- attrs.append('data')
+ attrs.append("data")
if report_grads:
- attrs.append('grad')
+ attrs.append("grad")
self._attrs = attrs
self._prefix = prefix
@@ -137,24 +136,30 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
# since the statistics function should make no
# assumption about the axes
params = getattr(param, attr_name).flatten()
- if (self._skip_nan_params
- and (
- torch.isnan(params).any())):
- value: Any = float('nan')
+ if self._skip_nan_params and (
+ torch.isnan(params).any()
+ ):
+ value: Any = float("nan")
else:
value = function(params)
key = self.report_key_template.format(
- prefix=self._prefix + '/' if self._prefix else '',
+ prefix=self._prefix + "/" if self._prefix else "",
param_name=param_name,
attr_name=attr_name,
- function_name=function_name
+ function_name=function_name,
)
- if (isinstance(value, torch.Tensor)
- and value.numel() > 1):
+ if (
+ isinstance(value, torch.Tensor)
+ and value.numel() > 1
+ ):
# Append integer indices to the keys if the
# statistic function return multiple values
- statistics.update({'{}/{}'.format(key, i): v for
- i, v in enumerate(value)})
+ statistics.update(
+ {
+ "{}/{}".format(key, i): v
+ for i, v in enumerate(value)
+ }
+ )
else:
statistics[key] = value
diff --git a/pytorch_pfn_extras/training/extensions/plot_report.py b/pytorch_pfn_extras/training/extensions/plot_report.py
index a906f6b7d..fb01ec305 100644
--- a/pytorch_pfn_extras/training/extensions/plot_report.py
+++ b/pytorch_pfn_extras/training/extensions/plot_report.py
@@ -1,20 +1,21 @@
import json
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import warnings
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy
-
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
_available = None
def matplotlib_savefun(target: Tuple[Any, Any, Any], file_o: Any) -> None:
fig, leg, plt = target
- fig.savefig(file_o, bbox_extra_artists=(leg,), bbox_inches='tight')
+ fig.savefig(file_o, bbox_extra_artists=(leg,), bbox_inches="tight")
fig.clf()
plt.close(fig)
@@ -23,6 +24,7 @@ def _try_import_matplotlib() -> None:
global matplotlib, _available
try:
import matplotlib # NOQA
+
_available = True
except (ImportError, TypeError):
_available = False
@@ -33,10 +35,12 @@ def _check_available() -> None:
_try_import_matplotlib()
if not _available:
- warnings.warn('matplotlib is not installed on your environment, '
- 'so nothing will be plotted at this time. '
- 'Please install matplotlib to plot figures.\n\n'
- ' $ pip install matplotlib\n')
+ warnings.warn(
+ "matplotlib is not installed on your environment, "
+ "so nothing will be plotted at this time. "
+ "Please install matplotlib to plot figures.\n\n"
+ " $ pip install matplotlib\n"
+ )
class PlotReport(extension.Extension):
@@ -114,18 +118,17 @@ class PlotReport(extension.Extension):
"""
def __init__(
- self,
- y_keys: Union[Iterable[str], str],
- x_key: str = 'iteration',
- trigger: trigger_module.TriggerLike = (1, 'epoch'),
- postprocess: Any = None,
- filename: Optional[str] = None,
- marker: str = 'x',
- grid: bool = True,
- **kwargs: Any,
+ self,
+ y_keys: Union[Iterable[str], str],
+ x_key: str = "iteration",
+ trigger: trigger_module.TriggerLike = (1, "epoch"),
+ postprocess: Any = None,
+ filename: Optional[str] = None,
+ marker: str = "x",
+ grid: bool = True,
+ **kwargs: Any,
):
-
- file_name = kwargs.get('file_name', 'plot.png')
+ file_name = kwargs.get("file_name", "plot.png")
if filename is None:
filename = file_name
del file_name # avoid accidental use
@@ -144,7 +147,7 @@ def __init__(
self._postprocess = postprocess
self._init_summary()
self._data: Dict[str, List[Tuple[Any, Any]]] = {k: [] for k in y_keys}
- self._writer = kwargs.get('writer', None)
+ self._writer = kwargs.get("writer", None)
@staticmethod
def available() -> bool:
@@ -173,8 +176,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
for name, value in stats.items():
stats_cpu[name] = float(value) # copy to CPU
- stats_cpu['epoch'] = manager.epoch
- stats_cpu['iteration'] = manager.iteration
+ stats_cpu["epoch"] = manager.epoch
+ stats_cpu["iteration"] = manager.iteration
x = stats_cpu[self._x_key]
data = self._data
@@ -200,9 +203,14 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
if self._postprocess is not None:
self._postprocess(f, a, summary)
leg = a.legend(
- bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
- writer(self._file_name, manager.out, (f, leg, plt), # type: ignore
- savefun=matplotlib_savefun)
+ bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0
+ )
+ writer(
+ self._file_name,
+ manager.out,
+ (f, leg, plt), # type: ignore
+ savefun=matplotlib_savefun,
+ )
else:
print(
f"[WARNING] No data found for key {self._y_keys}, "
@@ -215,11 +223,11 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
self._init_summary()
def state_dict(self) -> Dict[str, Any]:
- state = {'_plot_{}'.format(self._file_name): json.dumps(self._data)}
+ state = {"_plot_{}".format(self._file_name): json.dumps(self._data)}
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- key = '_plot_{}'.format(self._file_name)
+ key = "_plot_{}".format(self._file_name)
self._data = json.loads(to_load[key])
def _init_summary(self) -> None:
diff --git a/pytorch_pfn_extras/training/extensions/print_report.py b/pytorch_pfn_extras/training/extensions/print_report.py
index c2f33bad8..86f5fa11c 100644
--- a/pytorch_pfn_extras/training/extensions/print_report.py
+++ b/pytorch_pfn_extras/training/extensions/print_report.py
@@ -1,17 +1,20 @@
-from copy import deepcopy
import os
import sys
-from typing import Any, Dict, IO, List, Optional, Sequence, Tuple, Union
+from copy import deepcopy
+from typing import IO, Any, Dict, List, Optional, Sequence, Tuple, Union
from pytorch_pfn_extras.training import extension
-from pytorch_pfn_extras.training.extensions import log_report \
- as log_report_module
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
+from pytorch_pfn_extras.training.extensions import (
+ log_report as log_report_module,
+)
from pytorch_pfn_extras.training.extensions import util
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
def create_header_and_templates(
- entries: Sequence[str],
+ entries: Sequence[str],
) -> Tuple[str, List[Tuple[str, str, str]]]:
"""Construct header and templates from `entries`
@@ -25,35 +28,36 @@ def create_header_and_templates(
# format information
entry_widths = [max(10, len(s)) for s in entries]
- header = ' '.join(('{:%d}' % w for w in entry_widths)).format(
- *entries) + '\n'
+ header = (
+ " ".join(("{:%d}" % w for w in entry_widths)).format(*entries) + "\n"
+ )
templates = []
for entry, w in zip(entries, entry_widths):
- templates.append((entry, '{:<%dg} ' % w, ' ' * (w + 2)))
+ templates.append((entry, "{:<%dg} " % w, " " * (w + 2)))
return header, templates
def filter_and_sort_entries(
- all_entries: List[str],
- unit: str = 'epoch',
+ all_entries: List[str],
+ unit: str = "epoch",
) -> List[str]:
entries = deepcopy(all_entries)
# TODO(nakago): sort other entries if necessary
- if 'iteration' in entries:
+ if "iteration" in entries:
# move iteration to head
- entries.pop(entries.index('iteration'))
- if unit == 'iteration':
- entries = ['iteration'] + entries
- if 'epoch' in entries:
+ entries.pop(entries.index("iteration"))
+ if unit == "iteration":
+ entries = ["iteration"] + entries
+ if "epoch" in entries:
# move epoch to head
- entries.pop(entries.index('epoch'))
- if unit == 'epoch':
- entries = ['epoch'] + entries
- if 'elapsed_time' in entries:
+ entries.pop(entries.index("epoch"))
+ if unit == "epoch":
+ entries = ["epoch"] + entries
+ if "elapsed_time" in entries:
# move elapsed_time to tail
- entries.pop(entries.index('elapsed_time'))
- entries.append('elapsed_time')
+ entries.pop(entries.index("elapsed_time"))
+ entries.append("elapsed_time")
return entries
@@ -76,10 +80,10 @@ class PrintReport(extension.Extension):
"""
def __init__(
- self,
- entries: Optional[Sequence[str]] = None,
- log_report: Union[str, log_report_module.LogReport] = 'LogReport',
- out: IO[Any] = sys.stdout,
+ self,
+ entries: Optional[Sequence[str]] = None,
+ log_report: Union[str, log_report_module.LogReport] = "LogReport",
+ out: IO[Any] = sys.stdout,
) -> None:
if entries is None:
self._infer_entries = True
@@ -98,21 +102,20 @@ def __init__(
self._all_entries: List[str] = []
def get_log_report(
- self,
- manager: ExtensionsManagerProtocol,
+ self,
+ manager: ExtensionsManagerProtocol,
) -> log_report_module.LogReport:
log_report = self._log_report
if isinstance(log_report, str):
ext = manager.get_extension(log_report)
if not isinstance(ext, log_report_module.LogReport):
- raise TypeError('`log_report` must be LogReport object')
+ raise TypeError("`log_report` must be LogReport object")
return ext
elif isinstance(log_report, log_report_module.LogReport):
log_report(manager) # update the log report
return log_report
else:
- raise TypeError('log report has a wrong type %s' %
- type(log_report))
+ raise TypeError("log report has a wrong type %s" % type(log_report))
@property
def _log_looker(self) -> log_report_module._LogLooker:
@@ -133,12 +136,13 @@ def _update_entries(self, log_report: log_report_module.LogReport) -> None:
updated_flag = True
if updated_flag:
- if hasattr(log_report, '_trigger') and hasattr(log_report._trigger,
- 'unit'):
+ if hasattr(log_report, "_trigger") and hasattr(
+ log_report._trigger, "unit"
+ ):
unit = log_report._trigger.unit # type: ignore[attr-defined]
else:
# Failed to infer `unit`, use epoch as default
- unit = 'epoch'
+ unit = "epoch"
entries = filter_and_sort_entries(self._all_entries, unit=unit)
self._entries = entries
header, templates = create_header_and_templates(entries)
@@ -160,23 +164,23 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
for line in self._log_looker.get():
# delete the printed contents from the current cursor
- if os.name == 'nt':
+ if os.name == "nt":
util.erase_console(0, 0)
else:
- out.write('\033[J')
+ out.write("\033[J")
self._print(line)
self._log_looker.clear()
def state_dict(self) -> Dict[str, Any]:
log_report = self._log_report
if isinstance(log_report, log_report_module.LogReport):
- return {'_log_report': log_report.state_dict()}
+ return {"_log_report": log_report.state_dict()}
return {}
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
log_report = self._log_report
if isinstance(log_report, log_report_module.LogReport):
- log_report.load_state_dict(to_load['_log_report'])
+ log_report.load_state_dict(to_load["_log_report"])
def _print(self, observation: log_report_module.Observation) -> None:
out = self._out
@@ -185,6 +189,6 @@ def _print(self, observation: log_report_module.Observation) -> None:
out.write(template.format(observation[entry]))
else:
out.write(empty)
- out.write('\n')
- if hasattr(out, 'flush'):
+ out.write("\n")
+ if hasattr(out, "flush"):
out.flush()
diff --git a/pytorch_pfn_extras/training/extensions/print_report_notebook.py b/pytorch_pfn_extras/training/extensions/print_report_notebook.py
index 2d0c6d928..f09c47b1f 100644
--- a/pytorch_pfn_extras/training/extensions/print_report_notebook.py
+++ b/pytorch_pfn_extras/training/extensions/print_report_notebook.py
@@ -1,15 +1,16 @@
import sys
-from typing import Any, IO, List, Optional, Union
+from typing import IO, Any, List, Optional, Union
from IPython.display import display
from ipywidgets import HTML
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
+from pytorch_pfn_extras.training.extensions import (
+ log_report as log_report_module,
+)
from pytorch_pfn_extras.training.extensions.print_report import PrintReport
-from pytorch_pfn_extras.training.extensions import log_report \
- as log_report_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
class PrintReportNotebook(PrintReport):
@@ -32,10 +33,10 @@ class PrintReportNotebook(PrintReport):
"""
def __init__(
- self,
- entries: Optional[List[str]] = None,
- log_report: Union[str, log_report_module.LogReport] = 'LogReport',
- out: IO[Any] = sys.stdout,
+ self,
+ entries: Optional[List[str]] = None,
+ log_report: Union[str, log_report_module.LogReport] = "LogReport",
+ out: IO[Any] = sys.stdout,
) -> None:
super(PrintReportNotebook, self).__init__(
entries=entries, log_report=log_report, out=out
@@ -56,4 +57,4 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
if self._infer_entries:
# --- update entries ---
self._update_entries(log_report)
- self._widget.value = df[self._entries].to_html(index=False, na_rep='')
+ self._widget.value = df[self._entries].to_html(index=False, na_rep="")
diff --git a/pytorch_pfn_extras/training/extensions/profile_report.py b/pytorch_pfn_extras/training/extensions/profile_report.py
index fdc97c055..95514d275 100644
--- a/pytorch_pfn_extras/training/extensions/profile_report.py
+++ b/pytorch_pfn_extras/training/extensions/profile_report.py
@@ -1,13 +1,15 @@
-from collections import OrderedDict
import json
+from collections import OrderedDict
from typing import Any, Dict, Iterable, List, Optional
from pytorch_pfn_extras import reporting
+from pytorch_pfn_extras.profiler._time_summary import get_time_summary
from pytorch_pfn_extras.training import extension
-from pytorch_pfn_extras.training.extensions import log_report
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-from pytorch_pfn_extras.profiler._time_summary import get_time_summary
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
+from pytorch_pfn_extras.training.extensions import log_report
class ProfileReport(extension.Extension):
@@ -44,15 +46,16 @@ class ProfileReport(extension.Extension):
header (str): header string
templates (str): template string for print values.
"""
+
def __init__(
- self,
- store_keys: Optional[Iterable[str]] = None,
- report_keys: Optional[Iterable[str]] = None,
- trigger: trigger_module.TriggerLike = (1, "epoch"),
- filename: Optional[str] = None,
- append: bool = False,
- format: Optional[str] = None,
- **kwargs: Any,
+ self,
+ store_keys: Optional[Iterable[str]] = None,
+ report_keys: Optional[Iterable[str]] = None,
+ trigger: trigger_module.TriggerLike = (1, "epoch"),
+ filename: Optional[str] = None,
+ append: bool = False,
+ format: Optional[str] = None,
+ **kwargs: Any,
):
self.time_summary = get_time_summary()
# Initializes global TimeSummary.
@@ -73,15 +76,15 @@ def __init__(
filename = log_name
del log_name # avoid accidental use
self._log_name = filename
- self._writer = kwargs.get('writer', None)
+ self._writer = kwargs.get("writer", None)
if format is None and filename is not None:
- if filename.endswith('.jsonl'):
- format = 'json-lines'
- elif filename.endswith('.yaml'):
- format = 'yaml'
+ if filename.endswith(".jsonl"):
+ format = "json-lines"
+ elif filename.endswith(".yaml"):
+ format = "yaml"
else:
- format = 'json'
+ format = "json"
self._append = append
self._format = format
@@ -114,7 +117,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
stats_cpu["elapsed_time"] = manager.elapsed_time
# Recreate dict to fix order of logs
out = OrderedDict(
- [(k, stats_cpu[k]) for k in sorted(stats_cpu.keys())])
+ [(k, stats_cpu[k]) for k in sorted(stats_cpu.keys())]
+ )
self._log.append(out)
@@ -123,9 +127,15 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
log_name = self._log_name.format(**out)
assert self._format is not None
savefun = log_report.LogWriterSaveFunc(
- self._format, self._append)
- writer(log_name, out, self._log, # type: ignore
- savefun=savefun, append=self._append)
+ self._format, self._append
+ )
+ writer(
+ log_name,
+ out,
+ self._log, # type: ignore
+ savefun=savefun,
+ append=self._append,
+ )
if self._append:
self._log = []
diff --git a/pytorch_pfn_extras/training/extensions/progress_bar.py b/pytorch_pfn_extras/training/extensions/progress_bar.py
index 5c32d0944..4859e6be7 100644
--- a/pytorch_pfn_extras/training/extensions/progress_bar.py
+++ b/pytorch_pfn_extras/training/extensions/progress_bar.py
@@ -3,8 +3,10 @@
from typing import Any, List, Optional
from pytorch_pfn_extras.training import extension
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
from pytorch_pfn_extras.training.extensions import util
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
class ProgressBar(extension.Extension):
@@ -28,18 +30,19 @@ class ProgressBar(extension.Extension):
"""
def __init__(
- self,
- training_length: Any = None,
- update_interval: int = 100,
- bar_length: int = 50,
- out: Any = sys.stdout,
+ self,
+ training_length: Any = None,
+ update_interval: int = 100,
+ bar_length: int = 50,
+ out: Any = sys.stdout,
):
self._training_length = training_length
self._update_interval = update_interval
self._bar_length = bar_length
self._out = out
self._pbar = _ManagerProgressBar(
- self._training_length, self._bar_length, self._out)
+ self._training_length, self._bar_length, self._out
+ )
def __call__(self, manager: ExtensionsManagerProtocol) -> None:
if self._pbar.manager is None:
@@ -55,7 +58,6 @@ def finalize(self, manager: ExtensionsManagerProtocol) -> None:
class _ManagerProgressBar(util.ProgressBar):
-
def __init__(self, training_length: Any, bar_length: int, out: Any) -> None:
super().__init__(out)
self.training_length = training_length
@@ -75,37 +77,46 @@ def get_lines(self) -> List[str]:
self.training_length = t.get_training_length() # type: ignore[attr-defined]
length, unit = self.training_length
- if unit == 'iteration':
+ if unit == "iteration":
rate = iteration / length
else:
rate = epoch / length
rate = min(rate, 1.0)
bar_length = self.bar_length
- marks = '#' * int(rate * bar_length)
- lines.append(' total [{}{}] {:6.2%}\n'.format(
- marks, '.' * (bar_length - len(marks)), rate))
+ marks = "#" * int(rate * bar_length)
+ lines.append(
+ " total [{}{}] {:6.2%}\n".format(
+ marks, "." * (bar_length - len(marks)), rate
+ )
+ )
epoch_rate = epoch - int(epoch)
- marks = '#' * int(epoch_rate * bar_length)
- lines.append('this epoch [{}{}] {:6.2%}\n'.format(
- marks, '.' * (bar_length - len(marks)), epoch_rate))
+ marks = "#" * int(epoch_rate * bar_length)
+ lines.append(
+ "this epoch [{}{}] {:6.2%}\n".format(
+ marks, "." * (bar_length - len(marks)), epoch_rate
+ )
+ )
if self.progress_template is None:
self.progress_template = (
- '{0.iteration:10} iter, {0.epoch} epoch / %s %ss\n' %
- self.training_length)
+ "{0.iteration:10} iter, {0.epoch} epoch / %s %ss\n"
+ % self.training_length
+ )
progress = self.progress_template.format(self.manager)
lines.append(progress)
speed_t, speed_e = self.update_speed(iteration, epoch)
- if unit == 'iteration':
+ if unit == "iteration":
estimated_time = (length - iteration) / speed_t
else:
estimated_time = (length - epoch) / speed_e
estimated_time = max(estimated_time, 0.0)
- lines.append('{:10.5g} iters/sec. Estimated time to finish: {}.\n'
- .format(speed_t,
- datetime.timedelta(seconds=estimated_time)))
+ lines.append(
+ "{:10.5g} iters/sec. Estimated time to finish: {}.\n".format(
+ speed_t, datetime.timedelta(seconds=estimated_time)
+ )
+ )
return lines
diff --git a/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py b/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py
index 8207222a0..438c8f9b4 100644
--- a/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py
+++ b/pytorch_pfn_extras/training/extensions/progress_bar_notebook.py
@@ -5,9 +5,10 @@
from IPython.display import display
from ipywidgets import HTML, FloatProgress, HBox, VBox # NOQA
-
from pytorch_pfn_extras.training import extension, trigger # NOQA
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
class ProgressBarNotebook(extension.Extension):
@@ -34,11 +35,11 @@ class ProgressBarNotebook(extension.Extension):
"""
def __init__(
- self,
- training_length: Any = None,
- update_interval: int = 100,
- bar_length: int = 50,
- out: Any = sys.stdout,
+ self,
+ training_length: Any = None,
+ update_interval: int = 100,
+ bar_length: int = 50,
+ out: Any = sys.stdout,
):
self._training_length = training_length
if training_length is not None:
@@ -46,26 +47,31 @@ def __init__(
self._update_interval = update_interval
self._recent_timing: List[Tuple[float, float, float]] = []
- self._total_bar = FloatProgress(description='total',
- min=0, max=1, value=0,
- bar_style='info')
+ self._total_bar = FloatProgress(
+ description="total", min=0, max=1, value=0, bar_style="info"
+ )
self._total_html = HTML()
- self._epoch_bar = FloatProgress(description='this epoch',
- min=0, max=1, value=0,
- bar_style='info')
+ self._epoch_bar = FloatProgress(
+ description="this epoch", min=0, max=1, value=0, bar_style="info"
+ )
self._epoch_html = HTML()
self._status_html = HTML()
- self._widget = VBox([HBox([self._total_bar, self._total_html]),
- HBox([self._epoch_bar, self._epoch_html]),
- self._status_html])
+ self._widget = VBox(
+ [
+ HBox([self._total_bar, self._total_html]),
+ HBox([self._epoch_bar, self._epoch_html]),
+ self._status_html,
+ ]
+ )
def initialize(self, manager: ExtensionsManagerProtocol) -> None:
if self._training_length is None:
t = manager._stop_trigger
if not isinstance(t, trigger.IntervalTrigger):
raise TypeError(
- 'cannot retrieve the training length from %s' % type(t))
+ "cannot retrieve the training length from %s" % type(t)
+ )
self._training_length = t.period, t.unit
self._init_status_template()
@@ -77,7 +83,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
iteration, epoch_detail = manager.iteration, manager.epoch_detail
- if unit == 'iteration':
+ if unit == "iteration":
is_finished = iteration == length
else:
is_finished = epoch_detail == length
@@ -87,8 +93,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
def finalize(self, manager: ExtensionsManagerProtocol) -> None:
if self._total_bar.value != 1:
- self._total_bar.bar_style = 'warning'
- self._epoch_bar.bar_style = 'warning'
+ self._total_bar.bar_style = "warning"
+ self._epoch_bar.bar_style = "warning"
@property
def widget(self) -> VBox:
@@ -102,7 +108,7 @@ def update(self, iteration: int, epoch_detail: float) -> None:
recent_timing.append((iteration, epoch_detail, now))
- if unit == 'iteration':
+ if unit == "iteration":
rate = iteration / length
else:
rate = epoch_detail / length
@@ -113,12 +119,13 @@ def update(self, iteration: int, epoch_detail: float) -> None:
self._epoch_bar.value = epoch_rate
self._epoch_html.value = "{:6.2%}".format(epoch_rate)
- status = self._status_template.format(iteration=iteration,
- epoch=int(epoch_detail))
+ status = self._status_template.format(
+ iteration=iteration, epoch=int(epoch_detail)
+ )
if rate == 1:
- self._total_bar.bar_style = 'success'
- self._epoch_bar.bar_style = 'success'
+ self._total_bar.bar_style = "success"
+ self._epoch_bar.bar_style = "success"
old_t, old_e, old_sec = recent_timing[0]
span = now - old_sec
@@ -126,16 +133,16 @@ def update(self, iteration: int, epoch_detail: float) -> None:
speed_t = (iteration - old_t) / span
speed_e = (epoch_detail - old_e) / span
else:
- speed_t = float('inf')
- speed_e = float('inf')
+ speed_t = float("inf")
+ speed_e = float("inf")
- if unit == 'iteration':
+ if unit == "iteration":
estimated_time = (length - iteration) / speed_t
else:
estimated_time = (length - epoch_detail) / speed_e
- estimate = ('{:10.5g} iters/sec. Estimated time to finish: {}.'
- .format(speed_t,
- datetime.timedelta(seconds=estimated_time)))
+ estimate = "{:10.5g} iters/sec. Estimated time to finish: {}.".format(
+ speed_t, datetime.timedelta(seconds=estimated_time)
+ )
self._status_html.value = status + estimate
@@ -144,5 +151,6 @@ def update(self, iteration: int, epoch_detail: float) -> None:
def _init_status_template(self) -> None:
self._status_template = (
-            '{iteration:10} iter, {epoch} epoch / %s %ss<br />' %
-            self._training_length)
+            "{iteration:10} iter, {epoch} epoch / %s %ss<br />"
+            % self._training_length
+        )
diff --git a/pytorch_pfn_extras/training/extensions/slack.py b/pytorch_pfn_extras/training/extensions/slack.py
index 2319492a9..0c01e0c09 100644
--- a/pytorch_pfn_extras/training/extensions/slack.py
+++ b/pytorch_pfn_extras/training/extensions/slack.py
@@ -1,24 +1,25 @@
import getpass
-import os
import json
-import urllib.request
+import os
import shlex
-import sys
import socket
+import sys
import traceback
import types
-from typing import Any, Callable, Optional, Sequence, Union
+import urllib.request
import warnings
-
+from typing import Any, Callable, Optional, Sequence, Union
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
from pytorch_pfn_extras.training._trigger_util import TriggerLike
-
try:
import slack_sdk
+
_slack_sdk_available = True
except ImportError:
_slack_sdk_available = False
@@ -29,51 +30,49 @@ def _failsafe(func: Callable[[], Any]) -> str:
try:
return str(func())
except Exception:
- return 'UNKNOWN'
+ return "UNKNOWN"
_identity = (
- f'{_failsafe(getpass.getuser)}@{_failsafe(socket.gethostname)} '
- f'[PID {_failsafe(os.getpid)}]')
+ f"{_failsafe(getpass.getuser)}@{_failsafe(socket.gethostname)} "
+ f"[PID {_failsafe(os.getpid)}]"
+)
def _default_msg(
- manager: ExtensionsManagerProtocol,
- context: Any,
+ manager: ExtensionsManagerProtocol,
+ context: Any,
) -> str:
- return f'Epoch #{manager.epoch}'
+ return f"Epoch #{manager.epoch}"
def _default_start_msg(
- manager: ExtensionsManagerProtocol,
- context: Any,
+ manager: ExtensionsManagerProtocol,
+ context: Any,
) -> str:
- cmdline = ' '.join([shlex.quote(x) for x in sys.argv])
- return (
- f'🏃 *Training started! {_identity}*\n'
- f'Command: `{cmdline}`'
- )
+ cmdline = " ".join([shlex.quote(x) for x in sys.argv])
+ return f"🏃 *Training started! {_identity}*\n" f"Command: `{cmdline}`"
def _default_end_msg(
- manager: ExtensionsManagerProtocol,
- context: Any,
+ manager: ExtensionsManagerProtocol,
+ context: Any,
) -> str:
- return f'✅ *Training finished! {_identity}*'
+ return f"✅ *Training finished! {_identity}*"
def _default_error_msg(
- manager: ExtensionsManagerProtocol,
- exc: Exception,
- context: Any,
+ manager: ExtensionsManagerProtocol,
+ exc: Exception,
+ context: Any,
) -> str:
return (
- f'❌ *Error during training. {_identity}*\n'
- f'{type(exc).__name__}: {exc}\n'
- 'Traceback:\n'
- '```\n'
- ''.join(traceback.format_tb(exc.__traceback__)).strip() + '\n'
- '```'
+ f"❌ *Error during training. {_identity}*\n"
+ f"{type(exc).__name__}: {exc}\n"
+ "Traceback:\n"
+ "```\n"
+ "".join(traceback.format_tb(exc.__traceback__)).strip() + "\n"
+ "```"
)
@@ -107,8 +106,7 @@ def _default_error_msg(
class _SlackBase(extension.Extension):
-
- trigger: TriggerLike = (1, 'epoch')
+ trigger: TriggerLike = (1, "epoch")
default_msg = _default_msg
default_start_msg = _default_start_msg
@@ -132,25 +130,25 @@ def _upload_files(self, filenames: Sequence[str]) -> Sequence[str]:
raise NotImplementedError
def _format(
- self,
- msg: Union[_MessageFunc, str],
- default: Optional[_MessageFunc],
- manager: ExtensionsManagerProtocol,
+ self,
+ msg: Union[_MessageFunc, str],
+ default: Optional[_MessageFunc],
+ manager: ExtensionsManagerProtocol,
) -> str:
- default_str = '' if default is None else default(manager, self._context)
+ default_str = "" if default is None else default(manager, self._context)
if isinstance(msg, str):
return msg.format(
manager=manager,
context=self._context,
default=default_str,
- **manager.observation
+ **manager.observation,
)
return msg(manager, self._context)
def _format_error(
- self,
- manager: ExtensionsManagerProtocol,
- error: Exception,
+ self,
+ manager: ExtensionsManagerProtocol,
+ error: Exception,
) -> str:
msg = self._error_msg
assert msg is not None
@@ -159,7 +157,7 @@ def _format_error(
manager=manager,
context=self._context,
default=_default_error_msg(manager, error, self._context),
- error=error
+ error=error,
)
return msg(manager, error, self._context)
@@ -172,30 +170,30 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
pass
elif isinstance(self._filenames, Sequence):
filenames = [
- self._format(f, None, manager) for f in self._filenames]
+ self._format(f, None, manager) for f in self._filenames
+ ]
else: # callable
filenames = self._filenames(manager, self._context)
- needs_upload = (
- len(filenames) != 0
- and (self._upload_trigger is None
- or self._upload_trigger(manager)))
+ needs_upload = len(filenames) != 0 and (
+ self._upload_trigger is None or self._upload_trigger(manager)
+ )
if self._msg is None and not needs_upload:
# The message is not set and no files to upload.
return
- text = ''
+ text = ""
if self._msg is not None:
text = self._format(self._msg, _default_msg, manager)
# TODO(kmaehashi): keep track of already uploaded files and warn
# TODO(kmaehashi): warn too many or too large files
- attachments = ''
+ attachments = ""
if needs_upload:
permalinks = self._upload_files(filenames)
- attachments = ''.join([f'<{link}| >' for link in permalinks])
+ attachments = "".join([f"<{link}| >" for link in permalinks])
self._post_message(text + attachments)
@@ -203,19 +201,21 @@ def initialize(self, manager: ExtensionsManagerProtocol) -> None:
if not self._available or self._start_msg is None:
return
self._post_message(
- self._format(self._start_msg, _default_start_msg, manager))
+ self._format(self._start_msg, _default_start_msg, manager)
+ )
def finalize(self, manager: ExtensionsManagerProtocol) -> None:
if not self._available or self._end_msg is None:
return
self._post_message(
- self._format(self._end_msg, _default_end_msg, manager))
+ self._format(self._end_msg, _default_end_msg, manager)
+ )
def on_error(
- self,
- manager: ExtensionsManagerProtocol,
- exc: Exception,
- tb: types.TracebackType
+ self,
+ manager: ExtensionsManagerProtocol,
+ exc: Exception,
+ tb: types.TracebackType,
) -> None:
if not self._available or self._error_msg is None:
return
@@ -223,7 +223,8 @@ def on_error(
class Slack(_SlackBase):
- __doc__ = """An extension to communicate with Slack.
+ __doc__ = (
+ """An extension to communicate with Slack.
.. admonition:: Example
@@ -236,7 +237,9 @@ class Slack(_SlackBase):
... filenames=["result/statistics.png"],
... upload_trigger=(max_epoch, 'epoch'),
... )
- """ + _message_spec_doc + """
+ """
+ + _message_spec_doc
+ + """
This extension can upload files along with the message when triggered.
``filenames`` can be a list of filenames (the same formatting rule as
``msg`` apply), or a callable taking (ExtensionsManager, context) and
@@ -273,17 +276,18 @@ class Slack(_SlackBase):
variable ``SLACK_BOT_TOKEN`` will be used.
Optional, default is ``None``.
"""
+ )
- trigger: TriggerLike = (1, 'epoch')
+ trigger: TriggerLike = (1, "epoch")
def __init__(
self,
channel: str,
msg: Optional[Union[str, _MessageFunc]] = None,
*,
- start_msg: Optional[Union[str, _MessageFunc]] = '{default}',
- end_msg: Optional[Union[str, _MessageFunc]] = '{default}',
- error_msg: Optional[Union[str, _ErrorMessageFunc]] = '{default}',
+ start_msg: Optional[Union[str, _MessageFunc]] = "{default}",
+ end_msg: Optional[Union[str, _MessageFunc]] = "{default}",
+ error_msg: Optional[Union[str, _ErrorMessageFunc]] = "{default}",
thread: bool = True,
filenames: Optional[Union[Sequence[str], _FilenamesFunc]] = None,
upload_trigger: Optional[TriggerLike] = None,
@@ -294,8 +298,9 @@ def __init__(
if not _slack_sdk_available:
self._available = False
warnings.warn(
- '`slack_sdk` package is unavailable. '
- 'The Slack extension will do nothing.')
+ "`slack_sdk` package is unavailable. "
+ "The Slack extension will do nothing."
+ )
return
self._channel = channel
@@ -311,10 +316,11 @@ def __init__(
self._upload_trigger = trigger_module.get_trigger(upload_trigger)
if token is None:
- token = os.environ.get('SLACK_BOT_TOKEN', None)
+ token = os.environ.get("SLACK_BOT_TOKEN", None)
if token is None:
raise RuntimeError(
- 'A bot `token` is needed for communicating with Slack')
+ "A bot `token` is needed for communicating with Slack"
+ )
self._client = slack_sdk.WebClient(token=token)
self._thread_ts: Optional[str] = None
@@ -324,11 +330,12 @@ def _upload_files(self, filenames: Sequence[str]) -> Sequence[str]:
for filename in filenames:
response = self._client.files_upload(file=filename)
assert response.get("ok") # type: ignore[no-untyped-call]
- permalinks.append(response['file']['permalink'])
+ permalinks.append(response["file"]["permalink"])
except Exception as e:
warnings.warn(
- f'Slack upload failed: {type(e).__name__}: {e} '
- f'[{filenames}]')
+ f"Slack upload failed: {type(e).__name__}: {e} "
+ f"[{filenames}]"
+ )
return permalinks
def _post_message(self, text: str) -> None:
@@ -344,12 +351,13 @@ def _post_message(self, text: str) -> None:
self._thread_ts = ts
except Exception as e:
warnings.warn(
- f'Slack post failed: {type(e).__name__}: {e} '
- f'[{text}]')
+ f"Slack post failed: {type(e).__name__}: {e} " f"[{text}]"
+ )
class SlackWebhook(_SlackBase):
- __doc__ = """An extension to communicate with Slack using Incoming Webhook.
+ __doc__ = (
+ """An extension to communicate with Slack using Incoming Webhook.
.. admonition:: Example
@@ -358,7 +366,9 @@ class SlackWebhook(_SlackBase):
... msg="Epoch #{manager.epoch}: loss = {val/loss}",
... end_msg="{default} \\n <@username> Check out the result!",
... )
- """ + _message_spec_doc + """
+ """
+ + _message_spec_doc
+ + """
Args:
url (str): Incoming webhook URL to send messages.
msg (str, callable, or None): A message to be sent when triggered.
@@ -373,15 +383,16 @@ class SlackWebhook(_SlackBase):
context (object): Any arbitrary user object you will need when
generating a message.
"""
+ )
def __init__(
self,
url: str,
msg: Optional[Union[str, _MessageFunc]] = None,
*,
- start_msg: Optional[Union[str, _MessageFunc]] = '{default}',
- end_msg: Optional[Union[str, _MessageFunc]] = '{default}',
- error_msg: Optional[Union[str, _ErrorMessageFunc]] = '{default}',
+ start_msg: Optional[Union[str, _MessageFunc]] = "{default}",
+ end_msg: Optional[Union[str, _MessageFunc]] = "{default}",
+ error_msg: Optional[Union[str, _ErrorMessageFunc]] = "{default}",
context: Any = None,
) -> None:
super().__init__()
@@ -393,12 +404,12 @@ def __init__(
self._context = context
def _post_message(self, text: str) -> None:
- payload = json.dumps({'text': text}).encode('utf-8')
- request_headers = {'Content-Type': 'application/json; charset=utf-8'}
+ payload = json.dumps({"text": text}).encode("utf-8")
+ request_headers = {"Content-Type": "application/json; charset=utf-8"}
request = urllib.request.Request(
url=self._url,
data=payload,
- method='POST',
+ method="POST",
headers=request_headers,
)
try:
@@ -406,5 +417,6 @@ def _post_message(self, text: str) -> None:
assert 200 <= response.status < 300, response
except Exception as e:
warnings.warn(
- f'Slack WebHook request failed: {type(e).__name__}: {e} '
- f'[{text}]')
+ f"Slack WebHook request failed: {type(e).__name__}: {e} "
+ f"[{text}]"
+ )
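Note: for context, a minimal usage sketch of the Slack extension reformatted above. It is hedged: `manager` and the channel name are placeholders, `Slack` is assumed to be re-exported from `pytorch_pfn_extras.training.extensions`, and the bot token is taken from the `SLACK_BOT_TOKEN` environment variable as the docstring describes.

    from pytorch_pfn_extras.training import extensions

    # Post a status message once per epoch and upload the plot every 10 epochs.
    manager.extend(
        extensions.Slack(
            channel="#training-status",  # hypothetical channel
            msg="Epoch #{manager.epoch}: loss = {val/loss}",
            filenames=["result/statistics.png"],
            upload_trigger=(10, "epoch"),
        ),
        trigger=(1, "epoch"),
    )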
diff --git a/pytorch_pfn_extras/training/extensions/snapshot_writers.py b/pytorch_pfn_extras/training/extensions/snapshot_writers.py
index 988c86df5..e9ce7f7fb 100644
--- a/pytorch_pfn_extras/training/extensions/snapshot_writers.py
+++ b/pytorch_pfn_extras/training/extensions/snapshot_writers.py
@@ -1,2 +1,3 @@
from pytorch_pfn_extras.writing import * # NOQA
+
# TODO(ecastill) deprecate this
diff --git a/pytorch_pfn_extras/training/extensions/util.py b/pytorch_pfn_extras/training/extensions/util.py
index eebf8fab3..37b42c134 100644
--- a/pytorch_pfn_extras/training/extensions/util.py
+++ b/pytorch_pfn_extras/training/extensions/util.py
@@ -4,11 +4,13 @@
import time
from typing import Deque, Optional, Sequence, TextIO, Tuple
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
try:
from IPython import get_ipython
+
_ipython_available = True
except ImportError:
_ipython_available = False
@@ -16,11 +18,11 @@
def _is_notebook() -> bool:
if _ipython_available and get_ipython() is not None:
- return 'IPKernelApp' in get_ipython().config
+ return "IPKernelApp" in get_ipython().config
return False
-if os.name == 'nt':
+if os.name == "nt":
import ctypes
from ctypes import windll # type: ignore [attr-defined]
@@ -29,10 +31,13 @@ def _is_notebook() -> bool:
_COORD = ctypes.wintypes._COORD
class _CONSOLE_SCREEN_BUFFER_INFO(ctypes.Structure):
- _fields_ = [('dwSize', _COORD), ('dwCursorPosition', _COORD),
- ('wAttributes', ctypes.c_ushort),
- ('srWindow', ctypes.wintypes.SMALL_RECT),
- ('dwMaximumWindowSize', _COORD)]
+ _fields_ = [
+ ("dwSize", _COORD),
+ ("dwCursorPosition", _COORD),
+ ("wAttributes", ctypes.c_ushort),
+ ("srWindow", ctypes.wintypes.SMALL_RECT),
+ ("dwMaximumWindowSize", _COORD),
+ ]
def set_console_cursor_position(x: int, y: int) -> None:
"""Set relative cursor position from current position to (x,y)"""
@@ -64,29 +69,31 @@ def erase_console(x: int, y: int, mode: int = 0) -> None:
cur_pos = csbi.dwCursorPosition
wr = ctypes.c_ulong()
if mode == 0:
- num = csbi.srWindow.Right * (
- csbi.srWindow.Bottom - cur_pos.Y) - cur_pos.X
+ num = (
+ csbi.srWindow.Right * (csbi.srWindow.Bottom - cur_pos.Y)
+ - cur_pos.X
+ )
windll.kernel32.FillConsoleOutputCharacterA(
- whnd, ord(' '), num, cur_pos, ctypes.byref(wr))
+ whnd, ord(" "), num, cur_pos, ctypes.byref(wr)
+ )
elif mode == 1:
num = cur_pos.X
windll.kernel32.FillConsoleOutputCharacterA(
- whnd, ord(' '), num, _COORD(0, cur_pos.Y), ctypes.byref(wr))
+ whnd, ord(" "), num, _COORD(0, cur_pos.Y), ctypes.byref(wr)
+ )
elif mode == 2:
- os.system('cls')
+ os.system("cls")
class ProgressBar:
-
def __init__(self, out: Optional[TextIO] = None) -> None:
self._out = sys.stdout if out is None else out
- self._recent_timing: Deque[Tuple[int, float, float]] = collections.deque(
- [], maxlen=100)
+ self._recent_timing: Deque[
+ Tuple[int, float, float]
+ ] = collections.deque([], maxlen=100)
def update_speed(
- self,
- iteration: int,
- epoch_detail: float
+ self, iteration: int, epoch_detail: float
) -> Tuple[float, float]:
now = time.time()
self._recent_timing.append((iteration, epoch_detail, now))
@@ -96,16 +103,15 @@ def update_speed(
speed_t = (iteration - old_t) / span
speed_e = (epoch_detail - old_e) / span
else:
- speed_t = float('inf')
- speed_e = float('inf')
+ speed_t = float("inf")
+ speed_e = float("inf")
return speed_t, speed_e
def get_lines(self) -> Sequence[str]:
raise NotImplementedError
def update(
- self,
- manager: Optional[ExtensionsManagerProtocol] = None
+ self, manager: Optional[ExtensionsManagerProtocol] = None
) -> None:
self.erase_console()
@@ -121,18 +127,18 @@ def close(self) -> None:
self.flush()
def erase_console(self) -> None:
- if os.name == 'nt':
+ if os.name == "nt":
erase_console(0, 0)
else:
- self._out.write('\033[J')
+ self._out.write("\033[J")
def move_cursor_up(self, n: int) -> None:
# move the cursor to the head of the progress bar
- if os.name == 'nt':
- set_console_cursor_position(0, - n)
+ if os.name == "nt":
+ set_console_cursor_position(0, -n)
else:
- self._out.write('\033[{:d}A'.format(n))
+ self._out.write("\033[{:d}A".format(n))
def flush(self) -> None:
- if hasattr(self._out, 'flush'):
+ if hasattr(self._out, "flush"):
self._out.flush()
diff --git a/pytorch_pfn_extras/training/extensions/value_observation.py b/pytorch_pfn_extras/training/extensions/value_observation.py
index 906650b70..0a615db9c 100644
--- a/pytorch_pfn_extras/training/extensions/value_observation.py
+++ b/pytorch_pfn_extras/training/extensions/value_observation.py
@@ -1,14 +1,15 @@
from typing import Any, Callable
import torch.optim
-
from pytorch_pfn_extras.training import extension
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
def observe_value(
- observation_key: str,
- target_func: Callable[[ExtensionsManagerProtocol], Any],
+ observation_key: str,
+ target_func: Callable[[ExtensionsManagerProtocol], Any],
) -> Callable[[ExtensionsManagerProtocol], None]:
"""Returns an extension to continuously record a value.
@@ -26,17 +27,20 @@ def observe_value(
.ExtensionsManager>` method.
"""
+
@extension.make_extension(
- trigger=(1, 'epoch'), priority=extension.PRIORITY_WRITER)
+ trigger=(1, "epoch"), priority=extension.PRIORITY_WRITER
+ )
def _observe_value(manager: ExtensionsManagerProtocol) -> None:
manager.observation[observation_key] = target_func(manager)
+
return _observe_value
def observe_lr(
- optimizer: torch.optim.Optimizer,
- param_group: int = 0,
- observation_key: str = 'lr',
+ optimizer: torch.optim.Optimizer,
+ param_group: int = 0,
+ observation_key: str = "lr",
) -> Any:
"""Returns an extension to record the learning rate.
@@ -57,4 +61,5 @@ def observe_lr(
"""
return observe_value(
observation_key,
- lambda manager: optimizer.param_groups[param_group]['lr'])
+ lambda manager: optimizer.param_groups[param_group]["lr"],
+ )
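Note: a short sketch of attaching the `observe_lr` extension reformatted above; `manager` and `optimizer` are assumed to already exist.

    from pytorch_pfn_extras.training import extensions

    # Records optimizer.param_groups[0]["lr"] under the observation key "lr"
    # once per epoch (the default trigger of the generated extension).
    manager.extend(extensions.observe_lr(optimizer), trigger=(1, "epoch"))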
diff --git a/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py b/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py
index 7f9b9105e..a54d0446f 100644
--- a/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py
+++ b/pytorch_pfn_extras/training/extensions/variable_statistics_plot.py
@@ -1,13 +1,13 @@
-from typing import Any, Dict, Optional, Tuple, Union
import warnings
+from typing import Any, Dict, Optional, Tuple, Union
import numpy
import torch
-
from pytorch_pfn_extras.training import extension
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
matplotlib: Any = None
_available: Optional[bool] = None
@@ -16,10 +16,13 @@
_plot_common_kwargs: Any = None
-def percentile(a: torch.Tensor, q: Union[float, Tuple[float, ...]], axis: int) -> Any:
+def percentile(
+ a: torch.Tensor, q: Union[float, Tuple[float, ...]], axis: int
+) -> Any:
# fallback to numpy
return torch.Tensor(
- numpy.percentile(a.cpu().numpy(), q, axis)) # type: ignore[no-untyped-call]
+ numpy.percentile(a.cpu().numpy(), q, axis)
+ ) # type: ignore[no-untyped-call]
def matplotlib_savefun(target: Tuple[Any, Any], file_o: Any) -> None:
@@ -34,20 +37,24 @@ def _try_import_matplotlib() -> None:
global _plot_color, _plot_color_trans, _plot_common_kwargs
try:
import matplotlib
+
_available = True
except ImportError:
_available = False
if _available:
- if hasattr(matplotlib.colors, 'to_rgba'):
+ if hasattr(matplotlib.colors, "to_rgba"):
_to_rgba = matplotlib.colors.to_rgba
else:
# For matplotlib 1.x
_to_rgba = matplotlib.colors.ColorConverter().to_rgba
- _plot_color = _to_rgba('#1f77b4') # C0 color
+ _plot_color = _to_rgba("#1f77b4") # C0 color
_plot_color_trans = _plot_color[:3] + (0.2,) # apply alpha
_plot_common_kwargs = {
- 'alpha': 0.2, 'linewidth': 0, 'color': _plot_color_trans}
+ "alpha": 0.2,
+ "linewidth": 0,
+ "color": _plot_color_trans,
+ }
def _check_available() -> None:
@@ -55,10 +62,12 @@ def _check_available() -> None:
_try_import_matplotlib()
if not _available:
- warnings.warn('matplotlib is not installed on your environment, '
- 'so nothing will be plotted at this time. '
- 'Please install matplotlib to plot figures.\n\n'
- ' $ pip install matplotlib\n')
+ warnings.warn(
+ "matplotlib is not installed on your environment, "
+ "so nothing will be plotted at this time. "
+ "Please install matplotlib to plot figures.\n\n"
+ " $ pip install matplotlib\n"
+ )
def _unpack_variables(x: Any, memo: Any = None) -> Any:
@@ -79,10 +88,10 @@ class Reservoir:
"""Reservoir sample with a fixed sized buffer."""
def __init__(
- self,
- size: int,
- data_shape: Tuple[int, ...],
- dtype: Any = numpy.float32,
+ self,
+ size: int,
+ data_shape: Tuple[int, ...],
+ dtype: Any = numpy.float32,
) -> None:
self.size = size
self.data = numpy.zeros((size,) + data_shape, dtype=dtype)
@@ -93,15 +102,17 @@ def add(self, x: Any, idx: Any = None) -> None:
if self.counter < self.size:
self.data[self.counter] = x
self.idxs[self.counter] = idx or self.counter
- elif self.counter >= self.size and \
- numpy.random.random() < self.size / float(self.counter + 1):
+ elif (
+ self.counter >= self.size
+ and numpy.random.random() < self.size / float(self.counter + 1)
+ ):
i = numpy.random.randint(self.size)
self.data[i] = x
self.idxs[i] = idx or self.counter
self.counter += 1
def get_data(self) -> Tuple[Any, Any]:
- idxs = self.idxs[:min(self.counter, self.size)]
+ idxs = self.idxs[: min(self.counter, self.size)]
sorted_args = numpy.argsort(idxs)
return idxs[sorted_args], self.data[sorted_args]
@@ -111,20 +122,22 @@ class Statistician:
"""Helper to compute basic NumPy-like statistics."""
def __init__(
- self,
- collect_mean: bool,
- collect_std: bool,
- percentile_sigmas: Union[float, Tuple[float, ...]],
+ self,
+ collect_mean: bool,
+ collect_std: bool,
+ percentile_sigmas: Union[float, Tuple[float, ...]],
) -> None:
self.collect_mean = collect_mean
self.collect_std = collect_std
self.percentile_sigmas = percentile_sigmas
- def __call__(self, x: Any, axis: Any = 0, dtype: Any = None) -> Dict[str, Any]:
+ def __call__(
+ self, x: Any, axis: Any = 0, dtype: Any = None
+ ) -> Dict[str, Any]:
if axis is None:
axis = tuple(range(x.ndim))
elif not isinstance(axis, (tuple, list)):
- axis = axis,
+ axis = (axis,)
return self.collect(x, axis)
@@ -132,14 +145,14 @@ def collect(self, x: Any, axis: int) -> Dict[str, Any]:
out = dict()
if self.collect_mean:
- out['mean'] = x.mean(axis=axis)
+ out["mean"] = x.mean(axis=axis)
if self.collect_std:
- out['std'] = x.std(axis=axis)
+ out["std"] = x.std(axis=axis)
if self.percentile_sigmas:
p = percentile(x, self.percentile_sigmas, axis=axis)
- out['percentile'] = p
+ out["percentile"] = p
return out
@@ -215,26 +228,34 @@ class VariableStatisticsPlot(extension.Extension):
"""
def __init__(
- self,
- targets: Any,
- max_sample_size: int = 1000,
- report_data: bool = True,
- report_grad: bool = True,
- plot_mean: bool = True,
- plot_std: bool = True,
- percentile_sigmas: Union[float, Tuple[float, ...]] = (
- 0, 0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87, 100),
- trigger: trigger_module.TriggerLike = (1, 'epoch'),
- filename: Optional[str] = None,
- figsize: Optional[Tuple[int, ...]] = None,
- marker: Optional[str] = None,
- grid: bool = True,
- **kwargs: Any,
+ self,
+ targets: Any,
+ max_sample_size: int = 1000,
+ report_data: bool = True,
+ report_grad: bool = True,
+ plot_mean: bool = True,
+ plot_std: bool = True,
+ percentile_sigmas: Union[float, Tuple[float, ...]] = (
+ 0,
+ 0.13,
+ 2.28,
+ 15.87,
+ 50,
+ 84.13,
+ 97.72,
+ 99.87,
+ 100,
+ ),
+ trigger: trigger_module.TriggerLike = (1, "epoch"),
+ filename: Optional[str] = None,
+ figsize: Optional[Tuple[int, ...]] = None,
+ marker: Optional[str] = None,
+ grid: bool = True,
+ **kwargs: Any,
):
-
_check_available()
- file_name = kwargs.get('file_name', 'statistics.png')
+ file_name = kwargs.get("file_name", "statistics.png")
if filename is None:
filename = file_name
del file_name # avoid accidental use
@@ -242,24 +263,27 @@ def __init__(
self._vars = _unpack_variables(targets)
if not self._vars:
raise ValueError(
- 'Need at least one variables for which to collect statistics.'
- '\nActual: 0 <= 0')
+ "Need at least one variables for which to collect statistics."
+ "\nActual: 0 <= 0"
+ )
if not any((plot_mean, plot_std, bool(percentile_sigmas))):
- raise ValueError('Nothing to plot')
+ raise ValueError("Nothing to plot")
self._keys = []
if report_data:
- self._keys.append('data')
+ self._keys.append("data")
if report_grad:
- self._keys.append('grad')
+ self._keys.append("grad")
self._report_data = report_data
self._report_grad = report_grad
self._statistician = Statistician(
- collect_mean=plot_mean, collect_std=plot_std,
- percentile_sigmas=percentile_sigmas)
+ collect_mean=plot_mean,
+ collect_std=plot_std,
+ percentile_sigmas=percentile_sigmas,
+ )
self._plot_mean = plot_mean
self._plot_std = plot_std
@@ -270,7 +294,7 @@ def __init__(
self._figsize = figsize
self._marker = marker
self._grid = grid
- self._writer = kwargs.get('writer', None)
+ self._writer = kwargs.get("writer", None)
if not self._plot_percentile:
n_percentile = 0
@@ -280,7 +304,9 @@ def __init__(
else:
n_percentile = len(percentile_sigmas)
self._data_shape = (
- len(self._keys), int(plot_mean) + int(plot_std) + n_percentile)
+ len(self._keys),
+ int(plot_mean) + int(plot_std) + n_percentile,
+ )
self._samples = Reservoir(max_sample_size, data_shape=self._data_shape)
@staticmethod
@@ -309,15 +335,17 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
stat_list = []
if self._plot_mean:
stat_list.append(
- numpy.atleast_1d(stat_dict['mean'].cpu().numpy()))
+ numpy.atleast_1d(stat_dict["mean"].cpu().numpy())
+ )
if self._plot_std:
stat_list.append(
- numpy.atleast_1d(stat_dict['std'].cpu().numpy()))
+ numpy.atleast_1d(stat_dict["std"].cpu().numpy())
+ )
if self._plot_percentile:
- stat_list.append(
- numpy.atleast_1d(stat_dict['percentile']))
+ stat_list.append(numpy.atleast_1d(stat_dict["percentile"]))
stats[i] = numpy.concatenate( # type: ignore[no-untyped-call]
- stat_list, axis=0)
+ stat_list, axis=0
+ )
self._samples.add(stats, idx=manager.iteration)
@@ -325,16 +353,18 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
self.save_plot_using_module(plt, manager)
def save_plot_using_module(
- self,
- plt: Any,
- manager: ExtensionsManagerProtocol,
+ self,
+ plt: Any,
+ manager: ExtensionsManagerProtocol,
) -> None:
- nrows = int(self._plot_mean or self._plot_std) \
- + int(self._plot_percentile)
+ nrows = int(self._plot_mean or self._plot_std) + int(
+ self._plot_percentile
+ )
ncols = len(self._keys)
fig, axes = plt.subplots(
- nrows, ncols, figsize=self._figsize, sharex=True)
+ nrows, ncols, figsize=self._figsize, sharex=True
+ )
if not isinstance(axes, numpy.ndarray): # single subplot
axes = numpy.asarray([axes])
@@ -362,17 +392,26 @@ def save_plot_using_module(
if self._plot_mean or self._plot_std:
if self._plot_mean and self._plot_std:
ax.errorbar(
- idxs, data[:, col, 0], data[:, col, 1],
- color=_plot_color, ecolor=_plot_color_trans,
- label='mean, std', marker=self._marker)
+ idxs,
+ data[:, col, 0],
+ data[:, col, 1],
+ color=_plot_color,
+ ecolor=_plot_color_trans,
+ label="mean, std",
+ marker=self._marker,
+ )
else:
if self._plot_mean:
- label = 'mean'
+ label = "mean"
elif self._plot_std:
- label = 'std'
+ label = "std"
ax.plot(
- idxs, data[:, col, 0], color=_plot_color, label=label,
- marker=self._marker)
+ idxs,
+ data[:, col, 0],
+ color=_plot_color,
+ label=label,
+ marker=self._marker,
+ )
row += 1
if self._plot_percentile:
@@ -384,22 +423,27 @@ def save_plot_using_module(
# percentile is the mid percentile and the number of
# percentiles are odd
ax.plot(
- idxs, data[:, col, offset + i], color=_plot_color,
- label='percentile', marker=self._marker)
+ idxs,
+ data[:, col, offset + i],
+ color=_plot_color,
+ label="percentile",
+ marker=self._marker,
+ )
else:
if i == n_percentile_mid_floor:
# Last percentiles and the number of all
# percentiles are even
- label = 'percentile'
+ label = "percentile"
else:
- label = '_nolegend_'
+ label = "_nolegend_"
ax.fill_between(
idxs,
data[:, col, offset + i],
data[:, col, -i - 1],
label=label,
- **_plot_common_kwargs)
- ax.set_xlabel('iteration')
+ **_plot_common_kwargs,
+ )
+ ax.set_xlabel("iteration")
for ax in axes.ravel():
ax.legend()
@@ -407,8 +451,12 @@ def save_plot_using_module(
ax.grid()
ax.set_axisbelow(True)
- writer(self._filename, manager.out, (fig, plt), # type: ignore
- savefun=matplotlib_savefun)
+ writer(
+ self._filename,
+ manager.out,
+ (fig, plt), # type: ignore
+ savefun=matplotlib_savefun,
+ )
def finalize(self, manager: ExtensionsManagerProtocol) -> None:
if self._writer is not None:
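Note: a hedged sketch of registering the `VariableStatisticsPlot` extension reformatted above; `model` and `manager` are placeholders, and matplotlib must be installed for anything to be plotted.

    from pytorch_pfn_extras.training import extensions

    # Plot mean/std and percentile statistics of the parameters and their
    # gradients, writing <out_dir>/statistics.png once per epoch.
    manager.extend(
        extensions.VariableStatisticsPlot(
            model,
            report_data=True,
            report_grad=True,
            filename="statistics.png",
        ),
        trigger=(1, "epoch"),
    )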
diff --git a/pytorch_pfn_extras/training/manager.py b/pytorch_pfn_extras/training/manager.py
index 28dbc7f57..3ba206e86 100644
--- a/pytorch_pfn_extras/training/manager.py
+++ b/pytorch_pfn_extras/training/manager.py
@@ -1,30 +1,36 @@
import collections
import contextlib
import copy
-from pytorch_pfn_extras.profiler import record
import time
+import warnings
from typing import (
- Any, Dict, Generator, Mapping, Optional, Sequence, Union, TYPE_CHECKING
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Generator,
+ Mapping,
+ Optional,
+ Sequence,
+ Union,
)
-import warnings
-
-import torch
import pytorch_pfn_extras
-from pytorch_pfn_extras import writing
-from pytorch_pfn_extras import reporting
+import torch
+from pytorch_pfn_extras import reporting, writing
+from pytorch_pfn_extras.profiler import record
+from pytorch_pfn_extras.training import _util as util_module
from pytorch_pfn_extras.training import extension as extension_module
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training import _util as util_module
from pytorch_pfn_extras.training._transform_model import (
- default_transform_model, _TransformModel,
+ _TransformModel,
+ default_transform_model,
)
_get_time = time.perf_counter
class _ManagerProxy:
- def __init__(self, manager: '_BaseExtensionsManager') -> None:
+ def __init__(self, manager: "_BaseExtensionsManager") -> None:
self._manager = manager
@property
@@ -99,26 +105,27 @@ class _BaseExtensionsManager:
"""
def __init__(
- self,
- models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
- optimizers: Union[torch.optim.Optimizer,
- Mapping[str, torch.optim.Optimizer]],
- max_epochs: int,
- extensions: Optional[Sequence['extension_module.ExtensionLike']],
- out_dir: str,
- writer: Optional[writing.Writer],
- stop_trigger: 'trigger_module.TriggerLike' = None,
- transform_model: _TransformModel = default_transform_model,
- enable_profile: bool = False,
+ self,
+ models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
+ optimizers: Union[
+ torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer]
+ ],
+ max_epochs: int,
+ extensions: Optional[Sequence["extension_module.ExtensionLike"]],
+ out_dir: str,
+ writer: Optional[writing.Writer],
+ stop_trigger: "trigger_module.TriggerLike" = None,
+ transform_model: _TransformModel = default_transform_model,
+ enable_profile: bool = False,
) -> None:
if extensions is None:
extensions = []
if stop_trigger is None:
self._stop_trigger = trigger_module.get_trigger(
- (max_epochs, 'epoch'))
+ (max_epochs, "epoch")
+ )
else:
- self._stop_trigger = trigger_module.get_trigger(
- stop_trigger)
+ self._stop_trigger = trigger_module.get_trigger(stop_trigger)
if writer is None:
writer = writing.SimpleWriter(out_dir=out_dir)
# triggers are stateful, so we need to make a copy for internal use
@@ -142,22 +149,22 @@ def __init__(
else:
if not isinstance(models, torch.nn.Module):
raise ValueError(
- 'model must be an instance of dict or toch.nn.Module')
- self._models = {'main': models}
+ "model must be an instance of dict or toch.nn.Module"
+ )
+ self._models = {"main": models}
if isinstance(optimizers, collections.abc.Mapping):
self._optimizers = optimizers
else:
# TODO(ecastill) Optimizer type is not checked because of tests
# using mocks and other classes
- self._optimizers = {'main': optimizers}
+ self._optimizers = {"main": optimizers}
for name, model in self._models.items():
# TODO we should not initialize extensions at this point
# so, we cannot use `self.models`
model = self._transform_model(name, model)
self.reporter.add_observer(name, model)
- self.reporter.add_observers(
- name, model.named_modules())
+ self.reporter.add_observers(name, model.named_modules())
self._finalized = False
self.max_epochs = max_epochs
self._start_iteration = 0
@@ -165,7 +172,8 @@ def __init__(
self._start_time: Optional[float] = None
self.__iters_per_epoch: Optional[int] = None
self._extensions: Dict[
- str, extension_module.ExtensionEntry] = collections.OrderedDict()
+ str, extension_module.ExtensionEntry
+ ] = collections.OrderedDict()
for ext in extensions:
self.extend(ext)
@@ -189,16 +197,18 @@ def _check_model_available(self) -> None:
if self._model_available:
return
raise RuntimeError(
- 'Models cannot be accessed from extensions in this iteration. '
- 'Extensions accessing models must declare '
- '`needs_model_state = True` attribute.')
+ "Models cannot be accessed from extensions in this iteration. "
+ "Extensions accessing models must declare "
+ "`needs_model_state = True` attribute."
+ )
@property
def models(self) -> Mapping[str, torch.nn.Module]:
self.start_extensions()
self._check_model_available()
- models = {k: self._transform_model(k, v)
- for k, v in self._models.items()}
+ models = {
+ k: self._transform_model(k, v) for k, v in self._models.items()
+ }
return models
@property
@@ -216,7 +226,8 @@ def optimizers(self) -> Mapping[str, torch.optim.Optimizer]:
def elapsed_time(self) -> float:
if self._start_time is None:
raise RuntimeError(
- 'Unavailable until the initial run_iteration call.')
+ "Unavailable until the initial run_iteration call."
+ )
return _get_time() - self._start_time
@property
@@ -251,22 +262,21 @@ def out(self) -> str:
return self.writer.out_dir
@property
- def updater(self) -> '_BaseExtensionsManager':
+ def updater(self) -> "_BaseExtensionsManager":
warnings.warn(
- 'The `updater` attribute has been deprecated in v0.3.0.'
- ' Use `iteration`, `epoch`, and `epoch_detail` attributes in'
- ' `ExtensionsManager` instead of attributes under `updater`.'
- ' You may also need to update the filename template specified to'
- ' snapshot extensions (e.g., from '
- '`snapshot_iter_{.updater.iteration}` to'
- ' `snapshot_iter_{.iteration}`).', DeprecationWarning)
+ "The `updater` attribute has been deprecated in v0.3.0."
+ " Use `iteration`, `epoch`, and `epoch_detail` attributes in"
+ " `ExtensionsManager` instead of attributes under `updater`."
+ " You may also need to update the filename template specified to"
+ " snapshot extensions (e.g., from "
+ "`snapshot_iter_{.updater.iteration}` to"
+ " `snapshot_iter_{.iteration}`).",
+ DeprecationWarning,
+ )
return self
def _prepare_for_training(
- self,
- start_iteration: int,
- start_execution: int,
- iters_per_epoch: int
+ self, start_iteration: int, start_execution: int, iters_per_epoch: int
) -> None:
self.iteration = start_iteration
self.execution = start_execution
@@ -281,15 +291,14 @@ def start_extensions(self) -> None:
exts = self._extensions
extension_order = sorted(
- exts.keys(),
- key=lambda name: exts[name].priority, reverse=True)
- self.extensions = [(name, exts[name])
- for name in extension_order]
+ exts.keys(), key=lambda name: exts[name].priority, reverse=True
+ )
+ self.extensions = [(name, exts[name]) for name in extension_order]
# invoke initializer of each extension
for _, entry in self.extensions:
initializer = entry.extension.initialize
- finished = getattr(entry.trigger, 'finished', False)
+ finished = getattr(entry.trigger, "finished", False)
if not finished:
initializer(self)
@@ -301,17 +310,17 @@ def start_extensions(self) -> None:
entry.extension(self)
def extend(
- self,
- extension: Union[
- 'extension_module.ExtensionLike',
- 'extension_module.ExtensionEntry',
- ],
- name: Optional[str] = None,
- trigger: 'trigger_module.TriggerLike' = None,
- priority: Optional[int] = None,
- *,
- call_before_training: Optional[bool] = None,
- **kwargs: Dict[str, Any],
+ self,
+ extension: Union[
+ "extension_module.ExtensionLike",
+ "extension_module.ExtensionEntry",
+ ],
+ name: Optional[str] = None,
+ trigger: "trigger_module.TriggerLike" = None,
+ priority: Optional[int] = None,
+ *,
+ call_before_training: Optional[bool] = None,
+ **kwargs: Dict[str, Any],
) -> None:
"""Registers an extension to the manager.
@@ -353,7 +362,8 @@ def extend(
"""
if self._start_extensions_called:
raise RuntimeError(
- 'extend called after the extensions were initialized')
+ "extend called after the extensions were initialized"
+ )
if isinstance(extension, extension_module.ExtensionEntry):
entry = extension
@@ -373,7 +383,7 @@ def extend(
ordinal = 0
while modified_name in self._extensions:
ordinal += 1
- modified_name = '%s_%d' % (name, ordinal)
+ modified_name = "%s_%d" % (name, ordinal)
entry._update_name(modified_name)
self._extensions[modified_name] = entry
@@ -392,7 +402,7 @@ def get_extension(self, name: str) -> extension_module.Extension:
if name in extensions:
return extensions[name].extension
else:
- raise ValueError('extension %s not found' % name)
+ raise ValueError("extension %s not found" % name)
def _run_on_error(self, exc: Exception) -> None:
if not self._run_on_error_called:
@@ -428,15 +438,15 @@ def run_extensions(self) -> None:
to_run.append((name, entry.extension))
else:
with record(
- f'pytorch_pfn_extras.training.ExtensionsManager'
- f'.run_extensions:{name}',
+ f"pytorch_pfn_extras.training.ExtensionsManager"
+ f".run_extensions:{name}",
enable=self._enable_profile,
):
entry.extension(self)
for name, extension in to_run:
with record(
- f'pytorch_pfn_extras.training.ExtensionsManager'
- f'.run_extensions:{name}',
+ f"pytorch_pfn_extras.training.ExtensionsManager"
+ f".run_extensions:{name}",
enable=self._enable_profile,
):
extension(self)
@@ -452,9 +462,10 @@ def needs_model_state(self, iteration: Optional[int] = None) -> bool:
# is increased just right before calling extensions
iteration = self.iteration + 1
for _, entry in self._extensions.items():
- needs_state = getattr(entry.extension, 'needs_model_state', False)
- if (needs_state and entry.trigger.may_fire(
- iteration, self._iters_per_epoch)):
+ needs_state = getattr(entry.extension, "needs_model_state", False)
+ if needs_state and entry.trigger.may_fire(
+ iteration, self._iters_per_epoch
+ ):
return True
return False
@@ -469,57 +480,66 @@ def _finalize_extensions(self) -> None:
pass
def state_dict(
- self,
+ self,
) -> Dict[str, Any]:
to_save: Dict[str, Any] = {}
- to_save['_start_iteration'] = self.iteration
- to_save['_start_execution'] = self.execution
+ to_save["_start_iteration"] = self.iteration
+ to_save["_start_execution"] = self.execution
# Use self.models to apply transform_model
- to_save['models'] = {
- name: self.models[name].state_dict()
- for name in self.models}
- to_save['optimizers'] = {name: self._optimizers[name].state_dict()
- for name in self._optimizers}
- to_save['extensions'] = {name: self._extensions[name].state_dict()
- for name in self._extensions}
- to_save['ppe_version'] = pytorch_pfn_extras.__version__
+ to_save["models"] = {
+ name: self.models[name].state_dict() for name in self.models
+ }
+ to_save["optimizers"] = {
+ name: self._optimizers[name].state_dict()
+ for name in self._optimizers
+ }
+ to_save["extensions"] = {
+ name: self._extensions[name].state_dict()
+ for name in self._extensions
+ }
+ to_save["ppe_version"] = pytorch_pfn_extras.__version__
return to_save
def _check_snapshot_version(self, ppe_version: Optional[str]) -> None:
must_warn = ppe_version is None or (
- ppe_version != pytorch_pfn_extras.__version__)
+ ppe_version != pytorch_pfn_extras.__version__
+ )
if not must_warn:
return
- msg = ('You are trying to load a snapshot file taken using a different '
- 'PPE version.\n')
+ msg = (
+ "You are trying to load a snapshot file taken using a different "
+ "PPE version.\n"
+ )
if ppe_version is not None:
- msg += (f'Snapshot taken with PPE {ppe_version} but '
- f'currently using PPE {pytorch_pfn_extras.__version__}')
+ msg += (
+ f"Snapshot taken with PPE {ppe_version} but "
+ f"currently using PPE {pytorch_pfn_extras.__version__}"
+ )
warnings.warn(msg)
def load_state_dict(
- self,
- to_load: Dict[str, Any],
+ self,
+ to_load: Dict[str, Any],
) -> None:
- self._check_snapshot_version(to_load.get('ppe_version', None))
- self._start_iteration = to_load['_start_iteration']
+ self._check_snapshot_version(to_load.get("ppe_version", None))
+ self._start_iteration = to_load["_start_iteration"]
self.iteration = self._start_iteration
- self._start_execution = to_load.get('_start_execution', self.iteration)
+ self._start_execution = to_load.get("_start_execution", self.iteration)
self.execution = self._start_execution
for name in self.models:
# TODO(ecastill) map_loc when loading the model and DDP check
# Use self.models to apply transform_model
- self.models[name].load_state_dict(to_load['models'][name])
+ self.models[name].load_state_dict(to_load["models"][name])
for name in self._optimizers:
- self._optimizers[name].load_state_dict(to_load['optimizers'][name])
+ self._optimizers[name].load_state_dict(to_load["optimizers"][name])
for name in self._extensions:
- self._extensions[name].load_state_dict(to_load['extensions'][name])
+ self._extensions[name].load_state_dict(to_load["extensions"][name])
class ExtensionsManager(_BaseExtensionsManager):
@@ -545,33 +565,43 @@ class ExtensionsManager(_BaseExtensionsManager):
"""
def __init__(
- self,
- models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
- optimizers: Union[torch.optim.Optimizer, Dict[str, torch.optim.Optimizer]],
- max_epochs: int,
- *,
- iters_per_epoch: int,
- extensions: Optional[Sequence['extension_module.ExtensionLike']] = None,
- out_dir: str = 'result',
- stop_trigger: 'trigger_module.TriggerLike' = None,
- writer: Optional[writing.Writer] = None,
- transform_model: _TransformModel = lambda n, x: x,
- enable_profile: bool = False,
+ self,
+ models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
+ optimizers: Union[
+ torch.optim.Optimizer, Dict[str, torch.optim.Optimizer]
+ ],
+ max_epochs: int,
+ *,
+ iters_per_epoch: int,
+ extensions: Optional[Sequence["extension_module.ExtensionLike"]] = None,
+ out_dir: str = "result",
+ stop_trigger: "trigger_module.TriggerLike" = None,
+ writer: Optional[writing.Writer] = None,
+ transform_model: _TransformModel = lambda n, x: x,
+ enable_profile: bool = False,
) -> None:
super().__init__(
- models, optimizers, max_epochs, extensions,
- out_dir, writer, stop_trigger, transform_model, enable_profile)
+ models,
+ optimizers,
+ max_epochs,
+ extensions,
+ out_dir,
+ writer,
+ stop_trigger,
+ transform_model,
+ enable_profile,
+ )
if iters_per_epoch < 1:
raise ValueError(
- 'iters_per_epoch must be an integer >= 1 ({} given)'.format(
- iters_per_epoch))
+ "iters_per_epoch must be an integer >= 1 ({} given)".format(
+ iters_per_epoch
+ )
+ )
self._prepare_for_training(0, 0, iters_per_epoch)
@contextlib.contextmanager
def run_iteration(
- self,
- *,
- step_optimizers: Optional[Sequence[str]] = None
+ self, *, step_optimizers: Optional[Sequence[str]] = None
) -> Generator[None, None, None]:
"""Context manager to run an iteration.
@@ -583,7 +613,7 @@ def run_iteration(
to call `zero_grad` and `step`
"""
if self._finalized:
- raise RuntimeError('Attempted to run a finalized manager')
+ raise RuntimeError("Attempted to run a finalized manager")
if self._start_time is None:
self._start_time = _get_time()
self.start_extensions()
@@ -639,36 +669,47 @@ class IgniteExtensionsManager(_BaseExtensionsManager):
enable_profile (bool): Flag to enable/disable profiling of iterations.
Default is `False`.
"""
+
def __init__(
- self,
- engine: 'ignite.engine.Engine',
- models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
- optimizers: Union[torch.optim.Optimizer,
- Mapping[str, torch.optim.Optimizer]],
- max_epochs: int,
- *,
- extensions: Optional[Sequence['extension_module.ExtensionLike']] = None,
- out_dir: str = 'result',
- writer: Optional[writing.Writer] = None,
- enable_profile: bool = False,
+ self,
+ engine: "ignite.engine.Engine",
+ models: Union[torch.nn.Module, Mapping[str, torch.nn.Module]],
+ optimizers: Union[
+ torch.optim.Optimizer, Mapping[str, torch.optim.Optimizer]
+ ],
+ max_epochs: int,
+ *,
+ extensions: Optional[Sequence["extension_module.ExtensionLike"]] = None,
+ out_dir: str = "result",
+ writer: Optional[writing.Writer] = None,
+ enable_profile: bool = False,
) -> None:
import ignite
+
if not isinstance(engine, ignite.engine.Engine):
raise TypeError("Argument 'engine' must be of ignite.Engine type.")
- if (util_module._get_ignite_version(ignite.__version__)
- < util_module._get_ignite_version('0.3.0')):
- raise ImportError('Ignite version found {}. '
- 'Required is >=0.3.0'.format(ignite.__version__))
+ if util_module._get_ignite_version(
+ ignite.__version__
+ ) < util_module._get_ignite_version("0.3.0"):
+ raise ImportError(
+ "Ignite version found {}. "
+ "Required is >=0.3.0".format(ignite.__version__)
+ )
super().__init__(
- models, optimizers, max_epochs, extensions, out_dir, writer,
- enable_profile=enable_profile)
+ models,
+ optimizers,
+ max_epochs,
+ extensions,
+ out_dir,
+ writer,
+ enable_profile=enable_profile,
+ )
self.engine = engine
self._start_epoch = 0 # Used to correctly restore snapshots
self.set_ignite_handlers()
def set_ignite_handlers(self) -> None:
- from ignite.engine import Engine
- from ignite.engine import Events
+ from ignite.engine import Engine, Events
# Set a handler that sets the reporter scope on every iteration
@self.engine.on(Events.ITERATION_STARTED)
@@ -689,7 +730,8 @@ def set_training_started(engine: Engine) -> None:
self._start_time = _get_time()
# Initialize manager again after all state is restored
self._prepare_for_training(
- start_iteration, start_iteration, iters_per_epoch)
+ start_iteration, start_iteration, iters_per_epoch
+ )
# Make all the next
# handlers to be executed after user defined ones
@@ -710,13 +752,13 @@ def set_extensions_cleanup(engine: Engine) -> None:
def state_dict(self) -> Dict[str, Any]:
to_save = super().state_dict()
- to_save['_epoch_length'] = self.engine.state.epoch_length
- to_save['_start_iteration'] = self.engine.state.iteration
+ to_save["_epoch_length"] = self.engine.state.epoch_length
+ to_save["_start_iteration"] = self.engine.state.iteration
return to_save
def load_state_dict(
- self,
- to_load: Dict[str, Any],
+ self,
+ to_load: Dict[str, Any],
) -> None:
super().load_state_dict(to_load)
- self._start_epoch = self._start_iteration // to_load['_epoch_length']
+ self._start_epoch = self._start_iteration // to_load["_epoch_length"]
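Note: to illustrate the `ExtensionsManager` interface reshaped above, a minimal, self-contained training-loop sketch; the model, data, and hyperparameters are made up for illustration.

    import torch
    import torch.nn.functional as F
    import pytorch_pfn_extras as ppe

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(5)]

    manager = ppe.training.ExtensionsManager(
        model, optimizer, max_epochs=2,
        iters_per_epoch=len(loader), out_dir="result",
    )

    for _ in range(2):  # epochs
        for x, t in loader:
            # step_optimizers makes the manager call zero_grad()/step()
            # on the "main" optimizer around this block.
            with manager.run_iteration(step_optimizers=["main"]):
                loss = F.mse_loss(model(x), t)
                loss.backward()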
diff --git a/pytorch_pfn_extras/training/metrics.py b/pytorch_pfn_extras/training/metrics.py
index 7e7321663..4139e0ca3 100644
--- a/pytorch_pfn_extras/training/metrics.py
+++ b/pytorch_pfn_extras/training/metrics.py
@@ -2,7 +2,6 @@
import torch
-
Batch = Dict[str, torch.Tensor]
MetricType = Callable[[Batch, Batch], Batch]
@@ -17,12 +16,13 @@ class AccuracyMetric:
.. seealso:
:func:`pytorch_pfn_extras.engine.create_evaluator`
"""
+
def __init__(self, label_key: str, output_key: str) -> None:
self.label_key = label_key
self.output_key = output_key
def _preprocess_input(
- self, batch: Batch, out: Batch
+ self, batch: Batch, out: Batch
) -> Tuple[torch.Tensor, int, torch.Tensor]:
labels = batch[self.label_key].cpu()
n_output = labels.shape[0]
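Note: a small sketch of calling the `AccuracyMetric` touched above directly; the batch keys are hypothetical, and the exact key of the returned entry is an assumption based on the class name.

    import torch
    from pytorch_pfn_extras.training.metrics import AccuracyMetric

    metric = AccuracyMetric(label_key="label", output_key="logits")
    batch = {"label": torch.tensor([0, 1, 1])}
    out = {"logits": torch.tensor([[2.0, 0.1], [0.2, 1.5], [1.0, 0.0]])}
    # Returns a dict of metric values (an accuracy entry) that can be merged
    # into the reported outputs, e.g. via pytorch_pfn_extras.engine
    # .create_evaluator as the class docstring suggests.
    print(metric(batch, out))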
diff --git a/pytorch_pfn_extras/training/trigger.py b/pytorch_pfn_extras/training/trigger.py
index cbfb573e9..5a7bd42de 100644
--- a/pytorch_pfn_extras/training/trigger.py
+++ b/pytorch_pfn_extras/training/trigger.py
@@ -1,8 +1,12 @@
from pytorch_pfn_extras.training._trigger_util import Trigger # NOQA
-from pytorch_pfn_extras.training._trigger_util import get_trigger # NOQA
from pytorch_pfn_extras.training._trigger_util import TriggerFunc # NOQA
from pytorch_pfn_extras.training._trigger_util import TriggerLike # NOQA
+from pytorch_pfn_extras.training._trigger_util import get_trigger # NOQA
+from pytorch_pfn_extras.training._trigger_util import ( # NOQA
+ _never_fire_trigger,
+)
# For backward compatibility
-from pytorch_pfn_extras.training.triggers.interval_trigger import IntervalTrigger # NOQA
-from pytorch_pfn_extras.training._trigger_util import _never_fire_trigger # NOQA
+from pytorch_pfn_extras.training.triggers.interval_trigger import ( # NOQA
+ IntervalTrigger,
+)
diff --git a/pytorch_pfn_extras/training/triggers/__init__.py b/pytorch_pfn_extras/training/triggers/__init__.py
index 074978b92..99602a71c 100644
--- a/pytorch_pfn_extras/training/triggers/__init__.py
+++ b/pytorch_pfn_extras/training/triggers/__init__.py
@@ -1,9 +1,21 @@
# import classes and functions
-from pytorch_pfn_extras.training.triggers.early_stopping_trigger import EarlyStoppingTrigger # NOQA
-from pytorch_pfn_extras.training.triggers.interval_trigger import IntervalTrigger # NOQA
-from pytorch_pfn_extras.training.triggers.manual_schedule_trigger import ManualScheduleTrigger # NOQA
-from pytorch_pfn_extras.training.triggers.minmax_value_trigger import BestValueTrigger # NOQA
-from pytorch_pfn_extras.training.triggers.minmax_value_trigger import MaxValueTrigger # NOQA
-from pytorch_pfn_extras.training.triggers.minmax_value_trigger import MinValueTrigger # NOQA
-from pytorch_pfn_extras.training.triggers.once_trigger import OnceTrigger # NOQA
-from pytorch_pfn_extras.training.triggers.time_trigger import TimeTrigger # NOQA
+from pytorch_pfn_extras.training.triggers.early_stopping_trigger import ( # NOQA
+ EarlyStoppingTrigger,
+)
+from pytorch_pfn_extras.training.triggers.interval_trigger import ( # NOQA
+ IntervalTrigger,
+)
+from pytorch_pfn_extras.training.triggers.manual_schedule_trigger import ( # NOQA
+ ManualScheduleTrigger,
+)
+from pytorch_pfn_extras.training.triggers.minmax_value_trigger import ( # NOQA
+ BestValueTrigger,
+ MaxValueTrigger,
+ MinValueTrigger,
+)
+from pytorch_pfn_extras.training.triggers.once_trigger import ( # NOQA
+ OnceTrigger,
+)
+from pytorch_pfn_extras.training.triggers.time_trigger import ( # NOQA
+ TimeTrigger,
+)
diff --git a/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py b/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py
index 855a4ebec..7262e2e43 100644
--- a/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py
+++ b/pytorch_pfn_extras/training/triggers/early_stopping_trigger.py
@@ -1,15 +1,18 @@
import operator
-from typing import Tuple, TYPE_CHECKING
import warnings
+from typing import TYPE_CHECKING, Tuple
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training import trigger
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
if TYPE_CHECKING:
- from pytorch_pfn_extras.training._trigger_util import TriggerLike
- from pytorch_pfn_extras.training._trigger_util import UnitLiteral
+ from pytorch_pfn_extras.training._trigger_util import (
+ TriggerLike,
+ UnitLiteral,
+ )
class EarlyStoppingTrigger(trigger.Trigger):
@@ -46,13 +49,13 @@ class EarlyStoppingTrigger(trigger.Trigger):
"""
def __init__(
- self,
- check_trigger: 'TriggerLike' = (1, 'epoch'),
- monitor: str = 'main/loss',
- patience: int = 3,
- mode: str = 'auto',
- verbose: bool = False,
- max_trigger: Tuple[int, 'UnitLiteral'] = (100, 'epoch'),
+ self,
+ check_trigger: "TriggerLike" = (1, "epoch"),
+ monitor: str = "main/loss",
+ patience: int = 3,
+ mode: str = "auto",
+ verbose: bool = False,
+ max_trigger: Tuple[int, "UnitLiteral"] = (100, "epoch"),
) -> None:
self.count = 0
self.patience = patience
@@ -64,14 +67,14 @@ def __init__(
self._init_summary()
- if mode == 'max':
+ if mode == "max":
self._compare = operator.gt
- elif mode == 'min':
+ elif mode == "min":
self._compare = operator.lt
else:
- if 'accuracy' in monitor:
+ if "accuracy" in monitor:
self._compare = operator.gt
else:
@@ -79,13 +82,13 @@ def __init__(
if self._compare == operator.gt:
if verbose:
- print('early stopping: operator is greater')
- self.best = float('-inf')
+ print("early stopping: operator is greater")
+ self.best = float("-inf")
else:
if verbose:
- print('early stopping: operator is less')
- self.best = float('inf')
+ print("early stopping: operator is less")
+ self.best = float("inf")
def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
"""Decides whether the training loop should be stopped.
@@ -114,7 +117,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
return False
if self.monitor not in observation.keys():
- warnings.warn('{} is not in observation'.format(self.monitor))
+ warnings.warn("{} is not in observation".format(self.monitor))
return False
stat = self._summary.compute_mean()
@@ -130,7 +133,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
if self._stop_condition():
if self.verbose:
- print('Epoch {}: early stopping'.format(manager.epoch))
+ print("Epoch {}: early stopping".format(manager.epoch))
return True
return False
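Note: a hedged sketch of using the `EarlyStoppingTrigger` above as a stop trigger; the monitored key must actually be reported during training, and `model`, `optimizer`, and `iters_per_epoch` are assumed to be defined elsewhere.

    import pytorch_pfn_extras as ppe
    from pytorch_pfn_extras.training.triggers import EarlyStoppingTrigger

    stop = EarlyStoppingTrigger(
        check_trigger=(1, "epoch"),
        monitor="val/loss",       # hypothetical reported key
        patience=3,
        mode="min",
        max_trigger=(100, "epoch"),
    )
    manager = ppe.training.ExtensionsManager(
        model, optimizer, 100,
        iters_per_epoch=iters_per_epoch, stop_trigger=stop,
    )
    # The training loop can then poll `manager.stop_trigger` each epoch and
    # break once it returns True.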
diff --git a/pytorch_pfn_extras/training/triggers/interval_trigger.py b/pytorch_pfn_extras/training/triggers/interval_trigger.py
index 1a1847a81..58dcad773 100644
--- a/pytorch_pfn_extras/training/triggers/interval_trigger.py
+++ b/pytorch_pfn_extras/training/triggers/interval_trigger.py
@@ -1,8 +1,9 @@
-from typing import Tuple, TYPE_CHECKING
+from typing import TYPE_CHECKING, Tuple
from pytorch_pfn_extras.training import trigger
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
if TYPE_CHECKING:
from pytorch_pfn_extras.training._trigger_util import UnitLiteral
@@ -30,10 +31,11 @@ class IntervalTrigger(trigger.Trigger):
"""
- def __init__(self, period: float, unit: 'UnitLiteral'):
- if unit not in ('epoch', 'iteration'):
+ def __init__(self, period: float, unit: "UnitLiteral"):
+ if unit not in ("epoch", "iteration"):
raise ValueError(
- 'Trigger unit must be either \'epoch\' or \'iteration\'.')
+ "Trigger unit must be either 'epoch' or 'iteration'."
+ )
self.period = period
self.unit = unit
@@ -64,17 +66,17 @@ def __str__(self) -> str:
Returns:
str: IntervalTrigger(<period>, '<unit>')
"""
- return '{}({}, \'{}\')'.format(
+ return "{}({}, '{}')".format(
self.__class__.__name__, self.period, self.unit
)
def may_fire(self, iteration: int, epoch_length: int) -> bool:
if iteration == 0:
- if self.unit == 'epoch':
+ if self.unit == "epoch":
return epoch_length == 0
else:
return self.period == 0
- if self.unit == 'epoch':
+ if self.unit == "epoch":
fire = (iteration % (epoch_length * self.period)) == 0
else:
fire = (iteration % self.period) == 0
diff --git a/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py b/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py
index c9662d1f0..4914755ee 100644
--- a/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py
+++ b/pytorch_pfn_extras/training/triggers/manual_schedule_trigger.py
@@ -1,8 +1,9 @@
-from typing import Sequence, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence, Union
from pytorch_pfn_extras.training import trigger
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
if TYPE_CHECKING:
from pytorch_pfn_extras.training._trigger_util import UnitLiteral
@@ -27,12 +28,15 @@ class ManualScheduleTrigger(trigger.Trigger):
"""
- def __init__(self, points: Union[float, Sequence[float]], unit: 'UnitLiteral'):
- if unit not in ('epoch', 'iteration'):
+ def __init__(
+ self, points: Union[float, Sequence[float]], unit: "UnitLiteral"
+ ):
+ if unit not in ("epoch", "iteration"):
raise ValueError(
- 'Trigger unit must be either \'epoch\' or \'iteration\'.')
+ "Trigger unit must be either 'epoch' or 'iteration'."
+ )
- self.points = (points if isinstance(points, list) else [points])
+ self.points = points if isinstance(points, list) else [points]
self.unit = unit
def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
@@ -53,9 +57,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
return fire
def may_fire(self, iteration: int, epoch_length: int) -> bool:
- if self.unit == 'epoch':
- fire = any(
- int(p * epoch_length) == iteration for p in self.points)
+ if self.unit == "epoch":
+ fire = any(int(p * epoch_length) == iteration for p in self.points)
else:
fire = any(p == iteration for p in self.points)
return fire
diff --git a/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py b/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py
index d0e74902c..dcf286f5f 100644
--- a/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py
+++ b/pytorch_pfn_extras/training/triggers/minmax_value_trigger.py
@@ -1,9 +1,10 @@
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from pytorch_pfn_extras import reporting
from pytorch_pfn_extras.training import trigger as trigger_module
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
-
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
if TYPE_CHECKING:
from pytorch_pfn_extras.training._trigger_util import TriggerLike
@@ -26,10 +27,10 @@ class BestValueTrigger(trigger_module.Trigger):
"""
def __init__(
- self,
- key: str,
- compare: Callable[[float, float], bool],
- trigger: 'TriggerLike' = (1, 'epoch'),
+ self,
+ key: str,
+ compare: Callable[[float, float], bool],
+ trigger: "TriggerLike" = (1, "epoch"),
) -> None:
self._key = key
self._best_value: Optional[float] = None
@@ -63,9 +64,12 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
stats = summary.compute_mean()
if key not in stats:
- raise KeyError('Key "{}" not found in the observation '
- '(current available keys are {})'
- .format(key, list(stats.keys())))
+ raise KeyError(
+ 'Key "{}" not found in the observation '
+ "(current available keys are {})".format(
+ key, list(stats.keys())
+ )
+ )
value = float(stats[key]) # copy to CPU
self._init_summary()
@@ -79,15 +83,17 @@ def _init_summary(self) -> None:
self._summary = reporting.DictSummary()
def state_dict(self) -> Dict[str, Any]:
- state = {'interval_trigger': self._interval_trigger.state_dict(),
- '_summary': self._summary.state_dict(),
- '_best_value': self._best_value}
+ state = {
+ "interval_trigger": self._interval_trigger.state_dict(),
+ "_summary": self._summary.state_dict(),
+ "_best_value": self._best_value,
+ }
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- self._interval_trigger.load_state_dict(to_load['interval_trigger'])
- self._summary.load_state_dict(to_load['_summary'])
- self._best_value = to_load['_best_value']
+ self._interval_trigger.load_state_dict(to_load["interval_trigger"])
+ self._summary.load_state_dict(to_load["_summary"])
+ self._best_value = to_load["_best_value"]
def may_fire(self, iteration: int, epoch_length: int) -> bool:
return self._interval_trigger.may_fire(iteration, epoch_length)
@@ -110,9 +116,10 @@ class MaxValueTrigger(BestValueTrigger):
"""
- def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')):
+ def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")):
super().__init__(
- key, lambda max_value, new_value: new_value > max_value, trigger)
+ key, lambda max_value, new_value: new_value > max_value, trigger
+ )
class MinValueTrigger(BestValueTrigger):
@@ -132,6 +139,7 @@ class MinValueTrigger(BestValueTrigger):
"""
- def __init__(self, key: str, trigger: 'TriggerLike' = (1, 'epoch')):
+ def __init__(self, key: str, trigger: "TriggerLike" = (1, "epoch")):
super().__init__(
- key, lambda min_value, new_value: new_value < min_value, trigger)
+ key, lambda min_value, new_value: new_value < min_value, trigger
+ )
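Note: a minimal sketch pairing the `MaxValueTrigger` above with a snapshot extension; `manager`, the `extensions.snapshot` helper, and the reported "val/accuracy" key are assumptions for illustration.

    from pytorch_pfn_extras.training import extensions
    from pytorch_pfn_extras.training.triggers import MaxValueTrigger

    # Save a snapshot only when the reported validation accuracy improves.
    manager.extend(
        extensions.snapshot(filename="best_snapshot"),
        trigger=MaxValueTrigger("val/accuracy", trigger=(1, "epoch")),
    )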
diff --git a/pytorch_pfn_extras/training/triggers/once_trigger.py b/pytorch_pfn_extras/training/triggers/once_trigger.py
index 538ab3ee0..011ff0481 100644
--- a/pytorch_pfn_extras/training/triggers/once_trigger.py
+++ b/pytorch_pfn_extras/training/triggers/once_trigger.py
@@ -1,7 +1,9 @@
from typing import Any, Dict
from pytorch_pfn_extras.training import trigger
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
class OnceTrigger(trigger.Trigger):
@@ -38,11 +40,11 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
return fire
def state_dict(self) -> Dict[str, Any]:
- state = {'_flag_first': self._flag_first}
+ state = {"_flag_first": self._flag_first}
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- self._flag_first = to_load['_flag_first']
+ self._flag_first = to_load["_flag_first"]
def may_fire(self, iteration: int, epoch_length: int) -> bool:
return not (self._flag_first or self._flag_resumed)
diff --git a/pytorch_pfn_extras/training/triggers/time_trigger.py b/pytorch_pfn_extras/training/triggers/time_trigger.py
index 4d74100e4..e4006ac05 100644
--- a/pytorch_pfn_extras/training/triggers/time_trigger.py
+++ b/pytorch_pfn_extras/training/triggers/time_trigger.py
@@ -1,7 +1,9 @@
from typing import Any, Dict
from pytorch_pfn_extras.training import trigger
-from pytorch_pfn_extras.training._manager_protocol import ExtensionsManagerProtocol
+from pytorch_pfn_extras.training._manager_protocol import (
+ ExtensionsManagerProtocol,
+)
class TimeTrigger(trigger.Trigger):
@@ -27,8 +29,8 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> bool:
return False
def state_dict(self) -> Dict[str, Any]:
- state = {'next_time': self._next_time}
+ state = {"next_time": self._next_time}
return state
def load_state_dict(self, to_load: Dict[str, Any]) -> None:
- self._next_time = to_load['next_time']
+ self._next_time = to_load["next_time"]
diff --git a/pytorch_pfn_extras/utils/checkpoint.py b/pytorch_pfn_extras/utils/checkpoint.py
index 685843183..0af2efdd4 100644
--- a/pytorch_pfn_extras/utils/checkpoint.py
+++ b/pytorch_pfn_extras/utils/checkpoint.py
@@ -16,20 +16,22 @@ class _CheckpointFunction(torch.utils.checkpoint.CheckpointFunction):
can help deal with incorrect values in the BatchNormalization
persistent parameters.
"""
+
@staticmethod
def forward( # type: ignore[override]
- ctx: Any,
- run_function: Any,
- preserve_rng_state: bool,
- *args: Any,
+ ctx: Any,
+ run_function: Any,
+ preserve_rng_state: bool,
+ *args: Any,
) -> Any:
_patch_bn_momentum(run_function)
return super(_CheckpointFunction, _CheckpointFunction).forward(
- ctx, run_function, preserve_rng_state, *args)
+ ctx, run_function, preserve_rng_state, *args
+ )
def _patch_bn_momentum(module: torch.nn.Module) -> None:
- if not hasattr(module, '_bn_momentum_patched'):
+ if not hasattr(module, "_bn_momentum_patched"):
if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
return
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
@@ -43,8 +45,9 @@ def _patch_bn_momentum(module: torch.nn.Module) -> None:
def checkpoint(function: torch.nn.Module, *args: Any, **kwargs: Any) -> Any:
# Hack to mix *args with **kwargs in a python 2.7-compliant way
- preserve = kwargs.pop('preserve_rng_state', True)
+ preserve = kwargs.pop("preserve_rng_state", True)
if kwargs:
raise ValueError(
- 'Unexpected keyword arguments: ' + ','.join(arg for arg in kwargs))
+ "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
+ )
return _CheckpointFunction.apply(function, preserve, *args) # type: ignore[no-untyped-call]
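Note: a short, self-contained sketch of the `checkpoint` utility reformatted above; it wraps `torch.utils.checkpoint` while patching BatchNorm momentum so the extra forward pass does not skew the running statistics.

    import torch
    from pytorch_pfn_extras.utils.checkpoint import checkpoint

    block = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.BatchNorm1d(16),
        torch.nn.ReLU(),
    )
    x = torch.randn(4, 16, requires_grad=True)
    # The block's forward pass is recomputed during backward to save memory.
    y = checkpoint(block, x, preserve_rng_state=True)
    y.sum().backward()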
diff --git a/pytorch_pfn_extras/utils/comparer.py b/pytorch_pfn_extras/utils/comparer.py
index aeaa091d1..30afe1650 100644
--- a/pytorch_pfn_extras/utils/comparer.py
+++ b/pytorch_pfn_extras/utils/comparer.py
@@ -1,40 +1,46 @@
import collections
+import concurrent.futures
import pathlib
import re
import threading
import weakref
-import concurrent.futures
from typing import (
- Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence,
- Tuple, Type, Union,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
)
+import pytorch_pfn_extras
import torch.nn
import torch.testing
-
-import pytorch_pfn_extras
from pytorch_pfn_extras import handler as _handler_module
from pytorch_pfn_extras.handler import _logic
-from pytorch_pfn_extras.training import _trainer
+from pytorch_pfn_extras.training import _evaluator, _trainer
from pytorch_pfn_extras.training import manager as manager_module
-from pytorch_pfn_extras.training import _evaluator
from pytorch_pfn_extras.training import trigger as trigger_module
-
_thread_local = threading.local()
_intermediate_prefix = "intermedaite:"
class _ComparableHandler(_handler_module.BaseHandler):
def __init__(
- self,
- handler: _handler_module.BaseHandler,
- name: str,
- get_target_cb: Any,
- compare_cb: Any,
- trigger: Optional[trigger_module.Trigger] = None,
- *,
- dir: Optional[str] = None,
+ self,
+ handler: _handler_module.BaseHandler,
+ name: str,
+ get_target_cb: Any,
+ compare_cb: Any,
+ trigger: Optional[trigger_module.Trigger] = None,
+ *,
+ dir: Optional[str] = None,
) -> None:
self._handler = handler
self._get_target_cb = get_target_cb
@@ -52,7 +58,8 @@ def train_setup(self, trainer: _trainer._Trainer, loader: Any) -> None:
self._handler.train_setup(trainer, loader)
def train_epoch_begin(
- self, trainer: _trainer._Trainer, loader: Iterable[Any]) -> None:
+ self, trainer: _trainer._Trainer, loader: Iterable[Any]
+ ) -> None:
self._handler.train_epoch_begin(trainer, loader)
_thread_local.handler = self
self._epoch = trainer.epoch
@@ -61,31 +68,33 @@ def train_epoch_end(self, trainer: _trainer._Trainer) -> None:
self._handler.train_epoch_end(trainer)
def train_validation_begin(
- self, trainer: _trainer._Trainer, evaluator: _evaluator.Evaluator) -> None:
+ self, trainer: _trainer._Trainer, evaluator: _evaluator.Evaluator
+ ) -> None:
self._handler.train_validation_begin(trainer, evaluator)
_thread_local.handler = self
def train_validation_end(
- self, trainer: _trainer._Trainer, evaluator: _evaluator.Evaluator) -> None:
+ self, trainer: _trainer._Trainer, evaluator: _evaluator.Evaluator
+ ) -> None:
self._handler.train_validation_end(trainer, evaluator)
def train_step(
- self,
- trainer: _trainer._Trainer,
- batch_idx: int,
- batch: Any,
- complete_fn: Callable[[int, Any], None],
+ self,
+ trainer: _trainer._Trainer,
+ batch_idx: int,
+ batch: Any,
+ complete_fn: Callable[[int, Any], None],
) -> None:
self._batch_idx = batch_idx
self._reset_intermediate_values()
self._handler.train_step(trainer, batch_idx, batch, complete_fn)
def train_post_step(
- self,
- trainer: _trainer.Trainer,
- batch_idx: int,
- batch: Any,
- outputs: Any
+ self,
+ trainer: _trainer.Trainer,
+ batch_idx: int,
+ batch: Any,
+ outputs: Any,
) -> None:
class _ManagerProxy(manager_module._ManagerProxy):
@property
@@ -100,7 +109,8 @@ def iteration(self) -> int:
self._compare(trainer, batch_idx, outputs)
def eval_setup(
- self, evaluator: _evaluator.Evaluator, loader: Iterable[Any]) -> None:
+ self, evaluator: _evaluator.Evaluator, loader: Iterable[Any]
+ ) -> None:
self._handler.eval_setup(evaluator, loader)
def eval_loop_begin(self, evaluator: _evaluator.Evaluator) -> None:
@@ -108,11 +118,11 @@ def eval_loop_begin(self, evaluator: _evaluator.Evaluator) -> None:
_thread_local.handler = self
def eval_step(
- self,
- evaluator: _evaluator.Evaluator,
- batch_idx: int,
- batch: Any,
- complete_fn: Callable[[int, Any], None],
+ self,
+ evaluator: _evaluator.Evaluator,
+ batch_idx: int,
+ batch: Any,
+ complete_fn: Callable[[int, Any], None],
) -> None:
self._batch_idx = batch_idx
self._reset_intermediate_values()
@@ -122,11 +132,11 @@ def eval_loop_end(self, evaluator: _evaluator.Evaluator) -> None:
self._handler.eval_loop_end(evaluator)
def eval_post_step(
- self,
- evaluator: _evaluator.Evaluator,
- batch_idx: int,
- batch: Any,
- outputs: Any,
+ self,
+ evaluator: _evaluator.Evaluator,
+ batch_idx: int,
+ batch: Any,
+ outputs: Any,
) -> None:
self._handler.eval_post_step(evaluator, batch_idx, batch, outputs)
self._compare(evaluator, batch_idx, outputs)
@@ -146,12 +156,13 @@ def _add_intermediate_value(self, name: str, value: torch.Tensor) -> None:
values = self._intermediate_values[self._batch_idx]
counts = self._intermediate_counts[self._batch_idx]
value = value.detach()
- count = counts.get('name', 0)
- counts['name'] = count + 1
- name = _intermediate_prefix + name + f'_{count}'
+ count = counts.get("name", 0)
+ counts["name"] = count + 1
+ name = _intermediate_prefix + name + f"_{count}"
# Defer importing onnx for performance and to avoid linkage issues.
import pytorch_pfn_extras.onnx
+
if pytorch_pfn_extras.onnx.available:
pytorch_pfn_extras.onnx.as_output(name, value)
values[name] = value
@@ -159,10 +170,12 @@ def _add_intermediate_value(self, name: str, value: torch.Tensor) -> None:
def _overwrite_handler(engine: Any, *args: Any, **kwargs: Any) -> None:
engine.handler = _ComparableHandler(engine.handler, *args, **kwargs)
- evaluator = getattr(engine, 'evaluator', None)
+ evaluator = getattr(engine, "evaluator", None)
if evaluator is not None:
# For trainer with evaluator
- evaluator.handler = _ComparableHandler(evaluator.handler, *args, **kwargs)
+ evaluator.handler = _ComparableHandler(
+ evaluator.handler, *args, **kwargs
+ )
_CompareFn = Callable[[str, str, str, Any, Any], None]
@@ -170,9 +183,9 @@ def _overwrite_handler(engine: Any, *args: Any, **kwargs: Any) -> None:
def get_default_comparer(
- rtol: float = 1e-04,
- atol: float = 0,
- equal_nan: bool = True,
+ rtol: float = 1e-04,
+ atol: float = 0,
+ equal_nan: bool = True,
) -> _CompareFn:
"""Creates default comparer function.
@@ -184,9 +197,10 @@ def get_default_comparer(
atol (float): Absolute tolerance.
equal_nan (bool): If ``True``, NaNs will be ignored.
"""
+
def compare_fn(
- backend1: str, backend2: str, name: str,
- val1: Any, val2: Any) -> None:
+ backend1: str, backend2: str, name: str, val1: Any, val2: Any
+ ) -> None:
# TODO select the device where
# the tensors will be compared?
if isinstance(val1, torch.Tensor):
@@ -208,17 +222,17 @@ def compare_fn(
def _compare_targets(
- compare_fn: _CompareFn,
- targets: Dict[str, Any],
- baseline: Optional[str],
- batch_idx: int,
+ compare_fn: _CompareFn,
+ targets: Dict[str, Any],
+ baseline: Optional[str],
+ batch_idx: int,
) -> None:
names = list(targets.keys())
if baseline is None:
baseline = names[0]
keys = sorted(targets[baseline].keys())
- err_msg = ''
+ err_msg = ""
for backend in set(names) - set([baseline]):
for val_name in keys:
out1 = targets[baseline][val_name]
@@ -228,18 +242,19 @@ def _compare_targets(
except AssertionError as e:
err_msg += (
f"Comparing '{baseline}' and '{backend}' in '{val_name}'\n"
- f"{str(e)}\n")
+ f"{str(e)}\n"
+ )
if err_msg:
- raise AssertionError(f'Batch: {batch_idx}\n' + str(err_msg))
+ raise AssertionError(f"Batch: {batch_idx}\n" + str(err_msg))
class _ComparerBase:
def __init__(
- self,
- engines: Mapping[str, _Engine],
- *,
- compare_fn: _CompareFn = _default_comparer,
- concurrency: Optional[int] = None,
+ self,
+ engines: Mapping[str, _Engine],
+ *,
+ compare_fn: _CompareFn = _default_comparer,
+ concurrency: Optional[int] = None,
) -> None:
e_type = type(next(iter(engines.values())))
if e_type not in (
@@ -257,16 +272,19 @@ def __init__(
self.compare_fn = compare_fn
self._finalized = False
self._semaphore = threading.Semaphore(
- len(engines) if concurrency is None else concurrency)
+ len(engines) if concurrency is None else concurrency
+ )
self.targets: Dict[str, Dict[str, Any]] = {}
self._iters: Dict[str, int] = {}
# engines must be a dict
for name, engine in engines.items():
- _overwrite_handler(engine, name, self._get_target, self.compare_targets)
+ _overwrite_handler(
+ engine, name, self._get_target, self.compare_targets
+ )
def _assert_incompatible_trigger(self, condition: bool) -> None:
if not condition:
- raise ValueError('Engines have different triggers.')
+ raise ValueError("Engines have different triggers.")
def run_engine(self, engine: _Engine, loaders: Any) -> None:
try:
@@ -307,26 +325,27 @@ def compare(self, loaders: Any, n_iters: Optional[int] = None) -> None:
) as executor:
futures = []
for name, engine in self.engines.items():
- futures.append(executor.submit(
- self.run_engine, engine, loaders[name]))
+ futures.append(
+ executor.submit(self.run_engine, engine, loaders[name])
+ )
for future in concurrent.futures.as_completed(futures):
future.result()
def _get_target(
- self,
- handle: _handler_module.BaseHandler,
- engine: _Engine,
- batch_idx: int,
- outputs: Any,
+ self,
+ handle: _handler_module.BaseHandler,
+ engine: _Engine,
+ batch_idx: int,
+ outputs: Any,
) -> Dict[str, Any]:
- raise NotImplementedError('Comparers must override _get_target')
+ raise NotImplementedError("Comparers must override _get_target")
def compare_targets(
- self,
- name: str,
- engine: _Engine,
- batch_idx: int,
- target: Dict[str, Any],
+ self,
+ name: str,
+ engine: _Engine,
+ batch_idx: int,
+ target: Dict[str, Any],
) -> None:
self._iters[name] += 1
if (self.n_iters is None) or (self._iters[name] % self.n_iters == 0):
@@ -335,7 +354,9 @@ def compare_targets(
self.targets[name] = target
if len(self.targets.keys()) == len(self.engines.keys()):
# all outputs have been filled, let's compare and reset
- _compare_targets(self.compare_fn, self.targets, None, batch_idx)
+ _compare_targets(
+ self.compare_fn, self.targets, None, batch_idx
+ )
self.targets = {}
self._assert_incompatible_trigger(not self._finalized)
# Explicitly synchronize
@@ -346,12 +367,12 @@ def compare_targets(
class OutputsComparer(_ComparerBase):
def __init__(
- self,
- engines: Mapping[str, _Engine],
- to_compare_keys: Optional[Sequence[str]] = None,
- *,
- compare_fn: _CompareFn = _default_comparer,
- concurrency: Optional[int] = None,
+ self,
+ engines: Mapping[str, _Engine],
+ to_compare_keys: Optional[Sequence[str]] = None,
+ *,
+ compare_fn: _CompareFn = _default_comparer,
+ concurrency: Optional[int] = None,
) -> None:
"""A class for comparison of iteration outputs.
@@ -378,15 +399,17 @@ def __init__(
>>> comp.compare({"cpu": loader, "gpu": loader})
"""
# If to_compare_key is None, then we compare all
- super().__init__(engines, compare_fn=compare_fn, concurrency=concurrency)
+ super().__init__(
+ engines, compare_fn=compare_fn, concurrency=concurrency
+ )
self.to_compare_keys = to_compare_keys
def _get_target(
- self,
- handle: _handler_module.BaseHandler,
- engine: _Engine,
- batch_idx: int,
- outputs: Any,
+ self,
+ handle: _handler_module.BaseHandler,
+ engine: _Engine,
+ batch_idx: int,
+ outputs: Any,
) -> Dict[str, Any]:
keys = (
self.to_compare_keys
@@ -398,12 +421,12 @@ def _get_target(
class ModelComparer(_ComparerBase):
def __init__(
- self,
- engines: Mapping[str, _Engine],
- to_compare_keys: Optional[Sequence[str]] = None,
- *,
- compare_fn: _CompareFn = _default_comparer,
- concurrency: Optional[int] = None,
+ self,
+ engines: Mapping[str, _Engine],
+ to_compare_keys: Optional[Sequence[str]] = None,
+ *,
+ compare_fn: _CompareFn = _default_comparer,
+ concurrency: Optional[int] = None,
):
"""A class for comparison of iteration model parameters.
@@ -430,7 +453,9 @@ def __init__(
>>> comp.compare({"cpu": loader, "gpu": loader})
"""
# If to_compare_key is None, then we compare all
- super().__init__(engines, compare_fn=compare_fn, concurrency=concurrency)
+ super().__init__(
+ engines, compare_fn=compare_fn, concurrency=concurrency
+ )
self.to_compare_keys = to_compare_keys
self._preprocessed_keys: Optional[List[str]] = None
@@ -447,16 +472,17 @@ def _preprocess_keys(self, sdict: Dict[str, Any]) -> None:
matched = True
if not matched:
raise ValueError(
- f'didnt find a match for {tc_k} in the model')
+ f"didnt find a match for {tc_k} in the model"
+ )
def _get_target(
- self,
- handle: _handler_module.BaseHandler,
- engine: _Engine,
- batch_idx: int,
- outputs: Any,
+ self,
+ handle: _handler_module.BaseHandler,
+ engine: _Engine,
+ batch_idx: int,
+ outputs: Any,
) -> Dict[str, Any]:
- sdict = engine.models['main'].state_dict()
+ sdict = engine.models["main"].state_dict()
if self._preprocessed_keys is None:
self._preprocess_keys(sdict)
assert self._preprocessed_keys is not None
@@ -465,9 +491,10 @@ def _get_target(
# New comparer interface
+
def _filter(
- keys: Union[bool, str, Sequence[str]],
- get_dict: Callable[[], Dict[str, Any]],
+ keys: Union[bool, str, Sequence[str]],
+ get_dict: Callable[[], Dict[str, Any]],
) -> Dict[str, Any]:
if keys is False:
return {}
@@ -487,23 +514,22 @@ def _filter(
ret[sd_k] = sdict[sd_k]
break
else:
- raise ValueError(f'didnt find a match for {tc_k} in the model')
+ raise ValueError(f"didnt find a match for {tc_k} in the model")
return ret
- raise ValueError(f'Unsupported type: {type(keys)}')
+ raise ValueError(f"Unsupported type: {type(keys)}")
class Comparer:
-
def __init__(
- self,
- *,
- trigger: Optional[trigger_module.TriggerLike] = None,
- compare_fn: _CompareFn = _default_comparer,
- concurrency: Optional[int] = None,
- outputs: Union[bool, str, Sequence[str]] = True,
- params: Union[bool, str, Sequence[str]] = False,
- baseline: Optional[str] = None,
+ self,
+ *,
+ trigger: Optional[trigger_module.TriggerLike] = None,
+ compare_fn: _CompareFn = _default_comparer,
+ concurrency: Optional[int] = None,
+ outputs: Union[bool, str, Sequence[str]] = True,
+ params: Union[bool, str, Sequence[str]] = False,
+ baseline: Optional[str] = None,
) -> None:
"""A class for comparison of iteration outputs and model parameters.
@@ -559,21 +585,24 @@ def __init__(
self._trigger = trigger_module.get_trigger(trigger)
def _get_target(
- self,
- handler: _ComparableHandler,
- engine: _Engine,
- batch_idx: int,
- outputs: Dict[str, Any],
+ self,
+ handler: _ComparableHandler,
+ engine: _Engine,
+ batch_idx: int,
+ outputs: Dict[str, Any],
) -> Dict[str, Any]:
targets = {}
outputs = _filter(self._output_keys, lambda: outputs)
- targets.update({
- k if k.startswith(_intermediate_prefix) else 'output:' + k: v
- for k, v in outputs.items()})
+ targets.update(
+ {
+ k if k.startswith(_intermediate_prefix) else "output:" + k: v
+ for k, v in outputs.items()
+ }
+ )
targets.update(handler._intermediate_values.pop(batch_idx))
- params = _filter(self._param_keys, engine.models['main'].state_dict)
- targets.update({'param:' + k: v for k, v in params.items()})
+ params = _filter(self._param_keys, engine.models["main"].state_dict)
+ targets.update({"param:" + k: v for k, v in params.items()})
return targets
def _assert_incompatible_trigger(self, condition: bool) -> None:
@@ -581,22 +610,23 @@ def _assert_incompatible_trigger(self, condition: bool) -> None:
raise ValueError("Engines have different triggers.")
def _get_filename(
- self, engine: _Engine, handler_name: str, batch_idx: int) -> str:
- name = f'dump_{self._count:08}'
- name += '_' + type(engine).__name__
+ self, engine: _Engine, handler_name: str, batch_idx: int
+ ) -> str:
+ name = f"dump_{self._count:08}"
+ name += "_" + type(engine).__name__
orig_engine, _, _ = self._engines[handler_name]
epoch = orig_engine.handler._epoch # type: ignore
if epoch is not None:
- name += f'_epoch_{epoch}'
- name += f'_iter_{batch_idx}'
+ name += f"_epoch_{epoch}"
+ name += f"_iter_{batch_idx}"
return name
def _compare_targets(
- self,
- name: str,
- engine: _Engine,
- batch_idx: int,
- target: Dict[str, Any],
+ self,
+ name: str,
+ engine: _Engine,
+ batch_idx: int,
+ target: Dict[str, Any],
) -> None:
# Save the outputs of this iteration
with self._report_lock:
@@ -604,7 +634,8 @@ def _compare_targets(
if len(self._targets.keys()) == len(self._engines.keys()):
# all outputs have been filled, let's compare and reset
_compare_targets(
- self._compare_fn, self._targets, self._baseline, batch_idx)
+ self._compare_fn, self._targets, self._baseline, batch_idx
+ )
self._targets = {}
self._count += 1
self._assert_incompatible_trigger(not self._finalized)
@@ -619,11 +650,11 @@ def _compare_targets(
self._semaphore.acquire()
def add_engine(
- self,
- name: str,
- engine: _Engine,
- *args: Any,
- **kwargs: Any,
+ self,
+ name: str,
+ engine: _Engine,
+ *args: Any,
+ **kwargs: Any,
) -> None:
"""Add an engine to compare variables.
@@ -649,7 +680,8 @@ def add_engine(
raise ValueError(f"Engine named {name} already registered")
_overwrite_handler(
- engine, name, self._get_target, self._compare_targets, self._trigger)
+ engine, name, self._get_target, self._compare_targets, self._trigger
+ )
self._engines[name] = engine, args, kwargs
@@ -666,18 +698,20 @@ def add_dump(self, name: str, dir: str) -> None:
self._engines[name] = engine, (), {}
def _dump_targets(
- self,
- name: str,
- engine: _Engine,
- batch_idx: int,
- target: Dict[str, Any],
+ self,
+ name: str,
+ engine: _Engine,
+ batch_idx: int,
+ target: Dict[str, Any],
) -> None:
name = self._get_filename(engine, name, batch_idx)
assert isinstance(engine.handler, _ComparableHandler)
- torch.save(target, f'{engine.handler._dir}/{name}')
+ torch.save(target, f"{engine.handler._dir}/{name}")
self._count += 1
- def dump(self, engine: _Engine, dir: str, *args: Any, **kwargs: Any) -> None:
+ def dump(
+ self, engine: _Engine, dir: str, *args: Any, **kwargs: Any
+ ) -> None:
"""Add an engine to compare variables.
Args:
@@ -691,10 +725,15 @@ def dump(self, engine: _Engine, dir: str, *args: Any, **kwargs: Any) -> None:
pathlib.Path(dir).mkdir(parents=True, exist_ok=True)
self._count = 0
- name = '__dump'
+ name = "__dump"
_overwrite_handler(
- engine, name, self._get_target, self._dump_targets,
- self._trigger, dir=dir)
+ engine,
+ name,
+ self._get_target,
+ self._dump_targets,
+ self._trigger,
+ dir=dir,
+ )
self._engines[name] = engine, args, {}
engine.run(*args, **kwargs)
@@ -719,14 +758,16 @@ def _run_engine(self, engine: _Engine, args: Any, kwargs: Any) -> None:
self._semaphore.release()
def compare(self) -> None:
- """Compares outputs.
- """
+ """Compares outputs."""
self._count = 0
n_workers = len(self._engines)
self._barrier = threading.Barrier(n_workers)
self._semaphore = threading.Semaphore(
- n_workers if self._concurrency is None else self._concurrency)
- with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
+ n_workers if self._concurrency is None else self._concurrency
+ )
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=n_workers
+ ) as executor:
futures = []
for _, (engine, args, kwargs) in self._engines.items():
futures.append(executor.submit(self._run_engine, engine, args, kwargs)) # type: ignore[arg-type]
@@ -735,7 +776,7 @@ def compare(self) -> None:
def intermediate_value(name: str, value: torch.Tensor) -> None:
- if not hasattr(_thread_local, 'handler'):
+ if not hasattr(_thread_local, "handler"):
return
_thread_local.handler._add_intermediate_value(name, value)
@@ -751,10 +792,11 @@ def run(self) -> None:
assert comparer is not None
for path in sorted(pathlib.Path(self._dir).iterdir()):
filename = path.name
- if filename.startswith('dump_'):
+ if filename.startswith("dump_"):
target = torch.load(path) # type: ignore[no-untyped-call]
- iter_str = '_iter_'
+ iter_str = "_iter_"
pos = filename.find(iter_str)
- batch_idx = int(filename[pos + len(iter_str):])
+ batch_idx = int(filename[pos + len(iter_str) :])
comparer._compare_targets(
- self.name, None, batch_idx, target) # type: ignore
+ self.name, None, batch_idx, target # type: ignore
+ )
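The comparer utilities reformatted above run two or more engines side by side (for example a CPU trainer and a GPU trainer) and assert that their per-iteration outputs, and optionally their parameters, stay within tolerance. A rough sketch of the newer Comparer interface; make_trainer, the loaders, and the tolerance values are placeholders, not taken from this diff:

from pytorch_pfn_extras.utils import comparer

# make_trainer stands in for however the project builds its engines,
# one per backend or device being compared.
trainer_cpu = make_trainer(device="cpu")
trainer_gpu = make_trainer(device="cuda:0")

comp = comparer.Comparer(
    compare_fn=comparer.get_default_comparer(rtol=1e-3, atol=1e-6),
    outputs=True,   # compare the dict returned by each iteration
    params=False,   # do not compare model parameters
)
# Extra positional arguments are forwarded to each engine's run().
comp.add_engine("cpu", trainer_cpu, train_loader, val_loader)
comp.add_engine("gpu", trainer_gpu, train_loader, val_loader)
comp.compare()  # raises AssertionError on the first mismatching batch

Inside a model's forward pass, comparer.intermediate_value("h", h) registers an intermediate tensor so that it is included in the same comparison.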
diff --git a/pytorch_pfn_extras/writing/__init__.py b/pytorch_pfn_extras/writing/__init__.py
index 3ae2e3f9f..4e8411ef2 100644
--- a/pytorch_pfn_extras/writing/__init__.py
+++ b/pytorch_pfn_extras/writing/__init__.py
@@ -1,9 +1,11 @@
-from pytorch_pfn_extras.writing._writer_base import Writer # NOQA
-from pytorch_pfn_extras.writing._writer_base import StandardWriter # NOQA
-from pytorch_pfn_extras.writing._simple_writer import SimpleWriter # NOQA
-from pytorch_pfn_extras.writing._parallel_writer import ThreadWriter # NOQA
from pytorch_pfn_extras.writing._parallel_writer import ProcessWriter # NOQA
+from pytorch_pfn_extras.writing._parallel_writer import ThreadWriter # NOQA
+from pytorch_pfn_extras.writing._queue_writer import ProcessQueueWriter # NOQA
from pytorch_pfn_extras.writing._queue_writer import QueueWriter # NOQA
from pytorch_pfn_extras.writing._queue_writer import ThreadQueueWriter # NOQA
-from pytorch_pfn_extras.writing._queue_writer import ProcessQueueWriter # NOQA
-from pytorch_pfn_extras.writing._tensorboard_writer import TensorBoardWriter # NOQA
+from pytorch_pfn_extras.writing._simple_writer import SimpleWriter # NOQA
+from pytorch_pfn_extras.writing._tensorboard_writer import ( # NOQA
+ TensorBoardWriter,
+)
+from pytorch_pfn_extras.writing._writer_base import StandardWriter # NOQA
+from pytorch_pfn_extras.writing._writer_base import Writer # NOQA
diff --git a/pytorch_pfn_extras/writing/_parallel_writer.py b/pytorch_pfn_extras/writing/_parallel_writer.py
index e40f37b3a..2c784e15a 100644
--- a/pytorch_pfn_extras/writing/_parallel_writer.py
+++ b/pytorch_pfn_extras/writing/_parallel_writer.py
@@ -1,12 +1,14 @@
import multiprocessing
-import threading
import sys
+import threading
from typing import Any, Optional
import torch
-
from pytorch_pfn_extras.writing._writer_base import (
- StandardWriter, _TargetType, _SaveFun, _FileSystem,
+ StandardWriter,
+ _FileSystem,
+ _SaveFun,
+ _TargetType,
)
@@ -21,47 +23,51 @@ class ThreadWriter(StandardWriter[threading.Thread]):
"""
def __init__(
- self,
- savefun: _SaveFun = torch.save,
- fs: _FileSystem = None,
- out_dir: str = '',
- **kwds: Any
+ self,
+ savefun: _SaveFun = torch.save,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ **kwds: Any,
) -> None:
super().__init__(savefun=savefun, fs=fs, out_dir=out_dir, **kwds)
def _save_with_exitcode(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- savefun: _SaveFun,
- append: bool,
- **savefun_kwargs: Any,
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ savefun: _SaveFun,
+ append: bool,
+ **savefun_kwargs: Any,
) -> None:
try:
self.save(
- filename, out_dir, target, savefun, append, **savefun_kwargs)
+ filename, out_dir, target, savefun, append, **savefun_kwargs
+ )
except Exception as e:
thread = threading.current_thread()
thread.exitcode = -1 # type: ignore[attr-defined]
print(
f'Error: ThreadWriter failed in thread "{thread.name}": '
- f'{type(e).__name__}: {str(e)}', file=sys.stderr)
+ f"{type(e).__name__}: {str(e)}",
+ file=sys.stderr,
+ )
def create_worker(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False,
- **savefun_kwargs: Any,
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
+ **savefun_kwargs: Any,
) -> threading.Thread:
return threading.Thread(
target=self._save_with_exitcode,
args=(filename, out_dir, target, savefun, append),
- kwargs=savefun_kwargs)
+ kwargs=savefun_kwargs,
+ )
class ProcessWriter(StandardWriter[multiprocessing.Process]):
@@ -80,25 +86,26 @@ class ProcessWriter(StandardWriter[multiprocessing.Process]):
"""
def __init__(
- self,
- savefun: _SaveFun = torch.save,
- fs: _FileSystem = None,
- out_dir: str = '',
- **kwds: Any,
+ self,
+ savefun: _SaveFun = torch.save,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ **kwds: Any,
) -> None:
super().__init__(savefun=savefun, fs=fs, out_dir=out_dir, **kwds)
def create_worker(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False,
- **savefun_kwargs: Any,
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
+ **savefun_kwargs: Any,
) -> multiprocessing.Process:
return multiprocessing.Process(
target=self.save,
args=(filename, out_dir, target, savefun, append),
- kwargs=savefun_kwargs)
+ kwargs=savefun_kwargs,
+ )
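ThreadWriter and ProcessWriter hand the actual save call to a worker thread or process so that training is not blocked by disk I/O. A sketch of the usual wiring through the snapshot extension; the manager object and the trigger are assumptions, not shown in this diff:

from pytorch_pfn_extras.training import extensions
from pytorch_pfn_extras.writing import ThreadWriter

# manager is an ExtensionsManager created elsewhere.
writer = ThreadWriter(out_dir="result")
manager.extend(
    extensions.snapshot(writer=writer),  # saved asynchronously
    trigger=(1, "epoch"),
)
# ProcessWriter exposes the same interface but uses a separate process,
# avoiding the GIL at the cost of pickling the snapshot target.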
diff --git a/pytorch_pfn_extras/writing/_queue_writer.py b/pytorch_pfn_extras/writing/_queue_writer.py
index 885ad2d4a..b4348570b 100644
--- a/pytorch_pfn_extras/writing/_queue_writer.py
+++ b/pytorch_pfn_extras/writing/_queue_writer.py
@@ -4,15 +4,19 @@
from typing import Generic, Optional, Tuple
import torch
-
+from pytorch_pfn_extras.writing._simple_writer import SimpleWriter
from pytorch_pfn_extras.writing._writer_base import (
- Writer, _TargetType, _SaveFun, _TaskFun, _Worker, _FileSystem,
+ Writer,
+ _FileSystem,
+ _SaveFun,
+ _TargetType,
+ _TaskFun,
+ _Worker,
)
-from pytorch_pfn_extras.writing._simple_writer import SimpleWriter
-
-_QueUnit = Optional[Tuple[
- _TaskFun, str, str, _TargetType, Optional[_SaveFun], bool]]
+_QueUnit = Optional[
+ Tuple[_TaskFun, str, str, _TargetType, Optional[_SaveFun], bool]
+]
class QueueWriter(Writer, Generic[_Worker]):
@@ -41,11 +45,11 @@ class QueueWriter(Writer, Generic[_Worker]):
"""
def __init__(
- self,
- savefun: _SaveFun = torch.save,
- fs: _FileSystem = None,
- out_dir: str = '',
- task: Optional[_TaskFun] = None,
+ self,
+ savefun: _SaveFun = torch.save,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ task: Optional[_TaskFun] = None,
) -> None:
super().__init__(fs=fs, out_dir=out_dir)
self._started = False
@@ -60,28 +64,29 @@ def __init__(
self._started = True
def __call__(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
) -> None:
assert not self._finalized
self._queue.put(
- (self._task, filename, out_dir, target, savefun, append))
+ (self._task, filename, out_dir, target, savefun, append)
+ )
def create_task(self, savefun: _SaveFun) -> _TaskFun:
return SimpleWriter(savefun=savefun)
- def create_queue(self) -> 'queue.Queue[_QueUnit]':
+ def create_queue(self) -> "queue.Queue[_QueUnit]":
raise NotImplementedError
- def create_consumer(self, q: 'queue.Queue[_QueUnit]') -> _Worker:
+ def create_consumer(self, q: "queue.Queue[_QueUnit]") -> _Worker:
raise NotImplementedError
- def consume(self, q: 'queue.Queue[_QueUnit]') -> None:
+ def consume(self, q: "queue.Queue[_QueUnit]") -> None:
while True:
task = q.get()
if task is None:
@@ -89,7 +94,8 @@ def consume(self, q: 'queue.Queue[_QueUnit]') -> None:
return
else:
task[0](
- task[1], task[2], task[3], savefun=task[4], append=task[5])
+ task[1], task[2], task[3], savefun=task[4], append=task[5]
+ )
q.task_done()
def finalize(self) -> None:
@@ -116,18 +122,18 @@ class ThreadQueueWriter(QueueWriter[threading.Thread]):
"""
def __init__(
- self,
- savefun: _SaveFun = torch.save,
- fs: _FileSystem = None,
- out_dir: str = '',
- task: Optional[_TaskFun] = None
+ self,
+ savefun: _SaveFun = torch.save,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ task: Optional[_TaskFun] = None,
) -> None:
super().__init__(savefun=savefun, fs=fs, task=task, out_dir=out_dir)
- def create_queue(self) -> 'queue.Queue[_QueUnit]':
+ def create_queue(self) -> "queue.Queue[_QueUnit]":
return queue.Queue()
- def create_consumer(self, q: 'queue.Queue[_QueUnit]') -> threading.Thread:
+ def create_consumer(self, q: "queue.Queue[_QueUnit]") -> threading.Thread:
return threading.Thread(target=self.consume, args=(q,))
@@ -149,16 +155,18 @@ class ProcessQueueWriter(QueueWriter[multiprocessing.Process]):
"""
def __init__(
- self,
- savefun: _SaveFun = torch.save,
- fs: _FileSystem = None,
- out_dir: str = '',
- task: Optional[_TaskFun] = None
+ self,
+ savefun: _SaveFun = torch.save,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ task: Optional[_TaskFun] = None,
) -> None:
super().__init__(savefun=savefun, fs=fs, out_dir=out_dir, task=task)
- def create_queue(self) -> 'queue.Queue[_QueUnit]':
+ def create_queue(self) -> "queue.Queue[_QueUnit]":
return multiprocessing.JoinableQueue()
- def create_consumer(self, q: 'queue.Queue[_QueUnit]') -> multiprocessing.Process:
+ def create_consumer(
+ self, q: "queue.Queue[_QueUnit]"
+ ) -> multiprocessing.Process:
return multiprocessing.Process(target=self.consume, args=(q,))
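Unlike the writers above, which spawn one worker per snapshot, the queue writers keep a single long-lived consumer (a thread, or a process fed by a JoinableQueue) and push every save request onto it. A sketch under the same assumptions as the previous example:

from pytorch_pfn_extras.training import extensions
from pytorch_pfn_extras.writing import ProcessQueueWriter

# One background process drains (task, filename, out_dir, target,
# savefun, append) tuples from the queue; finalize() flushes it and
# joins the worker when training ends.
writer = ProcessQueueWriter(out_dir="result")
manager.extend(extensions.snapshot(writer=writer), trigger=(1, "epoch"))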
diff --git a/pytorch_pfn_extras/writing/_simple_writer.py b/pytorch_pfn_extras/writing/_simple_writer.py
index 9d14a3f4e..b26db45a7 100644
--- a/pytorch_pfn_extras/writing/_simple_writer.py
+++ b/pytorch_pfn_extras/writing/_simple_writer.py
@@ -1,9 +1,11 @@
from typing import Any, Optional
import torch
-
from pytorch_pfn_extras.writing._writer_base import (
- Writer, _TargetType, _SaveFun, _FileSystem
+ Writer,
+ _FileSystem,
+ _SaveFun,
+ _TargetType,
)
@@ -29,24 +31,24 @@ class SimpleWriter(Writer):
"""
def __init__(
- self,
- savefun: _SaveFun = torch.save,
- fs: _FileSystem = None,
- out_dir: str = '',
- **kwds: Any,
+ self,
+ savefun: _SaveFun = torch.save,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ **kwds: Any,
) -> None:
super().__init__(fs=fs, out_dir=out_dir)
self._savefun = savefun
self._kwds = kwds
def __call__(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
) -> None:
if savefun is None:
savefun = self._savefun
diff --git a/pytorch_pfn_extras/writing/_tensorboard_writer.py b/pytorch_pfn_extras/writing/_tensorboard_writer.py
index aa99f5d22..8e004e972 100644
--- a/pytorch_pfn_extras/writing/_tensorboard_writer.py
+++ b/pytorch_pfn_extras/writing/_tensorboard_writer.py
@@ -1,13 +1,15 @@
-from typing import Any, KeysView, Optional
import warnings
+from typing import Any, KeysView, Optional
from pytorch_pfn_extras.writing._writer_base import (
- _TargetType, _SaveFun, _FileSystem
+ _FileSystem,
+ _SaveFun,
+ _TargetType,
)
class TensorBoardWriter(object):
- """ Writer that sends statistics to TensorBoard.
+ """Writer that sends statistics to TensorBoard.
This class contains a `torch.utils.tensorboard.SummaryWriter`
object that is used to send the collected statistics to TensorBoard.
@@ -20,38 +22,40 @@ class TensorBoardWriter(object):
stats (list): List of statistic keys.
kwds: Passed as additional arguments to SummaryWriter.
"""
+
def __init__(
- self,
- savefun: Optional[_SaveFun] = None,
- fs: _FileSystem = None,
- out_dir: str = '',
- stats: Optional[KeysView[str]] = None,
- **kwds: Any
+ self,
+ savefun: Optional[_SaveFun] = None,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ stats: Optional[KeysView[str]] = None,
+ **kwds: Any,
) -> None:
self._writer = None
try:
import torch.utils.tensorboard
except ImportError:
warnings.warn(
- 'tensorboard is unavailable. '
- 'TensorBoardWriter will do nothing.')
+ "tensorboard is unavailable. "
+ "TensorBoardWriter will do nothing."
+ )
return
self._stats = stats
- self._writer = (
- torch.utils.tensorboard.SummaryWriter( # type: ignore[no-untyped-call]
- log_dir=out_dir, **kwds))
+ self._writer = torch.utils.tensorboard.SummaryWriter( # type: ignore[no-untyped-call]
+ log_dir=out_dir, **kwds
+ )
def __del__(self) -> None:
self.finalize()
def __call__(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False,
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
) -> None:
"""Sends the statistics to the TensorBoard.
@@ -71,14 +75,15 @@ def __call__(
stats_cpu = target[-1]
if not isinstance(stats_cpu, dict):
- raise TypeError('target must be dict or list of dicts')
+ raise TypeError("target must be dict or list of dicts")
keys = stats_cpu.keys()
if self._stats is not None:
keys = self._stats # type: ignore[assignment]
for key in keys:
value = stats_cpu[key]
self._writer.add_scalar( # type: ignore[no-untyped-call]
- key, value, stats_cpu['iteration'])
+ key, value, stats_cpu["iteration"]
+ )
def finalize(self) -> None:
if self._writer is not None:
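TensorBoardWriter does not write snapshot files; its __call__ takes the latest statistics dict from the target and forwards each selected key to SummaryWriter.add_scalar. A small sketch calling it directly; the key names are made up, and in practice the writer is usually handed to a reporting extension such as LogReport (an assumption not shown in this diff):

from pytorch_pfn_extras.writing import TensorBoardWriter

# Only "loss" and "accuracy" are forwarded; "iteration" becomes the
# global step for add_scalar.  Emits a warning and does nothing when
# tensorboard is not installed.
writer = TensorBoardWriter(out_dir="runs/exp-1", stats=["loss", "accuracy"])
stats = [{"iteration": 100, "loss": 0.42, "accuracy": 0.88}]
writer("log", "runs/exp-1", stats)
writer.finalize()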
diff --git a/pytorch_pfn_extras/writing/_writer_base.py b/pytorch_pfn_extras/writing/_writer_base.py
index b5d41a908..10f96c58d 100644
--- a/pytorch_pfn_extras/writing/_writer_base.py
+++ b/pytorch_pfn_extras/writing/_writer_base.py
@@ -1,23 +1,32 @@
-import multiprocessing
import io
+import multiprocessing
import os
import shutil
import sys
import threading
import types
from typing import (
- Any, Callable, Generic, IO, Iterator, List, Mapping, Optional, Sequence,
- Type, TypeVar, Union,
+ IO,
+ Any,
+ Callable,
+ Generic,
+ Iterator,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Type,
+ TypeVar,
+ Union,
)
import torch
-
_TargetType = Union[Sequence[Any], Mapping[str, Any]]
_SaveFun = Callable[..., None]
_HookFun = Callable[[], None]
_TaskFun = Callable[..., None]
-_Worker = TypeVar('_Worker', threading.Thread, multiprocessing.Process)
+_Worker = TypeVar("_Worker", threading.Thread, multiprocessing.Process)
_FileSystem = Any
@@ -40,6 +49,7 @@ class _PosixFileSystem(object):
This class currently abstracts POSIX
"""
+
def __init__(self, root: Optional[str] = None) -> None:
if root is None:
self._root = os.getcwd()
@@ -50,10 +60,11 @@ def get_actual_path(self, path: str) -> str:
return os.path.join(self.root, path)
def _wrap_fileobject(
- self,
- file_obj: IO[Any],
- file_path: str,
- *args: Any, **kwargs: Any,
+ self,
+ file_obj: IO[Any],
+ file_path: str,
+ *args: Any,
+ **kwargs: Any,
) -> IO[Any]:
return file_obj
@@ -66,35 +77,51 @@ def root(self, root: str) -> None:
self._root = root
def open(
- self,
- file_path: str,
- mode: str = 'r',
- buffering: int = -1,
- encoding: Optional[str] = None,
- errors: Optional[str] = None,
- newline: Optional[str] = None,
- closefd: bool = True,
- opener: Optional[Callable[[str, int], int]] = None,
+ self,
+ file_path: str,
+ mode: str = "r",
+ buffering: int = -1,
+ encoding: Optional[str] = None,
+ errors: Optional[str] = None,
+ newline: Optional[str] = None,
+ closefd: bool = True,
+ opener: Optional[Callable[[str, int], int]] = None,
) -> IO[Any]:
file_path = self.get_actual_path(file_path)
- file_obj = io.open(file_path, mode,
- buffering, encoding, errors,
- newline, closefd, opener)
+ file_obj = io.open(
+ file_path,
+ mode,
+ buffering,
+ encoding,
+ errors,
+ newline,
+ closefd,
+ opener,
+ )
return self._wrap_fileobject(
- file_obj, file_path, mode, buffering, encoding,
- errors, newline, closefd, opener)
+ file_obj,
+ file_path,
+ mode,
+ buffering,
+ encoding,
+ errors,
+ newline,
+ closefd,
+ opener,
+ )
def list(
- self,
- path_or_prefix: Optional[str] = None,
- recursive: bool = False,
+ self,
+ path_or_prefix: Optional[str] = None,
+ recursive: bool = False,
) -> Iterator[str]:
if path_or_prefix is not None:
path_or_prefix = self.get_actual_path(path_or_prefix)
if recursive:
if path_or_prefix is None:
raise ValueError(
- "'path_or_prefix' must not be none in recursive mode.")
+ "'path_or_prefix' must not be none in recursive mode."
+ )
path_or_prefix = path_or_prefix.rstrip("/")
# plus 1 to include the trailing slash
prefix_end_index = len(path_or_prefix) + 1
@@ -104,7 +131,9 @@ def list(
yield file.name
def _recursive_list(
- self, prefix_end_index: int, path: str,
+ self,
+ prefix_end_index: int,
+ path: str,
) -> Iterator[str]:
path = self.get_actual_path(path)
for file in os.scandir(path):
@@ -119,14 +148,14 @@ def stat(self, path: str) -> _PosixFileStat:
def close(self) -> None:
pass
- def __enter__(self) -> '_PosixFileSystem':
+ def __enter__(self) -> "_PosixFileSystem":
return self
def __exit__(
- self,
- exc_type: Optional[Type[BaseException]],
- exc_value: Optional[BaseException],
- traceback: Optional[types.TracebackType],
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
) -> None:
pass
@@ -135,20 +164,17 @@ def isdir(self, file_path: str) -> bool:
return os.path.isdir(path)
def mkdir(
- self,
- file_path: str,
- mode: int = 0o777,
- *args: Any,
- dir_fd: Optional[int] = None
+ self,
+ file_path: str,
+ mode: int = 0o777,
+ *args: Any,
+ dir_fd: Optional[int] = None,
) -> None:
file_path = self.get_actual_path(file_path)
return os.mkdir(file_path, mode, *args, dir_fd=dir_fd)
def makedirs(
- self,
- file_path: str,
- mode: int = 0o777,
- exist_ok: bool = False
+ self, file_path: str, mode: int = 0o777, exist_ok: bool = False
) -> None:
file_path = self.get_actual_path(file_path)
return os.makedirs(file_path, mode, exist_ok)
@@ -160,9 +186,11 @@ def rename(self, src: str, dst: str) -> None:
try:
return os.replace(src, dst)
except OSError:
- print('Destination {} is a directory '
- 'but source is not'.format(dst),
- file=sys.stderr)
+ print(
+ "Destination {} is a directory "
+ "but source is not".format(dst),
+ file=sys.stderr,
+ )
raise
def remove(self, file_path: str, recursive: bool = False) -> None:
@@ -195,9 +223,9 @@ class Writer:
"""
def __init__(
- self,
- fs: _FileSystem = None,
- out_dir: str = '',
+ self,
+ fs: _FileSystem = None,
+ out_dir: str = "",
) -> None:
self._post_save_hooks: List[_HookFun] = []
self.fs = fs or _PosixFileSystem()
@@ -205,13 +233,13 @@ def __init__(
self._initialized = False
def __call__(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
) -> None:
"""Does the actual writing to the file.
@@ -250,13 +278,13 @@ def finalize(self) -> None:
pass
def save(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- savefun: _SaveFun,
- append: bool,
- **savefun_kwargs: Any,
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ savefun: _SaveFun,
+ append: bool,
+ **savefun_kwargs: Any,
) -> None:
out_dir = self.out_dir
if not self._initialized:
@@ -265,19 +293,19 @@ def save(
dest = os.path.join(out_dir, filename)
if append:
- with self.fs.open(dest, 'ab') as f:
+ with self.fs.open(dest, "ab") as f:
# HDFS does not support overwrite
savefun(target, f, **savefun_kwargs)
else:
# Some filesystems are not compatible with temp folders, etc
# so we rely on raw temp files
- prefix = 'tmp_{}'.format(filename)
+ prefix = "tmp_{}".format(filename)
tmppath = os.path.join(out_dir, prefix)
make_backup = self.fs.exists(dest)
- with self.fs.open(tmppath, 'wb') as f:
+ with self.fs.open(tmppath, "wb") as f:
savefun(target, f, **savefun_kwargs)
if make_backup:
- bak = '{}.bak'.format(dest)
+ bak = "{}.bak".format(dest)
# Check if another backup file exists
# due to some unexpected termination of an earlier
# process
@@ -331,11 +359,11 @@ class StandardWriter(Writer, Generic[_Worker]):
"""
def __init__(
- self,
- savefun: _SaveFun = torch.save,
- fs: _FileSystem = None,
- out_dir: str = '',
- **kwds: Any,
+ self,
+ savefun: _SaveFun = torch.save,
+ fs: _FileSystem = None,
+ out_dir: str = "",
+ **kwds: Any,
) -> None:
super().__init__(fs=fs, out_dir=out_dir)
self._savefun = savefun
@@ -345,13 +373,13 @@ def __init__(
self._finalized = False
def __call__(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
) -> None:
assert not self._finalized
if savefun is None:
@@ -360,21 +388,26 @@ def __call__(
self.finalize()
self._filename = filename
self._worker = self.create_worker(
- filename, out_dir, target,
- savefun=savefun, append=append, **self._kwds)
+ filename,
+ out_dir,
+ target,
+ savefun=savefun,
+ append=append,
+ **self._kwds,
+ )
self._worker.start()
self._started = True
self._finalized = False
def create_worker(
- self,
- filename: str,
- out_dir: str,
- target: _TargetType,
- *,
- savefun: Optional[_SaveFun] = None,
- append: bool = False,
- **savefun_kwargs: Any,
+ self,
+ filename: str,
+ out_dir: str,
+ target: _TargetType,
+ *,
+ savefun: Optional[_SaveFun] = None,
+ append: bool = False,
+ **savefun_kwargs: Any,
) -> _Worker:
"""Creates a worker for the snapshot.
@@ -390,13 +423,13 @@ def finalize(self) -> None:
return
if self._worker is None:
- raise RuntimeError('worker is not created')
+ raise RuntimeError("worker is not created")
try:
if self._started and not self._finalized:
self._worker.join()
- exitcode = getattr(self._worker, 'exitcode', 0)
+ exitcode = getattr(self._worker, "exitcode", 0)
if exitcode != 0:
- raise RuntimeError(f'exit code is non-zero: {exitcode}')
+ raise RuntimeError(f"exit code is non-zero: {exitcode}")
finally:
self._started = False
self._finalized = True
diff --git a/setup.py b/setup.py
index 71b32251a..2868a27a0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,25 +1,25 @@
import os
-import setuptools
+import setuptools
here = os.path.abspath(os.path.dirname(__file__))
# Get __version__ variable
-exec(open(os.path.join(here, 'pytorch_pfn_extras', '_version.py')).read())
+exec(open(os.path.join(here, "pytorch_pfn_extras", "_version.py")).read())
setuptools.setup(
- name='pytorch-pfn-extras',
- version=__version__, # NOQA
- description='Supplementary components to accelerate research and '
- 'development in PyTorch.',
- author='Preferred Networks, Inc.',
- license='MIT License',
- install_requires=['numpy', 'packaging', 'torch', 'typing-extensions>=3.10'],
+ name="pytorch-pfn-extras",
+ version=__version__, # NOQA
+ description="Supplementary components to accelerate research and "
+ "development in PyTorch.",
+ author="Preferred Networks, Inc.",
+ license="MIT License",
+ install_requires=["numpy", "packaging", "torch", "typing-extensions>=3.10"],
extras_require={
- 'test': ['pytest', 'onnxruntime', 'torchvision'],
- 'onnx': ['onnx'],
+ "test": ["pytest", "onnxruntime", "torchvision"],
+ "onnx": ["onnx"],
},
- python_requires='>=3.6.0',
- packages=setuptools.find_packages(exclude=['tests', 'tests.*']),
- package_data={'pytorch_pfn_extras': ['py.typed']},
+ python_requires=">=3.6.0",
+ packages=setuptools.find_packages(exclude=["tests", "tests.*"]),
+ package_data={"pytorch_pfn_extras": ["py.typed"]},
)
diff --git a/stubs/torch/_C/__init__.pyi b/stubs/torch/_C/__init__.pyi
index bfe8ef95a..fb7e05143 100644
--- a/stubs/torch/_C/__init__.pyi
+++ b/stubs/torch/_C/__init__.pyi
@@ -1,34 +1,63 @@
# @generated from torch/_C/__init__.pyi.in
# flake8: noqa
-import torch
-from torch.package import PackageExporter
-from torch import Tensor, inf
-from torch.autograd.graph import Node as _Node
+import builtins
from enum import Enum
from pathlib import Path
from typing import (
- Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
- NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
- Literal, Generic, Set, AnyStr)
+ Any,
+ AnyStr,
+ BinaryIO,
+ Callable,
+ ContextManager,
+ Dict,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ Literal,
+ NamedTuple,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+import torch
+from torch import Tensor, inf
+from torch.autograd.graph import Node as _Node
+from torch.package import PackageExporter
+from torch.storage import TypedStorage
from torch.types import (
- _int, _float, _bool, _dtype, _device, _qscheme, _size, _layout, Device, Number, Storage, SymInt, _dispatchkey
+ Device,
+ Number,
+ Storage,
+ SymInt,
+ _bool,
+ _device,
+ _dispatchkey,
+ _dtype,
+ _float,
+ _int,
+ _layout,
+ _qscheme,
+ _size,
)
-from torch.storage import TypedStorage
-
-import builtins
-
-# This module is defined in torch/csrc/Module.cpp
-from . import _nn as _nn
-from . import _onnx as _onnx
-from . import _VariableFunctions as _VariableFunctions
from . import _functorch as _functorch
from . import _lazy as _lazy
from . import _lazy_ts_backend as _lazy_ts_backend
+from . import _nn as _nn
+from . import _onnx as _onnx
+from . import _VariableFunctions as _VariableFunctions
+
+# This module is defined in torch/csrc/Module.cpp
-T = TypeVar('T')
+T = TypeVar("T")
S = TypeVar("S", bound="torch.Tensor")
# Defined in torch/csrc/Device.cpp
@@ -41,7 +70,6 @@ class device:
# THPDevice_pynew
@overload
def __init__(self, device: Union[_device, _int, str]) -> None: ...
-
@overload
def __init__(self, type: str, index: _int) -> None: ...
@@ -49,9 +77,7 @@ class device:
# def __call__(self, func: T) -> T: ...
def __enter__(self) -> "device": ...
-
def __exit__(self, exc_type, exc_val, exc_tb) -> None: ...
-
def __reduce__(self) -> Tuple[Any, ...]: ... # THPDevice_reduce
# Defined in torch/csrc/Stream.cpp
@@ -60,7 +86,7 @@ class Stream:
device_index: _int
device_type: _int
- device: device # The device of the stream
+ device: device # The device of the stream
...
@@ -70,10 +96,8 @@ class Size(Tuple[_int, ...]):
@overload # type: ignore[override]
def __getitem__(self: Size, key: _int) -> _int: ...
-
@overload
def __getitem__(self: Size, key: slice) -> Size: ...
-
def numel(self: Size) -> _int: ...
...
@@ -107,7 +131,6 @@ class finfo:
@overload
def __init__(self, dtype: _dtype) -> None: ...
-
@overload
def __init__(self) -> None: ...
@@ -139,21 +162,20 @@ quint4x2: dtype = ...
quint2x4: dtype = ...
# Defined in torch/csrc/Layout.cpp
-class layout:
- ...
+class layout: ...
# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunction(): ...
def DisableTorchFunctionSubclass(): ...
# Defined in torch/csrc/utils/tensor_layouts.cpp
-strided : layout = ...
-sparse_coo : layout = ...
-sparse_csr : layout = ...
-sparse_csc : layout = ...
-sparse_bsr : layout = ...
-sparse_bsc : layout = ...
-_mkldnn : layout = ...
+strided: layout = ...
+sparse_coo: layout = ...
+sparse_csr: layout = ...
+sparse_csc: layout = ...
+sparse_bsr: layout = ...
+sparse_bsc: layout = ...
+_mkldnn: layout = ...
# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...
@@ -175,45 +197,42 @@ per_channel_symmetric: qscheme = ...
per_channel_affine_float_qparams: qscheme = ...
# Defined in torch/csrc/autograd/python_function.cpp
-class _FunctionBase:
- ...
+class _FunctionBase: ...
# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase(Tensor): # inherits from Tensor to appease mypy
def __init__(
self,
- data: Optional[Tensor]=...,
- requires_grad: Optional[_bool]=...,
- volatile: Optional[_bool]=...,
- _grad_fn: Optional[_FunctionBase]=...
+ data: Optional[Tensor] = ...,
+ requires_grad: Optional[_bool] = ...,
+ volatile: Optional[_bool] = ...,
+ _grad_fn: Optional[_FunctionBase] = ...,
) -> None: ...
# Defined in torch/csrc/jit/python/init.cpp
class IODescriptor: ...
-
class JITException: ...
class Future:
- def __init__(self, devices: List[device]) -> None: ...
- def done(self) -> _bool: ...
- def value(self) -> Any: ...
- def wait(self) -> Any: ...
- def add_done_callback(self, callback: Callable) -> None: ...
- def then(self, callback: Callable) -> Future: ...
- def set_result(self, result: Any) -> None: ...
- def _set_unwrap_func(self, callback: Callable) -> None: ...
+ def __init__(self, devices: List[device]) -> None: ...
+ def done(self) -> _bool: ...
+ def value(self) -> Any: ...
+ def wait(self) -> Any: ...
+ def add_done_callback(self, callback: Callable) -> None: ...
+ def then(self, callback: Callable) -> Future: ...
+ def set_result(self, result: Any) -> None: ...
+ def _set_unwrap_func(self, callback: Callable) -> None: ...
class _Await:
- def __init__(self) -> None: ...
- def fn(self) -> Callable: ...
- def args(self) -> Tuple[Any, ...]: ...
- def is_nowait(self) -> _bool: ...
+ def __init__(self) -> None: ...
+ def fn(self) -> Callable: ...
+ def args(self) -> Tuple[Any, ...]: ...
+ def is_nowait(self) -> _bool: ...
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
-class _MobileOptimizerType:
- ...
+class _MobileOptimizerType: ...
CONV_BN_FUSION: _MobileOptimizerType
INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType
@@ -229,50 +248,66 @@ def _awaitable_wait(aw: _Await) -> Any: ...
def _awaitable_nowait(x: Any) -> _Await: ...
def _collect_all(futures: List[Future]) -> Future: ...
def _set_print_stack_traces_on_fatal_signal(print: _bool) -> None: ...
-
def unify_type_list(types: List[JitType]) -> JitType: ...
-def _freeze_module(module: ScriptModule,
- preserved_attrs: List[str] = [],
- freeze_interfaces: _bool = True,
- preserveParameters: _bool = True) -> ScriptModule: ...
-def _jit_pass_optimize_frozen_graph(Graph, optimize_numerics: _bool = True) -> None: ...
-def _jit_pass_optimize_for_inference(module: 'torch.jit.ScriptModule',
- other_methods: List[str] = []) -> None: ...
+def _freeze_module(
+ module: ScriptModule,
+ preserved_attrs: List[str] = [],
+ freeze_interfaces: _bool = True,
+ preserveParameters: _bool = True,
+) -> ScriptModule: ...
+def _jit_pass_optimize_frozen_graph(
+ Graph, optimize_numerics: _bool = True
+) -> None: ...
+def _jit_pass_optimize_for_inference(
+ module: "torch.jit.ScriptModule", other_methods: List[str] = []
+) -> None: ...
def _jit_pass_fold_frozen_conv_bn(graph: Graph): ...
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
def _jit_pass_concat_frozen_linear(graph: Graph): ...
def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
-def _jit_pass_transpose_frozen_linear(graph:Graph): ...
-def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...
-
+def _jit_pass_transpose_frozen_linear(graph: Graph): ...
+def _jit_pass_remove_dropout(module: "torch.jit.ScriptModule"): ...
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
-def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, Callable, List[Any]]: ...
+def _get_operation_overload(
+ op_name: str, op_overload_name: str
+) -> Tuple[Callable, Callable, List[Any]]: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
-def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
- optimization_blocklist: Set[_MobileOptimizerType],
- preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
-def _clone_module_with_class(module: 'torch.jit.ScriptModule',
- ignored_methods: List[AnyStr],
- ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
-def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
- optimization_blocklist: Set[_MobileOptimizerType],
- preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
-def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
- preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
+def _jit_pass_optimize_for_mobile(
+ module: "torch.jit.ScriptModule",
+ optimization_blocklist: Set[_MobileOptimizerType],
+ preserved_methods: List[AnyStr],
+) -> "torch.jit.ScriptModule": ...
+def _clone_module_with_class(
+ module: "torch.jit.ScriptModule",
+ ignored_methods: List[AnyStr],
+ ignored_attributes: List[AnyStr],
+) -> "torch.jit.ScriptModule": ...
+def _jit_pass_vulkan_optimize_for_mobile(
+ module: "torch.jit.ScriptModule",
+ optimization_blocklist: Set[_MobileOptimizerType],
+ preserved_methods: List[AnyStr],
+) -> "torch.jit.ScriptModule": ...
+def _jit_pass_metal_optimize_for_mobile(
+ module: "torch.jit.ScriptModule", preserved_methods: List[AnyStr]
+) -> "torch.jit.ScriptModule": ...
def _jit_pass_inline(Graph) -> None: ...
def _jit_pass_constant_propagation(Graph) -> None: ...
def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
-def _jit_register_decomposition_for_schema(schema: FunctionSchema, Graph) -> None: ...
+def _jit_register_decomposition_for_schema(
+ schema: FunctionSchema, Graph
+) -> None: ...
def _jit_erase_non_input_shape_information(Graph) -> None: ...
-def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
+def _jit_get_schemas_for_operator(name: str) -> List[FunctionSchema]: ...
def _jit_get_all_schemas() -> List[FunctionSchema]: ...
-def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
+def _jit_check_alias_annotation(
+ g: Graph, args: Tuple[Any, ...], unqualified_op_name: str
+): ...
def _jit_can_fuse_on_cpu() -> _bool: ...
def _jit_can_fuse_on_gpu() -> _bool: ...
def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
@@ -295,36 +330,48 @@ def _jit_cat_wo_conditionals(optimize_cat: _bool): ...
def _jit_opt_conditionals(opt_conds: _bool): ...
def _jit_pass_canonicalize(graph: Graph, keep_unique_names: _bool = True): ...
def _jit_pass_erase_shape_information(graph: Graph): ...
-def _jit_pass_fold_convbn(module: 'torch.jit.ScriptModule'): ...
-def _jit_pass_insert_observers(module: 'torch.jit.ScriptModule',
- method_name: str,
- qconfig_dict: Dict[str, Any],
- inplace: _bool,
- quant_type: _int): ...
-def _jit_pass_insert_quant_dequant(module: 'torch.jit.ScriptModule',
- method_name: str,
- inplace: _bool,
- debug: _bool,
- quant_type: _int): ...
-def _jit_pass_insert_quant_dequant_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
- method_name: str,
- inplace: _bool,
- debug: _bool,
- quant_type: _int): ...
-def _jit_pass_quant_finalize(module: 'torch.jit.ScriptModule',
- quant_type: _int,
- preserved_attrs: Sequence[str]): ...
-def _jit_pass_quant_finalize_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
- quant_type: _int,
- method_name: str): ...
-def _jit_pass_insert_observer_method_for_ondevice_ptq(module: 'torch.jit.ScriptModule',
- method_name: str,
- qconfig_dict: Dict[str, Any],
- inplace: _bool,
- quant_type: _int): ...
+def _jit_pass_fold_convbn(module: "torch.jit.ScriptModule"): ...
+def _jit_pass_insert_observers(
+ module: "torch.jit.ScriptModule",
+ method_name: str,
+ qconfig_dict: Dict[str, Any],
+ inplace: _bool,
+ quant_type: _int,
+): ...
+def _jit_pass_insert_quant_dequant(
+ module: "torch.jit.ScriptModule",
+ method_name: str,
+ inplace: _bool,
+ debug: _bool,
+ quant_type: _int,
+): ...
+def _jit_pass_insert_quant_dequant_for_ondevice_ptq(
+ module: "torch.jit.ScriptModule",
+ method_name: str,
+ inplace: _bool,
+ debug: _bool,
+ quant_type: _int,
+): ...
+def _jit_pass_quant_finalize(
+ module: "torch.jit.ScriptModule",
+ quant_type: _int,
+ preserved_attrs: Sequence[str],
+): ...
+def _jit_pass_quant_finalize_for_ondevice_ptq(
+ module: "torch.jit.ScriptModule", quant_type: _int, method_name: str
+): ...
+def _jit_pass_insert_observer_method_for_ondevice_ptq(
+ module: "torch.jit.ScriptModule",
+ method_name: str,
+ qconfig_dict: Dict[str, Any],
+ inplace: _bool,
+ quant_type: _int,
+): ...
def _jit_set_profiling_executor(profiling_flag: _bool) -> _bool: ...
def _jit_set_profiling_mode(profiling_flag: _bool) -> _bool: ...
-def _jit_set_fusion_strategy(strategy: List[Tuple[str, _int]]) -> List[Tuple[str, _int]]: ...
+def _jit_set_fusion_strategy(
+ strategy: List[Tuple[str, _int]]
+) -> List[Tuple[str, _int]]: ...
def _jit_try_infer_type(obj: Any) -> InferredType: ...
def _jit_get_trigger_value(trigger_name: str) -> _int: ...
@@ -333,23 +380,43 @@ ResolutionCallback = Callable[[str], Callable[..., Any]]
# Defined in torch/csrc/jit/python/script_init.cpp
# and torch/csrc/jit/python/init.cpp
-def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
+def _create_function_from_graph(
+ qualname: str, graph: Graph
+) -> ScriptFunction: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
def _jit_assert_is_instance(obj: Any, type: JitType): ...
def _jit_clear_class_registry() -> None: ...
-def _jit_set_emit_hooks(ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]) -> None: ...
+def _jit_set_emit_hooks(
+ ModuleHook: Optional[Callable], FunctionHook: Optional[Callable]
+) -> None: ...
def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
-def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
-def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
+def _load_for_lite_interpreter(
+ filename: Union[str, Path], map_location: Union[_device, str, None]
+): ...
+def _load_for_lite_interpreter_from_buffer(
+ buffer: BinaryIO, map_location: Union[_device, str, None]
+): ...
def _export_operator_list(module: LiteScriptModule): ...
-def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
+def _quantize_ondevice_ptq_dynamic(
+ module: LiteScriptModule, method_name: str
+): ...
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
-def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ...
-def _backport_for_mobile_from_buffer(buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int) -> None: ...
-def _backport_for_mobile_to_buffer(filename_input: Union[str, Path], to_version: _int) -> bytes:...
-def _backport_for_mobile_from_buffer_to_buffer(buffer: BinaryIO, to_version: _int) -> bytes:...
+def _backport_for_mobile(
+ filename_input: Union[str, Path],
+ filename_output: Union[str, Path],
+ to_version: _int,
+) -> None: ...
+def _backport_for_mobile_from_buffer(
+ buffer: BinaryIO, filename_output: Union[str, Path], to_version: _int
+) -> None: ...
+def _backport_for_mobile_to_buffer(
+ filename_input: Union[str, Path], to_version: _int
+) -> bytes: ...
+def _backport_for_mobile_from_buffer_to_buffer(
+ buffer: BinaryIO, to_version: _int
+) -> bytes: ...
def _get_model_ops_and_info(filename: Union[str, Path]): ...
def _get_model_ops_and_info_from_buffer(buffer: BinaryIO): ...
def _get_mobile_model_contained_types(filename: Union[str, Path]): ...
@@ -365,7 +432,7 @@ def _create_function_from_trace(
var_lookup_fn: Callable[[Tensor], str],
strict: _bool,
force_outplace: _bool,
- argument_names: List[str]
+ argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _create_function_from_trace_with_dict(
qualname: str,
@@ -374,7 +441,7 @@ def _create_function_from_trace_with_dict(
var_lookup_fn: Callable[[Tensor], str],
strict: _bool,
force_outplace: _bool,
- argument_names: List[str]
+ argument_names: List[str],
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
@@ -383,25 +450,46 @@ def _get_upgraders_map_size() -> _int: ...
def _dump_upgraders_map() -> Dict[str, str]: ...
def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ...
def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ...
-def merge_type_from_type_comment(decl: Decl, type_annotation_decl: Decl, is_method: _bool) -> Decl: ...
+def merge_type_from_type_comment(
+ decl: Decl, type_annotation_decl: Decl, is_method: _bool
+) -> Decl: ...
def parse_ir(input: str, parse_tensor_constants: _bool) -> Graph: ...
def parse_schema(schema: str) -> FunctionSchema: ...
def get_device(input: Tensor) -> _int: ...
-
-def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallback) -> JitType: ...
+def _resolve_type_from_object(
+ obj: Any, range: SourceRange, rcb: ResolutionCallback
+) -> JitType: ...
def _create_module_with_type(ty: JitType) -> ScriptModule: ...
def _create_object_with_type(ty: ClassType) -> ScriptObject: ...
def _run_emit_module_hook(m: ScriptModule): ...
-def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def, new_name: str) -> Def: ...
-
+def _replace_overloaded_method_decl(
+ overload_decl: Decl, implementation_def: Def, new_name: str
+) -> Def: ...
def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
-def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
-def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, params_dict: Dict[str, IValue], opset_version: _int) -> None: ...
-def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool, is_script: _bool, opset_version: _int) -> None: ...
-def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph, module: Optional[ScriptModule] = None) -> None: ...
+def _jit_pass_onnx_set_dynamic_input_shape(
+ graph: Graph,
+ dynamic_axes: Dict[str, Dict[_int, str]],
+ input_names: List[str],
+) -> None: ...
+def _jit_pass_onnx_graph_shape_type_inference(
+ graph: Graph, params_dict: Dict[str, IValue], opset_version: _int
+) -> None: ...
+def _jit_pass_onnx_assign_output_shape(
+ graph: Graph,
+ tensors: List[Tensor],
+ desc: IODescriptor,
+ onnx_shape_inference: _bool,
+ is_script: _bool,
+ opset_version: _int,
+) -> None: ...
+def _jit_pass_onnx_remove_inplace_ops_for_onnx(
+ graph: Graph, module: Optional[ScriptModule] = None
+) -> None: ...
def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
-def _jit_pass_peephole(graph: Graph, disable_shape_peepholes: _bool = False) -> None: ...
+def _jit_pass_peephole(
+ graph: Graph, disable_shape_peepholes: _bool = False
+) -> None: ...
def _jit_pass_onnx_autograd_function_process(graph: Graph) -> None: ...
def _jit_pass_fuse_addmm(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess(graph: Graph) -> None: ...
@@ -409,102 +497,133 @@ def _jit_pass_prepare_division_for_onnx(graph: Graph) -> None: ...
def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
def _jit_pass_onnx_unpack_quantized_weights(
- graph: Graph,
- paramsDict: Dict[str, IValue],
- caffe2: _bool
+ graph: Graph, paramsDict: Dict[str, IValue], caffe2: _bool
) -> Dict[str, IValue]: ...
def _jit_pass_onnx_quantization_insert_permutes(
- graph: Graph,
- paramsDict: Dict[str, IValue]
+ graph: Graph, paramsDict: Dict[str, IValue]
) -> Dict[str, IValue]: ...
-def _jit_pass_custom_pattern_based_rewrite_graph(pattern: str, fused_node_name: str, graph: Graph) -> None: ...
-def _jit_onnx_list_model_parameters(module: ScriptModule) -> Tuple[ScriptModule, List[IValue]]: ...
+def _jit_pass_custom_pattern_based_rewrite_graph(
+ pattern: str, fused_node_name: str, graph: Graph
+) -> None: ...
+def _jit_onnx_list_model_parameters(
+ module: ScriptModule,
+) -> Tuple[ScriptModule, List[IValue]]: ...
def _jit_pass_erase_number_types(graph: Graph) -> None: ...
def _jit_pass_onnx_lint(graph: Graph) -> None: ...
-def _jit_pass_onnx(graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes) -> Graph: ...
-def _jit_pass_onnx_scalar_type_analysis(graph: Graph, lowprecision_cast: _bool, opset_version: _int) -> None: ...
-def _jit_pass_onnx_peephole(graph: Graph, opset_version: _int, fixed_batch_size: _bool) -> None: ...
-def _jit_pass_dce_allow_deleting_nodes_with_side_effects(graph: Graph) -> None: ...
+def _jit_pass_onnx(
+ graph: Graph, _jit_pass_onnx: _onnx.OperatorExportTypes
+) -> Graph: ...
+def _jit_pass_onnx_scalar_type_analysis(
+ graph: Graph, lowprecision_cast: _bool, opset_version: _int
+) -> None: ...
+def _jit_pass_onnx_peephole(
+ graph: Graph, opset_version: _int, fixed_batch_size: _bool
+) -> None: ...
+def _jit_pass_dce_allow_deleting_nodes_with_side_effects(
+ graph: Graph,
+) -> None: ...
def _jit_pass_onnx_function_substitution(graph: Graph) -> None: ...
-def _jit_pass_onnx_function_extraction(graph: Graph, module_names : Set[str], param_names : List[str]) -> Dict[Node, Dict[str, str]]: ...
+def _jit_pass_onnx_function_extraction(
+ graph: Graph, module_names: Set[str], param_names: List[str]
+) -> Dict[Node, Dict[str, str]]: ...
def _jit_pass_onnx_clear_scope_records() -> None: ...
-def _jit_pass_onnx_track_scope_attributes(graph: Graph, onnx_attrs: Dict[str, Any]) -> None: ...
+def _jit_pass_onnx_track_scope_attributes(
+ graph: Graph, onnx_attrs: Dict[str, Any]
+) -> None: ...
def _jit_is_onnx_log_enabled() -> _bool: ...
def _jit_set_onnx_log_enabled(enabled: _bool) -> None: ...
def _jit_set_onnx_log_output_stream(stream_name: str) -> None: ...
def _jit_onnx_log(*args: Any) -> None: ...
-def _jit_pass_lower_graph(graph: Graph, m: Module) -> Tuple[Graph, List[IValue]]: ...
+def _jit_pass_lower_graph(
+ graph: Graph, m: Module
+) -> Tuple[Graph, List[IValue]]: ...
def _jit_pass_inline_fork_wait(graph: Graph) -> None: ...
-def _jit_pass_onnx_deduplicate_initializers(graph: Graph, params_dict: Dict[str, IValue], is_train: _bool) -> Dict[str, IValue]: ...
-def _jit_pass_onnx_eval_peephole(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
-def _jit_pass_onnx_constant_fold(graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int) -> Dict[str, IValue]: ...
-def _jit_pass_onnx_eliminate_unused_items(graph: Graph, paramsDict: Dict[str, IValue]) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_deduplicate_initializers(
+ graph: Graph, params_dict: Dict[str, IValue], is_train: _bool
+) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_eval_peephole(
+ graph: Graph, paramsDict: Dict[str, IValue]
+) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_constant_fold(
+ graph: Graph, paramsDict: Dict[str, IValue], opset_version: _int
+) -> Dict[str, IValue]: ...
+def _jit_pass_onnx_eliminate_unused_items(
+ graph: Graph, paramsDict: Dict[str, IValue]
+) -> Dict[str, IValue]: ...
def _jit_pass_onnx_cast_all_constant_to_floating(graph: Graph) -> None: ...
-def _jit_pass_filter_non_tensor_arguments(params: Dict[str, IValue]) -> Dict[str, Tensor]: ...
+def _jit_pass_filter_non_tensor_arguments(
+ params: Dict[str, IValue]
+) -> Dict[str, Tensor]: ...
def _jit_decay_packed_param_input_types(graph: Graph) -> None: ...
-def _jit_pass_onnx_node_shape_type_inference(n: Node, paramsDict: Dict[str, IValue], opset_version: _int) -> None: ...
-def _jit_onnx_convert_pattern_from_subblock(block: Block, n: Node, env: Dict[Value, Value]) -> List[Value]: ...
+def _jit_pass_onnx_node_shape_type_inference(
+ n: Node, paramsDict: Dict[str, IValue], opset_version: _int
+) -> None: ...
+def _jit_onnx_convert_pattern_from_subblock(
+ block: Block, n: Node, env: Dict[Value, Value]
+) -> List[Value]: ...
def _jit_pass_onnx_block(
old_block: Block,
new_block: Block,
operator_export_type: _onnx.OperatorExportTypes,
env: Dict[Value, Value],
- is_sub_block: _bool
+ is_sub_block: _bool,
) -> Dict[Value, Value]: ...
-def _jit_pass_onnx_assign_scoped_names_for_node_and_value(graph: Graph) -> None: ...
-def _jit_pass_fixup_onnx_controlflow_node(n: Node, opset_version: _int) -> List[Value]: ...
-def _jit_onnx_create_full_scope_name(class_name: str, variable_name: str) -> str: ...
-
+def _jit_pass_onnx_assign_scoped_names_for_node_and_value(
+ graph: Graph,
+) -> None: ...
+def _jit_pass_fixup_onnx_controlflow_node(
+ n: Node, opset_version: _int
+) -> List[Value]: ...
+def _jit_onnx_create_full_scope_name(
+ class_name: str, variable_name: str
+) -> str: ...
def _compile_graph_to_code_table(name: str, graph: Graph) -> IValue: ...
-
def _generate_upgraders_graph() -> Dict[str, Graph]: ...
-
def _calculate_package_version_based_on_upgraders(val: _bool): ...
-
def _get_version_calculator_flag() -> _bool: ...
-
-def _jit_script_interface_compile(name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool): ...
+def _jit_script_interface_compile(
+ name: str, class_def: ClassDef, rcb: ResolutionCallback, is_module: _bool
+): ...
def _jit_script_compile_overload(
qualname: str,
overload_decl: Decl,
implementation_def: Def,
rcb: ResolutionCallback,
implementation_defaults: Dict[str, Any],
- signature: Any
+ signature: Any,
): ...
def _jit_script_compile(
qual_name: str,
definition: Def,
rcb: ResolutionCallback,
- defaults: Dict[str, Any]
+ defaults: Dict[str, Any],
): ...
def _jit_script_class_compile(
qual_name: str,
definition: ClassDef,
defaults: Dict[str, Dict[str, Any]],
- rcb: ResolutionCallback
+ rcb: ResolutionCallback,
): ...
def _parse_source_def(src: str) -> Def: ...
def import_ir_module(
cu: CompilationUnit,
filename: Union[str, Path],
map_location: Union[_device, str, None],
- extra_files: Dict[str, Any]
+ extra_files: Dict[str, Any],
) -> ScriptModule: ...
def import_ir_module_from_buffer(
cu: CompilationUnit,
buffer: BinaryIO,
map_location: Union[_device, str, None],
- extra_files: Dict[str, Any]
+ extra_files: Dict[str, Any],
) -> ScriptModule: ...
def _import_ir_module_from_package(
cu: CompilationUnit,
reader: PyTorchFileReader,
storage_context: DeserializationStorageContext,
map_location: Union[_device, str, None],
- ts_id: str
+ ts_id: str,
) -> ScriptModule: ...
-
def _assign_output_shapes(graph: Graph, inputs: List[Tensor]) -> Graph: ...
def _check_onnx_proto(proto: str) -> None: ...
def _propagate_and_assign_input_shapes(
@@ -512,12 +631,11 @@ def _propagate_and_assign_input_shapes(
inputs: Tuple[Tensor, ...],
param_count_list: List[_int],
with_grad: _bool,
- propagate: _bool
+ propagate: _bool,
) -> Graph: ...
# Defined in torch/csrc/jit/runtime/graph_executor.h
-class GraphExecutorState:
- ...
+class GraphExecutorState: ...
# Defined in torch/torch/csrc/jit/ir/alias_analysis.h
class AliasDb:
@@ -539,7 +657,7 @@ class Use:
# Defined in torch/csrc/jit/ir/ir.h
class Value:
- def type(self)-> JitType: ...
+ def type(self) -> JitType: ...
def setType(self, t: JitType) -> Value: ...
def setTypeAs(self, other: Value) -> Value: ...
def inferTypeFrom(self, t: Tensor) -> None: ...
@@ -677,31 +795,34 @@ class Graph:
def setInsertPoint(self, n: Union[Block, Node]) -> None: ...
def insert_point_guard(self, n: Union[Block, Node]) -> _InsertPoint: ...
def insertPoint(self) -> Node: ...
- def insertGraph(self, callee: Graph, inputs: List[Value]) -> List[Value]: ...
+ def insertGraph(
+ self, callee: Graph, inputs: List[Value]
+ ) -> List[Value]: ...
def makeMultiOutputIntoTuple(self) -> None: ...
def copy(self) -> Graph: ...
def create(self, name: str, *args, num_outputs: _int = 1) -> Node: ...
- def op(self, kind: str, *args: Any, **kwargs: Any) -> Union[Value, Sequence[Value]]: ...
+ def op(
+ self, kind: str, *args: Any, **kwargs: Any
+ ) -> Union[Value, Sequence[Value]]: ...
...
-
# Defined in torch/aten/src/ATen/core/alias_info.h
class AliasInfo:
is_write: _bool
before_set: Set[str]
after_set: Set[str]
-
# Defined in torch/aten/src/ATen/core/function_schema.h
class Argument:
name: str
type: JitType
default_value: Optional[Any]
def has_default_value(self) -> _bool: ...
- kwarg_only : _bool
+ kwarg_only: _bool
is_out: _bool
alias_info: Optional[AliasInfo]
...
+
class FunctionSchema:
arguments: List[Argument]
returns: List[Argument]
@@ -713,26 +834,28 @@ class _UpgraderEntry:
bumped_at_version: _int
upgrader_name: str
old_schema: str
- def __init__(self, bumped_at_version: _int, upgrader_name: str, old_schema: str) -> None: ...
+ def __init__(
+ self, bumped_at_version: _int, upgrader_name: str, old_schema: str
+ ) -> None: ...
class _UpgraderRange:
min_version: _int
max_version: _int
def _get_max_operator_version() -> _int: ...
-
def _get_operator_version_map() -> Dict[str, List[_UpgraderEntry]]: ...
-
def _get_upgrader_ranges(name: str) -> List[_UpgraderRange]: ...
-
-def _test_only_add_entry_to_op_version(op_name: str, entry: _UpgraderEntry) -> None: ...
-
+def _test_only_add_entry_to_op_version(
+ op_name: str, entry: _UpgraderEntry
+) -> None: ...
def _test_only_remove_entry_to_op_version(op_name: str) -> None: ...
# Defined in torch/csrc/jit/python/script_init.cpp
class ScriptModuleSerializer:
def __init__(self, export_writer: PyTorchFileWriter) -> None: ...
- def serialize(self, model: ScriptModule, script_module_id: _int) -> None: ...
+ def serialize(
+ self, model: ScriptModule, script_module_id: _int
+ ) -> None: ...
def write_files(self) -> None: ...
def storage_context(self) -> SerializationStorageContext: ...
...
@@ -759,13 +882,19 @@ class ConcreteModuleTypeBuilder:
def set_module_list(self): ...
def set_parameter_list(self): ...
def set_parameter_dict(self): ...
- def add_attribute(self, name: str, ty: JitType, is_param: _bool, is_buffer: _bool): ...
+ def add_attribute(
+ self, name: str, ty: JitType, is_param: _bool, is_buffer: _bool
+ ): ...
def add_module(self, name: str, meta: ConcreteModuleType): ...
def add_constant(self, name: str, value: Any): ...
- def add_overload(self, method_name: str, overloaded_method_names: List[str]): ...
+ def add_overload(
+ self, method_name: str, overloaded_method_names: List[str]
+ ): ...
def add_builtin_function(self, name: str, symbol_name: str): ...
def add_failed_attribute(self, name: str, failure_reason: str): ...
- def add_function_attribute(self, name: str, ty: JitType, func: Callable[..., Any]): ...
+ def add_function_attribute(
+ self, name: str, ty: JitType, func: Callable[..., Any]
+ ): ...
def add_ignored_attribute(self, name: str): ...
def add_ignored_attributes(self, names: List[str]): ...
def add_forward_hook(self, hook: Callable[..., Any]): ...
@@ -773,8 +902,7 @@ class ConcreteModuleTypeBuilder:
class ConcreteModuleType:
def get_constants(self) -> Dict[str, Any]: ...
- def equals(self, other: 'ConcreteModuleType') -> _bool: ...
-
+ def equals(self, other: "ConcreteModuleType") -> _bool: ...
@staticmethod
def from_jit_type(ty: JitType) -> ConcreteModuleType: ...
@@ -784,18 +912,21 @@ class CallStack:
class ErrorReport:
def __init__(self, range: SourceRange) -> None: ...
def what(self) -> str: ...
-
@staticmethod
def call_stack() -> str: ...
class CompilationUnit:
- def __init__(self, lang: str=..., _frames_up: _int=...) -> None: ...
+ def __init__(self, lang: str = ..., _frames_up: _int = ...) -> None: ...
def find_function(self, name: str) -> ScriptFunction: ...
def __getattr__(self, name: str) -> ScriptFunction: ...
- def define(self, script: str, rcb: ResolutionCallback=..., _frames_up: _int=...): ...
+ def define(
+ self, script: str, rcb: ResolutionCallback = ..., _frames_up: _int = ...
+ ): ...
def get_interface(self, name: str) -> InterfaceType: ...
def get_functions(self) -> List[ScriptFunction]: ...
- def create_function(self, name: str, graph: Graph, shouldMangle: _bool=...) -> ScriptFunction: ...
+ def create_function(
+ self, name: str, graph: Graph, shouldMangle: _bool = ...
+ ) -> ScriptFunction: ...
def get_class(self, name: str) -> ClassType: ...
class ScriptObject:
@@ -842,16 +973,19 @@ class BufferDict:
def __init__(self, mod: ScriptModule) -> None: ...
# Defined in torch/csrc/jit/api/module.h
-class Module:
- ...
+class Module: ...
# Defined in torch/csrc/Module.cpp
-def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
+def _initExtension(
+ shm_manager_path: str,
+) -> None: ... # THPModule_initExtension
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
def _add_docstr(obj: T, doc_obj: str) -> T: ... # THPModule_addDocStr
def _init_names(arg: Sequence[Type]) -> None: ... # THPModule_initNames
def _has_distributed() -> _bool: ... # THPModule_hasDistributed
-def _set_default_tensor_type(type) -> None: ... # THPModule_setDefaultTensorType
+def _set_default_tensor_type(
+ type,
+) -> None: ... # THPModule_setDefaultTensorType
def _set_default_dtype(d: _dtype) -> None: ... # THPModule_setDefaultDtype
def _infer_size(arg1: Size, arg2: Size) -> Size: ... # THPModule_inferSize
def _crash_if_csrc_asan() -> _int: ... # THPModule_crashIfCsrcASAN
@@ -860,56 +994,101 @@ def _crash_if_aten_asan() -> _int: ... # THPModule_crashIfATenASAN
def _show_config() -> str: ... # THPModule_showConfig
def _cxx_flags() -> str: ... # THPModule_cxxFlags
def _parallel_info() -> str: ... # THPModule_parallelInfo
-def _set_backcompat_broadcast_warn(arg: _bool) -> None: ... # THPModule_setBackcompatBroadcastWarn
-def _get_backcompat_broadcast_warn() -> _bool: ... # THPModule_getBackcompatBroadcastWarn
-def _set_backcompat_keepdim_warn(arg: _bool) -> None: ... # THPModule_setBackcompatKeepdimWarn
-def _get_backcompat_keepdim_warn() -> _bool: ... # THPModule_getBackcompatKeepdimWarn
+def _set_backcompat_broadcast_warn(
+ arg: _bool,
+) -> None: ... # THPModule_setBackcompatBroadcastWarn
+def _get_backcompat_broadcast_warn() -> (
+ _bool
+): ... # THPModule_getBackcompatBroadcastWarn
+def _set_backcompat_keepdim_warn(
+ arg: _bool,
+) -> None: ... # THPModule_setBackcompatKeepdimWarn
+def _get_backcompat_keepdim_warn() -> (
+ _bool
+): ... # THPModule_getBackcompatKeepdimWarn
def get_num_thread() -> _int: ... # THPModule_getNumThreads
def set_num_threads(nthreads: _int) -> None: ... # THPModule_setNumThreads
def get_num_interop_threads() -> _int: ... # THPModule_getNumInteropThreads
-def set_num_interop_threads(nthreads: _int) -> None: ... # THPModule_setNumInteropThreads
+def set_num_interop_threads(
+ nthreads: _int,
+) -> None: ... # THPModule_setNumInteropThreads
def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN
def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN
def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP
def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash
-def _get_mem_efficient_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
-def _set_sdp_use_mem_efficient(arg: _bool) -> None: ... # THPModule_setSDPUseMemEfficient
+def _get_mem_efficient_sdp_enabled() -> (
+ _bool
+): ... # THPModule_userEnabledMathSDP
+def _set_sdp_use_mem_efficient(
+ arg: _bool,
+) -> None: ... # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn
-def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn
+def _set_mkldnn_enabled(
+ arg: _bool,
+) -> None: ... # THPModule_setUserEnabledMkldnn
def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN
def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN
def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN
-def _set_cudnn_deterministic(arg: _bool) -> None: ... # THPModule_setDeterministicCuDNN
-def _get_deterministic_algorithms() -> _bool: ... # THPModule_deterministicAlgorithms
-def _get_deterministic_algorithms_warn_only() -> _bool: ... # THPModule_deterministicAlgorithmsWarnOnly
-def _set_deterministic_algorithms(mode: _bool, *, warn_only: _bool=...) -> None: ... # THPModule_setDeterministicAlgorithms
+def _set_cudnn_deterministic(
+ arg: _bool,
+) -> None: ... # THPModule_setDeterministicCuDNN
+def _get_deterministic_algorithms() -> (
+ _bool
+): ... # THPModule_deterministicAlgorithms
+def _get_deterministic_algorithms_warn_only() -> (
+ _bool
+): ... # THPModule_deterministicAlgorithmsWarnOnly
+def _set_deterministic_algorithms(
+ mode: _bool, *, warn_only: _bool = ...
+) -> None: ... # THPModule_setDeterministicAlgorithms
def _get_warnAlways() -> _bool: ... # THPModule_warnAlways
def _set_warnAlways(arg: _bool) -> None: ... # THPModule_setWarnAlways
def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
-def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN
+def _set_cudnn_allow_tf32(
+ arg: _bool,
+) -> None: ... # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS
-def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS
-def _get_float32_matmul_precision() -> str: ... #THPModule_float32MatmulPrecision
-def _set_float32_matmul_precision(arg: str) -> None: ... #THPModule_setFloat32MatmulPrecision
-def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... #THPModule_allowFP16ReductionCuBLAS
-def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowFP16ReductionCuBLAS
-def _get_cublas_allow_bf16_reduced_precision_reduction() -> _bool: ... #THPModule_allowBF16ReductionCuBLAS
-def _set_cublas_allow_bf16_reduced_precision_reduction(arg: _bool) -> None: ... #THPModule_setAllowBF16ReductionCuBLAS
+def _set_cublas_allow_tf32(
+ arg: _bool,
+) -> None: ... # THPModule_setAllowTF32CuBLAS
+def _get_float32_matmul_precision() -> (
+ str
+): ... # THPModule_float32MatmulPrecision
+def _set_float32_matmul_precision(
+ arg: str,
+) -> None: ... # THPModule_setFloat32MatmulPrecision
+def _get_cublas_allow_fp16_reduced_precision_reduction() -> (
+ _bool
+): ... # THPModule_allowFP16ReductionCuBLAS
+def _set_cublas_allow_fp16_reduced_precision_reduction(
+ arg: _bool,
+) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS
+def _get_cublas_allow_bf16_reduced_precision_reduction() -> (
+ _bool
+): ... # THPModule_allowBF16ReductionCuBLAS
+def _set_cublas_allow_bf16_reduced_precision_reduction(
+ arg: _bool,
+) -> None: ... # THPModule_setAllowBF16ReductionCuBLAS
def _set_conj(x: Tensor, conj: _bool) -> None: ...
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _set_meta_in_tls_dispatch_include(meta_in_tls: _bool) -> None: ...
def _meta_in_tls_dispatch_include() -> _bool: ...
def _select_conv_backend(*args, **kwargs) -> ConvBackend: ...
-def _conv_determine_backend_memory_format(input: Tensor, weight: Tensor, backend: ConvBackend) -> memory_format: ...
+def _conv_determine_backend_memory_format(
+ input: Tensor, weight: Tensor, backend: ConvBackend
+) -> memory_format: ...
def _has_storage(x: Tensor) -> _bool: ...
def _should_allow_numbers_as_tensors(func_name: str) -> _bool: ...
+
# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack
def _from_dlpack(data: Any) -> Tensor: ... # THPModule_fromDLPack
-def _get_cpp_backtrace(frames_to_skip: _int, maximum_number_of_frames: _int) -> str: ... # THPModule_getCppBacktrace
+def _get_cpp_backtrace(
+ frames_to_skip: _int, maximum_number_of_frames: _int
+) -> str: ... # THPModule_getCppBacktrace
def set_flush_denormal(arg: _bool) -> _bool: ... # THPModule_setFlushDenormal
def get_default_dtype() -> _dtype: ... # THPModule_getDefaultDtype
def _get_default_device() -> str: ... # THPModule_getDefaultDevice
@@ -917,34 +1096,60 @@ def _get_qengine() -> _int: ... # THPModule_qEngine
def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine
def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines
def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
-def _check_sparse_tensor_invariants() -> _bool: ... # THPModule_checkSparseTensorInvariants
-def _set_check_sparse_tensor_invariants(arg: _bool) -> None: ... # THPModule_setCheckSparseTensorInvariants
-def _set_default_mobile_cpu_allocator() -> None: ... # THPModule_setDefaultMobileCPUAllocator
-def _unset_default_mobile_cpu_allocator() -> None: ... # THPModule_unsetDefaultMobileCPUAllocator
-def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction
-def _has_torch_function(args: Iterable[Any]) -> _bool: ... # THPModule_has_torch_function
-def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary
-def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic
-def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
-def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
+def _check_sparse_tensor_invariants() -> (
+ _bool
+): ... # THPModule_checkSparseTensorInvariants
+def _set_check_sparse_tensor_invariants(
+ arg: _bool,
+) -> None: ... # THPModule_setCheckSparseTensorInvariants
+def _set_default_mobile_cpu_allocator() -> (
+ None
+): ... # THPModule_setDefaultMobileCPUAllocator
+def _unset_default_mobile_cpu_allocator() -> (
+ None
+): ... # THPModule_unsetDefaultMobileCPUAllocator
+def _is_torch_function_enabled() -> (
+ _bool
+): ... # THPModule_isEnabledTorchFunction
+def _has_torch_function(
+ args: Iterable[Any],
+) -> _bool: ... # THPModule_has_torch_function
+def _has_torch_function_unary(
+ Any,
+) -> _bool: ... # THPModule_has_torch_function_unary
+def _has_torch_function_variadic(
+ *args: Any,
+) -> _bool: ... # THPModule_has_torch_function_variadic
+def _vmapmode_increment_nesting() -> (
+ _int
+): ... # THPModule_vmapmode_increment_nesting
+def _vmapmode_decrement_nesting() -> (
+ _int
+): ... # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
def _demangle(str) -> str: ... # c10::demangle
-def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_torch_function
-def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
+def _disabled_torch_function_impl(
+ func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict
+) -> Any: ... # THPModule_disable_torch_function
+def _disabled_torch_dispatch_impl(
+ func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict
+) -> Any: ... # THPModule_disable_dispatch_function
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
+
class _LinalgBackend:
Default: _LinalgBackend
Cusolver: _LinalgBackend
Magma: _LinalgBackend
-class ConvBackend(Enum):
- ...
+class ConvBackend(Enum): ...
# Defined in `valgrind.h` and `callgrind.h` respectively.
def _valgrind_supported_platform() -> _bool: ... # NVALGRIND
def _valgrind_toggle() -> None: ... # CALLGRIND_TOGGLE_COLLECT
-def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS
+def _valgrind_toggle_and_dump_stats() -> (
+ None
+): ... # CALLGRIND_TOGGLE_COLLECT and CALLGRIND_DUMP_STATS
has_openmp: _bool
has_mkl: _bool
@@ -985,16 +1190,16 @@ def _make_dual(tensor: Tensor, tangent: Tensor, level: _int) -> Tensor: ...
def _unpack_dual(tensor: Tensor, level: _int) -> Tensor: ...
def __set_forward_AD_enabled(enabled: _bool) -> None: ...
def __is_forward_AD_enabled() -> _bool: ...
-def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
+def _register_default_hooks(
+ pack_hook: Callable, unpack_hook: Callable
+) -> None: ...
def _reset_default_hooks() -> None: ...
-
-def _is_torch_function_mode_enabled()-> _bool: ...
+def _is_torch_function_mode_enabled() -> _bool: ...
def _set_torch_function_mode(cls: Any) -> None: ...
def _push_on_torch_function_stack(cls: Any) -> None: ...
def _pop_torch_function_stack() -> Any: ...
def _get_function_stack_at(idx: _int) -> Any: ...
def _len_torch_function_stack() -> _int: ...
-
def _set_torch_dispatch_mode(cls: Any) -> None: ...
def _push_on_torch_dispatch_stack(cls: Any) -> None: ...
def _pop_torch_dispatch_stack() -> Any: ...
@@ -1017,14 +1222,9 @@ class _ViewReplayEnabled:
def __init__(self, mode: _bool) -> None: ...
# Defined in torch/csrc/jit/python/script_init.cpp
-class LoggerBase:
- ...
-
-class NoopLogger(LoggerBase):
- ...
-
-class LockingLogger(LoggerBase):
- ...
+class LoggerBase: ...
+class NoopLogger(LoggerBase): ...
+class LockingLogger(LoggerBase): ...
class AggregationType(Enum):
SUM = 0
@@ -1032,13 +1232,15 @@ class AggregationType(Enum):
class FileCheck:
def run(self, test_string: str) -> None: ...
- def check(self, test_string: str) -> 'FileCheck': ...
- def check_not(self, test_string: str) -> 'FileCheck': ...
- def check_same(self, test_string: str) -> 'FileCheck': ...
- def check_next(self, test_string: str) -> 'FileCheck': ...
- def check_count(self, test_string: str, count: _int, exactly: _bool = False) -> 'FileCheck': ...
- def check_dag(self, test_string: str) -> 'FileCheck': ...
- def check_source_highlighted(self, test_string: str) -> 'FileCheck': ...
+ def check(self, test_string: str) -> "FileCheck": ...
+ def check_not(self, test_string: str) -> "FileCheck": ...
+ def check_same(self, test_string: str) -> "FileCheck": ...
+ def check_next(self, test_string: str) -> "FileCheck": ...
+ def check_count(
+ self, test_string: str, count: _int, exactly: _bool = False
+ ) -> "FileCheck": ...
+ def check_dag(self, test_string: str) -> "FileCheck": ...
+ def check_source_highlighted(self, test_string: str) -> "FileCheck": ...
...
# Defined in torch/csrc/jit/python/init.cpp
@@ -1055,7 +1257,9 @@ class PyTorchFileWriter:
def __init__(self, name: str) -> None: ...
@overload
def __init__(self, buffer: BinaryIO) -> None: ...
- def write_record(self, name: str, data: Union[bytes, _int], size: _int) -> None: ...
+ def write_record(
+ self, name: str, data: Union[bytes, _int], size: _int
+ ) -> None: ...
def write_end_of_file(self) -> None: ...
def set_min_version(self, version: _int) -> None: ...
def get_all_written_records(self) -> List[str]: ...
@@ -1087,7 +1291,6 @@ class Generator:
def seed(self) -> _int: ...
def initial_seed(self) -> _int: ...
-
# Defined in torch/csrc/utils/python_dispatch.cpp
class _DispatchOperatorHandle:
@@ -1096,27 +1299,51 @@ class _DispatchOperatorHandle:
class _DispatchModule:
def def_(self, schema: str, alias: str = "") -> _DispatchModule: ...
def def_legacy(self, schema: str) -> _DispatchModule: ...
- def def_name_t_t(self, name: str, dispatch: str, debug: str = "default_def_name_t_t") -> _DispatchModule: ...
- def def_schema_t_t(self, schema: str, dispatch: str, alias: str, debug: str = "default_def_schema_t_t") -> _DispatchModule: ...
- def impl_t_t(self, name: str, dispatch: str, debug: str = "impl_t_t") -> _DispatchModule: ...
- def impl(self, name: str, dispatch: str, func: Callable) -> _DispatchModule: ...
+ def def_name_t_t(
+ self, name: str, dispatch: str, debug: str = "default_def_name_t_t"
+ ) -> _DispatchModule: ...
+ def def_schema_t_t(
+ self,
+ schema: str,
+ dispatch: str,
+ alias: str,
+ debug: str = "default_def_schema_t_t",
+ ) -> _DispatchModule: ...
+ def impl_t_t(
+ self, name: str, dispatch: str, debug: str = "impl_t_t"
+ ) -> _DispatchModule: ...
+ def impl(
+ self, name: str, dispatch: str, func: Callable
+ ) -> _DispatchModule: ...
def define(self, schema: str, alias: str = "") -> _DispatchModule: ...
def fallback_fallthrough(self, dispatch: str = "") -> _DispatchModule: ...
-def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> _DispatchModule: ...
+def _dispatch_library(
+ kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0
+) -> _DispatchModule: ...
def _dispatch_dump(name: str) -> str: ...
def _dispatch_dump_table(name: str) -> str: ...
def _dispatch_check_invariants(name: str) -> None: ...
def _dispatch_check_all_invariants() -> None: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
-def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: _dispatchkey) -> _bool: ...
-def _dispatch_has_kernel_for_any_dispatch_key(name: str, dispatch_key_set: DispatchKeySet) -> _bool: ...
-def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: _dispatchkey) -> _bool: ...
+def _dispatch_has_kernel_for_dispatch_key(
+ name: str, dispatch: _dispatchkey
+) -> _bool: ...
+def _dispatch_has_kernel_for_any_dispatch_key(
+ name: str, dispatch_key_set: DispatchKeySet
+) -> _bool: ...
+def _dispatch_has_computed_kernel_for_dispatch_key(
+ name: str, dispatch: _dispatchkey
+) -> _bool: ...
def _dispatch_find_dangling_impls() -> List[str]: ...
def _dispatch_get_all_op_names() -> List[str]: ...
-def _dispatch_tls_set_dispatch_key_excluded(dispatch: _dispatchkey, val: _bool) -> None: ...
+def _dispatch_tls_set_dispatch_key_excluded(
+ dispatch: _dispatchkey, val: _bool
+) -> None: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: _dispatchkey) -> _bool: ...
-def _dispatch_tls_set_dispatch_key_included(dispatch: _dispatchkey, val: _bool) -> None: ...
+def _dispatch_tls_set_dispatch_key_included(
+ dispatch: _dispatchkey, val: _bool
+) -> None: ...
def _dispatch_tls_is_dispatch_key_included(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
def _dispatch_key_name(dispatch: _dispatchkey) -> str: ...
@@ -1241,14 +1468,19 @@ class DispatchKeySet:
def __repr__(self) -> str: ...
_dispatch_autogradother_backends: DispatchKeySet
+
def _dispatch_has_backend_fallback(dispatch: _dispatchkey) -> _bool: ...
def _dispatch_keyset_full_after(t: _dispatchkey) -> DispatchKeySet: ...
def _dispatch_keyset_to_string(keyset: DispatchKeySet) -> str: ...
-def _dispatch_get_backend_keyset_from_autograd(dispatch: _dispatchkey) -> DispatchKeySet: ...
+def _dispatch_get_backend_keyset_from_autograd(
+ dispatch: _dispatchkey,
+) -> DispatchKeySet: ...
def _dispatch_keys(tensor: Tensor) -> DispatchKeySet: ...
def _dispatch_tls_local_exclude_set() -> DispatchKeySet: ...
def _dispatch_tls_local_include_set() -> DispatchKeySet: ...
-def _dispatch_is_included_in_alias(dispatch_a: _dispatchkey, dispatch_b: _dispatchkey) -> _bool: ...
+def _dispatch_is_included_in_alias(
+ dispatch_a: _dispatchkey, dispatch_b: _dispatchkey
+) -> _bool: ...
class ExcludeDispatchKeyGuard:
pass
@@ -1256,9 +1488,12 @@ class ExcludeDispatchKeyGuard:
class _AutoDispatchBelowAutograd:
pass
-def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
-def _dispatch_get_registrations_for_dispatch_key(dispatch_key: str = "") -> List[str]: ...
-
+def _dispatch_print_registrations_for_dispatch_key(
+ dispatch_key: str = "",
+) -> None: ...
+def _dispatch_get_registrations_for_dispatch_key(
+ dispatch_key: str = "",
+) -> List[str]: ...
def _are_functorch_transforms_active() -> _bool: ...
# Defined in torch/csrc/autograd/init.cpp
@@ -1270,7 +1505,6 @@ class _EnablePythonDispatcher:
def _set_python_dispatcher(dispatcher: object) -> None: ...
-
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig:
num_calling_threads: _int
@@ -1353,7 +1587,9 @@ class _TensorBase(metaclass=_TensorMeta):
def __float__(self) -> builtins.float: ...
def __floordiv__(self, other: Any) -> Tensor: ...
def __ge__(self, other: Any) -> Tensor: ...
- def __getitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple]) -> Tensor: ...
+ def __getitem__(
+ self, indices: Union[None, _int, slice, Tensor, List, Tuple]
+ ) -> Tensor: ...
def __gt__(self, other: Any) -> Tensor: ...
def __iadd__(self, other: Any) -> Tensor: ...
@overload
@@ -1374,13 +1610,13 @@ class _TensorBase(metaclass=_TensorMeta):
def __imul__(self, other: Any) -> Tensor: ...
def __index__(self) -> builtins.int: ...
@overload
- def __init__(self, *args: Any, device: Device=None) -> None: ...
+ def __init__(self, *args: Any, device: Device = None) -> None: ...
@overload
def __init__(self, storage: Storage) -> None: ...
@overload
def __init__(self, other: Tensor) -> None: ...
@overload
- def __init__(self, size: _size, *, device: Device=None) -> None: ...
+ def __init__(self, size: _size, *, device: Device = None) -> None: ...
def __int__(self) -> builtins.int: ...
def __invert__(self) -> Tensor: ...
@overload
@@ -1439,7 +1675,11 @@ class _TensorBase(metaclass=_TensorMeta):
def __rsub__(self, other: Any) -> Tensor: ...
def __rtruediv__(self, other: Any) -> Tensor: ...
def __rxor__(self, other: Any) -> Tensor: ...
- def __setitem__(self, indices: Union[None, _int, slice, Tensor, List, Tuple], val: Union[Tensor, Number]) -> None: ...
+ def __setitem__(
+ self,
+ indices: Union[None, _int, slice, Tensor, List, Tuple],
+ val: Union[Tensor, Number],
+ ) -> None: ...
def __sub__(self, other: Any) -> Tensor: ...
def __truediv__(self, other: Any) -> Tensor: ...
@overload
@@ -1448,9 +1688,25 @@ class _TensorBase(metaclass=_TensorMeta):
def __xor__(self, other: Number) -> Tensor: ...
@overload
def __xor__(self, other: Any) -> Tensor: ...
- def _addmm_activation(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1, use_gelu: _bool=False) -> Tensor: ...
- def _autocast_to_full_precision(self, cuda_enabled: _bool, cpu_enabled: _bool) -> Tensor: ...
- def _autocast_to_reduced_precision(self, cuda_enabled: _bool, cpu_enabled: _bool, cuda_dtype: _dtype, cpu_dtype: _dtype) -> Tensor: ...
+ def _addmm_activation(
+ self,
+ mat1: Tensor,
+ mat2: Tensor,
+ *,
+ beta: Number = 1,
+ alpha: Number = 1,
+ use_gelu: _bool = False,
+ ) -> Tensor: ...
+ def _autocast_to_full_precision(
+ self, cuda_enabled: _bool, cpu_enabled: _bool
+ ) -> Tensor: ...
+ def _autocast_to_reduced_precision(
+ self,
+ cuda_enabled: _bool,
+ cpu_enabled: _bool,
+ cuda_dtype: _dtype,
+ cpu_dtype: _dtype,
+ ) -> Tensor: ...
def _coalesced_(self, coalesced: _bool) -> Tensor: ...
def _conj(self) -> Tensor: ...
def _conj_physical(self) -> Tensor: ...
@@ -1461,12 +1717,19 @@ class _TensorBase(metaclass=_TensorMeta):
def _is_any_true(self) -> Tensor: ...
def _is_view(self) -> _bool: ...
def _is_zerotensor(self) -> _bool: ...
- def _make_subclass(cls, data: Tensor, require_grad: _bool = False, dispatch_strides: _bool=False, dispatch_device: _bool=False, device_for_backend_keys: Optional[_device] = None) -> Tensor: ...
+ def _make_subclass(
+ cls,
+ data: Tensor,
+ require_grad: _bool = False,
+ dispatch_strides: _bool = False,
+ dispatch_device: _bool = False,
+ device_for_backend_keys: Optional[_device] = None,
+ ) -> Tensor: ...
def _neg_view(self) -> Tensor: ...
def _nested_tensor_size(self) -> Tensor: ...
def _nested_tensor_strides(self) -> Tensor: ...
def _nnz(self) -> _int: ...
- def _to_dense(self, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def _to_dense(self, dtype: Optional[_dtype] = None) -> Tensor: ...
def _values(self) -> Tensor: ...
def abs(self) -> Tensor: ...
def abs_(self) -> Tensor: ...
@@ -1476,43 +1739,108 @@ class _TensorBase(metaclass=_TensorMeta):
def acos_(self) -> Tensor: ...
def acosh(self) -> Tensor: ...
def acosh_(self) -> Tensor: ...
- def add(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...
- def add_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1) -> Tensor: ...
- def addbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def addbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def addcdiv(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ...
- def addcdiv_(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ...
- def addcmul(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ...
- def addcmul_(self, tensor1: Tensor, tensor2: Tensor, *, value: Number=1) -> Tensor: ...
- def addmm(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def addmm_(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def addmv(self, mat: Tensor, vec: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def addmv_(self, mat: Tensor, vec: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def addr(self, vec1: Tensor, vec2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def addr_(self, vec1: Tensor, vec2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
+ def add(
+ self,
+ other: Union[Tensor, Number, torch.SymInt, torch.SymFloat],
+ *,
+ alpha: Optional[Number] = 1,
+ out: Optional[Tensor] = None,
+ ) -> Tensor: ...
+ def add_(
+ self,
+ other: Union[Tensor, Number, torch.SymInt, torch.SymFloat],
+ *,
+ alpha: Optional[Number] = 1,
+ ) -> Tensor: ...
+ def addbmm(
+ self,
+ batch1: Tensor,
+ batch2: Tensor,
+ *,
+ beta: Number = 1,
+ alpha: Number = 1,
+ ) -> Tensor: ...
+ def addbmm_(
+ self,
+ batch1: Tensor,
+ batch2: Tensor,
+ *,
+ beta: Number = 1,
+ alpha: Number = 1,
+ ) -> Tensor: ...
+ def addcdiv(
+ self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1
+ ) -> Tensor: ...
+ def addcdiv_(
+ self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1
+ ) -> Tensor: ...
+ def addcmul(
+ self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1
+ ) -> Tensor: ...
+ def addcmul_(
+ self, tensor1: Tensor, tensor2: Tensor, *, value: Number = 1
+ ) -> Tensor: ...
+ def addmm(
+ self, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1
+ ) -> Tensor: ...
+ def addmm_(
+ self, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1
+ ) -> Tensor: ...
+ def addmv(
+ self, mat: Tensor, vec: Tensor, *, beta: Number = 1, alpha: Number = 1
+ ) -> Tensor: ...
+ def addmv_(
+ self, mat: Tensor, vec: Tensor, *, beta: Number = 1, alpha: Number = 1
+ ) -> Tensor: ...
+ def addr(
+ self, vec1: Tensor, vec2: Tensor, *, beta: Number = 1, alpha: Number = 1
+ ) -> Tensor: ...
+ def addr_(
+ self, vec1: Tensor, vec2: Tensor, *, beta: Number = 1, alpha: Number = 1
+ ) -> Tensor: ...
def adjoint(self) -> Tensor: ...
def align_as(self, other: Tensor) -> Tensor: ...
@overload
- def align_to(self, order: Sequence[Union[str, ellipsis, None]], ellipsis_idx: _int) -> Tensor: ...
+ def align_to(
+ self, order: Sequence[Union[str, ellipsis, None]], ellipsis_idx: _int
+ ) -> Tensor: ...
@overload
- def align_to(self, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ...
+ def align_to(
+ self, names: Sequence[Union[str, ellipsis, None]]
+ ) -> Tensor: ...
@overload
def all(self) -> Tensor: ...
@overload
- def all(self, dim: _int, keepdim: _bool=False) -> Tensor: ...
+ def all(self, dim: _int, keepdim: _bool = False) -> Tensor: ...
@overload
- def all(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> Tensor: ...
- def allclose(self, other: Tensor, rtol: _float=1e-05, atol: _float=1e-08, equal_nan: _bool=False) -> _bool: ...
- def amax(self, dim: Union[_int, _size]=(), keepdim: _bool=False) -> Tensor: ...
- def amin(self, dim: Union[_int, _size]=(), keepdim: _bool=False) -> Tensor: ...
- def aminmax(self, *, dim: Optional[_int]=None, keepdim: _bool=False) -> torch.return_types.aminmax: ...
+ def all(
+ self, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> Tensor: ...
+ def allclose(
+ self,
+ other: Tensor,
+ rtol: _float = 1e-05,
+ atol: _float = 1e-08,
+ equal_nan: _bool = False,
+ ) -> _bool: ...
+ def amax(
+ self, dim: Union[_int, _size] = (), keepdim: _bool = False
+ ) -> Tensor: ...
+ def amin(
+ self, dim: Union[_int, _size] = (), keepdim: _bool = False
+ ) -> Tensor: ...
+ def aminmax(
+ self, *, dim: Optional[_int] = None, keepdim: _bool = False
+ ) -> torch.return_types.aminmax: ...
def angle(self) -> Tensor: ...
@overload
def any(self) -> Tensor: ...
@overload
- def any(self, dim: _int, keepdim: _bool=False) -> Tensor: ...
+ def any(self, dim: _int, keepdim: _bool = False) -> Tensor: ...
@overload
- def any(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> Tensor: ...
+ def any(
+ self, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> Tensor: ...
def apply_(self, callable: Callable) -> Tensor: ...
def arccos(self) -> Tensor: ...
def arccos_(self) -> Tensor: ...
@@ -1528,18 +1856,42 @@ class _TensorBase(metaclass=_TensorMeta):
def arctan_(self) -> Tensor: ...
def arctanh(self) -> Tensor: ...
def arctanh_(self) -> Tensor: ...
- def argmax(self, dim: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
- def argmin(self, dim: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
- @overload
- def argsort(self, *, stable: _bool, dim: _int=-1, descending: _bool=False) -> Tensor: ...
- @overload
- def argsort(self, dim: _int=-1, descending: _bool=False) -> Tensor: ...
- @overload
- def argsort(self, dim: Union[str, ellipsis, None], descending: _bool=False) -> Tensor: ...
+ def argmax(
+ self, dim: Optional[_int] = None, keepdim: _bool = False
+ ) -> Tensor: ...
+ def argmin(
+ self, dim: Optional[_int] = None, keepdim: _bool = False
+ ) -> Tensor: ...
+ @overload
+ def argsort(
+ self, *, stable: _bool, dim: _int = -1, descending: _bool = False
+ ) -> Tensor: ...
+ @overload
+ def argsort(self, dim: _int = -1, descending: _bool = False) -> Tensor: ...
+ @overload
+ def argsort(
+ self, dim: Union[str, ellipsis, None], descending: _bool = False
+ ) -> Tensor: ...
def argwhere(self) -> Tensor: ...
- def as_strided(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]]=None) -> Tensor: ...
- def as_strided_(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]]=None) -> Tensor: ...
- def as_strided_scatter(self, src: Tensor, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], storage_offset: Optional[Union[_int, SymInt]]=None) -> Tensor: ...
+ def as_strided(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ stride: Sequence[Union[_int, SymInt]],
+ storage_offset: Optional[Union[_int, SymInt]] = None,
+ ) -> Tensor: ...
+ def as_strided_(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ stride: Sequence[Union[_int, SymInt]],
+ storage_offset: Optional[Union[_int, SymInt]] = None,
+ ) -> Tensor: ...
+ def as_strided_scatter(
+ self,
+ src: Tensor,
+ size: Sequence[Union[_int, SymInt]],
+ stride: Sequence[Union[_int, SymInt]],
+ storage_offset: Optional[Union[_int, SymInt]] = None,
+ ) -> Tensor: ...
def as_subclass(self, cls: Type[S]) -> S: ...
def asin(self) -> Tensor: ...
def asin_(self) -> Tensor: ...
@@ -1551,18 +1903,40 @@ class _TensorBase(metaclass=_TensorMeta):
def atan_(self) -> Tensor: ...
def atanh(self) -> Tensor: ...
def atanh_(self) -> Tensor: ...
- def baddbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- def baddbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
- @overload
- def bernoulli(self, *, generator: Optional[Generator]=None) -> Tensor: ...
- @overload
- def bernoulli(self, p: _float, *, generator: Optional[Generator]=None) -> Tensor: ...
- @overload
- def bernoulli_(self, p: Tensor, *, generator: Optional[Generator]=None) -> Tensor: ...
- @overload
- def bernoulli_(self, p: _float=0.5, *, generator: Optional[Generator]=None) -> Tensor: ...
+ def baddbmm(
+ self,
+ batch1: Tensor,
+ batch2: Tensor,
+ *,
+ beta: Number = 1,
+ alpha: Number = 1,
+ ) -> Tensor: ...
+ def baddbmm_(
+ self,
+ batch1: Tensor,
+ batch2: Tensor,
+ *,
+ beta: Number = 1,
+ alpha: Number = 1,
+ ) -> Tensor: ...
+ @overload
+ def bernoulli(self, *, generator: Optional[Generator] = None) -> Tensor: ...
+ @overload
+ def bernoulli(
+ self, p: _float, *, generator: Optional[Generator] = None
+ ) -> Tensor: ...
+ @overload
+ def bernoulli_(
+ self, p: Tensor, *, generator: Optional[Generator] = None
+ ) -> Tensor: ...
+ @overload
+ def bernoulli_(
+ self, p: _float = 0.5, *, generator: Optional[Generator] = None
+ ) -> Tensor: ...
def bfloat16(self) -> Tensor: ...
- def bincount(self, weights: Optional[Tensor]=None, minlength: _int=0) -> Tensor: ...
+ def bincount(
+ self, weights: Optional[Tensor] = None, minlength: _int = 0
+ ) -> Tensor: ...
@overload
def bitwise_and(self, other: Tensor) -> Tensor: ...
@overload
@@ -1612,24 +1986,42 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def broadcast_to(self, *size: _int) -> Tensor: ...
def byte(self) -> Tensor: ...
- def cauchy_(self, median: _float=0, sigma: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ...
+ def cauchy_(
+ self,
+ median: _float = 0,
+ sigma: _float = 1,
+ *,
+ generator: Optional[Generator] = None,
+ ) -> Tensor: ...
def ccol_indices(self) -> Tensor: ...
def ceil(self) -> Tensor: ...
def ceil_(self) -> Tensor: ...
- def chalf(self, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
+ def chalf(
+ self, *, memory_format: Optional[memory_format] = None
+ ) -> Tensor: ...
def char(self) -> Tensor: ...
- def cholesky(self, upper: _bool=False) -> Tensor: ...
- def cholesky_inverse(self, upper: _bool=False) -> Tensor: ...
- def cholesky_solve(self, input2: Tensor, upper: _bool=False) -> Tensor: ...
- def chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]: ...
+ def cholesky(self, upper: _bool = False) -> Tensor: ...
+ def cholesky_inverse(self, upper: _bool = False) -> Tensor: ...
+ def cholesky_solve(
+ self, input2: Tensor, upper: _bool = False
+ ) -> Tensor: ...
+ def chunk(self, chunks: _int, dim: _int = 0) -> List[Tensor]: ...
@overload
- def clamp(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ...
+ def clamp(
+ self, min: Optional[Tensor] = None, max: Optional[Tensor] = None
+ ) -> Tensor: ...
@overload
- def clamp(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ...
+ def clamp(
+ self, min: Optional[Number] = None, max: Optional[Number] = None
+ ) -> Tensor: ...
@overload
- def clamp_(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ...
+ def clamp_(
+ self, min: Optional[Tensor] = None, max: Optional[Tensor] = None
+ ) -> Tensor: ...
@overload
- def clamp_(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ...
+ def clamp_(
+ self, min: Optional[Number] = None, max: Optional[Number] = None
+ ) -> Tensor: ...
@overload
def clamp_max(self, max: Tensor) -> Tensor: ...
@overload
@@ -1647,21 +2039,31 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def clamp_min_(self, min: Number) -> Tensor: ...
@overload
- def clip(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ...
+ def clip(
+ self, min: Optional[Tensor] = None, max: Optional[Tensor] = None
+ ) -> Tensor: ...
@overload
- def clip(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ...
+ def clip(
+ self, min: Optional[Number] = None, max: Optional[Number] = None
+ ) -> Tensor: ...
@overload
- def clip_(self, min: Optional[Tensor]=None, max: Optional[Tensor]=None) -> Tensor: ...
+ def clip_(
+ self, min: Optional[Tensor] = None, max: Optional[Tensor] = None
+ ) -> Tensor: ...
@overload
- def clip_(self, min: Optional[Number]=None, max: Optional[Number]=None) -> Tensor: ...
- def clone(self, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
+ def clip_(
+ self, min: Optional[Number] = None, max: Optional[Number] = None
+ ) -> Tensor: ...
+ def clone(
+ self, *, memory_format: Optional[memory_format] = None
+ ) -> Tensor: ...
def coalesce(self) -> Tensor: ...
def col_indices(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def conj_physical(self) -> Tensor: ...
def conj_physical_(self) -> Tensor: ...
def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...
- def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ...
+ def copy_(self, src: Tensor, non_blocking: _bool = False) -> Tensor: ...
@overload
def copysign(self, other: Tensor) -> Tensor: ...
@overload
@@ -1676,40 +2078,70 @@ class _TensorBase(metaclass=_TensorMeta):
def cosh(self) -> Tensor: ...
def cosh_(self) -> Tensor: ...
@overload
- def count_nonzero(self, dim: Optional[_int]=None) -> Tensor: ...
+ def count_nonzero(self, dim: Optional[_int] = None) -> Tensor: ...
@overload
def count_nonzero(self, dim: _size) -> Tensor: ...
@overload
def count_nonzero(self, *dim: _int) -> Tensor: ...
- def cov(self, *, correction: _int=1, fweights: Optional[Tensor]=None, aweights: Optional[Tensor]=None) -> Tensor: ...
+ def cov(
+ self,
+ *,
+ correction: _int = 1,
+ fweights: Optional[Tensor] = None,
+ aweights: Optional[Tensor] = None,
+ ) -> Tensor: ...
def cpu(self) -> Tensor: ...
- def cross(self, other: Tensor, dim: Optional[_int]=None) -> Tensor: ...
+ def cross(self, other: Tensor, dim: Optional[_int] = None) -> Tensor: ...
def crow_indices(self) -> Tensor: ...
- def cuda(self, device: Optional[Union[_device, _int, str]]=None, non_blocking: _bool=False) -> Tensor: ...
+ def cuda(
+ self,
+ device: Optional[Union[_device, _int, str]] = None,
+ non_blocking: _bool = False,
+ ) -> Tensor: ...
@overload
def cummax(self, dim: _int) -> torch.return_types.cummax: ...
@overload
- def cummax(self, dim: Union[str, ellipsis, None]) -> torch.return_types.cummax: ...
+ def cummax(
+ self, dim: Union[str, ellipsis, None]
+ ) -> torch.return_types.cummax: ...
@overload
def cummin(self, dim: _int) -> torch.return_types.cummin: ...
@overload
- def cummin(self, dim: Union[str, ellipsis, None]) -> torch.return_types.cummin: ...
+ def cummin(
+ self, dim: Union[str, ellipsis, None]
+ ) -> torch.return_types.cummin: ...
@overload
- def cumprod(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumprod(
+ self, dim: _int, *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def cumprod(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumprod(
+ self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def cumprod_(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumprod_(
+ self, dim: _int, *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def cumprod_(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumprod_(
+ self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def cumsum(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumsum(
+ self, dim: _int, *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def cumsum(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumsum(
+ self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def cumsum_(self, dim: _int, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumsum_(
+ self, dim: _int, *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def cumsum_(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def cumsum_(
+ self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
def data_ptr(self) -> _int: ...
def deg2rad(self) -> Tensor: ...
def deg2rad_(self) -> Tensor: ...
@@ -1718,35 +2150,72 @@ class _TensorBase(metaclass=_TensorMeta):
def det(self) -> Tensor: ...
def detach(self) -> Tensor: ...
def detach_(self) -> Tensor: ...
- def diag(self, diagonal: _int=0) -> Tensor: ...
- def diag_embed(self, offset: _int=0, dim1: _int=-2, dim2: _int=-1) -> Tensor: ...
- def diagflat(self, offset: _int=0) -> Tensor: ...
+ def diag(self, diagonal: _int = 0) -> Tensor: ...
+ def diag_embed(
+ self, offset: _int = 0, dim1: _int = -2, dim2: _int = -1
+ ) -> Tensor: ...
+ def diagflat(self, offset: _int = 0) -> Tensor: ...
@overload
- def diagonal(self, *, outdim: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None], dim2: Union[str, ellipsis, None], offset: _int=0) -> Tensor: ...
- @overload
- def diagonal(self, offset: _int=0, dim1: _int=0, dim2: _int=1) -> Tensor: ...
- def diagonal_scatter(self, src: Tensor, offset: _int=0, dim1: _int=0, dim2: _int=1) -> Tensor: ...
- def diff(self, n: _int=1, dim: _int=-1, prepend: Optional[Tensor]=None, append: Optional[Tensor]=None) -> Tensor: ...
+ def diagonal(
+ self,
+ *,
+ outdim: Union[str, ellipsis, None],
+ dim1: Union[str, ellipsis, None],
+ dim2: Union[str, ellipsis, None],
+ offset: _int = 0,
+ ) -> Tensor: ...
+ @overload
+ def diagonal(
+ self, offset: _int = 0, dim1: _int = 0, dim2: _int = 1
+ ) -> Tensor: ...
+ def diagonal_scatter(
+ self, src: Tensor, offset: _int = 0, dim1: _int = 0, dim2: _int = 1
+ ) -> Tensor: ...
+ def diff(
+ self,
+ n: _int = 1,
+ dim: _int = -1,
+ prepend: Optional[Tensor] = None,
+ append: Optional[Tensor] = None,
+ ) -> Tensor: ...
def digamma(self) -> Tensor: ...
def digamma_(self) -> Tensor: ...
def dim(self) -> _int: ...
- def dist(self, other: Tensor, p: Number=2) -> Tensor: ...
- def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...
- def div_(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ...
+ def dist(self, other: Tensor, p: Number = 2) -> Tensor: ...
+ def div(
+ self,
+ other: Union[Tensor, Number],
+ *,
+ rounding_mode: Optional[str] = None,
+ ) -> Tensor: ...
+ def div_(
+ self,
+ other: Union[Tensor, Number],
+ *,
+ rounding_mode: Optional[str] = None,
+ ) -> Tensor: ...
@overload
def divide(self, other: Tensor) -> Tensor: ...
@overload
- def divide(self, other: Tensor, *, rounding_mode: Optional[str]) -> Tensor: ...
+ def divide(
+ self, other: Tensor, *, rounding_mode: Optional[str]
+ ) -> Tensor: ...
@overload
- def divide(self, other: Number, *, rounding_mode: Optional[str]) -> Tensor: ...
+ def divide(
+ self, other: Number, *, rounding_mode: Optional[str]
+ ) -> Tensor: ...
@overload
def divide(self, other: Number) -> Tensor: ...
@overload
def divide_(self, other: Tensor) -> Tensor: ...
@overload
- def divide_(self, other: Tensor, *, rounding_mode: Optional[str]) -> Tensor: ...
+ def divide_(
+ self, other: Tensor, *, rounding_mode: Optional[str]
+ ) -> Tensor: ...
@overload
- def divide_(self, other: Number, *, rounding_mode: Optional[str]) -> Tensor: ...
+ def divide_(
+ self, other: Number, *, rounding_mode: Optional[str]
+ ) -> Tensor: ...
@overload
def divide_(self, other: Number) -> Tensor: ...
def dot(self, tensor: Tensor) -> Tensor: ...
@@ -1778,28 +2247,48 @@ class _TensorBase(metaclass=_TensorMeta):
def exp2_(self) -> Tensor: ...
def exp_(self) -> Tensor: ...
@overload
- def expand(self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool=False) -> Tensor: ...
+ def expand(
+ self, size: Sequence[Union[_int, SymInt]], *, implicit: _bool = False
+ ) -> Tensor: ...
@overload
- def expand(self, *size: _int, implicit: _bool=False) -> Tensor: ...
+ def expand(self, *size: _int, implicit: _bool = False) -> Tensor: ...
def expand_as(self, other: Tensor) -> Tensor: ...
def expm1(self) -> Tensor: ...
def expm1_(self) -> Tensor: ...
- def exponential_(self, lambd: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ...
+ def exponential_(
+ self, lambd: _float = 1, *, generator: Optional[Generator] = None
+ ) -> Tensor: ...
@overload
def fill_(self, value: Tensor) -> Tensor: ...
@overload
def fill_(self, value: Number) -> Tensor: ...
- def fill_diagonal_(self, fill_value: Number, wrap: _bool=False) -> Tensor: ...
+ def fill_diagonal_(
+ self, fill_value: Number, wrap: _bool = False
+ ) -> Tensor: ...
def fix(self) -> Tensor: ...
def fix_(self) -> Tensor: ...
@overload
- def flatten(self, start_dim: _int=0, end_dim: _int=-1) -> Tensor: ...
+ def flatten(self, start_dim: _int = 0, end_dim: _int = -1) -> Tensor: ...
@overload
- def flatten(self, start_dim: _int, end_dim: _int, out_dim: Union[str, ellipsis, None]) -> Tensor: ...
+ def flatten(
+ self,
+ start_dim: _int,
+ end_dim: _int,
+ out_dim: Union[str, ellipsis, None],
+ ) -> Tensor: ...
@overload
- def flatten(self, start_dim: Union[str, ellipsis, None], end_dim: Union[str, ellipsis, None], out_dim: Union[str, ellipsis, None]) -> Tensor: ...
+ def flatten(
+ self,
+ start_dim: Union[str, ellipsis, None],
+ end_dim: Union[str, ellipsis, None],
+ out_dim: Union[str, ellipsis, None],
+ ) -> Tensor: ...
@overload
- def flatten(self, dims: Sequence[Union[str, ellipsis, None]], out_dim: Union[str, ellipsis, None]) -> Tensor: ...
+ def flatten(
+ self,
+ dims: Sequence[Union[str, ellipsis, None]],
+ out_dim: Union[str, ellipsis, None],
+ ) -> Tensor: ...
@overload
def flip(self, dims: _size) -> Tensor: ...
@overload
@@ -1817,8 +2306,15 @@ class _TensorBase(metaclass=_TensorMeta):
def float_power_(self, exponent: Number) -> Tensor: ...
def floor(self) -> Tensor: ...
def floor_(self) -> Tensor: ...
- def floor_divide(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor]=None) -> Tensor: ...
- def floor_divide_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: ...
+ def floor_divide(
+ self,
+ other: Union[Tensor, Number, torch.SymInt, torch.SymFloat],
+ *,
+ out: Optional[Tensor] = None,
+ ) -> Tensor: ...
+ def floor_divide_(
+ self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]
+ ) -> Tensor: ...
def fmax(self, other: Tensor) -> Tensor: ...
def fmin(self, other: Tensor) -> Tensor: ...
@overload
@@ -1833,9 +2329,17 @@ class _TensorBase(metaclass=_TensorMeta):
def frac_(self) -> Tensor: ...
def frexp(self) -> torch.return_types.frexp: ...
@overload
- def gather(self, dim: _int, index: Tensor, *, sparse_grad: _bool=False) -> Tensor: ...
+ def gather(
+ self, dim: _int, index: Tensor, *, sparse_grad: _bool = False
+ ) -> Tensor: ...
@overload
- def gather(self, dim: Union[str, ellipsis, None], index: Tensor, *, sparse_grad: _bool=False) -> Tensor: ...
+ def gather(
+ self,
+ dim: Union[str, ellipsis, None],
+ index: Tensor,
+ *,
+ sparse_grad: _bool = False,
+ ) -> Tensor: ...
def gcd(self, other: Tensor) -> Tensor: ...
def gcd_(self, other: Tensor) -> Tensor: ...
@overload
@@ -1846,7 +2350,9 @@ class _TensorBase(metaclass=_TensorMeta):
def ge_(self, other: Tensor) -> Tensor: ...
@overload
def ge_(self, other: Number) -> Tensor: ...
- def geometric_(self, p: _float, *, generator: Optional[Generator]=None) -> Tensor: ...
+ def geometric_(
+ self, p: _float, *, generator: Optional[Generator] = None
+ ) -> Tensor: ...
def geqrf(self) -> torch.return_types.geqrf: ...
def ger(self, vec2: Tensor) -> Tensor: ...
def get_device(self) -> _int: ...
@@ -1875,15 +2381,30 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def gt_(self, other: Number) -> Tensor: ...
def half(self) -> Tensor: ...
- def hardshrink(self, lambd: Number=0.5) -> Tensor: ...
+ def hardshrink(self, lambd: Number = 0.5) -> Tensor: ...
def has_names(self) -> _bool: ...
def heaviside(self, values: Tensor) -> Tensor: ...
def heaviside_(self, values: Tensor) -> Tensor: ...
- def histc(self, bins: _int=100, min: Number=0, max: Number=0) -> Tensor: ...
+ def histc(
+ self, bins: _int = 100, min: Number = 0, max: Number = 0
+ ) -> Tensor: ...
@overload
- def histogram(self, bins: Tensor, *, weight: Optional[Tensor]=None, density: _bool=False) -> torch.return_types.histogram: ...
+ def histogram(
+ self,
+ bins: Tensor,
+ *,
+ weight: Optional[Tensor] = None,
+ density: _bool = False,
+ ) -> torch.return_types.histogram: ...
@overload
- def histogram(self, bins: _int=100, *, range: Optional[Sequence[_float]]=None, weight: Optional[Tensor]=None, density: _bool=False) -> torch.return_types.histogram: ...
+ def histogram(
+ self,
+ bins: _int = 100,
+ *,
+ range: Optional[Sequence[_float]] = None,
+ weight: Optional[Tensor] = None,
+ density: _bool = False,
+ ) -> torch.return_types.histogram: ...
@overload
def hsplit(self, sections: _int) -> List[Tensor]: ...
@overload
@@ -1899,42 +2420,101 @@ class _TensorBase(metaclass=_TensorMeta):
def igammac(self, other: Tensor) -> Tensor: ...
def igammac_(self, other: Tensor) -> Tensor: ...
@overload
- def index_add(self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number=1) -> Tensor: ...
+ def index_add(
+ self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number = 1
+ ) -> Tensor: ...
@overload
- def index_add(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor, *, alpha: Number=1) -> Tensor: ...
- def index_add_(self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number=1) -> Tensor: ...
- @overload
- def index_copy(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: ...
- @overload
- def index_copy(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: ...
- @overload
- def index_copy_(self, dim: _int, index: Tensor, source: Tensor) -> Tensor: ...
- @overload
- def index_copy_(self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor) -> Tensor: ...
+ def index_add(
+ self,
+ dim: Union[str, ellipsis, None],
+ index: Tensor,
+ source: Tensor,
+ *,
+ alpha: Number = 1,
+ ) -> Tensor: ...
+ def index_add_(
+ self, dim: _int, index: Tensor, source: Tensor, *, alpha: Number = 1
+ ) -> Tensor: ...
+ @overload
+ def index_copy(
+ self, dim: _int, index: Tensor, source: Tensor
+ ) -> Tensor: ...
+ @overload
+ def index_copy(
+ self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor
+ ) -> Tensor: ...
+ @overload
+ def index_copy_(
+ self, dim: _int, index: Tensor, source: Tensor
+ ) -> Tensor: ...
+ @overload
+ def index_copy_(
+ self, dim: Union[str, ellipsis, None], index: Tensor, source: Tensor
+ ) -> Tensor: ...
@overload
def index_fill(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: ...
@overload
- def index_fill(self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: ...
+ def index_fill(
+ self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor
+ ) -> Tensor: ...
@overload
def index_fill(self, dim: _int, index: Tensor, value: Number) -> Tensor: ...
@overload
- def index_fill(self, dim: Union[str, ellipsis, None], index: Tensor, value: Number) -> Tensor: ...
+ def index_fill(
+ self, dim: Union[str, ellipsis, None], index: Tensor, value: Number
+ ) -> Tensor: ...
@overload
- def index_fill_(self, dim: _int, index: Tensor, value: Tensor) -> Tensor: ...
+ def index_fill_(
+ self, dim: _int, index: Tensor, value: Tensor
+ ) -> Tensor: ...
@overload
- def index_fill_(self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor) -> Tensor: ...
+ def index_fill_(
+ self, dim: Union[str, ellipsis, None], index: Tensor, value: Tensor
+ ) -> Tensor: ...
@overload
- def index_fill_(self, dim: _int, index: Tensor, value: Number) -> Tensor: ...
+ def index_fill_(
+ self, dim: _int, index: Tensor, value: Number
+ ) -> Tensor: ...
@overload
- def index_fill_(self, dim: Union[str, ellipsis, None], index: Tensor, value: Number) -> Tensor: ...
- def index_put(self, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool=False) -> Tensor: ...
- def index_put_(self, indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]], values: Tensor, accumulate: _bool=False) -> Tensor: ...
- def index_reduce(self, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ...
- def index_reduce_(self, dim: _int, index: Tensor, source: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ...
+ def index_fill_(
+ self, dim: Union[str, ellipsis, None], index: Tensor, value: Number
+ ) -> Tensor: ...
+ def index_put(
+ self,
+ indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]],
+ values: Tensor,
+ accumulate: _bool = False,
+ ) -> Tensor: ...
+ def index_put_(
+ self,
+ indices: Optional[Union[Tuple[Tensor, ...], List[Tensor]]],
+ values: Tensor,
+ accumulate: _bool = False,
+ ) -> Tensor: ...
+ def index_reduce(
+ self,
+ dim: _int,
+ index: Tensor,
+ source: Tensor,
+ reduce: str,
+ *,
+ include_self: _bool = True,
+ ) -> Tensor: ...
+ def index_reduce_(
+ self,
+ dim: _int,
+ index: Tensor,
+ source: Tensor,
+ reduce: str,
+ *,
+ include_self: _bool = True,
+ ) -> Tensor: ...
@overload
def index_select(self, dim: _int, index: Tensor) -> Tensor: ...
@overload
- def index_select(self, dim: Union[str, ellipsis, None], index: Tensor) -> Tensor: ...
+ def index_select(
+ self, dim: Union[str, ellipsis, None], index: Tensor
+ ) -> Tensor: ...
def indices(self) -> Tensor: ...
def inner(self, other: Tensor) -> Tensor: ...
def int(self) -> Tensor: ...
@@ -1957,7 +2537,9 @@ class _TensorBase(metaclass=_TensorMeta):
is_nested: _bool
def is_nonzero(self) -> _bool: ...
is_ort: _bool
- def is_pinned(self, device: Optional[Union[_device, str, None]]=None) -> _bool: ...
+ def is_pinned(
+ self, device: Optional[Union[_device, str, None]] = None
+ ) -> _bool: ...
is_quantized: _bool
def is_same_size(self, other: Tensor) -> _bool: ...
def is_set_to(self, tensor: Tensor) -> _bool: ...
@@ -1965,20 +2547,41 @@ class _TensorBase(metaclass=_TensorMeta):
is_sparse: _bool
is_sparse_csr: _bool
is_vulkan: _bool
- def isclose(self, other: Tensor, rtol: _float=1e-05, atol: _float=1e-08, equal_nan: _bool=False) -> Tensor: ...
+ def isclose(
+ self,
+ other: Tensor,
+ rtol: _float = 1e-05,
+ atol: _float = 1e-08,
+ equal_nan: _bool = False,
+ ) -> Tensor: ...
def isfinite(self) -> Tensor: ...
def isinf(self) -> Tensor: ...
def isnan(self) -> Tensor: ...
def isneginf(self) -> Tensor: ...
def isposinf(self) -> Tensor: ...
def isreal(self) -> Tensor: ...
- def istft(self, n_fft: _int, hop_length: Optional[_int]=None, win_length: Optional[_int]=None, window: Optional[Tensor]=None, center: _bool=True, normalized: _bool=False, onesided: Optional[_bool]=None, length: Optional[_int]=None, return_complex: _bool=False) -> Tensor: ...
+ def istft(
+ self,
+ n_fft: _int,
+ hop_length: Optional[_int] = None,
+ win_length: Optional[_int] = None,
+ window: Optional[Tensor] = None,
+ center: _bool = True,
+ normalized: _bool = False,
+ onesided: Optional[_bool] = None,
+ length: Optional[_int] = None,
+ return_complex: _bool = False,
+ ) -> Tensor: ...
def item(self) -> Number: ...
def kron(self, other: Tensor) -> Tensor: ...
@overload
- def kthvalue(self, k: _int, dim: _int=-1, keepdim: _bool=False) -> torch.return_types.kthvalue: ...
+ def kthvalue(
+ self, k: _int, dim: _int = -1, keepdim: _bool = False
+ ) -> torch.return_types.kthvalue: ...
@overload
- def kthvalue(self, k: _int, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.kthvalue: ...
+ def kthvalue(
+ self, k: _int, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> torch.return_types.kthvalue: ...
def lcm(self, other: Tensor) -> Tensor: ...
def lcm_(self, other: Tensor) -> Tensor: ...
def ldexp(self, other: Tensor) -> Tensor: ...
@@ -2025,11 +2628,21 @@ class _TensorBase(metaclass=_TensorMeta):
def log2(self) -> Tensor: ...
def log2_(self) -> Tensor: ...
def log_(self) -> Tensor: ...
- def log_normal_(self, mean: _float=1, std: _float=2, *, generator: Optional[Generator]=None) -> Tensor: ...
- @overload
- def log_softmax(self, dim: _int, dtype: Optional[_dtype]=None) -> Tensor: ...
- @overload
- def log_softmax(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def log_normal_(
+ self,
+ mean: _float = 1,
+ std: _float = 2,
+ *,
+ generator: Optional[Generator] = None,
+ ) -> Tensor: ...
+ @overload
+ def log_softmax(
+ self, dim: _int, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
+ @overload
+ def log_softmax(
+ self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
def logaddexp(self, other: Tensor) -> Tensor: ...
def logaddexp2(self, other: Tensor) -> Tensor: ...
@overload
@@ -2045,12 +2658,16 @@ class _TensorBase(metaclass=_TensorMeta):
def logical_or_(self, other: Tensor) -> Tensor: ...
def logical_xor(self, other: Tensor) -> Tensor: ...
def logical_xor_(self, other: Tensor) -> Tensor: ...
- def logit(self, eps: Optional[_float]=None) -> Tensor: ...
- def logit_(self, eps: Optional[_float]=None) -> Tensor: ...
+ def logit(self, eps: Optional[_float] = None) -> Tensor: ...
+ def logit_(self, eps: Optional[_float] = None) -> Tensor: ...
@overload
- def logsumexp(self, dim: Union[_int, _size], keepdim: _bool=False) -> Tensor: ...
+ def logsumexp(
+ self, dim: Union[_int, _size], keepdim: _bool = False
+ ) -> Tensor: ...
@overload
- def logsumexp(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool=False) -> Tensor: ...
+ def logsumexp(
+ self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool = False
+ ) -> Tensor: ...
def long(self) -> Tensor: ...
@overload
def lt(self, other: Tensor) -> Tensor: ...
@@ -2082,36 +2699,64 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def max(self, other: Tensor) -> Tensor: ...
@overload
- def max(self, dim: _int, keepdim: _bool=False) -> torch.return_types.max: ...
+ def max(
+ self, dim: _int, keepdim: _bool = False
+ ) -> torch.return_types.max: ...
@overload
- def max(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.max: ...
+ def max(
+ self, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> torch.return_types.max: ...
def maximum(self, other: Tensor) -> Tensor: ...
@overload
- def mean(self, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def mean(self, *, dtype: Optional[_dtype] = None) -> Tensor: ...
@overload
- def mean(self, dim: Optional[Union[_int, _size]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def mean(
+ self,
+ dim: Optional[Union[_int, _size]],
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
@overload
- def mean(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def mean(
+ self,
+ dim: Sequence[Union[str, ellipsis, None]],
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
@overload
def median(self) -> Tensor: ...
@overload
- def median(self, dim: _int, keepdim: _bool=False) -> torch.return_types.median: ...
+ def median(
+ self, dim: _int, keepdim: _bool = False
+ ) -> torch.return_types.median: ...
@overload
- def median(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.median: ...
+ def median(
+ self, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> torch.return_types.median: ...
@overload
def min(self) -> Tensor: ...
@overload
def min(self, other: Tensor) -> Tensor: ...
@overload
- def min(self, dim: _int, keepdim: _bool=False) -> torch.return_types.min: ...
+ def min(
+ self, dim: _int, keepdim: _bool = False
+ ) -> torch.return_types.min: ...
@overload
- def min(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.min: ...
+ def min(
+ self, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> torch.return_types.min: ...
def minimum(self, other: Tensor) -> Tensor: ...
def mm(self, mat2: Tensor) -> Tensor: ...
@overload
- def mode(self, dim: _int=-1, keepdim: _bool=False) -> torch.return_types.mode: ...
+ def mode(
+ self, dim: _int = -1, keepdim: _bool = False
+ ) -> torch.return_types.mode: ...
@overload
- def mode(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.mode: ...
+ def mode(
+ self, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> torch.return_types.mode: ...
@overload
def moveaxis(self, source: _int, destination: _int) -> Tensor: ...
@overload
@@ -2121,9 +2766,22 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def movedim(self, source: _size, destination: _size) -> Tensor: ...
def msort(self) -> Tensor: ...
- def mul(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor]=None) -> Tensor: ...
- def mul_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: ...
- def multinomial(self, num_samples: _int, replacement: _bool=False, *, generator: Optional[Generator]=None) -> Tensor: ...
+ def mul(
+ self,
+ other: Union[Tensor, Number, torch.SymInt, torch.SymFloat],
+ *,
+ out: Optional[Tensor] = None,
+ ) -> Tensor: ...
+ def mul_(
+ self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]
+ ) -> Tensor: ...
+ def multinomial(
+ self,
+ num_samples: _int,
+ replacement: _bool = False,
+ *,
+ generator: Optional[Generator] = None,
+ ) -> Tensor: ...
@overload
def multiply(self, other: Tensor) -> Tensor: ...
@overload
@@ -2135,25 +2793,71 @@ class _TensorBase(metaclass=_TensorMeta):
def mv(self, vec: Tensor) -> Tensor: ...
def mvlgamma(self, p: _int) -> Tensor: ...
def mvlgamma_(self, p: _int) -> Tensor: ...
- def nan_to_num(self, nan: Optional[_float]=None, posinf: Optional[_float]=None, neginf: Optional[_float]=None) -> Tensor: ...
- def nan_to_num_(self, nan: Optional[_float]=None, posinf: Optional[_float]=None, neginf: Optional[_float]=None) -> Tensor: ...
- def nanmean(self, dim: Optional[Union[_int, _size]]=None, keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def nan_to_num(
+ self,
+ nan: Optional[_float] = None,
+ posinf: Optional[_float] = None,
+ neginf: Optional[_float] = None,
+ ) -> Tensor: ...
+ def nan_to_num_(
+ self,
+ nan: Optional[_float] = None,
+ posinf: Optional[_float] = None,
+ neginf: Optional[_float] = None,
+ ) -> Tensor: ...
+ def nanmean(
+ self,
+ dim: Optional[Union[_int, _size]] = None,
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
@overload
def nanmedian(self) -> Tensor: ...
@overload
- def nanmedian(self, dim: _int, keepdim: _bool=False) -> torch.return_types.nanmedian: ...
- @overload
- def nanmedian(self, dim: Union[str, ellipsis, None], keepdim: _bool=False) -> torch.return_types.nanmedian: ...
- @overload
- def nanquantile(self, q: Tensor, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ...
+ def nanmedian(
+ self, dim: _int, keepdim: _bool = False
+ ) -> torch.return_types.nanmedian: ...
@overload
- def nanquantile(self, q: _float, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ...
- def nansum(self, dim: Optional[Union[_int, _size]]=None, keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def nanmedian(
+ self, dim: Union[str, ellipsis, None], keepdim: _bool = False
+ ) -> torch.return_types.nanmedian: ...
@overload
- def narrow(self, dim: _int, start: Tensor, length: Union[_int, SymInt]) -> Tensor: ...
- @overload
- def narrow(self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: ...
- def narrow_copy(self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]) -> Tensor: ...
+ def nanquantile(
+ self,
+ q: Tensor,
+ dim: Optional[_int] = None,
+ keepdim: _bool = False,
+ *,
+ interpolation: str = "linear",
+ ) -> Tensor: ...
+ @overload
+ def nanquantile(
+ self,
+ q: _float,
+ dim: Optional[_int] = None,
+ keepdim: _bool = False,
+ *,
+ interpolation: str = "linear",
+ ) -> Tensor: ...
+ def nansum(
+ self,
+ dim: Optional[Union[_int, _size]] = None,
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
+ @overload
+ def narrow(
+ self, dim: _int, start: Tensor, length: Union[_int, SymInt]
+ ) -> Tensor: ...
+ @overload
+ def narrow(
+ self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]
+ ) -> Tensor: ...
+ def narrow_copy(
+ self, dim: _int, start: Union[_int, SymInt], length: Union[_int, SymInt]
+ ) -> Tensor: ...
def ndimension(self) -> _int: ...
@overload
def ne(self, other: Tensor) -> Tensor: ...
@@ -2169,37 +2873,126 @@ class _TensorBase(metaclass=_TensorMeta):
def negative_(self) -> Tensor: ...
def nelement(self) -> _int: ...
@overload
- def new(self, *args: Any, device: Device=None) ->Tensor: ...
+ def new(self, *args: Any, device: Device = None) -> Tensor: ...
@overload
def new(self, storage: Storage) -> Tensor: ...
@overload
def new(self, other: Tensor) -> Tensor: ...
@overload
- def new(self, size: _size, *, device: Device=None) -> Tensor: ...
- @overload
- def new_empty(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
- @overload
- def new_empty(self, *size: _int, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
- def new_empty_strided(self, size: Sequence[Union[_int, SymInt]], stride: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
- def new_full(self, size: Sequence[Union[_int, SymInt]], fill_value: Number, *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
- @overload
- def new_ones(self, size: _size, dtype: Optional[_dtype]=None, device: Device=None, requires_grad: _bool=False) -> Tensor: ...
+ def new(self, size: _size, *, device: Device = None) -> Tensor: ...
@overload
- def new_ones(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
- @overload
- def new_ones(self, *size: _int, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
- def new_tensor(self, data: Any, dtype: Optional[_dtype]=None, device: Device=None, requires_grad: _bool=False) -> Tensor: ...
+ def new_empty(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ *,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
+ @overload
+ def new_empty(
+ self,
+ *size: _int,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
+ def new_empty_strided(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ stride: Sequence[Union[_int, SymInt]],
+ *,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
+ def new_full(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ fill_value: Number,
+ *,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
+ @overload
+ def new_ones(
+ self,
+ size: _size,
+ dtype: Optional[_dtype] = None,
+ device: Device = None,
+ requires_grad: _bool = False,
+ ) -> Tensor: ...
@overload
- def new_zeros(self, size: Sequence[Union[_int, SymInt]], *, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
+ def new_ones(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ *,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
+ @overload
+ def new_ones(
+ self,
+ *size: _int,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
+ def new_tensor(
+ self,
+ data: Any,
+ dtype: Optional[_dtype] = None,
+ device: Device = None,
+ requires_grad: _bool = False,
+ ) -> Tensor: ...
@overload
- def new_zeros(self, *size: _int, dtype: Optional[_dtype]=None, layout: Optional[_layout]=None, device: Optional[Union[_device, str, None]]=None, pin_memory: Optional[_bool]=False, requires_grad: Optional[_bool]=False) -> Tensor: ...
+ def new_zeros(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ *,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
+ @overload
+ def new_zeros(
+ self,
+ *size: _int,
+ dtype: Optional[_dtype] = None,
+ layout: Optional[_layout] = None,
+ device: Optional[Union[_device, str, None]] = None,
+ pin_memory: Optional[_bool] = False,
+ requires_grad: Optional[_bool] = False,
+ ) -> Tensor: ...
def nextafter(self, other: Tensor) -> Tensor: ...
def nextafter_(self, other: Tensor) -> Tensor: ...
@overload
- def nonzero(self, *, as_tuple: Literal[False]=False) -> Tensor: ...
+ def nonzero(self, *, as_tuple: Literal[False] = False) -> Tensor: ...
@overload
def nonzero(self, *, as_tuple: Literal[True]) -> Tuple[Tensor, ...]: ...
- def normal_(self, mean: _float=0, std: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ...
+ def normal_(
+ self,
+ mean: _float = 0,
+ std: _float = 1,
+ *,
+ generator: Optional[Generator] = None,
+ ) -> Tensor: ...
@overload
def not_equal(self, other: Tensor) -> Tensor: ...
@overload
@@ -2209,16 +3002,24 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def not_equal_(self, other: Number) -> Tensor: ...
def numel(self) -> _int: ...
- def numpy(self, *, force: _bool=False) -> Any: ...
+ def numpy(self, *, force: _bool = False) -> Any: ...
def orgqr(self, input2: Tensor) -> Tensor: ...
- def ormqr(self, input2: Tensor, input3: Tensor, left: _bool=True, transpose: _bool=False) -> Tensor: ...
+ def ormqr(
+ self,
+ input2: Tensor,
+ input3: Tensor,
+ left: _bool = True,
+ transpose: _bool = False,
+ ) -> Tensor: ...
def outer(self, vec2: Tensor) -> Tensor: ...
@overload
def permute(self, dims: _size) -> Tensor: ...
@overload
def permute(self, *dims: _int) -> Tensor: ...
- def pin_memory(self, device: Optional[Union[_device, str, None]]=None) -> Tensor: ...
- def pinverse(self, rcond: _float=1e-15) -> Tensor: ...
+ def pin_memory(
+ self, device: Optional[Union[_device, str, None]] = None
+ ) -> Tensor: ...
+ def pinverse(self, rcond: _float = 1e-15) -> Tensor: ...
def polygamma(self, n: _int) -> Tensor: ...
def polygamma_(self, n: _int) -> Tensor: ...
def positive(self) -> Tensor: ...
@@ -2232,37 +3033,77 @@ class _TensorBase(metaclass=_TensorMeta):
def pow_(self, exponent: Number) -> Tensor: ...
def prelu(self, weight: Tensor) -> Tensor: ...
@overload
- def prod(self, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def prod(self, *, dtype: Optional[_dtype] = None) -> Tensor: ...
@overload
- def prod(self, dim: _int, keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def prod(
+ self,
+ dim: _int,
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
@overload
- def prod(self, dim: Union[str, ellipsis, None], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
- def put(self, index: Tensor, source: Tensor, accumulate: _bool=False) -> Tensor: ...
- def put_(self, index: Tensor, source: Tensor, accumulate: _bool=False) -> Tensor: ...
+ def prod(
+ self,
+ dim: Union[str, ellipsis, None],
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
+ def put(
+ self, index: Tensor, source: Tensor, accumulate: _bool = False
+ ) -> Tensor: ...
+ def put_(
+ self, index: Tensor, source: Tensor, accumulate: _bool = False
+ ) -> Tensor: ...
def q_per_channel_axis(self) -> _int: ...
def q_per_channel_scales(self) -> Tensor: ...
def q_per_channel_zero_points(self) -> Tensor: ...
def q_scale(self) -> _float: ...
def q_zero_point(self) -> _int: ...
- def qr(self, some: _bool=True) -> torch.return_types.qr: ...
+ def qr(self, some: _bool = True) -> torch.return_types.qr: ...
def qscheme(self) -> _qscheme: ...
@overload
- def quantile(self, q: Tensor, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ...
- @overload
- def quantile(self, q: _float, dim: Optional[_int]=None, keepdim: _bool=False, *, interpolation: str="linear") -> Tensor: ...
+ def quantile(
+ self,
+ q: Tensor,
+ dim: Optional[_int] = None,
+ keepdim: _bool = False,
+ *,
+ interpolation: str = "linear",
+ ) -> Tensor: ...
+ @overload
+ def quantile(
+ self,
+ q: _float,
+ dim: Optional[_int] = None,
+ keepdim: _bool = False,
+ *,
+ interpolation: str = "linear",
+ ) -> Tensor: ...
def rad2deg(self) -> Tensor: ...
def rad2deg_(self) -> Tensor: ...
@overload
- def random_(self, *, generator: Optional[Generator]=None) -> Tensor: ...
- @overload
- def random_(self, from_: _int, to: Optional[_int], *, generator: Optional[Generator]=None) -> Tensor: ...
+ def random_(self, *, generator: Optional[Generator] = None) -> Tensor: ...
@overload
- def random_(self, to: _int, *, generator: Optional[Generator]=None) -> Tensor: ...
+ def random_(
+ self,
+ from_: _int,
+ to: Optional[_int],
+ *,
+ generator: Optional[Generator] = None,
+ ) -> Tensor: ...
+ @overload
+ def random_(
+ self, to: _int, *, generator: Optional[Generator] = None
+ ) -> Tensor: ...
def ravel(self) -> Tensor: ...
def reciprocal(self) -> Tensor: ...
def reciprocal_(self) -> Tensor: ...
def record_stream(self, s: Stream) -> None: ...
- def refine_names(self, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ...
+ def refine_names(
+ self, names: Sequence[Union[str, ellipsis, None]]
+ ) -> Tensor: ...
def relu(self) -> Tensor: ...
def relu_(self) -> Tensor: ...
@overload
@@ -2273,8 +3114,12 @@ class _TensorBase(metaclass=_TensorMeta):
def remainder_(self, other: Tensor) -> Tensor: ...
@overload
def remainder_(self, other: Number) -> Tensor: ...
- def rename(self, names: Optional[Sequence[Union[str, ellipsis, None]]]) -> Tensor: ...
- def rename_(self, names: Optional[Sequence[Union[str, ellipsis, None]]]) -> Tensor: ...
+ def rename(
+ self, names: Optional[Sequence[Union[str, ellipsis, None]]]
+ ) -> Tensor: ...
+ def rename_(
+ self, names: Optional[Sequence[Union[str, ellipsis, None]]]
+ ) -> Tensor: ...
def renorm(self, p: Number, dim: _int, maxnorm: Number) -> Tensor: ...
def renorm_(self, p: Number, dim: _int, maxnorm: Number) -> Tensor: ...
@overload
@@ -2282,26 +3127,52 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def repeat(self, *repeats: _int) -> Tensor: ...
@overload
- def repeat_interleave(self, repeats: Tensor, dim: Optional[_int]=None, *, output_size: Optional[_int]=None) -> Tensor: ...
+ def repeat_interleave(
+ self,
+ repeats: Tensor,
+ dim: Optional[_int] = None,
+ *,
+ output_size: Optional[_int] = None,
+ ) -> Tensor: ...
@overload
- def repeat_interleave(self, repeats: Union[_int, SymInt], dim: Optional[_int]=None, *, output_size: Optional[_int]=None) -> Tensor: ...
- def requires_grad_(self, mode: _bool=True) -> Tensor: ...
+ def repeat_interleave(
+ self,
+ repeats: Union[_int, SymInt],
+ dim: Optional[_int] = None,
+ *,
+ output_size: Optional[_int] = None,
+ ) -> Tensor: ...
+ def requires_grad_(self, mode: _bool = True) -> Tensor: ...
@overload
def reshape(self, shape: Sequence[Union[_int, SymInt]]) -> Tensor: ...
@overload
def reshape(self, *shape: _int) -> Tensor: ...
def reshape_as(self, other: Tensor) -> Tensor: ...
@overload
- def resize_(self, size: Sequence[Union[_int, SymInt]], *, memory_format: Optional[memory_format]=None) -> Tensor: ...
- @overload
- def resize_(self, *size: _int, memory_format: Optional[memory_format]=None) -> Tensor: ...
- def resize_as_(self, the_template: Tensor, *, memory_format: Optional[memory_format]=None) -> Tensor: ...
+ def resize_(
+ self,
+ size: Sequence[Union[_int, SymInt]],
+ *,
+ memory_format: Optional[memory_format] = None,
+ ) -> Tensor: ...
+ @overload
+ def resize_(
+ self, *size: _int, memory_format: Optional[memory_format] = None
+ ) -> Tensor: ...
+ def resize_as_(
+ self,
+ the_template: Tensor,
+ *,
+ memory_format: Optional[memory_format] = None,
+ ) -> Tensor: ...
def resize_as_sparse_(self, the_template: Tensor) -> Tensor: ...
def resolve_conj(self) -> Tensor: ...
def resolve_neg(self) -> Tensor: ...
def retain_grad(self) -> None: ...
- def roll(self, shifts: Union[_int, _size], dims: Union[_int, _size]=()) -> Tensor: ...
- def rot90(self, k: _int=1, dims: _size=(0,1)) -> Tensor: ...
+ def roll(
+ self, shifts: Union[_int, _size], dims: Union[_int, _size] = ()
+ ) -> Tensor: ...
+ def rot90(self, k: _int = 1, dims: _size = (0, 1)) -> Tensor: ...
@overload
def round(self) -> Tensor: ...
@overload
@@ -2316,37 +3187,77 @@ class _TensorBase(metaclass=_TensorMeta):
@overload
def scatter(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
@overload
- def scatter(self, dim: _int, index: Tensor, src: Tensor, *, reduce: str) -> Tensor: ...
+ def scatter(
+ self, dim: _int, index: Tensor, src: Tensor, *, reduce: str
+ ) -> Tensor: ...
@overload
- def scatter(self, dim: _int, index: Tensor, value: Number, *, reduce: str) -> Tensor: ...
+ def scatter(
+ self, dim: _int, index: Tensor, value: Number, *, reduce: str
+ ) -> Tensor: ...
@overload
- def scatter(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ...
+ def scatter(
+ self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor
+ ) -> Tensor: ...
@overload
def scatter(self, dim: _int, index: Tensor, value: Number) -> Tensor: ...
@overload
- def scatter(self, dim: Union[str, ellipsis, None], index: Tensor, value: Number) -> Tensor: ...
+ def scatter(
+ self, dim: Union[str, ellipsis, None], index: Tensor, value: Number
+ ) -> Tensor: ...
@overload
def scatter_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
@overload
- def scatter_(self, dim: _int, index: Tensor, src: Tensor, *, reduce: str) -> Tensor: ...
+ def scatter_(
+ self, dim: _int, index: Tensor, src: Tensor, *, reduce: str
+ ) -> Tensor: ...
@overload
- def scatter_(self, dim: _int, index: Tensor, value: Number, *, reduce: str) -> Tensor: ...
+ def scatter_(
+ self, dim: _int, index: Tensor, value: Number, *, reduce: str
+ ) -> Tensor: ...
@overload
def scatter_(self, dim: _int, index: Tensor, value: Number) -> Tensor: ...
@overload
def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
@overload
- def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ...
+ def scatter_add(
+ self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor
+ ) -> Tensor: ...
def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
- def scatter_reduce(self, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ...
- def scatter_reduce_(self, dim: _int, index: Tensor, src: Tensor, reduce: str, *, include_self: _bool=True) -> Tensor: ...
+ def scatter_reduce(
+ self,
+ dim: _int,
+ index: Tensor,
+ src: Tensor,
+ reduce: str,
+ *,
+ include_self: _bool = True,
+ ) -> Tensor: ...
+ def scatter_reduce_(
+ self,
+ dim: _int,
+ index: Tensor,
+ src: Tensor,
+ reduce: str,
+ *,
+ include_self: _bool = True,
+ ) -> Tensor: ...
@overload
def select(self, dim: _int, index: Union[_int, SymInt]) -> Tensor: ...
@overload
- def select(self, dim: Union[str, ellipsis, None], index: _int) -> Tensor: ...
- def select_scatter(self, src: Tensor, dim: _int, index: Union[_int, SymInt]) -> Tensor: ...
+ def select(
+ self, dim: Union[str, ellipsis, None], index: _int
+ ) -> Tensor: ...
+ def select_scatter(
+ self, src: Tensor, dim: _int, index: Union[_int, SymInt]
+ ) -> Tensor: ...
@overload
- def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...
+ def set_(
+ self,
+ storage: Union[Storage, TypedStorage],
+ offset: _int,
+ size: _size,
+ stride: _size,
+ ) -> Tensor: ...
@overload
def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...
def sgn(self) -> Tensor: ...
@@ -2367,30 +3278,63 @@ class _TensorBase(metaclass=_TensorMeta):
def size(self) -> Size: ...
@overload
def size(self, dim: _int) -> _int: ...
- def slice_scatter(self, src: Tensor, dim: _int=0, start: Optional[Union[_int, SymInt]]=None, end: Optional[Union[_int, SymInt]]=None, step: Union[_int, SymInt]=1) -> Tensor: ...
+ def slice_scatter(
+ self,
+ src: Tensor,
+ dim: _int = 0,
+ start: Optional[Union[_int, SymInt]] = None,
+ end: Optional[Union[_int, SymInt]] = None,
+ step: Union[_int, SymInt] = 1,
+ ) -> Tensor: ...
def slogdet(self) -> torch.return_types.slogdet: ...
def smm(self, mat2: Tensor) -> Tensor: ...
@overload
- def softmax(self, dim: _int, dtype: Optional[_dtype]=None) -> Tensor: ...
- @overload
- def softmax(self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def softmax(self, dim: _int, dtype: Optional[_dtype] = None) -> Tensor: ...
@overload
- def sort(self, *, stable: Optional[_bool], dim: _int=-1, descending: _bool=False) -> torch.return_types.sort: ...
+ def softmax(
+ self, dim: Union[str, ellipsis, None], *, dtype: Optional[_dtype] = None
+ ) -> Tensor: ...
@overload
- def sort(self, dim: _int=-1, descending: _bool=False) -> torch.return_types.sort: ...
+ def sort(
+ self,
+ *,
+ stable: Optional[_bool],
+ dim: _int = -1,
+ descending: _bool = False,
+ ) -> torch.return_types.sort: ...
@overload
- def sort(self, *, stable: Optional[_bool], dim: Union[str, ellipsis, None], descending: _bool=False) -> torch.return_types.sort: ...
+ def sort(
+ self, dim: _int = -1, descending: _bool = False
+ ) -> torch.return_types.sort: ...
@overload
- def sort(self, dim: Union[str, ellipsis, None], descending: _bool=False) -> torch.return_types.sort: ...
+ def sort(
+ self,
+ *,
+ stable: Optional[_bool],
+ dim: Union[str, ellipsis, None],
+ descending: _bool = False,
+ ) -> torch.return_types.sort: ...
+ @overload
+ def sort(
+ self, dim: Union[str, ellipsis, None], descending: _bool = False
+ ) -> torch.return_types.sort: ...
def sparse_dim(self) -> _int: ...
def sparse_mask(self, mask: Tensor) -> Tensor: ...
- def sparse_resize_(self, size: _size, sparse_dim: _int, dense_dim: _int) -> Tensor: ...
- def sparse_resize_and_clear_(self, size: _size, sparse_dim: _int, dense_dim: _int) -> Tensor: ...
- @overload
- def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...
- @overload
- def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...
- def split_with_sizes(self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int=0) -> List[Tensor]: ...
+ def sparse_resize_(
+ self, size: _size, sparse_dim: _int, dense_dim: _int
+ ) -> Tensor: ...
+ def sparse_resize_and_clear_(
+ self, size: _size, sparse_dim: _int, dense_dim: _int
+ ) -> Tensor: ...
+ @overload
+ def split(self, split_size: _int, dim: _int = 0) -> Sequence[Tensor]: ...
+ @overload
+ def split(
+ self, split_size: Tuple[_int, ...], dim: _int = 0
+ ) -> Sequence[Tensor]: ...
+ def split_with_sizes(
+ self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0
+ ) -> List[Tensor]: ...
def sqrt(self) -> Tensor: ...
def sqrt_(self) -> Tensor: ...
def square(self) -> Tensor: ...
@@ -2415,17 +3359,41 @@ class _TensorBase(metaclass=_TensorMeta):
def squeeze_(self, *dim: _int) -> Tensor: ...
@overload
def squeeze_(self, dim: Union[str, ellipsis, None]) -> Tensor: ...
- def sspaddmm(self, mat1: Tensor, mat2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ...
+ def sspaddmm(
+ self, mat1: Tensor, mat2: Tensor, *, beta: Number = 1, alpha: Number = 1
+ ) -> Tensor: ...
@overload
- def std(self, dim: Optional[Union[_int, _size]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ...
+ def std(
+ self,
+ dim: Optional[Union[_int, _size]],
+ unbiased: _bool = True,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
@overload
- def std(self, dim: Optional[Union[_int, _size]]=None, *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
+ def std(
+ self,
+ dim: Optional[Union[_int, _size]] = None,
+ *,
+ correction: Optional[_int] = None,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
@overload
- def std(self, unbiased: _bool=True) -> Tensor: ...
+ def std(self, unbiased: _bool = True) -> Tensor: ...
@overload
- def std(self, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ...
+ def std(
+ self,
+ dim: Sequence[Union[str, ellipsis, None]],
+ unbiased: _bool = True,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
@overload
- def std(self, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
+ def std(
+ self,
+ dim: Sequence[Union[str, ellipsis, None]],
+ *,
+ correction: Optional[_int] = None,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
def untyped_storage(self) -> Storage: ...
def storage_offset(self) -> _int: ...
def storage_type(self) -> Storage: ...
@@ -2433,27 +3401,52 @@ class _TensorBase(metaclass=_TensorMeta):
def stride(self) -> Tuple[_int, ...]: ...
@overload
def stride(self, _int) -> _int: ...
- def sub(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1, out: Optional[Tensor]=None) -> Tensor: ...
- def sub_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, alpha: Optional[Number]=1) -> Tensor: ...
+ def sub(
+ self,
+ other: Union[Tensor, Number, torch.SymInt, torch.SymFloat],
+ *,
+ alpha: Optional[Number] = 1,
+ out: Optional[Tensor] = None,
+ ) -> Tensor: ...
+ def sub_(
+ self,
+ other: Union[Tensor, Number, torch.SymInt, torch.SymFloat],
+ *,
+ alpha: Optional[Number] = 1,
+ ) -> Tensor: ...
@overload
- def subtract(self, other: Tensor, *, alpha: Number=1) -> Tensor: ...
+ def subtract(self, other: Tensor, *, alpha: Number = 1) -> Tensor: ...
@overload
- def subtract(self, other: Number, alpha: Number=1) -> Tensor: ...
+ def subtract(self, other: Number, alpha: Number = 1) -> Tensor: ...
@overload
- def subtract_(self, other: Tensor, *, alpha: Number=1) -> Tensor: ...
+ def subtract_(self, other: Tensor, *, alpha: Number = 1) -> Tensor: ...
@overload
- def subtract_(self, other: Number, alpha: Number=1) -> Tensor: ...
+ def subtract_(self, other: Number, alpha: Number = 1) -> Tensor: ...
@overload
- def sum(self, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def sum(self, *, dtype: Optional[_dtype] = None) -> Tensor: ...
@overload
- def sum(self, dim: Optional[Union[_int, _size]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def sum(
+ self,
+ dim: Optional[Union[_int, _size]],
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
@overload
- def sum(self, dim: Sequence[Union[str, ellipsis, None]], keepdim: _bool=False, *, dtype: Optional[_dtype]=None) -> Tensor: ...
+ def sum(
+ self,
+ dim: Sequence[Union[str, ellipsis, None]],
+ keepdim: _bool = False,
+ *,
+ dtype: Optional[_dtype] = None,
+ ) -> Tensor: ...
@overload
def sum_to_size(self, size: _size) -> Tensor: ...
@overload
def sum_to_size(self, *size: _int) -> Tensor: ...
- def svd(self, some: _bool=True, compute_uv: _bool=True) -> torch.return_types.svd: ...
+ def svd(
+ self, some: _bool = True, compute_uv: _bool = True
+ ) -> torch.return_types.svd: ...
def swapaxes(self, axis0: _int, axis1: _int) -> Tensor: ...
def swapaxes_(self, axis0: _int, axis1: _int) -> Tensor: ...
def swapdims(self, dim0: _int, dim1: _int) -> Tensor: ...
@@ -2461,86 +3454,178 @@ class _TensorBase(metaclass=_TensorMeta):
def t(self) -> Tensor: ...
def t_(self) -> Tensor: ...
def take(self, index: Tensor) -> Tensor: ...
- def take_along_dim(self, indices: Tensor, dim: Optional[_int]=None) -> Tensor: ...
+ def take_along_dim(
+ self, indices: Tensor, dim: Optional[_int] = None
+ ) -> Tensor: ...
def tan(self) -> Tensor: ...
def tan_(self) -> Tensor: ...
def tanh(self) -> Tensor: ...
def tanh_(self) -> Tensor: ...
@overload
- def tensor_split(self, indices: Sequence[Union[_int, SymInt]], dim: _int=0) -> List[Tensor]: ...
+ def tensor_split(
+ self, indices: Sequence[Union[_int, SymInt]], dim: _int = 0
+ ) -> List[Tensor]: ...
@overload
- def tensor_split(self, tensor_indices_or_sections: Tensor, dim: _int=0) -> List[Tensor]: ...
+ def tensor_split(
+ self, tensor_indices_or_sections: Tensor, dim: _int = 0
+ ) -> List[Tensor]: ...
@overload
- def tensor_split(self, sections: Union[_int, SymInt], dim: _int=0) -> List[Tensor]: ...
+ def tensor_split(
+ self, sections: Union[_int, SymInt], dim: _int = 0
+ ) -> List[Tensor]: ...
@overload
def tile(self, dims: _size) -> Tensor: ...
@overload
def tile(self, *dims: _int) -> Tensor: ...
@overload
- def to(self, dtype: _dtype, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...
+ def to(
+ self, dtype: _dtype, non_blocking: _bool = False, copy: _bool = False
+ ) -> Tensor: ...
@overload
- def to(self, device: Optional[Union[_device, str]]=None, dtype: Optional[_dtype]=None, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...
- @overload
- def to(self, other: Tensor, non_blocking: _bool=False, copy: _bool=False) -> Tensor: ...
- def to_dense(self, dtype: Optional[_dtype]=None) -> Tensor: ...
- def to_mkldnn(self, dtype: Optional[_dtype]=None) -> Tensor: ...
- def to_padded_tensor(self, padding: _float, output_size: Optional[Sequence[Union[_int, SymInt]]]=None) -> Tensor: ...
+ def to(
+ self,
+ device: Optional[Union[_device, str]] = None,
+ dtype: Optional[_dtype] = None,
+ non_blocking: _bool = False,
+ copy: _bool = False,
+ ) -> Tensor: ...
+ @overload
+ def to(
+ self, other: Tensor, non_blocking: _bool = False, copy: _bool = False
+ ) -> Tensor: ...
+ def to_dense(self, dtype: Optional[_dtype] = None) -> Tensor: ...
+ def to_mkldnn(self, dtype: Optional[_dtype] = None) -> Tensor: ...
+ def to_padded_tensor(
+ self,
+ padding: _float,
+ output_size: Optional[Sequence[Union[_int, SymInt]]] = None,
+ ) -> Tensor: ...
@overload
- def to_sparse(self, *, layout: Optional[_layout]=None, blocksize: Optional[Union[_int, _size]]=None, dense_dim: Optional[_int]=None) -> Tensor: ...
+ def to_sparse(
+ self,
+ *,
+ layout: Optional[_layout] = None,
+ blocksize: Optional[Union[_int, _size]] = None,
+ dense_dim: Optional[_int] = None,
+ ) -> Tensor: ...
@overload
def to_sparse(self, sparse_dim: _int) -> Tensor: ...
- def to_sparse_bsc(self, blocksize: Union[_int, _size], dense_dim: Optional[_int]=None) -> Tensor: ...
- def to_sparse_bsr(self, blocksize: Union[_int, _size], dense_dim: Optional[_int]=None) -> Tensor: ...
- def to_sparse_csc(self, dense_dim: Optional[_int]=None) -> Tensor: ...
- def to_sparse_csr(self, dense_dim: Optional[_int]=None) -> Tensor: ...
+ def to_sparse_bsc(
+ self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None
+ ) -> Tensor: ...
+ def to_sparse_bsr(
+ self, blocksize: Union[_int, _size], dense_dim: Optional[_int] = None
+ ) -> Tensor: ...
+ def to_sparse_csc(self, dense_dim: Optional[_int] = None) -> Tensor: ...
+ def to_sparse_csr(self, dense_dim: Optional[_int] = None) -> Tensor: ...
def tolist(self) -> List: ...
- def topk(self, k: _int, dim: _int=-1, largest: _bool=True, sorted: _bool=True) -> torch.return_types.topk: ...
+ def topk(
+ self,
+ k: _int,
+ dim: _int = -1,
+ largest: _bool = True,
+ sorted: _bool = True,
+ ) -> torch.return_types.topk: ...
def trace(self) -> Tensor: ...
@overload
def transpose(self, dim0: _int, dim1: _int) -> Tensor: ...
@overload
- def transpose(self, dim0: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None]) -> Tensor: ...
+ def transpose(
+ self, dim0: Union[str, ellipsis, None], dim1: Union[str, ellipsis, None]
+ ) -> Tensor: ...
def transpose_(self, dim0: _int, dim1: _int) -> Tensor: ...
- def triangular_solve(self, A: Tensor, upper: _bool=True, transpose: _bool=False, unitriangular: _bool=False) -> torch.return_types.triangular_solve: ...
- def tril(self, diagonal: _int=0) -> Tensor: ...
- def tril_(self, diagonal: _int=0) -> Tensor: ...
- def triu(self, diagonal: _int=0) -> Tensor: ...
- def triu_(self, diagonal: _int=0) -> Tensor: ...
- def true_divide(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat], *, out: Optional[Tensor]=None) -> Tensor: ...
- def true_divide_(self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]) -> Tensor: ...
+ def triangular_solve(
+ self,
+ A: Tensor,
+ upper: _bool = True,
+ transpose: _bool = False,
+ unitriangular: _bool = False,
+ ) -> torch.return_types.triangular_solve: ...
+ def tril(self, diagonal: _int = 0) -> Tensor: ...
+ def tril_(self, diagonal: _int = 0) -> Tensor: ...
+ def triu(self, diagonal: _int = 0) -> Tensor: ...
+ def triu_(self, diagonal: _int = 0) -> Tensor: ...
+ def true_divide(
+ self,
+ other: Union[Tensor, Number, torch.SymInt, torch.SymFloat],
+ *,
+ out: Optional[Tensor] = None,
+ ) -> Tensor: ...
+ def true_divide_(
+ self, other: Union[Tensor, Number, torch.SymInt, torch.SymFloat]
+ ) -> Tensor: ...
def trunc(self) -> Tensor: ...
def trunc_(self) -> Tensor: ...
@overload
- def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...
+ def type(self, dtype: None = None, non_blocking: _bool = False) -> str: ...
@overload
- def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...
+ def type(
+ self, dtype: Union[str, _dtype], non_blocking: _bool = False
+ ) -> Tensor: ...
def type_as(self, other: Tensor) -> Tensor: ...
@overload
- def unbind(self, dim: _int=0) -> List[Tensor]: ...
+ def unbind(self, dim: _int = 0) -> List[Tensor]: ...
@overload
def unbind(self, dim: Union[str, ellipsis, None]) -> List[Tensor]: ...
@overload
- def unflatten(self, dim: Union[str, ellipsis, None], sizes: _size, names: Sequence[Union[str, ellipsis, None]]) -> Tensor: ...
+ def unflatten(
+ self,
+ dim: Union[str, ellipsis, None],
+ sizes: _size,
+ names: Sequence[Union[str, ellipsis, None]],
+ ) -> Tensor: ...
@overload
def unflatten(self, dim: _int, sizes: _size) -> Tensor: ...
def unfold(self, dimension: _int, size: _int, step: _int) -> Tensor: ...
- def uniform_(self, from_: _float=0, to: _float=1, *, generator: Optional[Generator]=None) -> Tensor: ...
- def unsafe_chunk(self, chunks: _int, dim: _int=0) -> List[Tensor]: ...
- def unsafe_split(self, split_size: Union[_int, SymInt], dim: _int=0) -> List[Tensor]: ...
- def unsafe_split_with_sizes(self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int=0) -> List[Tensor]: ...
+ def uniform_(
+ self,
+ from_: _float = 0,
+ to: _float = 1,
+ *,
+ generator: Optional[Generator] = None,
+ ) -> Tensor: ...
+ def unsafe_chunk(self, chunks: _int, dim: _int = 0) -> List[Tensor]: ...
+ def unsafe_split(
+ self, split_size: Union[_int, SymInt], dim: _int = 0
+ ) -> List[Tensor]: ...
+ def unsafe_split_with_sizes(
+ self, split_sizes: Sequence[Union[_int, SymInt]], dim: _int = 0
+ ) -> List[Tensor]: ...
def unsqueeze(self, dim: _int) -> Tensor: ...
def unsqueeze_(self, dim: _int) -> Tensor: ...
def values(self) -> Tensor: ...
@overload
- def var(self, dim: Optional[Union[_int, _size]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ...
+ def var(
+ self,
+ dim: Optional[Union[_int, _size]],
+ unbiased: _bool = True,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
@overload
- def var(self, dim: Optional[Union[_int, _size]]=None, *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
+ def var(
+ self,
+ dim: Optional[Union[_int, _size]] = None,
+ *,
+ correction: Optional[_int] = None,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
@overload
- def var(self, unbiased: _bool=True) -> Tensor: ...
+ def var(self, unbiased: _bool = True) -> Tensor: ...
@overload
- def var(self, dim: Sequence[Union[str, ellipsis, None]], unbiased: _bool=True, keepdim: _bool=False) -> Tensor: ...
+ def var(
+ self,
+ dim: Sequence[Union[str, ellipsis, None]],
+ unbiased: _bool = True,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
@overload
- def var(self, dim: Sequence[Union[str, ellipsis, None]], *, correction: Optional[_int]=None, keepdim: _bool=False) -> Tensor: ...
+ def var(
+ self,
+ dim: Sequence[Union[str, ellipsis, None]],
+ *,
+ correction: Optional[_int] = None,
+ keepdim: _bool = False,
+ ) -> Tensor: ...
def vdot(self, other: Tensor) -> Tensor: ...
@overload
def view(self, dtype: _dtype) -> Tensor: ...
@@ -2599,10 +3684,14 @@ def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
def _cuda_getArchFlags() -> Optional[str]: ...
def _cuda_init() -> None: ...
-def _cuda_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ...
+def _cuda_setStream(
+ stream_id: _int, device_index: _int, device_type: _int
+) -> None: ...
def _cuda_getCompiledVersion() -> _int: ...
def _cuda_cudaHostAllocator() -> _int: ...
-def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
+def _cuda_cudaCachingAllocator_raw_alloc(
+ size: _int, cuda_stream: _int
+) -> _int: ...
def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
def _cuda_setMemoryFraction(fraction: _float, device: _int) -> None: ...
@@ -2611,56 +3700,74 @@ def _cuda_memoryStats(device: _int) -> Dict[str, Any]: ...
def _cuda_resetAccumulatedMemoryStats(device: _int) -> None: ...
def _cuda_resetPeakMemoryStats(device: _int) -> None: ...
def _cuda_memorySnapshot() -> Dict[str, Any]: ...
-def _cuda_recordMemoryHistory(enabled: _bool, record_context: _bool, record_context_cpp: _bool, alloc_trace_max_entries: _int, alloc_trace_record_context: _bool) -> None: ...
+def _cuda_recordMemoryHistory(
+ enabled: _bool,
+ record_context: _bool,
+ record_context_cpp: _bool,
+ alloc_trace_max_entries: _int,
+ alloc_trace_record_context: _bool,
+) -> None: ...
def _cuda_getAllocatorBackend() -> str: ...
-class _cuda_CUDAAllocator:
- ...
+class _cuda_CUDAAllocator: ...
-def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ...
+def _cuda_customAllocator(
+ alloc_fn: _int, free_fn: _int
+) -> _cuda_CUDAAllocator: ...
def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ...
def _cuda_getAllocator() -> _cuda_CUDAAllocator: ...
def _cuda_lock_mutex() -> None: ...
def _cuda_unlock_mutex() -> None: ...
def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ...
-def _cuda_jiterator_compile_and_launch_kernel(code_string: str,
- kernel_name: str,
- return_by_ref: _bool,
- num_outputs: _int,
- tensors: Tuple,
- kwargs: Dict[str, Union[_int, _float, _bool]]) -> Tensor: ...
+def _cuda_jiterator_compile_and_launch_kernel(
+ code_string: str,
+ kernel_name: str,
+ return_by_ref: _bool,
+ num_outputs: _int,
+ tensors: Tuple,
+ kwargs: Dict[str, Union[_int, _float, _bool]],
+) -> Tensor: ...
def _cuda_get_cudnn_benchmark_limit() -> _int: ...
def _cuda_set_cudnn_benchmark_limit(arg: _int) -> None: ...
def _nccl_version() -> _int: ...
def _nccl_unique_id() -> bytes: ...
def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...
-def _nccl_reduce(input: Sequence[Tensor],
- output: Tensor,
- root: _int,
- op: _int,
- streams: Optional[Sequence[_CudaStreamBase]],
- comms: Optional[Sequence[object]]) -> None: ...
-def _nccl_all_reduce(input: Sequence[Tensor],
- output: Sequence[Tensor],
- op: _int,
- streams: Optional[Sequence[_CudaStreamBase]],
- comms: Optional[Sequence[object]]) -> None: ...
-def _nccl_broadcast(input: Sequence[Tensor],
- root: _int,
- streams: Optional[Sequence[_CudaStreamBase]],
- comms: Optional[Sequence[object]]) -> None: ...
-def _nccl_all_gather(input: Sequence[Tensor],
- output: Sequence[Tensor],
- streams: Optional[Sequence[_CudaStreamBase]],
- comms: Optional[Sequence[object]]) -> None: ...
-def _nccl_reduce_scatter(input: Sequence[Tensor],
- output: Sequence[Tensor],
- op: _int,
- streams: Optional[Sequence[_CudaStreamBase]],
- comms: Optional[Sequence[object]]) -> None: ...
+def _nccl_reduce(
+ input: Sequence[Tensor],
+ output: Tensor,
+ root: _int,
+ op: _int,
+ streams: Optional[Sequence[_CudaStreamBase]],
+ comms: Optional[Sequence[object]],
+) -> None: ...
+def _nccl_all_reduce(
+ input: Sequence[Tensor],
+ output: Sequence[Tensor],
+ op: _int,
+ streams: Optional[Sequence[_CudaStreamBase]],
+ comms: Optional[Sequence[object]],
+) -> None: ...
+def _nccl_broadcast(
+ input: Sequence[Tensor],
+ root: _int,
+ streams: Optional[Sequence[_CudaStreamBase]],
+ comms: Optional[Sequence[object]],
+) -> None: ...
+def _nccl_all_gather(
+ input: Sequence[Tensor],
+ output: Sequence[Tensor],
+ streams: Optional[Sequence[_CudaStreamBase]],
+ comms: Optional[Sequence[object]],
+) -> None: ...
+def _nccl_reduce_scatter(
+ input: Sequence[Tensor],
+ output: Sequence[Tensor],
+ op: _int,
+ streams: Optional[Sequence[_CudaStreamBase]],
+ comms: Optional[Sequence[object]],
+) -> None: ...
def _rocm_is_backward_pass() -> _bool: ...
-
class _CudaDeviceProperties:
name: str
major: _int
@@ -2672,17 +3779,31 @@ class _CudaDeviceProperties:
# Defined in torch/csrc/cuda/python_comm.cpp
def _broadcast(tensor: Tensor, devices: List[_int]) -> List[Tensor]: ...
-def _broadcast_out(tensor: Tensor, out_tensors: List[Tensor]) -> List[Tensor]: ...
+def _broadcast_out(
+ tensor: Tensor, out_tensors: List[Tensor]
+) -> List[Tensor]: ...
def _broadcast_coalesced(
- tensors: List[Tensor],
- devices: List[_int],
- buffer_size: _int
+ tensors: List[Tensor], devices: List[_int], buffer_size: _int
) -> List[List[Tensor]]: ...
-
-def _scatter(tensor: Tensor, devices: List[_int], chunk_sizes: Optional[List[_int]], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ...
-def _scatter_out(tensor: Tensor, out_tensors: List[Tensor], dim: _int, streams: Optional[List[Stream]]) -> List[Tensor]: ...
-def _gather(tensors: List[Tensor], dim: _int, destination_index: Optional[_int]) -> Tensor: ...
-def _gather_out(tensors: List[Tensor], out_tensor: Tensor, dim: _int) -> Tensor: ...
+def _scatter(
+ tensor: Tensor,
+ devices: List[_int],
+ chunk_sizes: Optional[List[_int]],
+ dim: _int,
+ streams: Optional[List[Stream]],
+) -> List[Tensor]: ...
+def _scatter_out(
+ tensor: Tensor,
+ out_tensors: List[Tensor],
+ dim: _int,
+ streams: Optional[List[Stream]],
+) -> List[Tensor]: ...
+def _gather(
+ tensors: List[Tensor], dim: _int, destination_index: Optional[_int]
+) -> Tensor: ...
+def _gather_out(
+ tensors: List[Tensor], out_tensor: Tensor, dim: _int
+) -> Tensor: ...
# Defined in torch/csrc/cuda/Stream.cpp
class _CudaStreamBase:
@@ -2694,7 +3815,13 @@ class _CudaStreamBase:
cuda_stream: _int
priority: _int
- def __new__(self, priority: _int = 0, stream_id: _int = 0, device_index: _int = 0, stream_ptr: _int = 0) -> _CudaStreamBase: ...
+ def __new__(
+ self,
+ priority: _int = 0,
+ stream_id: _int = 0,
+ device_index: _int = 0,
+ stream_ptr: _int = 0,
+ ) -> _CudaStreamBase: ...
def query(self) -> _bool: ...
def synchronize(self) -> None: ...
def priority_range(self) -> Tuple[_int, _int]: ...
@@ -2704,9 +3831,16 @@ class _CudaEventBase:
device: _device
cuda_event: _int
- def __new__(cls, enable_timing: _bool = False, blocking: _bool = False, interprocess: _bool = False) -> _CudaEventBase: ...
+ def __new__(
+ cls,
+ enable_timing: _bool = False,
+ blocking: _bool = False,
+ interprocess: _bool = False,
+ ) -> _CudaEventBase: ...
@classmethod
- def from_ipc_handle(cls, device: _device, ipc_handle: bytes) -> _CudaEventBase: ...
+ def from_ipc_handle(
+ cls, device: _device, ipc_handle: bytes
+ ) -> _CudaEventBase: ...
def record(self, stream: _CudaStreamBase) -> None: ...
def wait(self, stream: _CudaStreamBase) -> None: ...
def query(self) -> _bool: ...
@@ -2716,24 +3850,29 @@ class _CudaEventBase:
# Defined in torch/csrc/cuda/Graph.cpp
class _CUDAGraph:
- def capture_begin(self,
- pool: Optional[Tuple[_int, _int]]=...) -> None: ...
+ def capture_begin(
+ self, pool: Optional[Tuple[_int, _int]] = ...
+ ) -> None: ...
def capture_end(self) -> None: ...
def replay(self) -> None: ...
def reset(self) -> None: ...
def pool(self) -> Tuple[_int, _int]: ...
def enable_debug_mode(self) -> None: ...
- def debug_dump(self,
- debug_path: str) -> None: ...
+ def debug_dump(self, debug_path: str) -> None: ...
def _cuda_isCurrentStreamCapturing() -> _bool: ...
-
def _graph_pool_handle() -> Tuple[_int, _int]: ...
# Defined in torch/csrc/DataLoader.cpp
-def _set_worker_signal_handlers(*arg: Any) -> None: ... # THPModule_setWorkerSignalHandlers
-def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ... # THPModule_setWorkerPIDs
-def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs
+def _set_worker_signal_handlers(
+ *arg: Any,
+) -> None: ... # THPModule_setWorkerSignalHandlers
+def _set_worker_pids(
+ key: _int, child_pids: Tuple[_int, ...]
+) -> None: ... # THPModule_setWorkerPIDs
+def _remove_worker_pids(
+ loader_id: _int,
+) -> None: ... # THPModule_removeWorkerPIDs
def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails
# Defined in torch/csrc/jit/python/python_tracer.cpp
@@ -2752,19 +3891,19 @@ def _create_graph_by_tracing(
strict: Any,
force_outplace: Any,
self: Any = None,
- argument_names: List[str] = []
+ argument_names: List[str] = [],
) -> Tuple[Graph, Stack]: ...
def _tracer_warn_use_python(): ...
def _get_tracing_state() -> TracingState: ...
# Defined in torch/csrc/jit/python/python_ir.cpp
# Not actually defined in python_ir.cpp, not sure where they are.
-class IValue:
- ...
+class IValue: ...
+
Stack = List[IValue]
class JitType:
- annotation_str : str
+ annotation_str: str
def isSubtypeOf(self, other: JitType) -> _bool: ...
def with_dtype(self, dtype: _dtype) -> JitType: ...
def with_sizes(self, sizes: List[Optional[_int]]) -> JitType: ...
@@ -2779,7 +3918,7 @@ class InferredType:
def success(self) -> _bool: ...
def reason(self) -> str: ...
-R = TypeVar('R', bound=JitType)
+R = TypeVar("R", bound=JitType)
class AnyType(JitType):
@staticmethod
@@ -2824,7 +3963,6 @@ class StreamObjType(JitType):
class ListType(JitType):
def __init__(self, a: JitType) -> None: ...
def getElementType(self) -> JitType: ...
-
@staticmethod
def ofInts() -> ListType: ...
@staticmethod
@@ -2861,7 +3999,6 @@ class InterfaceType(JitType):
class OptionalType(JitType, Generic[R]):
def __init__(self, a: JitType) -> None: ...
def getElementType(self) -> JitType: ...
-
@staticmethod
def ofTensor() -> OptionalType: ...
@@ -2881,16 +4018,17 @@ class EnumType(JitType):
self,
qualified_name: str,
value_type: JitType,
- enum_names_values: List[Any]
- ) -> None:
- ...
+ enum_names_values: List[Any],
+ ) -> None: ...
class TensorType(JitType):
@classmethod
def get(cls) -> TensorType: ...
@classmethod
def getInferred(cls) -> TensorType: ...
- def with_sizes(self, other: Optional[List[Optional[_int]]]) -> TensorType: ...
+ def with_sizes(
+ self, other: Optional[List[Optional[_int]]]
+ ) -> TensorType: ...
def sizes(self) -> Optional[List[_int]]: ...
def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
def strides(self) -> Optional[List[_int]]: ...
@@ -2901,24 +4039,19 @@ class TensorType(JitType):
def create_from_tensor(t: Tensor) -> TensorType: ...
# Defined in torch/csrc/jit/python/python_tree_views.cpp
-class SourceRange:
- ...
-
-class TreeView:
- ...
+class SourceRange: ...
+class TreeView: ...
class Ident(TreeView):
@property
def name(self) -> str: ...
-class ClassDef(TreeView):
- ...
+class ClassDef(TreeView): ...
class Def(TreeView):
def name(self) -> Ident: ...
-class Decl(TreeView):
- ...
+class Decl(TreeView): ...
# Defined in torch/csrc/distributed/rpc/init.cpp
def _rpc_init() -> _bool: ...
@@ -2931,7 +4064,6 @@ def _c10d_init() -> _bool: ...
# Defined in torch/csrc/distributed/rpc/testing/init.cpp
def _faulty_agent_init() -> _bool: ...
-
def _enable_minidumps(directory: str) -> None: ...
def _disable_minidumps() -> None: ...
def _enable_minidumps_on_exceptions() -> None: ...
@@ -2940,7 +4072,7 @@ def _activate_cuda_trace() -> None: ...
# Defined in torch/csrc/Module.cpp
def _current_graph_task_id() -> _int: ...
-def _current_autograd_node() -> _Node: ...
+def _current_autograd_node() -> _Node: ...
class _OutOfMemoryError:
pass
diff --git a/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py b/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py
index e0d7fb4fc..d72c4679f 100644
--- a/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py
+++ b/tests/pytorch_pfn_extras_tests/cuda_tests/test_allocator.py
@@ -1,11 +1,10 @@
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
def test_stream():
- cupy = pytest.importorskip('cupy')
+ cupy = pytest.importorskip("cupy")
assert 0 == cupy.cuda.get_current_stream().ptr
assert 0 == torch.cuda.current_stream().cuda_stream
@@ -42,7 +41,7 @@ def test_stream_none():
class TestMemoryPool:
@pytest.fixture
def cupy(self):
- cupy = pytest.importorskip('cupy')
+ cupy = pytest.importorskip("cupy")
mempool = cupy.get_default_memory_pool()
yield cupy
mempool.free_all_blocks()
@@ -98,7 +97,8 @@ def test_use_torch_mempool_stream_mismatch(self, cupy):
try:
stream.use()
with pytest.raises(
- RuntimeError, match='pytorch_pfn_extras.cuda.stream'):
+ RuntimeError, match="pytorch_pfn_extras.cuda.stream"
+ ):
arr = cupy.arange(10)
del arr
finally:
diff --git a/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py b/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py
index 6c7238a1e..ee868f850 100644
--- a/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py
+++ b/tests/pytorch_pfn_extras_tests/dataloader_test/test_dataloader.py
@@ -1,6 +1,5 @@
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
class DummyDataset(torch.utils.data.Dataset):
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py
index c21f35404..dbb6df388 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/dummy_dataset.py
@@ -1,15 +1,19 @@
import numpy as np
-
import pytorch_pfn_extras as ppe
class DummyDataset(ppe.dataset.TabularDataset):
-
def __init__(
- self, size=10, keys=('a', 'b', 'c'), mode=tuple,
- return_array=False, callback=None, convert=False):
+ self,
+ size=10,
+ keys=("a", "b", "c"),
+ mode=tuple,
+ return_array=False,
+ callback=None,
+ convert=False,
+ ):
if mode is None:
- keys = keys[0],
+ keys = (keys[0],)
self._keys = keys
self._mode = mode
@@ -47,6 +51,6 @@ def get_examples(self, indices, key_indices):
def convert(self, data):
if self._convert:
- return 'converted'
+ return "converted"
else:
return super(DummyDataset, self).convert(data)
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py
index 9ec9ddb2d..91b426e89 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_asmode.py
@@ -1,13 +1,11 @@
import pytest
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+ dummy_dataset, # NOQA
+)
-@pytest.mark.parametrize(
- 'mode',
- [tuple, dict, None]
-)
+@pytest.mark.parametrize("mode", [tuple, dict, None])
def test_astuple(mode):
dataset = dummy_dataset.DummyDataset(mode=mode, convert=True)
view = dataset.astuple()
@@ -15,15 +13,11 @@ def test_astuple(mode):
assert len(view) == len(dataset)
assert view.keys == dataset.keys
assert view.mode == tuple
- assert (
- view.get_examples(None, None) == dataset.get_examples(None, None))
- assert view.convert(view.fetch()) == 'converted'
+ assert view.get_examples(None, None) == dataset.get_examples(None, None)
+ assert view.convert(view.fetch()) == "converted"
-@pytest.mark.parametrize(
- 'mode',
- [tuple, dict, None]
-)
+@pytest.mark.parametrize("mode", [tuple, dict, None])
def test_asdict(mode):
dataset = dummy_dataset.DummyDataset(mode=mode, convert=True)
view = dataset.asdict()
@@ -31,6 +25,5 @@ def test_asdict(mode):
assert len(view) == len(dataset)
assert view.keys == dataset.keys
assert view.mode == dict
- assert (
- view.get_examples(None, None) == dataset.get_examples(None, None))
- assert view.convert(view.fetch()) == 'converted'
+ assert view.get_examples(None, None) == dataset.get_examples(None, None)
+ assert view.convert(view.fetch()) == "converted"
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py
index b7209aa51..19197c264 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_concat.py
@@ -3,57 +3,61 @@
import numpy as np
import pytest
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+ dummy_dataset, # NOQA
+)
mode_a = [tuple, dict, None]
mode_b = [tuple, dict, None]
return_array = [True, False]
parameter_set = [
- {'indices': None,
- 'expected_indices_a': None,
- 'expected_indices_b': None},
- {'indices': [3, 1, 4, 12, 14, 13, 7, 5],
- 'expected_indices_a': [3, 1, 4, 7, 5],
- 'expected_indices_b': [2, 4, 3]},
- {'indices': [3, 1, 4],
- 'expected_indices_a': [3, 1, 4]},
- {'indices': slice(13, 6, -2),
- 'expected_indices_a': slice(9, 6, -2),
- 'expected_indices_b': slice(3, None, -2)},
- {'indices': slice(9, None, -2),
- 'expected_indices_a': slice(9, None, -2)},
- {'indices': [1, 2, 1],
- 'expected_indices_a': [1, 2, 1]},
- {'indices': []},
+ {"indices": None, "expected_indices_a": None, "expected_indices_b": None},
+ {
+ "indices": [3, 1, 4, 12, 14, 13, 7, 5],
+ "expected_indices_a": [3, 1, 4, 7, 5],
+ "expected_indices_b": [2, 4, 3],
+ },
+ {"indices": [3, 1, 4], "expected_indices_a": [3, 1, 4]},
+ {
+ "indices": slice(13, 6, -2),
+ "expected_indices_a": slice(9, 6, -2),
+ "expected_indices_b": slice(3, None, -2),
+ },
+ {"indices": slice(9, None, -2), "expected_indices_a": slice(9, None, -2)},
+ {"indices": [1, 2, 1], "expected_indices_a": [1, 2, 1]},
+ {"indices": []},
]
@pytest.mark.parametrize(
- 'mode_a, mode_b, return_array, parameter_set',
- itertools.product(mode_a, mode_b, return_array, parameter_set)
+ "mode_a, mode_b, return_array, parameter_set",
+ itertools.product(mode_a, mode_b, return_array, parameter_set),
)
def test_concat(mode_a, mode_b, return_array, parameter_set):
def callback_a(indices, key_indices):
- assert indices == parameter_set['expected_indices_a']
+ assert indices == parameter_set["expected_indices_a"]
assert key_indices is None
dataset_a = dummy_dataset.DummyDataset(
- keys=('a', 'b', 'c') if mode_b else ('a',),
+ keys=("a", "b", "c") if mode_b else ("a",),
mode=mode_a,
- return_array=return_array, callback=callback_a,
- convert=True)
+ return_array=return_array,
+ callback=callback_a,
+ convert=True,
+ )
def callback_b(indices, key_indices):
- assert indices == parameter_set['expected_indices_b']
+ assert indices == parameter_set["expected_indices_b"]
assert key_indices is None
dataset_b = dummy_dataset.DummyDataset(
size=5,
- keys=('a', 'b', 'c') if mode_a else ('a',),
+ keys=("a", "b", "c") if mode_a else ("a",),
mode=mode_b,
- return_array=return_array, callback=callback_b)
+ return_array=return_array,
+ callback=callback_b,
+ )
view = dataset_a.concat(dataset_b)
assert isinstance(view, ppe.dataset.TabularDataset)
@@ -61,27 +65,28 @@ def callback_b(indices, key_indices):
assert view.keys == dataset_a.keys
assert view.mode == dataset_a.mode
- output = view.get_examples(parameter_set['indices'], None)
+ output = view.get_examples(parameter_set["indices"], None)
data = np.hstack((dataset_a.data, dataset_b.data))
- if parameter_set['indices'] is not None:
- data = data[:, parameter_set['indices']]
+ if parameter_set["indices"] is not None:
+ data = data[:, parameter_set["indices"]]
for out, d in itertools.zip_longest(output, data):
np.testing.assert_equal(out, d)
if return_array and operator.xor(
- ('expected_indices_a' in parameter_set),
- ('expected_indices_b' in parameter_set)):
+ ("expected_indices_a" in parameter_set),
+ ("expected_indices_b" in parameter_set),
+ ):
assert isinstance(out, np.ndarray)
else:
assert isinstance(out, list)
- assert view.convert(output) == 'converted'
+ assert view.convert(output) == "converted"
def test_concat_key_length():
dataset_a = dummy_dataset.DummyDataset()
- dataset_b = dummy_dataset.DummyDataset(keys=('a', 'b'))
+ dataset_b = dummy_dataset.DummyDataset(keys=("a", "b"))
with pytest.raises(ValueError):
dataset_a.concat(dataset_b)
@@ -89,7 +94,7 @@ def test_concat_key_length():
def test_concat_key_order():
dataset_a = dummy_dataset.DummyDataset()
- dataset_b = dummy_dataset.DummyDataset(keys=('b', 'a', 'c'))
+ dataset_b = dummy_dataset.DummyDataset(keys=("b", "a", "c"))
with pytest.raises(ValueError):
dataset_a.concat(dataset_b)
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py
index 3236bb795..0f1e6014a 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_delegate_dataset.py
@@ -1,21 +1,17 @@
import pytest
-
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.dataset import tabular
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+ dummy_dataset, # NOQA
+)
-@pytest.mark.parametrize(
- 'mode',
- [tuple, dict, None]
-)
+@pytest.mark.parametrize("mode", [tuple, dict, None])
def test_delegate_dataset(mode):
- dataset = tabular.DelegateDataset(
- dummy_dataset.DummyDataset(mode=mode))
+ dataset = tabular.DelegateDataset(dummy_dataset.DummyDataset(mode=mode))
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == len(dataset.dataset)
assert dataset.keys == dataset.dataset.keys
assert dataset.mode == dataset.dataset.mode
- assert (
- dataset.get_example(3) == dataset.dataset.get_example(3))
+ assert dataset.get_example(3) == dataset.dataset.get_example(3)
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py
index 6ca854ccb..783e2ac92 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_from_data.py
@@ -1,12 +1,10 @@
import numpy as np
import pytest
-
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.dataset import tabular
class TestFromData:
-
def test_unary_array(self):
dataset = tabular.from_data(np.arange(10))
@@ -20,11 +18,11 @@ def test_unary_array(self):
assert isinstance(output, np.ndarray)
def test_unary_array_with_key(self):
- dataset = tabular.from_data(('key_a', np.arange(10)))
+ dataset = tabular.from_data(("key_a", np.arange(10)))
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert dataset.keys == ('key_a',)
+ assert dataset.keys == ("key_a",)
assert dataset.mode is None
output = dataset.slice[[1, 3]].fetch()
@@ -44,11 +42,11 @@ def test_unary_list(self):
assert isinstance(output, list)
def test_unary_list_with_key(self):
- dataset = tabular.from_data(('key_a', [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]))
+ dataset = tabular.from_data(("key_a", [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]))
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert dataset.keys == ('key_a',)
+ assert dataset.keys == ("key_a",)
assert dataset.mode is None
output = dataset.slice[[1, 3]].fetch()
@@ -56,12 +54,12 @@ def test_unary_list_with_key(self):
assert isinstance(output, list)
def test_unary_callable_unary(self):
- dataset = tabular.from_data(('key_a', lambda i: i * i), size=10)
+ dataset = tabular.from_data(("key_a", lambda i: i * i), size=10)
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert dataset.keys == ('key_a',)
- assert(dataset.mode) is None
+ assert dataset.keys == ("key_a",)
+ assert (dataset.mode) is None
output = dataset.slice[[1, 3]].fetch()
np.testing.assert_equal(output, [1, 9])
@@ -69,11 +67,12 @@ def test_unary_callable_unary(self):
def test_unary_callable_tuple(self):
dataset = tabular.from_data(
- (('key_a', 'key_b'), lambda i: (i * i, -i)), size=10)
+ (("key_a", "key_b"), lambda i: (i * i, -i)), size=10
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert dataset.keys == ('key_a', 'key_b')
+ assert dataset.keys == ("key_a", "key_b")
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -83,16 +82,17 @@ def test_unary_callable_tuple(self):
def test_unary_callable_dict(self):
dataset = tabular.from_data(
- (('key_a', 'key_b'),
- lambda i: {'key_a': i * i, 'key_b': -i}), size=10)
+ (("key_a", "key_b"), lambda i: {"key_a": i * i, "key_b": -i}),
+ size=10,
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert dataset.keys == ('key_a', 'key_b')
+ assert dataset.keys == ("key_a", "key_b")
assert dataset.mode == dict
output = dataset.slice[[1, 3]].fetch()
- np.testing.assert_equal(output, {'key_a': [1, 9], 'key_b': [-1, -3]})
+ np.testing.assert_equal(output, {"key_a": [1, 9], "key_b": [-1, -3]})
for out in output.values():
assert isinstance(out, list)
@@ -102,11 +102,12 @@ def test_unary_callable_without_key(self):
def test_unary_callable_without_size(self):
with pytest.raises(ValueError):
- tabular.from_data(('key_a', lambda i: i * i))
+ tabular.from_data(("key_a", lambda i: i * i))
def test_tuple_array_list(self):
dataset = tabular.from_data(
- (np.arange(10), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]))
+ (np.arange(10), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
@@ -120,12 +121,13 @@ def test_tuple_array_list(self):
def test_tuple_array_with_key_list(self):
dataset = tabular.from_data(
- (('key_a', np.arange(10)), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]))
+ (("key_a", np.arange(10)), [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
assert len(dataset.keys) == 2
- assert dataset.keys[0] == 'key_a'
+ assert dataset.keys[0] == "key_a"
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -135,12 +137,13 @@ def test_tuple_array_with_key_list(self):
def test_tuple_array_list_with_key(self):
dataset = tabular.from_data(
- (np.arange(10), ('key_b', [2, 7, 1, 8, 4, 5, 9, 0, 3, 6])))
+ (np.arange(10), ("key_b", [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]))
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
assert len(dataset.keys) == 2
- assert dataset.keys[1] == 'key_b'
+ assert dataset.keys[1] == "key_b"
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -149,13 +152,12 @@ def test_tuple_array_list_with_key(self):
assert isinstance(output[1], list)
def test_tuple_array_callable_unary(self):
- dataset = tabular.from_data(
- (np.arange(10), ('key_b', lambda i: i * i)))
+ dataset = tabular.from_data((np.arange(10), ("key_b", lambda i: i * i)))
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
assert len(dataset.keys) == 2
- assert dataset.keys[1] == 'key_b'
+ assert dataset.keys[1] == "key_b"
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -165,12 +167,13 @@ def test_tuple_array_callable_unary(self):
def test_tuple_array_callable_tuple(self):
dataset = tabular.from_data(
- (np.arange(10), (('key_b', 'key_c'), lambda i: (i * i, -i))))
+ (np.arange(10), (("key_b", "key_c"), lambda i: (i * i, -i)))
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
assert len(dataset.keys) == 3
- assert dataset.keys[1] == ('key_b')
+ assert dataset.keys[1] == ("key_b")
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -180,13 +183,16 @@ def test_tuple_array_callable_tuple(self):
def test_tuple_array_callable_dict(self):
dataset = tabular.from_data(
- (np.arange(10), (('key_b', 'key_c'),
- lambda i: {'key_b': i * i, 'key_c': -i})))
+ (
+ np.arange(10),
+ (("key_b", "key_c"), lambda i: {"key_b": i * i, "key_c": -i}),
+ )
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
assert len(dataset.keys) == 3
- assert dataset.keys[1] == ('key_b')
+ assert dataset.keys[1] == ("key_b")
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -196,11 +202,12 @@ def test_tuple_array_callable_dict(self):
def test_tuple_array_with_key_callable_unary(self):
dataset = tabular.from_data(
- (('key_a', np.arange(10)), ('key_b', lambda i: i * i)))
+ (("key_a", np.arange(10)), ("key_b", lambda i: i * i))
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert dataset.keys == ('key_a', 'key_b')
+ assert dataset.keys == ("key_a", "key_b")
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -210,11 +217,12 @@ def test_tuple_array_with_key_callable_unary(self):
def test_tuple_callable_unary_callable_unary(self):
dataset = tabular.from_data(
- (('key_a', lambda i: i * i), ('key_b', lambda i: -i)), size=10)
+ (("key_a", lambda i: i * i), ("key_b", lambda i: -i)), size=10
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert dataset.keys == ('key_a', 'key_b')
+ assert dataset.keys == ("key_a", "key_b")
assert dataset.mode == tuple
output = dataset.slice[[1, 3]].fetch()
@@ -225,87 +233,97 @@ def test_tuple_callable_unary_callable_unary(self):
def test_tuple_callable_unary_callable_unary_without_size(self):
with pytest.raises(ValueError):
tabular.from_data(
- (('key_a', lambda i: i * i), ('key_b', lambda i: -i)))
+ (("key_a", lambda i: i * i), ("key_b", lambda i: -i))
+ )
def test_dict_array_list(self):
dataset = tabular.from_data(
- {'key_a': np.arange(10), 'key_b': [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]})
+ {"key_a": np.arange(10), "key_b": [2, 7, 1, 8, 4, 5, 9, 0, 3, 6]}
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert set(dataset.keys) == {'key_a', 'key_b'}
+ assert set(dataset.keys) == {"key_a", "key_b"}
assert dataset.mode == dict
output = dataset.slice[[1, 3]].fetch()
- np.testing.assert_equal(output, {'key_a': [1, 3], 'key_b': [7, 8]})
- assert isinstance(output['key_a'], np.ndarray)
- assert isinstance(output['key_b'], list)
+ np.testing.assert_equal(output, {"key_a": [1, 3], "key_b": [7, 8]})
+ assert isinstance(output["key_a"], np.ndarray)
+ assert isinstance(output["key_b"], list)
def test_dict_array_callable_unary(self):
dataset = tabular.from_data(
- {'key_a': np.arange(10), 'key_b': lambda i: i * i})
+ {"key_a": np.arange(10), "key_b": lambda i: i * i}
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert set(dataset.keys) == {'key_a', 'key_b'}
+ assert set(dataset.keys) == {"key_a", "key_b"}
assert dataset.mode == dict
output = dataset.slice[[1, 3]].fetch()
- np.testing.assert_equal(output, {'key_a': [1, 3], 'key_b': [1, 9]})
- assert isinstance(output['key_a'], np.ndarray)
- assert isinstance(output['key_b'], list)
+ np.testing.assert_equal(output, {"key_a": [1, 3], "key_b": [1, 9]})
+ assert isinstance(output["key_a"], np.ndarray)
+ assert isinstance(output["key_b"], list)
def test_dict_array_callable_tuple(self):
dataset = tabular.from_data(
- {'key_a': np.arange(10),
- ('key_b', 'key_c'): lambda i: (i * i, -i)})
+ {"key_a": np.arange(10), ("key_b", "key_c"): lambda i: (i * i, -i)}
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert set(dataset.keys) == {'key_a', 'key_b', 'key_c'}
+ assert set(dataset.keys) == {"key_a", "key_b", "key_c"}
assert dataset.mode == dict
output = dataset.slice[[1, 3]].fetch()
np.testing.assert_equal(
- output, {'key_a': [1, 3], 'key_b': [1, 9], 'key_c': [-1, -3]})
- assert isinstance(output['key_a'], np.ndarray)
- assert isinstance(output['key_b'], list)
- assert isinstance(output['key_c'], list)
+ output, {"key_a": [1, 3], "key_b": [1, 9], "key_c": [-1, -3]}
+ )
+ assert isinstance(output["key_a"], np.ndarray)
+ assert isinstance(output["key_b"], list)
+ assert isinstance(output["key_c"], list)
def test_dict_array_callable_dict(self):
dataset = tabular.from_data(
- {'key_a': np.arange(10),
- ('key_b', 'key_c'): lambda i: {'key_b': i * i, 'key_c': -i}})
+ {
+ "key_a": np.arange(10),
+ ("key_b", "key_c"): lambda i: {"key_b": i * i, "key_c": -i},
+ }
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert set(dataset.keys) == {'key_a', 'key_b', 'key_c'}
+ assert set(dataset.keys) == {"key_a", "key_b", "key_c"}
assert dataset.mode == dict
output = dataset.slice[[1, 3]].fetch()
np.testing.assert_equal(
- output, {'key_a': [1, 3], 'key_b': [1, 9], 'key_c': [-1, -3]})
- assert isinstance(output['key_a'], np.ndarray)
- assert isinstance(output['key_b'], list)
- assert isinstance(output['key_c'], list)
+ output, {"key_a": [1, 3], "key_b": [1, 9], "key_c": [-1, -3]}
+ )
+ assert isinstance(output["key_a"], np.ndarray)
+ assert isinstance(output["key_b"], list)
+ assert isinstance(output["key_c"], list)
def test_dict_callable_unary_callable_unary(self):
dataset = tabular.from_data(
- {'key_a': lambda i: i * i, 'key_b': lambda i: -i}, size=10)
+ {"key_a": lambda i: i * i, "key_b": lambda i: -i}, size=10
+ )
assert isinstance(dataset, ppe.dataset.TabularDataset)
assert len(dataset) == 10
- assert set(dataset.keys) == {'key_a', 'key_b'}
+ assert set(dataset.keys) == {"key_a", "key_b"}
output = dataset.slice[[1, 3]].fetch()
- np.testing.assert_equal(output, {'key_a': [1, 9], 'key_b': [-1, -3]})
- assert isinstance(output['key_a'], list)
- assert isinstance(output['key_b'], list)
+ np.testing.assert_equal(output, {"key_a": [1, 9], "key_b": [-1, -3]})
+ assert isinstance(output["key_a"], list)
+ assert isinstance(output["key_b"], list)
def test_dict_callable_unary_callable_unary_without_size(self):
with pytest.raises(ValueError):
- tabular.from_data((
- {'key_a': lambda i: i * i, 'key_b': lambda i: -i}))
+ tabular.from_data(
+ ({"key_a": lambda i: i * i, "key_b": lambda i: -i})
+ )
def test_unique(self):
dataset_a = tabular.from_data(np.arange(10))
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py
index 38bd6beff..79016d2e4 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_join.py
@@ -2,9 +2,10 @@
import numpy as np
import pytest
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+ dummy_dataset, # NOQA
+)
def _filter_params(params):
@@ -13,20 +14,22 @@ def _filter_params(params):
key_size += 3 if param[0] else 1
key_size += 2 if param[1] else 1
- if param[3] and \
- any(key_size <= key_index for key_index in param[3]):
+ if param[3] and any(key_size <= key_index for key_index in param[3]):
continue
yield param
@pytest.mark.parametrize(
- 'mode_a, mode_b, return_array, key_indices',
- _filter_params(itertools.product(
- [tuple, dict, None],
- [tuple, dict, None],
- [True, False],
- [None, (0, 4, 1), (0, 2), (1, 0), ()]))
+ "mode_a, mode_b, return_array, key_indices",
+ _filter_params(
+ itertools.product(
+ [tuple, dict, None],
+ [tuple, dict, None],
+ [True, False],
+ [None, (0, 4, 1), (0, 2), (1, 0), ()],
+ )
+ ),
)
def test_join(mode_a, mode_b, return_array, key_indices):
if key_indices is None:
@@ -37,13 +40,13 @@ def test_join(mode_a, mode_b, return_array, key_indices):
key_size_a = 3 if mode_a else 1
key_indices_a = tuple(
- key_index
- for key_index in key_indices
- if key_index < key_size_a)
+ key_index for key_index in key_indices if key_index < key_size_a
+ )
key_indices_b = tuple(
key_index - key_size_a
for key_index in key_indices
- if key_size_a <= key_index)
+ if key_size_a <= key_index
+ )
if key_indices_a:
expected_key_indices_a = key_indices_a
@@ -56,16 +59,21 @@ def callback_a(indices, key_indices):
dataset_a = dummy_dataset.DummyDataset(
mode=mode_a,
- return_array=return_array, callback=callback_a,
- convert=True)
+ return_array=return_array,
+ callback=callback_a,
+ convert=True,
+ )
def callback_b(indices, key_indices):
assert indices is None
assert key_indices == expected_key_indices_b
- dataset_b = dummy_dataset. DummyDataset(
- keys=('d', 'e'), mode=mode_b,
- return_array=return_array, callback=callback_b)
+ dataset_b = dummy_dataset.DummyDataset(
+ keys=("d", "e"),
+ mode=mode_b,
+ return_array=return_array,
+ callback=callback_b,
+ )
view = dataset_a.join(dataset_b)
assert isinstance(view, ppe.dataset.TabularDataset)
@@ -86,12 +94,12 @@ def callback_b(indices, key_indices):
else:
assert isinstance(out, list)
- assert view.convert(output) == 'converted'
+ assert view.convert(output) == "converted"
def test_join_length():
dataset_a = dummy_dataset.DummyDataset()
- dataset_b = dummy_dataset.DummyDataset(size=5, keys=('d', 'e'))
+ dataset_b = dummy_dataset.DummyDataset(size=5, keys=("d", "e"))
with pytest.raises(ValueError):
dataset_a.join(dataset_b)
@@ -99,7 +107,7 @@ def test_join_length():
def test_join_conflict_key():
dataset_a = dummy_dataset.DummyDataset()
- dataset_b = dummy_dataset.DummyDataset(keys=('a', 'd'))
+ dataset_b = dummy_dataset.DummyDataset(keys=("a", "d"))
with pytest.raises(ValueError):
dataset_a.join(dataset_b)
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py
index 5473ecc08..d8c2773bf 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_slice.py
@@ -3,9 +3,10 @@
import numpy as np
import pytest
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+ dummy_dataset, # NOQA
+)
def _values_to_dicts(names, values):
@@ -18,143 +19,170 @@ def safe_zip(ns, vs):
assert isinstance(vs, (tuple, list)) and len(ns) == len(vs)
return zip(ns, vs)
- names = names.split(',')
+ names = names.split(",")
params = [dict(safe_zip(names, value_list)) for value_list in values]
return params
def product(parameter):
if isinstance(parameter, dict):
- return product_dict(*[
- _values_to_dicts(names, values)
- for names, values in sorted(parameter.items())])
+ return product_dict(
+ *[
+ _values_to_dicts(names, values)
+ for names, values in sorted(parameter.items())
+ ]
+ )
elif isinstance(parameter, list):
# list of lists of dicts
if not all(isinstance(_, list) for _ in parameter):
- raise TypeError('parameter must be list of lists of dicts')
+ raise TypeError("parameter must be list of lists of dicts")
if not all(isinstance(_, dict) for l in parameter for _ in l): # NOQA
- raise TypeError('parameter must be list of lists of dicts')
+ raise TypeError("parameter must be list of lists of dicts")
return product_dict(*parameter)
else:
raise TypeError(
- 'parameter must be either dict or list. Actual: {}'.format(
- type(parameter)))
+ "parameter must be either dict or list. Actual: {}".format(
+ type(parameter)
+ )
+ )
def product_dict(*parameters):
return [
{k: v for dic in dicts for k, v in dic.items()}
- for dicts in itertools.product(*parameters)]
+ for dicts in itertools.product(*parameters)
+ ]
def _filter_params(params):
for param in params:
- if 'expected_len' in param and \
- isinstance(param['get_examples_indices'], list) and \
- any(param['expected_len'] <= index
- for index in param['get_examples_indices']):
+ if (
+ "expected_len" in param
+ and isinstance(param["get_examples_indices"], list)
+ and any(
+ param["expected_len"] <= index
+ for index in param["get_examples_indices"]
+ )
+ ):
continue
- if 'expected_keys' in param and \
- isinstance(param['get_examples_key_indices'], tuple) and \
- any(len(param['expected_keys']) <= key_index
- for key_index in param['get_examples_key_indices']):
+ if (
+ "expected_keys" in param
+ and isinstance(param["get_examples_key_indices"], tuple)
+ and any(
+ len(param["expected_keys"]) <= key_index
+ for key_index in param["get_examples_key_indices"]
+ )
+ ):
continue
# To reduce the number of tests,
# drop combinations of indices and keys.
# (check only `slice[indices]` and `slice[:, keys]`)
- if (not (param['indices'] == slice(None)
- and param['get_examples_indices'] is None)
- and not (param['keys'] is None
- and param['get_examples_key_indices'] is None)):
+ if not (
+ param["indices"] == slice(None)
+ and param["get_examples_indices"] is None
+ ) and not (
+ param["keys"] is None and param["get_examples_key_indices"] is None
+ ):
continue
yield param
-params = _filter_params(product_dict(
+params = _filter_params(
product_dict(
- [{'mode': tuple}, {'mode': dict}],
+ product_dict(
+ [{"mode": tuple}, {"mode": dict}],
+ [
+ {"keys": None, "expected_keys": ("a", "b", "c")},
+ {"keys": 1, "expected_keys": ("b",)},
+ {"keys": (1,), "expected_keys": ("b",)},
+ {"keys": 3, "key_exception": IndexError},
+ {"keys": (3,), "key_exception": IndexError},
+ {"keys": "c", "expected_keys": ("c",)},
+ {"keys": ("c",), "expected_keys": ("c",)},
+ {"keys": "d", "key_exception": KeyError},
+ {"keys": ("d",), "key_exception": KeyError},
+ {"keys": (-1, "a"), "expected_keys": ("c", "a")},
+ {"keys": (), "expected_keys": ()},
+ ],
+ )
+ + product_dict(
+ [{"mode": None}],
+ [
+ {"keys": None, "expected_keys": ("a",)},
+ {"keys": 0, "expected_keys": ("a",)},
+ {"keys": (0,), "expected_keys": ("a",)},
+ {"keys": 1, "key_exception": IndexError},
+ {"keys": (1,), "key_exception": IndexError},
+ {"keys": "a", "expected_keys": ("a",)},
+ {"keys": ("a",), "expected_keys": ("a",)},
+ {"keys": "b", "key_exception": KeyError},
+ {"keys": ("b",), "key_exception": KeyError},
+ {"keys": (), "expected_keys": ()},
+ ],
+ ),
+ product(
+ {
+ "return_array": [True, False],
+ "integer": [int, np.int32],
+ }
+ ),
[
- {'keys': None, 'expected_keys': ('a', 'b', 'c')},
- {'keys': 1, 'expected_keys': ('b',)},
- {'keys': (1,), 'expected_keys': ('b',)},
- {'keys': 3, 'key_exception': IndexError},
- {'keys': (3,), 'key_exception': IndexError},
- {'keys': 'c', 'expected_keys': ('c',)},
- {'keys': ('c',), 'expected_keys': ('c',)},
- {'keys': 'd', 'key_exception': KeyError},
- {'keys': ('d',), 'key_exception': KeyError},
- {'keys': (-1, 'a'), 'expected_keys': ('c', 'a')},
- {'keys': (), 'expected_keys': ()},
+ {"indices": slice(None), "expected_len": 10},
+ {"indices": [3, -2], "expected_len": 2},
+ {"indices": [11, 1], "index_exception": IndexError},
+ {"indices": [i in {1, 3} for i in range(10)], "expected_len": 2},
+ {"indices": [True] * 11, "index_exception": ValueError},
+ {"indices": slice(3, None, -2), "expected_len": 2},
+ {"indices": [False, 3, 9, 5, True], "expected_len": 5},
+ {"indices": [], "expected_len": 0},
],
+ product(
+ {
+ "get_examples_indices": [
+ None,
+ [1],
+ [1, 0],
+ slice(0, 2, 1),
+ slice(1, None, -1),
+ [],
+ ],
+ "get_examples_key_indices": [None, (1,), (1, 0), ()],
+ }
+ ),
)
- + product_dict(
- [{'mode': None}],
- [
- {'keys': None, 'expected_keys': ('a',)},
- {'keys': 0, 'expected_keys': ('a',)},
- {'keys': (0,), 'expected_keys': ('a',)},
- {'keys': 1, 'key_exception': IndexError},
- {'keys': (1,), 'key_exception': IndexError},
- {'keys': 'a', 'expected_keys': ('a',)},
- {'keys': ('a',), 'expected_keys': ('a',)},
- {'keys': 'b', 'key_exception': KeyError},
- {'keys': ('b',), 'key_exception': KeyError},
- {'keys': (), 'expected_keys': ()},
- ],
- ),
- product({
- 'return_array': [True, False],
- 'integer': [int, np.int32],
- }),
- [
- {'indices': slice(None), 'expected_len': 10},
- {'indices': [3, -2], 'expected_len': 2},
- {'indices': [11, 1], 'index_exception': IndexError},
- {'indices': [i in {1, 3} for i in range(10)], 'expected_len': 2},
- {'indices': [True] * 11, 'index_exception': ValueError},
- {'indices': slice(3, None, -2), 'expected_len': 2},
- {'indices': [False, 3, 9, 5, True], 'expected_len': 5},
- {'indices': [], 'expected_len': 0},
- ],
- product({
- 'get_examples_indices': [
- None, [1], [1, 0], slice(0, 2, 1), slice(1, None, -1), []],
- 'get_examples_key_indices': [None, (1,), (1, 0), ()],
- }),
-))
-
-
-@pytest.mark.parametrize(
- 'test_args',
- params
)
+
+
+@pytest.mark.parametrize("test_args", params)
def test_slice(test_args):
- exception = test_args.get('index_exception', None) \
- or test_args.get('key_exception', None)
+ exception = test_args.get("index_exception", None) or test_args.get(
+ "key_exception", None
+ )
- indices = test_args['indices']
- keys = test_args['keys']
- mode = test_args['mode']
- return_array = test_args['return_array']
- get_examples_indices = test_args['get_examples_indices']
- get_examples_key_indices = test_args['get_examples_key_indices']
+ indices = test_args["indices"]
+ keys = test_args["keys"]
+ mode = test_args["mode"]
+ return_array = test_args["return_array"]
+ get_examples_indices = test_args["get_examples_indices"]
+ get_examples_key_indices = test_args["get_examples_key_indices"]
if isinstance(indices, list):
indices = [
- index if isinstance(index, bool) else test_args['integer'](index)
- for index in indices]
+ index if isinstance(index, bool) else test_args["integer"](index)
+ for index in indices
+ ]
def callback(indices, key_indices):
- if isinstance(indices, list) \
- or isinstance(get_examples_indices, list):
+ if isinstance(indices, list) or isinstance(get_examples_indices, list):
assert isinstance(indices, list)
- elif isinstance(indices, slice) \
- or isinstance(get_examples_indices, slice):
+ elif isinstance(indices, slice) or isinstance(
+ get_examples_indices, slice
+ ):
assert isinstance(indices, slice)
else:
assert indices is None
@@ -165,8 +193,8 @@ def callback(indices, key_indices):
assert isinstance(key_indices, tuple)
dataset = dummy_dataset.DummyDataset(
- mode=mode, return_array=return_array, callback=callback,
- convert=True)
+ mode=mode, return_array=return_array, callback=callback, convert=True
+ )
if exception is not None:
with pytest.raises(exception):
@@ -184,15 +212,13 @@ def callback(indices, key_indices):
if isinstance(keys, tuple):
keys = keys
else:
- keys = keys,
- key_indices = [
- {'a': 0, 'b': 1, 'c': 2}.get(key, key) for key in keys]
- data = dataset.data[key_indices][
- :, _indices_for_numpy(indices)]
+ keys = (keys,)
+ key_indices = [{"a": 0, "b": 1, "c": 2}.get(key, key) for key in keys]
+ data = dataset.data[key_indices][:, _indices_for_numpy(indices)]
assert isinstance(view, ppe.dataset.TabularDataset)
- assert len(view) == test_args['expected_len']
- assert view.keys == test_args['expected_keys']
+ assert len(view) == test_args["expected_len"]
+ assert view.keys == test_args["expected_keys"]
if keys is None:
assert view.mode == mode
elif isinstance(keys, tuple):
@@ -200,8 +226,7 @@ def callback(indices, key_indices):
else:
assert view.mode is None
- output = view.get_examples(
- get_examples_indices, get_examples_key_indices)
+ output = view.get_examples(get_examples_indices, get_examples_key_indices)
if get_examples_indices is not None:
data = data[:, _indices_for_numpy(get_examples_indices)]
@@ -215,22 +240,24 @@ def callback(indices, key_indices):
else:
assert isinstance(out, list)
- assert view.convert(output) == 'converted'
+ assert view.convert(output) == "converted"
# Replace list of bool with ndarray of bool
# since old numpy cannot handle list of bool.
def _indices_for_numpy(indices):
with warnings.catch_warnings():
- warnings.simplefilter('ignore', FutureWarning)
+ warnings.simplefilter("ignore", FutureWarning)
if len(np.empty(2)[[False, True]]) == 1:
# new numpy
return indices
# old numpy
- if isinstance(indices, list) and \
- len(indices) > 0 and \
- isinstance(indices[0], bool):
+ if (
+ isinstance(indices, list)
+ and len(indices) > 0
+ and isinstance(indices[0], bool)
+ ):
return np.array(indices)
else:
return indices
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py
index 63885f243..84390cab8 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_tabular_dataset.py
@@ -2,10 +2,9 @@
import numpy as np
import pytest
-
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( # NOQA
dummy_dataset,
-) # NOQA
+)
@pytest.mark.parametrize(
@@ -42,7 +41,8 @@ def callback(indices, key_indices):
def test_convert(self, mode, return_array):
dataset = dummy_dataset.DummyDataset(
- mode=mode, return_array=return_array)
+ mode=mode, return_array=return_array
+ )
output = dataset.convert(dataset.fetch())
if mode is tuple:
@@ -80,7 +80,8 @@ def callback(indices, key_indices):
def test_iter(self, mode, return_array):
dataset = dummy_dataset.DummyDataset(
- mode=mode, return_array=return_array)
+ mode=mode, return_array=return_array
+ )
it = iter(dataset)
for i in range(10):
if mode is tuple:
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py
index 3391ed1ae..63744a587 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_transform.py
@@ -2,35 +2,41 @@
import numpy as np
import pytest
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+ dummy_dataset, # NOQA
+)
# filter out invalid combinations of params
def _filter_params(params):
for param in params:
- if param[1] is None and \
- isinstance(param[3], tuple) and \
- any(1 <= key_index
- for key_index in param[3]):
+ if (
+ param[1] is None
+ and isinstance(param[3], tuple)
+ and any(1 <= key_index for key_index in param[3])
+ ):
continue
yield param
@pytest.mark.parametrize(
- 'in_mode, out_mode, indices, key_indices, with_batch',
- _filter_params(itertools.product(
- [tuple, dict, None],
- [tuple, dict, None],
- [None, [1, 3], slice(None, 2)],
- [None, (0,), (1,), (1, 0)],
- [False, True]))
+ "in_mode, out_mode, indices, key_indices, with_batch",
+ _filter_params(
+ itertools.product(
+ [tuple, dict, None],
+ [tuple, dict, None],
+ [None, [1, 3], slice(None, 2)],
+ [None, (0,), (1,), (1, 0)],
+ [False, True],
+ )
+ ),
)
def test_transform(in_mode, out_mode, indices, key_indices, with_batch):
dataset = dummy_dataset.DummyDataset(
- mode=in_mode, return_array=True, convert=True)
+ mode=in_mode, return_array=True, convert=True
+ )
def transform(*args, **kwargs):
if in_mode is tuple:
@@ -40,11 +46,11 @@ def transform(*args, **kwargs):
elif in_mode is dict:
assert len(args) == 0
assert len(kwargs) == 3
- a, b, c = kwargs['a'], kwargs['b'], kwargs['c']
+ a, b, c = kwargs["a"], kwargs["b"], kwargs["c"]
elif in_mode is None:
assert len(args) == 1
assert len(kwargs) == 0
- a, = args
+ (a,) = args
b, c = a, a
if with_batch:
@@ -59,7 +65,7 @@ def transform(*args, **kwargs):
if out_mode is tuple:
return a + b, b + c
elif out_mode is dict:
- return {'alpha': a + b, 'beta': b + c}
+ return {"alpha": a + b, "beta": b + c}
elif out_mode is None:
return a + b + c
@@ -71,11 +77,11 @@ def transform_alpha(*args, **kwargs):
elif in_mode is dict:
assert len(args) == 0
assert len(kwargs) == 3
- a, b, c = kwargs['a'], kwargs['b'], kwargs['c']
+ a, b, c = kwargs["a"], kwargs["b"], kwargs["c"]
elif in_mode is None:
assert len(args) == 1
assert len(kwargs) == 0
- a, = args
+ (a,) = args
b, c = a, a
if with_batch:
@@ -88,9 +94,9 @@ def transform_alpha(*args, **kwargs):
assert isinstance(c, float)
if out_mode is tuple:
- return a + b,
+ return (a + b,)
elif out_mode is dict:
- return {'alpha': a + b}
+ return {"alpha": a + b}
elif out_mode is None:
return a + b + c
@@ -102,11 +108,11 @@ def transform_beta(*args, **kwargs):
elif in_mode is dict:
assert len(args) == 0
assert len(kwargs) == 3
- a, b, c = kwargs['a'], kwargs['b'], kwargs['c']
+ a, b, c = kwargs["a"], kwargs["b"], kwargs["c"]
elif in_mode is None:
assert len(args) == 1
assert len(kwargs) == 0
- a, = args
+ (a,) = args
b, c = a, a
if with_batch:
@@ -119,51 +125,49 @@ def transform_beta(*args, **kwargs):
assert isinstance(c, float)
if out_mode is tuple:
- return b + c,
+ return (b + c,)
elif out_mode is dict:
- return {'beta': b + c}
+ return {"beta": b + c}
elif out_mode is None:
return a + b + c
if in_mode is not None:
a, b, c = dataset.data
else:
- a, = dataset.data
+ (a,) = dataset.data
b, c = a, a
if out_mode is not None:
if in_mode is not None:
- d_transform = [
- ((('a', 'b', 'c'), ('alpha', 'beta')), transform)]
+ d_transform = [((("a", "b", "c"), ("alpha", "beta")), transform)]
else:
d_transform = [
- ((('a',), ('alpha',)), transform_alpha),
- ((('a',), ('beta',)), transform_beta)]
+ ((("a",), ("alpha",)), transform_alpha),
+ ((("a",), ("beta",)), transform_beta),
+ ]
if with_batch:
- view = dataset.transform_batch(('alpha', 'beta'), d_transform)
+ view = dataset.transform_batch(("alpha", "beta"), d_transform)
else:
- view = dataset.transform(('alpha', 'beta'), d_transform)
+ view = dataset.transform(("alpha", "beta"), d_transform)
data = np.vstack((a + b, b + c))
else:
if in_mode is not None:
- d_transform = [
- ((('a', 'b', 'c'), ('alpha',)), transform_alpha)]
+ d_transform = [((("a", "b", "c"), ("alpha",)), transform_alpha)]
else:
- d_transform = [
- ((('a',), ('alpha',)), transform_alpha)]
+ d_transform = [((("a",), ("alpha",)), transform_alpha)]
if with_batch:
- view = dataset.transform_batch(('alpha',), d_transform)
+ view = dataset.transform_batch(("alpha",), d_transform)
else:
- view = dataset.transform(('alpha',), d_transform)
+ view = dataset.transform(("alpha",), d_transform)
data = (a + b + c)[None]
assert isinstance(view, ppe.dataset.TabularDataset)
assert len(view) == len(dataset)
if out_mode is not None:
- assert view.keys == ('alpha', 'beta')
+ assert view.keys == ("alpha", "beta")
assert view.mode == out_mode
else:
- assert view.keys == ('alpha',)
+ assert view.keys == ("alpha",)
assert view.mode == out_mode
output = view.get_examples(indices, key_indices)
@@ -180,15 +184,11 @@ def transform_beta(*args, **kwargs):
else:
assert isinstance(out, list)
- assert view.convert(view.fetch()) == 'converted'
+ assert view.convert(view.fetch()) == "converted"
-@pytest.mark.parametrize(
- 'mode',
- [tuple, dict, None]
-)
+@pytest.mark.parametrize("mode", [tuple, dict, None])
class TestTransformInvalid:
-
def setup_method(self):
self.count = 0
@@ -205,9 +205,9 @@ def _transform(self, a, b, c):
mode = tuple
if mode is tuple:
- return a,
+ return (a,)
elif mode is dict:
- return {'a': a}
+ return {"a": a}
elif mode is None:
return a
@@ -215,8 +215,8 @@ def test_transform_inconsistent_mode(self, mode):
dataset = dummy_dataset.DummyDataset()
self.mode = mode
view = dataset.transform(
- ('a',),
- [((('a', 'b', 'c'), ('a',)), self._transform)])
+ ("a",), [((("a", "b", "c"), ("a",)), self._transform)]
+ )
view.get_examples([0], None)
with pytest.raises(ValueError):
view.get_examples([0], None)
@@ -225,8 +225,8 @@ def test_transform_batch_inconsistent_mode(self, mode):
dataset = dummy_dataset.DummyDataset()
self.mode = mode
view = dataset.transform_batch(
- ('a',),
- [((('a', 'b', 'c'), ('a',)), self._transform)])
+ ("a",), [((("a", "b", "c"), ("a",)), self._transform)]
+ )
view.get_examples(None, None)
with pytest.raises(ValueError):
view.get_examples(None, None)
@@ -237,14 +237,14 @@ def test_transform_batch_length_changed(self, mode):
def transform_batch(a, b, c):
if self.mode is tuple:
- return a + [0],
+ return (a + [0],)
elif self.mode is dict:
- return {'a': a + [0]}
+ return {"a": a + [0]}
elif self.mode is None:
return a + [0]
view = dataset.transform_batch(
- ('a',),
- [((('a', 'b', 'c'), ('a',)), transform_batch)])
+ ("a",), [((("a", "b", "c"), ("a",)), transform_batch)]
+ )
with pytest.raises(ValueError):
view.get_examples(None, None)
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py
index 0de654ad5..d8a135d42 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_converter.py
@@ -1,14 +1,12 @@
import numpy as np
import pytest
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import dummy_dataset # NOQA
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+ dummy_dataset, # NOQA
+)
-@pytest.mark.parametrize(
- 'mode',
- [tuple, dict, None]
-)
+@pytest.mark.parametrize("mode", [tuple, dict, None])
def test_with_converter(mode):
dataset = dummy_dataset.DummyDataset(mode=mode)
@@ -19,18 +17,18 @@ def converter(*args, **kwargs):
elif mode is dict:
assert args == ()
np.testing.assert_equal(
- kwargs, dict(zip(('a', 'b', 'c'), dataset.data)))
+ kwargs, dict(zip(("a", "b", "c"), dataset.data))
+ )
elif mode is None:
np.testing.assert_equal(args, tuple(dataset.data))
assert kwargs == {}
- return 'converted'
+ return "converted"
view = dataset.with_converter(converter)
assert isinstance(view, ppe.dataset.TabularDataset)
assert len(view) == len(dataset)
assert view.keys == dataset.keys
assert view.mode == dataset.mode
- assert (
- view.get_examples(None, None) == dataset.get_examples(None, None))
- assert view.convert(view.fetch()) == 'converted'
+ assert view.get_examples(None, None) == dataset.get_examples(None, None)
+ assert view.convert(view.fetch()) == "converted"
diff --git a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py
index 84d1beb8a..ca63f4880 100644
--- a/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py
+++ b/tests/pytorch_pfn_extras_tests/dataset_tests/tabular_tests/test_with_torch_dataloader.py
@@ -1,23 +1,22 @@
import pytest
import torch
-
-from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import (
+from pytorch_pfn_extras_tests.dataset_tests.tabular_tests import ( # NOQA
dummy_dataset,
-) # NOQA
+)
@pytest.mark.parametrize(
- 'batch_size,mode',
+ "batch_size,mode",
[(1, dict), (2, dict), (8, dict), (1, tuple), (2, tuple), (8, tuple)],
)
def test_with_dataloader(batch_size, mode):
size = 10
- keys = ('a', 'b', 'c')
+ keys = ("a", "b", "c")
dataset = dummy_dataset.DummyDataset(size=size, keys=keys, mode=mode)
expected = torch.tensor(dataset.data).type(torch.float64)
expected_per_key = [
[
- expected[i, j * batch_size:(j + 1) * batch_size]
+ expected[i, j * batch_size : (j + 1) * batch_size]
for j in range((size + batch_size - 1) // batch_size)
]
for i in range(len(keys))
@@ -27,4 +26,5 @@ def test_with_dataloader(batch_size, mode):
for i, example in enumerate(dataloader):
for j, key in enumerate(keys):
assert torch.allclose(
- expected_per_key[j][i], example[key if mode == dict else j])
+ expected_per_key[j][i], example[key if mode == dict else j]
+ )
diff --git a/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py b/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py
index 1554d4d6f..c8dbb4c0f 100644
--- a/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py
+++ b/tests/pytorch_pfn_extras_tests/distributed_tests/test_distributed_validation_sampler.py
@@ -1,8 +1,7 @@
-import torch.distributed as dist
-
-import pytest
from unittest import mock
+import pytest
+import torch.distributed as dist
from pytorch_pfn_extras.distributed import DistributedValidationSampler
_world_size = 4
@@ -17,10 +16,11 @@ def base_dataset():
def test_default(base_dataset):
expected_lengths = [6, 5, 5, 5]
sample_idxs = []
- with mock.patch.object(dist, 'get_world_size', return_value=_world_size), \
- mock.patch.object(dist, 'is_available', return_value=True):
+ with mock.patch.object(
+ dist, "get_world_size", return_value=_world_size
+ ), mock.patch.object(dist, "is_available", return_value=True):
for rank in range(_world_size):
- with mock.patch.object(dist, 'get_rank', return_value=rank):
+ with mock.patch.object(dist, "get_rank", return_value=rank):
sampler = DistributedValidationSampler(base_dataset)
assert len(sampler) == expected_lengths[rank]
sample_idxs += list(sampler)
@@ -39,11 +39,14 @@ def test_no_shuffle(base_dataset):
[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20],
]
- with mock.patch.object(dist, 'get_world_size', return_value=_world_size), \
- mock.patch.object(dist, 'is_available', return_value=True):
+ with mock.patch.object(
+ dist, "get_world_size", return_value=_world_size
+ ), mock.patch.object(dist, "is_available", return_value=True):
for rank in range(_world_size):
- with mock.patch.object(dist, 'get_rank', return_value=rank):
- sampler = DistributedValidationSampler(base_dataset, shuffle=False)
+ with mock.patch.object(dist, "get_rank", return_value=rank):
+ sampler = DistributedValidationSampler(
+ base_dataset, shuffle=False
+ )
assert list(sampler) == expected_samples[rank]
@@ -51,17 +54,27 @@ def test_manual_num_replicas_and_ranks(base_dataset):
# When manually specifying num_replicas and rank,
# it doesn't rely on these torch.distributed functions.
expected_lengths = [6, 5, 5, 5]
- with mock.patch.object(dist, 'get_world_size', side_effect=AssertionError()), \
- mock.patch.object(dist, 'is_available', side_effect=AssertionError()), \
- mock.patch.object(dist, 'get_rank', side_effect=AssertionError()):
+ with mock.patch.object(
+ dist, "get_world_size", side_effect=AssertionError()
+ ), mock.patch.object(
+ dist, "is_available", side_effect=AssertionError()
+ ), mock.patch.object(
+ dist, "get_rank", side_effect=AssertionError()
+ ):
for rank in range(_world_size):
- sampler = DistributedValidationSampler(base_dataset, num_replicas=_world_size, rank=rank)
+ sampler = DistributedValidationSampler(
+ base_dataset, num_replicas=_world_size, rank=rank
+ )
assert len(sampler) == expected_lengths[rank]
def test_seed(base_dataset):
- sampler1 = DistributedValidationSampler(base_dataset, num_replicas=_world_size, rank=0, seed=1)
- sampler2 = DistributedValidationSampler(base_dataset, num_replicas=_world_size, rank=0, seed=2)
+ sampler1 = DistributedValidationSampler(
+ base_dataset, num_replicas=_world_size, rank=0, seed=1
+ )
+ sampler2 = DistributedValidationSampler(
+ base_dataset, num_replicas=_world_size, rank=0, seed=2
+ )
assert list(sampler1) != list(sampler2)
@@ -73,7 +86,7 @@ def test_no_distributed_available(base_dataset):
def test_invalid_rank(base_dataset):
- with mock.patch.object(dist, 'get_world_size', return_value=_world_size):
+ with mock.patch.object(dist, "get_world_size", return_value=_world_size):
with pytest.raises(ValueError):
DistributedValidationSampler(base_dataset, rank=-1)
with pytest.raises(ValueError):
diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py
index 1e758cb9e..c53183b15 100644
--- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py
+++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_ensure_shape.py
@@ -1,16 +1,15 @@
import pytest
import torch
-
from pytorch_pfn_extras.nn import Ensure, ensure
class TestEnsure:
def test_wrong_initialization(self):
- with pytest.raises(ValueError, match='both arguments'):
+ with pytest.raises(ValueError, match="both arguments"):
Ensure(shape=None, dtype=None)
@pytest.mark.parametrize(
- 'shape', [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)]
+ "shape", [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)]
)
def test_valid_shape(self, shape):
tensor = torch.zeros(shape)
@@ -20,69 +19,81 @@ def test_valid_shape(self, shape):
ensure(tensor, shape)
@pytest.mark.parametrize(
- 'shape', [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)]
+ "shape", [(), (1,), (1, 1), (2,), (2, 4), (2, 3, 4)]
)
def test_invalid_shape(self, shape):
tensor = torch.zeros((1, 2, 3))
module = Ensure(shape=shape)
- with pytest.raises(ValueError, match='input shape is'):
+ with pytest.raises(ValueError, match="input shape is"):
module(tensor)
- with pytest.raises(ValueError, match='input shape is'):
+ with pytest.raises(ValueError, match="input shape is"):
ensure(tensor, shape)
- @pytest.mark.parametrize('shape_t, shape_c', [
- ((), (2,)),
- ((3,), ()),
- ((1,), (2,)),
- ((1, 1), (2, 1)),
- ((1, 1), (2,)),
- ((2, 1), (2, 2)),
- ((2, 4), (4,)),
- ((2, 3, 4), (1, 4)),
- ])
+ @pytest.mark.parametrize(
+ "shape_t, shape_c",
+ [
+ ((), (2,)),
+ ((3,), ()),
+ ((1,), (2,)),
+ ((1, 1), (2, 1)),
+ ((1, 1), (2,)),
+ ((2, 1), (2, 2)),
+ ((2, 4), (4,)),
+ ((2, 3, 4), (1, 4)),
+ ],
+ )
def test_broadcastable_shape(self, shape_t, shape_c):
tensor = torch.zeros(shape_t)
module = Ensure(shape=shape_c, broadcastable=True)
module(tensor)
- @pytest.mark.parametrize('shape_t, shape_c', [
- ((3,), (2,)),
- ((2, 3), (2, 2)),
- ((2, 4), (3, 4)),
- ((2, 4), (3,)),
- ((2, 3, 4), (2, 2, 1)),
- ])
+ @pytest.mark.parametrize(
+ "shape_t, shape_c",
+ [
+ ((3,), (2,)),
+ ((2, 3), (2, 2)),
+ ((2, 4), (3, 4)),
+ ((2, 4), (3,)),
+ ((2, 3, 4), (2, 2, 1)),
+ ],
+ )
def test_nonbroadcastable_shape(self, shape_t, shape_c):
tensor = torch.zeros(shape_t)
module = Ensure(shape=shape_c, broadcastable=True)
- with pytest.raises(ValueError, match='non broadcastable'):
+ with pytest.raises(ValueError, match="non broadcastable"):
module(tensor)
- @pytest.mark.parametrize('shape_t, shape_c', [
- ((2,), (None,)),
- ((2, 2), (2, None)),
- ((2, 1), (None, 2)),
- ((1, 4), (None, 4)),
- ((2, 3, 4), (2, None, 4)),
- ])
+ @pytest.mark.parametrize(
+ "shape_t, shape_c",
+ [
+ ((2,), (None,)),
+ ((2, 2), (2, None)),
+ ((2, 1), (None, 2)),
+ ((1, 4), (None, 4)),
+ ((2, 3, 4), (2, None, 4)),
+ ],
+ )
def test_unknown_shape(self, shape_t, shape_c):
tensor = torch.zeros(shape_t)
module = Ensure(shape=shape_c)
module(tensor)
- @pytest.mark.parametrize('shape_t, shape_c', [
- ((3, 2), (2, None)),
- ((1, 4), (None, 2)),
- ((2, 3, 4), (3, None, 4)),
- ])
+ @pytest.mark.parametrize(
+ "shape_t, shape_c",
+ [
+ ((3, 2), (2, None)),
+ ((1, 4), (None, 2)),
+ ((2, 3, 4), (3, None, 4)),
+ ],
+ )
def test_invalid_unknown_shape(self, shape_t, shape_c):
tensor = torch.zeros(shape_t)
module = Ensure(shape=shape_c)
- with pytest.raises(ValueError, match='non broadcastable'):
+ with pytest.raises(ValueError, match="non broadcastable"):
module(tensor)
@pytest.mark.parametrize(
- 'dtype', [torch.int32, torch.float32, torch.complex64]
+ "dtype", [torch.int32, torch.float32, torch.complex64]
)
def test_valid_dtypes(self, dtype):
tensor = torch.zeros(1, dtype=dtype)
@@ -90,36 +101,45 @@ def test_valid_dtypes(self, dtype):
module(tensor)
ensure(tensor, None, dtype)
- @pytest.mark.parametrize('dtype_t, dtype_c', [
- (torch.int32, torch.int16),
- (torch.int32, torch.float32),
- (torch.float32, torch.float64),
- (torch.float32, torch.complex64),
- ])
+ @pytest.mark.parametrize(
+ "dtype_t, dtype_c",
+ [
+ (torch.int32, torch.int16),
+ (torch.int32, torch.float32),
+ (torch.float32, torch.float64),
+ (torch.float32, torch.complex64),
+ ],
+ )
def test_invalid_dtypes(self, dtype_t, dtype_c):
tensor = torch.zeros(1, dtype=dtype_t)
module = Ensure(shape=None, dtype=dtype_c)
- with pytest.raises(ValueError, match='input dtype'):
+ with pytest.raises(ValueError, match="input dtype"):
module(tensor)
- @pytest.mark.parametrize('dtype_t, dtype_c', [
- (torch.int32, torch.float32),
- (torch.int32, torch.complex128),
- (torch.int8, torch.float16),
- ])
+ @pytest.mark.parametrize(
+ "dtype_t, dtype_c",
+ [
+ (torch.int32, torch.float32),
+ (torch.int32, torch.complex128),
+ (torch.int8, torch.float16),
+ ],
+ )
def test_dtypes_with_cast(self, dtype_t, dtype_c):
tensor = torch.zeros(1, dtype=dtype_t)
module = Ensure(shape=None, dtype=dtype_c, can_cast=True)
module(tensor)
- @pytest.mark.parametrize('dtype_t, dtype_c', [
- (torch.complex64, torch.int32),
- (torch.float32, torch.int32),
- ])
+ @pytest.mark.parametrize(
+ "dtype_t, dtype_c",
+ [
+ (torch.complex64, torch.int32),
+ (torch.float32, torch.int32),
+ ],
+ )
def test_invalid_dtypes_with_cast(self, dtype_t, dtype_c):
tensor = torch.zeros(1, dtype=dtype_t)
module = Ensure(shape=None, dtype=dtype_c, can_cast=True)
- with pytest.raises(ValueError, match='be casted to'):
+ with pytest.raises(ValueError, match="be casted to"):
module(tensor)
def test_valid_shape_and_dtype(self):
@@ -131,7 +151,7 @@ def test_valid_shape_and_dtype(self):
ensure(tensor, shape, dtype)
# Too many warnings to list them all
- @pytest.mark.filterwarnings('ignore')
+ @pytest.mark.filterwarnings("ignore")
def test_jit_module(self):
shape = (10, 5)
dtype = torch.float32
@@ -146,7 +166,7 @@ def test_jit_module(self):
jit_module(tensor)
# Tracing with a different shape fails during trace process
- with pytest.raises(ValueError, match='input shape is'):
+ with pytest.raises(ValueError, match="input shape is"):
jit_module = torch.jit.trace(module, (tensor,))
def test_torchscript_module(self):
@@ -160,5 +180,5 @@ def test_torchscript_module(self):
shape = (5, 5)
tensor = torch.zeros(shape, dtype=dtype)
# torchscript changes the exception type
- with pytest.raises(torch.jit.Error, match='input shape is'):
+ with pytest.raises(torch.jit.Error, match="input shape is"):
jit_module(tensor)
diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py
index b9fc52d5e..0a0f09cb5 100644
--- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py
+++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_extended_sequential.py
@@ -1,14 +1,14 @@
-import unittest
-import pytest
import functools
+import unittest
import warnings
import numpy
+import pytest
+import pytorch_pfn_extras as ppe
import torch
from torch import nn
-import pytorch_pfn_extras as ppe
-assertions = unittest.TestCase('__init__')
+assertions = unittest.TestCase("__init__")
class UserDefinedLayer(nn.Module):
@@ -19,27 +19,29 @@ def forward(self):
pass
-@pytest.mark.parametrize('container', [
- nn.Sequential,
- nn.ModuleList,
- nn.ModuleDict,
-])
-@pytest.mark.parametrize('irregular_layer', [
- UserDefinedLayer,
- # No reset_parameters
- nn.ReLU,
- # use reset_running_stats
- functools.partial(
- nn.BatchNorm1d, 1),
- # use _reset_parameters
- functools.partial(
- nn.MultiheadAttention, 1, 1),
- # ppe.nn layer
- functools.partial(
- ppe.nn.LazyConv1d, None, 1, 1),
-])
+@pytest.mark.parametrize(
+ "container",
+ [
+ nn.Sequential,
+ nn.ModuleList,
+ nn.ModuleDict,
+ ],
+)
+@pytest.mark.parametrize(
+ "irregular_layer",
+ [
+ UserDefinedLayer,
+ # No reset_parameters
+ nn.ReLU,
+ # use reset_running_stats
+ functools.partial(nn.BatchNorm1d, 1),
+ # use _reset_parameters
+ functools.partial(nn.MultiheadAttention, 1, 1),
+ # ppe.nn layer
+ functools.partial(ppe.nn.LazyConv1d, None, 1, 1),
+ ],
+)
class TestExtendedSequential(object):
-
@pytest.fixture(autouse=True)
def setUp(self, container, irregular_layer):
self.l1 = ppe.nn.LazyLinear(None, 3)
@@ -51,9 +53,7 @@ def setUp(self, container, irregular_layer):
if container == nn.Sequential:
self.s1 = container(self.l1, self.l2)
elif container == nn.ModuleDict:
- self.s1 = container({
- 'l1': self.l1,
- 'l2': self.l2})
+ self.s1 = container({"l1": self.l1, "l2": self.l2})
else:
self.s1 = container([self.l1, self.l2])
self.container = container
@@ -71,36 +71,46 @@ def test_repeat_with_init(self):
# bias is filled with 0, so they should have the same values
if self.container == nn.ModuleDict:
numpy.testing.assert_array_equal(
- ret[0][0]['l1'].bias.detach().numpy(),
- ret[1][0]['l1'].bias.detach().numpy())
+ ret[0][0]["l1"].bias.detach().numpy(),
+ ret[1][0]["l1"].bias.detach().numpy(),
+ )
else:
numpy.testing.assert_array_equal(
ret[0][0][0].bias.detach().numpy(),
- ret[1][0][0].bias.detach().numpy())
+ ret[1][0][0].bias.detach().numpy(),
+ )
# weight is initialized randomly, so they should be different
assertions.assertFalse(
- numpy.array_equal(ret[0][1].weight.detach().numpy(),
- self.l3.weight.detach().numpy()))
+ numpy.array_equal(
+ ret[0][1].weight.detach().numpy(),
+ self.l3.weight.detach().numpy(),
+ )
+ )
# And the object should also be different
- assertions.assertIsNot(ret[0][1].weight.detach().numpy(),
- self.l3.weight.detach().numpy())
+ assertions.assertIsNot(
+ ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy()
+ )
# Repeated elements should be different objects
assertions.assertIsNot(ret[0], ret[1])
# Also for the arrays
- assertions.assertIsNot(ret[0][1].weight.detach().numpy(),
- ret[1][1].weight.detach().numpy())
+ assertions.assertIsNot(
+ ret[0][1].weight.detach().numpy(), ret[1][1].weight.detach().numpy()
+ )
# And values should be different
assertions.assertFalse(
- numpy.array_equal(ret[0][1].weight.detach().numpy(),
- ret[1][1].weight.detach().numpy()))
+ numpy.array_equal(
+ ret[0][1].weight.detach().numpy(),
+ ret[1][1].weight.detach().numpy(),
+ )
+ )
assertions.assertEqual(len(ret), 2)
- ret = self.s2.repeat(0, mode='init')
+ ret = self.s2.repeat(0, mode="init")
assertions.assertEqual(len(ret), 0)
def test_repeat_with_copy(self):
# s2 ((l1 -> l2) -> l3 -> l4) -> s2 ((l1 -> l2) -> l3 -> l4)
- ret = self.s2.repeat(2, mode='copy')
+ ret = self.s2.repeat(2, mode="copy")
assertions.assertIsNot(ret[0], self.s2)
assertions.assertIs(type(ret[0]), type(self.s2))
assertions.assertIsNot(ret[1], self.s2)
@@ -111,14 +121,17 @@ def test_repeat_with_copy(self):
if self.container == nn.ModuleDict:
numpy.testing.assert_array_equal(
ret[0][0]["l1"].bias.detach().numpy(),
- ret[1][0]["l1"].bias.detach().numpy())
+ ret[1][0]["l1"].bias.detach().numpy(),
+ )
else:
numpy.testing.assert_array_equal(
ret[0][0][0].bias.detach().numpy(),
- ret[1][0][0].bias.detach().numpy())
+ ret[1][0][0].bias.detach().numpy(),
+ )
# W is shallowly copied, so the values should be the same
numpy.testing.assert_array_equal(
- ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy())
+ ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy()
+ )
# But the object should be different
assertions.assertIsNot(ret[0][1].weight, self.l3.weight)
# Repeated elements should be different objects
@@ -127,16 +140,16 @@ def test_repeat_with_copy(self):
assertions.assertIsNot(ret[0][1].weight, ret[1][1].weight)
# But the values should be the same
numpy.testing.assert_array_equal(
- ret[0][1].weight.detach().numpy(),
- ret[1][1].weight.detach().numpy())
+ ret[0][1].weight.detach().numpy(), ret[1][1].weight.detach().numpy()
+ )
assertions.assertEqual(len(ret), 2)
- ret = self.s2.repeat(0, mode='copy')
+ ret = self.s2.repeat(0, mode="copy")
assertions.assertEqual(len(ret), 0)
def test_repeat_with_share(self):
# s2 ((l1 -> l2) -> l3 -> l4) -> s2 ((l1 -> l2) -> l3 -> l4)
- ret = self.s2.repeat(2, mode='share')
+ ret = self.s2.repeat(2, mode="share")
assertions.assertIsNot(ret[0], self.s2)
assertions.assertIs(type(ret[0]), type(self.s2))
assertions.assertIsNot(ret[1], self.s2)
@@ -146,16 +159,20 @@ def test_repeat_with_share(self):
if self.container == nn.ModuleDict:
numpy.testing.assert_array_equal(
ret[0][0]["l1"].bias.detach().numpy(),
- ret[1][0]["l1"].bias.detach().numpy())
+ ret[1][0]["l1"].bias.detach().numpy(),
+ )
else:
numpy.testing.assert_array_equal(
ret[0][0][0].bias.detach().numpy(),
- ret[1][0][0].bias.detach().numpy())
+ ret[1][0][0].bias.detach().numpy(),
+ )
# W is shallowly copied, so the values should be the same
numpy.testing.assert_array_equal(
- ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy())
+ ret[0][1].weight.detach().numpy(), self.l3.weight.detach().numpy()
+ )
numpy.testing.assert_array_equal(
- ret[1][1].weight.detach().numpy(), self.l3.weight.detach().numpy())
+ ret[1][1].weight.detach().numpy(), self.l3.weight.detach().numpy()
+ )
# And the object should also be the same
assertions.assertIs(ret[0][1].weight, self.l3.weight)
assertions.assertIs(ret[1][1].weight, self.l3.weight)
@@ -163,7 +180,7 @@ def test_repeat_with_share(self):
assertions.assertIsNot(ret[0], ret[1])
assertions.assertEqual(len(ret), 2)
- ret = self.s2.repeat(0, mode='share')
+ ret = self.s2.repeat(0, mode="share")
assertions.assertEqual(len(ret), 0)
@@ -193,7 +210,7 @@ class UserDefinedLayerWithParameters(nn.Module):
def __init__(self):
super().__init__()
param = nn.Parameter(torch.zeros(1, 1))
- self.register_parameter('weight', param)
+ self.register_parameter("weight", param)
def forward(self):
pass
@@ -202,22 +219,25 @@ def forward(self):
class UserDefinedLayerWithBuffer(nn.Module):
def __init__(self):
super().__init__()
- self.register_buffer('weight', torch.zeros(1, 1))
+ self.register_buffer("weight", torch.zeros(1, 1))
def forward(self):
pass
-@pytest.mark.parametrize('module', [
- # buit-in, no parameters
- nn.ReLU,
- # no parameters
- UserDefinedLayer,
- # has `_reset_parameters`
- UserDefinedLayerWithUnderScoreReset,
- # has `reset_parameters`
- UserDefinedLayerWithReset,
-])
+@pytest.mark.parametrize(
+ "module",
+ [
+ # built-in, no parameters
+ nn.ReLU,
+ # no parameters
+ UserDefinedLayer,
+ # has `_reset_parameters`
+ UserDefinedLayerWithUnderScoreReset,
+ # has `reset_parameters`
+ UserDefinedLayerWithReset,
+ ],
+)
def test_no_warning_when_repeat(module):
model = ppe.nn.ExtendedSequential(module())
# no warnings are raised on these modules
@@ -226,10 +246,13 @@ def test_no_warning_when_repeat(module):
model.repeat(2)
-@pytest.mark.parametrize('module', [
- UserDefinedLayerWithParameters,
- UserDefinedLayerWithBuffer,
-])
+@pytest.mark.parametrize(
+ "module",
+ [
+ UserDefinedLayerWithParameters,
+ UserDefinedLayerWithBuffer,
+ ],
+)
def test_warning_when_repeat(module):
model = ppe.nn.ExtendedSequential(module())
# warnings are raised on these modules
diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py
index 25297d181..c91cda51b 100644
--- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py
+++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy.py
@@ -2,21 +2,21 @@
import pytest
import torch
+from pytorch_pfn_extras.nn.modules.lazy import (
+ LazyInitializationMixin,
+ UninitializedParameter,
+)
from torch import nn
from torch.nn import functional as F
-from pytorch_pfn_extras.nn.modules.lazy import LazyInitializationMixin
-from pytorch_pfn_extras.nn.modules.lazy import UninitializedParameter
-
class _MyFunc(torch.nn.Module):
-
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
- self.register_buffer('const', torch.full((in_features,), 1.0))
+ self.register_buffer("const", torch.full((in_features,), 1.0))
self._reset_params()
def forward(self, input):
@@ -27,9 +27,8 @@ def _reset_params(self):
class _LazyMyFunc(LazyInitializationMixin, _MyFunc):
-
- lazy_parameter_names = ('weight',)
- lazy_buffer_names = ('const',)
+ lazy_parameter_names = ("weight",)
+ lazy_buffer_names = ("const",)
def __init__(self, in_features, out_features):
super().__init__(in_features or 0, out_features)
@@ -41,7 +40,8 @@ def forward(self, input):
if isinstance(self.weight, UninitializedParameter):
self.in_features = input.shape[-1]
self.weight = torch.nn.Parameter(
- self.weight.new_empty((self.out_features, self.in_features)))
+ self.weight.new_empty((self.out_features, self.in_features))
+ )
self.const = self.const.new_full((self.in_features,), 1)
self._reset_params()
self.to(input.device)
@@ -53,7 +53,6 @@ def _reset_params(self):
class LazyTestBase:
-
def get_original_module(self):
raise NotImplementedError
@@ -122,15 +121,20 @@ def test_lazy_warning(self):
m = self.get_lazy_module()
with pytest.warns(UserWarning) as record:
torch.optim.SGD(m.parameters(), lr=0.1)
- assert ('Use of uninitialized lazy parameter in Optimizer '
- 'has been detected' in record[0].message.args[0])
-
- @pytest.mark.parametrize('init_src, init_dst', [
- (True, True),
- (True, False),
- (False, True),
- (False, False),
- ])
+ assert (
+ "Use of uninitialized lazy parameter in Optimizer "
+ "has been detected" in record[0].message.args[0]
+ )
+
+ @pytest.mark.parametrize(
+ "init_src, init_dst",
+ [
+ (True, True),
+ (True, False),
+ (False, True),
+ (False, False),
+ ],
+ )
def test_save_load(self, init_src, init_dst):
torch.manual_seed(0)
input = self.get_input()
@@ -148,8 +152,7 @@ def test_save_load(self, init_src, init_dst):
for name in model_dst.lazy_parameter_names
]
module_buffers = [
- getattr(model_dst, name)
- for name in model_dst.lazy_buffer_names
+ getattr(model_dst, name) for name in model_dst.lazy_buffer_names
]
with tempfile.NamedTemporaryFile(delete=False) as f:
@@ -181,7 +184,6 @@ def test_save_load(self, init_src, init_dst):
class TestLazyMyFunc(LazyTestBase):
-
def get_original_module(self):
return _MyFunc(10, 20)
diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py
index a27764929..6c236c436 100644
--- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py
+++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_batchnorm.py
@@ -1,14 +1,16 @@
import torch
+from pytorch_pfn_extras.nn import ( # NOQA
+ LazyBatchNorm1d,
+ LazyBatchNorm2d,
+ LazyBatchNorm3d,
+)
+from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import (
+ LazyTestBase,
+)
from torch import nn
-from pytorch_pfn_extras.nn import LazyBatchNorm1d, LazyBatchNorm2d, LazyBatchNorm3d # NOQA
-
-from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import \
- LazyTestBase
-
class TestLazyBatchNorm1d(LazyTestBase):
-
def get_original_module(self):
return nn.BatchNorm1d(10)
@@ -20,7 +22,6 @@ def get_input(self):
class TestLazyBatchNorm2d(LazyTestBase):
-
def get_original_module(self):
return nn.BatchNorm2d(10)
@@ -32,7 +33,6 @@ def get_input(self):
class TestLazyBatchNorm3d(LazyTestBase):
-
def get_original_module(self):
return nn.BatchNorm3d(10)
diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py
index c71d2d2ee..8488e212c 100644
--- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py
+++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_conv.py
@@ -1,14 +1,12 @@
import torch
-from torch import nn
-
from pytorch_pfn_extras.nn import LazyConv1d, LazyConv2d, LazyConv3d
-
-from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import \
- LazyTestBase
+from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import (
+ LazyTestBase,
+)
+from torch import nn
class TestLazyConv1d(LazyTestBase):
-
def get_original_module(self):
return nn.Conv1d(3, 4, 2)
@@ -20,7 +18,6 @@ def get_input(self):
class TestLazyConv2d(LazyTestBase):
-
def get_original_module(self):
return nn.Conv2d(3, 4, 2)
@@ -32,7 +29,6 @@ def get_input(self):
class TestLazyConv3d(LazyTestBase):
-
def get_original_module(self):
return nn.Conv3d(3, 4, 2)
diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py
index c57b72d31..f38702c0e 100644
--- a/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py
+++ b/tests/pytorch_pfn_extras_tests/nn_tests/modules_tests/test_lazy_linear.py
@@ -1,14 +1,12 @@
import torch
-from torch import nn
-
from pytorch_pfn_extras.nn import LazyLinear
-
-from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import \
- LazyTestBase
+from pytorch_pfn_extras_tests.nn_tests.modules_tests.test_lazy import (
+ LazyTestBase,
+)
+from torch import nn
class TestLazyLinear(LazyTestBase):
-
def get_original_module(self):
return nn.Linear(10, 20)
diff --git a/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py b/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py
index ec2fd1b01..657402d8e 100644
--- a/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py
+++ b/tests/pytorch_pfn_extras_tests/nn_tests/parallel_tests/test_distributed.py
@@ -1,20 +1,18 @@
import os
import sys
-import urllib.request
import tempfile
+import urllib.request
import numpy as np
import pytest
import pytorch_pfn_extras
import torch
-from torch import multiprocessing as mp
+from pytorch_pfn_extras.nn.parallel import DistributedDataParallel
from torch import distributed as dist
+from torch import multiprocessing as mp
from torch import nn
from torch.utils.checkpoint import checkpoint
-from pytorch_pfn_extras.nn.parallel import DistributedDataParallel
-
-
context = mp.get_context("spawn")
@@ -57,8 +55,8 @@ def _to_zero(values, group):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
- self.param0 = nn.Parameter(torch.tensor(-1.))
- self.param1 = nn.Parameter(torch.tensor(1.))
+ self.param0 = nn.Parameter(torch.tensor(-1.0))
+ self.param1 = nn.Parameter(torch.tensor(1.0))
buf = torch.zeros(1)
self.register_buffer("buffer", buf)
@@ -87,9 +85,9 @@ def forward(self, x):
def _run(init_file, input, module, rank, args, step, device_type):
init_method = "file://{}".format(urllib.request.pathname2url(init_file))
- dist.init_process_group(backend="gloo",
- init_method=init_method,
- world_size=2, rank=rank)
+ dist.init_process_group(
+ backend="gloo", init_method=init_method, world_size=2, rank=rank
+ )
if device_type == "cpu":
device = torch.device(device_type)
elif device_type == "cuda":
@@ -106,11 +104,9 @@ def _run(init_file, input, module, rank, args, step, device_type):
return output.detach(), module.state_dict(), grads
-def _launch(inputs,
- modules=None,
- args=None,
- step=Steps._step,
- device_type="cpu"):
+def _launch(
+ inputs, modules=None, args=None, step=Steps._step, device_type="cpu"
+):
procs = []
with tempfile.TemporaryDirectory() as tmpdir, context.Pool(2) as pool:
if modules is None:
@@ -121,9 +117,8 @@ def _launch(inputs,
file = os.path.join(tmpdir, "init")
for i, (input, module) in enumerate(zip(inputs, modules)):
p = pool.apply_async(
- _run,
- args=(file, input, module, i, args, step,
- device_type))
+ _run, args=(file, input, module, i, args, step, device_type)
+ )
procs.append(p)
return [p.get() for p in procs]
@@ -136,39 +131,44 @@ def _device_types():
@pytest.mark.skipif(
- sys.platform == 'win32',
- reason='DDP not fully supported on Windows')
+ sys.platform == "win32", reason="DDP not fully supported on Windows"
+)
class TestDistributedDataParallel:
def test_save_load(self):
module = MyModule()
with_ddp = DistributedDataParallel(module)
assert module.state_dict().keys() == with_ddp.state_dict().keys()
module.load_state_dict(with_ddp.state_dict())
- assert np.array_equal(module.state_dict()["param0"],
- with_ddp.state_dict()["param0"])
- assert np.array_equal(module.state_dict()["param1"],
- with_ddp.state_dict()["param1"])
- assert np.array_equal(module.state_dict()["buffer"],
- with_ddp.state_dict()["buffer"])
-
- @pytest.mark.parametrize('device_type', _device_types())
+ assert np.array_equal(
+ module.state_dict()["param0"], with_ddp.state_dict()["param0"]
+ )
+ assert np.array_equal(
+ module.state_dict()["param1"], with_ddp.state_dict()["param1"]
+ )
+ assert np.array_equal(
+ module.state_dict()["buffer"], with_ddp.state_dict()["buffer"]
+ )
+
+ @pytest.mark.parametrize("device_type", _device_types())
def test_sync_init_params(self, device_type):
module0 = MyModule()
- module0.param0.data = torch.tensor([1.])
+ module0.param0.data = torch.tensor([1.0])
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
modules=[module0, MyModule()],
- device_type=device_type)
+ device_type=device_type,
+ )
assert r0[0].item() == 1
assert r1[0].item() == 2
assert r0[1]["param0"].item() == 1.0
assert r1[1]["param0"].item() == 1.0
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_all_reduce(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
- device_type=device_type)
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
+ device_type=device_type,
+ )
assert r0[0].item() == -1
assert r1[0].item() == -2
assert r0[2]["module.param0"].item() == 1.5
@@ -176,52 +176,59 @@ def test_all_reduce(self, device_type):
assert r0[2]["module.param1"] is None
assert r1[2]["module.param1"] is None
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_specific_reduce(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
args={"reduce_function": Collectives._to_zero},
- device_type=device_type)
+ device_type=device_type,
+ )
assert r0[2]["module.param0"].item() == 0.0
assert r1[2]["module.param0"].item() == 0.0
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_nosync_buffer(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
args={"broadcast_buffers": False},
- device_type=device_type)
+ device_type=device_type,
+ )
assert r0[0].item() == -1
assert r1[0].item() == -2
assert r0[1]["buffer"].item() == 1
assert r1[1]["buffer"].item() == 2
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_sync_buffer(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
args={"broadcast_buffers": True},
- device_type=device_type)
+ device_type=device_type,
+ )
assert r0[0].item() == -1
assert r1[0].item() == -2
assert r0[1]["buffer"].item() == 1
assert r1[1]["buffer"].item() == 1
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_specific_broadcast(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
- args={"broadcast_function": Collectives._to_zero,
- "broadcast_buffers": True},
- device_type=device_type)
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
+ args={
+ "broadcast_function": Collectives._to_zero,
+ "broadcast_buffers": True,
+ },
+ device_type=device_type,
+ )
assert r0[1]["buffer"].item() == 0.0
assert r1[1]["buffer"].item() == 0.0
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_define_by_run(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([-1])],
- device_type=device_type)
+ inputs=[torch.tensor([1.0]), torch.tensor([-1])],
+ device_type=device_type,
+ )
assert r0[0].item() == -1
assert r1[0].item() == 1
assert r0[2]["module.param0"].item() == 0.5
@@ -229,12 +236,13 @@ def test_define_by_run(self, device_type):
assert r0[2]["module.param1"].item() == 0.5
assert r1[2]["module.param1"].item() == 0.5
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_no_sync(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
step=Steps._step_with_no_sync,
- device_type=device_type)
+ device_type=device_type,
+ )
assert r0[0].item() == -1
assert r1[0].item() == -2
assert r0[2]["module.param0"].item() == 1
@@ -242,12 +250,13 @@ def test_no_sync(self, device_type):
assert r0[2]["module.param1"] is None
assert r1[2]["module.param1"] is None
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
def test_hook(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([1.]), torch.tensor([2.])],
+ inputs=[torch.tensor([1.0]), torch.tensor([2.0])],
step=Steps._step_with_hook,
- device_type=device_type)
+ device_type=device_type,
+ )
assert r0[0].item() == -1
assert r1[0].item() == -2
assert r0[2]["module.param0"].item() == 0
@@ -255,19 +264,22 @@ def test_hook(self, device_type):
assert r0[2]["module.param1"].item() == 0
assert r1[2]["module.param1"].item() == 0
- @pytest.mark.parametrize('device_type', _device_types())
+ @pytest.mark.parametrize("device_type", _device_types())
@pytest.mark.skipif(
not pytorch_pfn_extras.requires("1.6.0"),
reason="Variable._execution_engine.queue_callback does not work "
- "with checkpointing when torch < 1.6.0")
+ "with checkpointing when torch < 1.6.0",
+ )
def test_checkpoint(self, device_type):
r0, r1 = _launch(
- inputs=[torch.tensor([[1.]]), torch.tensor([[2.]])],
+ inputs=[torch.tensor([[1.0]]), torch.tensor([[2.0]])],
modules=[MyModuleWithCheckpoint(), MyModuleWithCheckpoint()],
step=Steps._step_with_hook,
- device_type=device_type)
+ device_type=device_type,
+ )
grad0 = r0[2]
grad1 = r1[2]
for key in grad0.keys():
- assert np.array_equal(grad0[key].cpu().numpy(),
- grad1[key].cpu().numpy())
+ assert np.array_equal(
+ grad0[key].cpu().numpy(), grad1[key].cpu().numpy()
+ )
diff --git a/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py b/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py
index 2ac018b4d..00a4dbcb6 100644
--- a/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py
+++ b/tests/pytorch_pfn_extras_tests/profiler_tests/test_record.py
@@ -1,39 +1,34 @@
import os
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
-
-_profiler_available = (
- os.name != 'nt'
- or ppe.requires("1.9")
-)
+_profiler_available = os.name != "nt" or ppe.requires("1.9")
@pytest.mark.skipif(not _profiler_available, reason="profiler is not available")
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_record(device):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = torch.nn.Linear(30, 40)
model.to(device)
x = torch.arange(30, dtype=torch.float32).to(device)
with torch.profiler.profile() as prof:
- with ppe.profiler.record('my_tag_1'):
+ with ppe.profiler.record("my_tag_1"):
model(x)
keys = [event.key for event in prof.key_averages()]
- assert 'my_tag_1' in keys
- assert 'aten::linear' in keys
+ assert "my_tag_1" in keys
+ assert "aten::linear" in keys
@pytest.mark.skipif(not _profiler_available, reason="profiler is not available")
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_record_without_tag(device):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = torch.nn.Linear(30, 40)
model.to(device)
@@ -44,19 +39,19 @@ def test_record_without_tag(device):
model(x)
keys = [event.key for event in prof.key_averages()]
- assert 'aten::linear' in keys
- assert any(k.endswith('test_record_without_tag') for k in keys)
+ assert "aten::linear" in keys
+ assert any(k.endswith("test_record_without_tag") for k in keys)
@pytest.mark.skipif(not _profiler_available, reason="profiler is not available")
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_record_function(device):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = torch.nn.Linear(30, 40)
model.to(device)
- @ppe.profiler.record_function('my_tag_2')
+ @ppe.profiler.record_function("my_tag_2")
def my_run(x):
model(x)
@@ -65,14 +60,14 @@ def my_run(x):
my_run(x)
keys = [event.key for event in prof.key_averages()]
- assert 'aten::linear' in keys
- assert 'my_tag_2' in keys
+ assert "aten::linear" in keys
+ assert "my_tag_2" in keys
@pytest.mark.skipif(not _profiler_available, reason="profiler is not available")
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_record_function_without_tag(device):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = torch.nn.Linear(30, 40)
model.to(device)
@@ -86,14 +81,14 @@ def my_run(x):
my_run(x)
keys = [event.key for event in prof.key_averages()]
- assert 'aten::linear' in keys
- assert 'my_run' in keys
+ assert "aten::linear" in keys
+ assert "my_run" in keys
@pytest.mark.skipif(not _profiler_available, reason="profiler is not available")
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_record_iterable(device):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = torch.nn.Linear(30, 40)
model.to(device)
@@ -102,20 +97,20 @@ def test_record_iterable(device):
iters = [x, x, x]
with torch.profiler.profile() as prof:
- for x in ppe.profiler.record_iterable('my_tag_3', iters):
+ for x in ppe.profiler.record_iterable("my_tag_3", iters):
model(x)
keys = [event.key for event in prof.key_averages()]
- assert 'aten::linear' in keys
- assert 'my_tag_3-0' in keys
- assert 'my_tag_3-1' in keys
- assert 'my_tag_3-2' in keys
+ assert "aten::linear" in keys
+ assert "my_tag_3-0" in keys
+ assert "my_tag_3-1" in keys
+ assert "my_tag_3-2" in keys
@pytest.mark.skipif(not _profiler_available, reason="profiler is not available")
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_record_iterable_without_tag(device):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = torch.nn.Linear(30, 40)
model.to(device)
@@ -128,7 +123,7 @@ def test_record_iterable_without_tag(device):
model(x)
keys = [event.key for event in prof.key_averages()]
- assert 'aten::linear' in keys
- assert any(k.endswith('test_record_iterable_without_tag-0') for k in keys)
- assert any(k.endswith('test_record_iterable_without_tag-1') for k in keys)
- assert any(k.endswith('test_record_iterable_without_tag-2') for k in keys)
+ assert "aten::linear" in keys
+ assert any(k.endswith("test_record_iterable_without_tag-0") for k in keys)
+ assert any(k.endswith("test_record_iterable_without_tag-1") for k in keys)
+ assert any(k.endswith("test_record_iterable_without_tag-2") for k in keys)
diff --git a/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py b/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py
index 91388b7bc..62eafcc33 100644
--- a/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py
+++ b/tests/pytorch_pfn_extras_tests/profiler_tests/test_time_summary.py
@@ -4,7 +4,6 @@
import time
import pytest
-
from pytorch_pfn_extras.profiler import TimeSummary, get_time_summary
@@ -43,8 +42,9 @@ def worker(summary):
@pytest.mark.skipif(
- sys.platform == 'win32',
- reason='Multiprocessing not fully supported on Windows')
+ sys.platform == "win32",
+ reason="Multiprocessing not fully supported on Windows",
+)
def test_report_from_other_process():
summary = TimeSummary()
p = mp.Process(target=worker, args=(summary,))
@@ -65,8 +65,9 @@ def worker1():
@pytest.mark.skipif(
- sys.platform == 'win32',
- reason='Multiprocessing not fully supported on Windows')
+ sys.platform == "win32",
+ reason="Multiprocessing not fully supported on Windows",
+)
def test_global_summary():
time_summary = get_time_summary()
time_summary.initialize()
@@ -98,10 +99,14 @@ def test_clear():
def test_multiprocessing_start_method():
# Ensure that importing PPE does not initialize multiprocessing context.
# See #238 for the context.
- subprocess.check_call([
- sys.executable,
- '-c',
- ('import multiprocessing as mp; '
- + 'import pytorch_pfn_extras; '
- + 'mp.set_start_method("spawn"); ')
- ])
+ subprocess.check_call(
+ [
+ sys.executable,
+ "-c",
+ (
+ "import multiprocessing as mp; "
+ + "import pytorch_pfn_extras; "
+ + 'mp.set_start_method("spawn"); '
+ ),
+ ]
+ )
diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py
index b1115e810..b8a97faee 100644
--- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py
+++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_jit_runtime.py
@@ -1,20 +1,17 @@
import types
import warnings
-import torch
-import torch.nn.functional as F
-
import pytorch_pfn_extras as ppe
import pytorch_pfn_extras.utils.comparer as _comp
+import torch
+import torch.nn.functional as F
from pytorch_pfn_extras.onnx._as_output import trace
class JITRuntime(ppe.runtime.PyTorchRuntime):
-
def move_module(self, module):
-
def new_forward(self, *args):
- if hasattr(self, '_traced_mod'):
+ if hasattr(self, "_traced_mod"):
out = self._traced_mod(*args)
inter_size = len(self._names)
if inter_size == 0:
@@ -23,7 +20,10 @@ def new_forward(self, *args):
out = [out]
return dict(
**{str(i): x for i, x in enumerate(out[:-inter_size])},
- **{name: x for name, x in zip(self._names, out[-inter_size:])},
+ **{
+ name: x
+ for name, x in zip(self._names, out[-inter_size:])
+ },
)
new_forward = self.forward
@@ -33,7 +33,8 @@ def new_forward(self, *args):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self._traced_mod = torch.jit.trace_module(
- new_module, {"forward": args})
+ new_module, {"forward": args}
+ )
self._names = [out.name for out in outputs.values]
self.forward = new_forward
@@ -41,7 +42,7 @@ def new_forward(self, *args):
def forward_with_init(self, *args, **kwargs):
# `module.forward` is called multiple times while tracing.
- handler = getattr(_comp._thread_local, 'handler', None)
+ handler = getattr(_comp._thread_local, "handler", None)
if handler is not None:
handler._reset_intermediate_values()
return self._orig_forward(*args, **kwargs)
@@ -67,9 +68,9 @@ def __init__(self):
def forward(self, x, t):
y = self.model(x)
- prefix = 'train' if self.training else 'val'
+ prefix = "train" if self.training else "val"
loss = F.l1_loss(y, t)
- ppe.reporting.report({prefix + '/loss': loss})
+ ppe.reporting.report({prefix + "/loss": loss})
return loss
@@ -82,26 +83,60 @@ def _get_jit_cpu_model(device_type):
def test_jit_runtime_trainer():
- model, optimizer = _get_jit_cpu_model('jit-cpu')
- trainer = ppe.engine.create_trainer(model, optimizer, 10, device='jit-cpu')
+ model, optimizer = _get_jit_cpu_model("jit-cpu")
+ trainer = ppe.engine.create_trainer(model, optimizer, 10, device="jit-cpu")
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(100)
+ ]
+ )
trainer.run(data)
def test_jit_runtime_evaluator():
- model, optimizer = _get_jit_cpu_model('jit-cpu')
- evaluator = ppe.engine.create_evaluator(model, device='jit-cpu')
+ model, optimizer = _get_jit_cpu_model("jit-cpu")
+ evaluator = ppe.engine.create_evaluator(model, device="jit-cpu")
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(100)
+ ]
+ )
evaluator.run(data)
def test_jit_runtime_trainer_with_evaluator():
- model, optimizer = _get_jit_cpu_model('jit-cpu')
- evaluator = ppe.engine.create_evaluator(model, device='jit-cpu')
+ model, optimizer = _get_jit_cpu_model("jit-cpu")
+ evaluator = ppe.engine.create_evaluator(model, device="jit-cpu")
trainer = ppe.engine.create_trainer(
- model, optimizer, 10, device='jit-cpu', evaluator=evaluator)
+ model, optimizer, 10, device="jit-cpu", evaluator=evaluator
+ )
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(100)
+ ]
+ )
trainer.run(data, data)
diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py
index 696dd2582..2ee50f7ab 100644
--- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py
+++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_registry.py
@@ -1,6 +1,5 @@
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
class FallbackRuntime(ppe.runtime.BaseRuntime):
@@ -13,26 +12,26 @@ class MyCustomRuntime(ppe.runtime.PyTorchRuntime):
def test_registry_register():
registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime)
- registry.register('dummy_device', MyCustomRuntime)
+ registry.register("dummy_device", MyCustomRuntime)
assert (
- registry.get_runtime_class_for_device_spec('dummy_device')
+ registry.get_runtime_class_for_device_spec("dummy_device")
== MyCustomRuntime
)
def test_registry_fallback():
registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime)
- registry.register('dummy_device', MyCustomRuntime)
+ registry.register("dummy_device", MyCustomRuntime)
assert (
- registry.get_runtime_class_for_device_spec('unknown_device')
+ registry.get_runtime_class_for_device_spec("unknown_device")
== FallbackRuntime
)
def test_registry_torch_device():
registry = ppe.runtime._registry._RuntimeRegistry(FallbackRuntime)
- registry.register('cpu', MyCustomRuntime)
+ registry.register("cpu", MyCustomRuntime)
assert (
- registry.get_runtime_class_for_device_spec(torch.device('cpu'))
+ registry.get_runtime_class_for_device_spec(torch.device("cpu"))
== MyCustomRuntime
)
diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py
index aa6b94ad4..b0a921b1a 100644
--- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py
+++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_runtime.py
@@ -1,22 +1,19 @@
import contextlib
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
@pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason='Moving across devices requires CUDA'
+ not torch.cuda.is_available(), reason="Moving across devices requires CUDA"
)
class TestPytorchRuntime:
- @pytest.mark.parametrize('device', ['cpu', 'cuda'])
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize(
- 'batch', [{'x': torch.zeros(1)},
- [torch.zeros(1)],
- torch.zeros(1),
- object()])
+ "batch",
+ [{"x": torch.zeros(1)}, [torch.zeros(1)], torch.zeros(1), object()],
+ )
def test_convert_batch(self, device, batch):
rt = ppe.runtime.PyTorchRuntime(device, {})
cbatch = rt.convert_batch(batch)
@@ -31,14 +28,14 @@ def test_convert_batch(self, device, batch):
else:
assert cbatch is batch
- @pytest.mark.parametrize('device', ['cpu', 'cuda'])
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_move_module(self, device):
rt = ppe.runtime.PyTorchRuntime(device, {})
module = torch.nn.Linear(1, 1)
module = rt.move_module(module)
assert module.weight.device.type == device
- @pytest.mark.parametrize('device', ['cpu', 'cuda'])
+ @pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_move_tensor(self, device):
rt = ppe.runtime.PyTorchRuntime(device, {})
tensor = torch.zeros(10)
@@ -66,20 +63,20 @@ def __init__(self):
super().__init__()
self.layer1 = torch.nn.Linear(10, 10)
self.layer2 = torch.nn.Linear(10, 10)
- ppe.to(self.layer2, device='dummy', runtime_class=DummyRuntime)
+ ppe.to(self.layer2, device="dummy", runtime_class=DummyRuntime)
def test_runtime_container():
module = MyModule()
# This is a top module, so it won't show child ones
for _ in ppe.runtime._runtime.named_runtime_modules(module):
- pytest.fail('Never reach')
+ pytest.fail("Never reach")
def test_split_runtime_container():
module = SplitModule()
for name, mod in ppe.runtime._runtime.named_runtime_modules(module):
- assert name == 'layer2'
+ assert name == "layer2"
assert mod is module.layer2
@@ -89,24 +86,29 @@ def __init__(self):
super().__init__()
self.layer1 = SplitModule()
self.layer2 = SplitModule()
+
module = MultiLevelSplitModule()
- expected = [('layer2', module.layer1.layer2),
- ('layer2', module.layer2.layer2)]
+ expected = [
+ ("layer2", module.layer1.layer2),
+ ("layer2", module.layer2.layer2),
+ ]
for expected, (name, mod) in zip(
- expected, ppe.runtime._runtime.named_runtime_modules(module)):
+ expected, ppe.runtime._runtime.named_runtime_modules(module)
+ ):
assert name == expected[0]
assert mod is expected[1]
for _ in zip(
- expected, ppe.runtime._runtime.named_runtime_modules(
- module, recursive=False)):
- pytest.fail('Never reach')
+ expected,
+ ppe.runtime._runtime.named_runtime_modules(module, recursive=False),
+ ):
+ pytest.fail("Never reach")
def test_module_change_forward():
class Module1(torch.nn.Module):
def forward(self, input):
- raise RuntimeError('The module forward should never be executed')
+ raise RuntimeError("The module forward should never be executed")
class Module2:
def __init__(self):
@@ -129,7 +131,7 @@ def move_module(self, module):
with pytest.raises(RuntimeError):
module(None)
- ppe.to(module, device='dummy', runtime_class=ForwardIntercepterRuntime)
+ ppe.to(module, device="dummy", runtime_class=ForwardIntercepterRuntime)
assert int(module(None)) == 5
@@ -164,9 +166,9 @@ def trace(cls, event_name, arg):
called = 2
assert called == 0
- with ppe.runtime.BaseRuntime.trace('dummy', None):
+ with ppe.runtime.BaseRuntime.trace("dummy", None):
assert called == 0
assert called == 0
- with TracerRuntime.trace('dummy', None):
+ with TracerRuntime.trace("dummy", None):
assert called == 1
assert called == 2
diff --git a/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py b/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py
index 86a46f7b9..37eb725db 100644
--- a/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py
+++ b/tests/pytorch_pfn_extras_tests/runtime_tests/test_to.py
@@ -1,13 +1,13 @@
import io
-import pytest
-import torch
+import pytest
import pytorch_pfn_extras as ppe
+import torch
@pytest.mark.gpu
def test_tensor_ppe_to():
- device = 'cuda:0'
+ device = "cuda:0"
tensor = torch.zeros(10)
out = ppe.to(tensor, device)
assert str(out.device) == device
@@ -22,7 +22,7 @@ def __init__(self):
@pytest.mark.gpu
def test_module_ppe_to():
- device = 'cuda:0'
+ device = "cuda:0"
module = MyModule()
ppe.to(module, device)
assert all([str(p.device) == device for p in module.parameters()])
@@ -30,7 +30,7 @@ def test_module_ppe_to():
def test_invalid_ppe_to():
- device = 'cpu'
+ device = "cpu"
with pytest.raises(ValueError):
ppe.to(object(), device)
@@ -46,24 +46,23 @@ def initialize_module(self, module, loader_or_batch):
def test_module_split_ppe_to():
module = MyModule()
- ppe.to(module.layer2, 'dummy', runtime_class=MyRuntime,
- options={'opt': 1})
+ ppe.to(module.layer2, "dummy", runtime_class=MyRuntime, options={"opt": 1})
rt_layer1 = ppe.runtime._runtime._module_runtime_tag(module.layer1)
rt_layer2 = ppe.runtime._runtime._module_runtime_tag(module.layer2)
assert str(next(iter(module.layer1.parameters())).device) == "cpu"
assert rt_layer1 is None
assert isinstance(rt_layer2, MyRuntime)
- assert rt_layer2.device_spec == 'dummy'
- assert rt_layer2.options['opt'] == 1
+ assert rt_layer2.device_spec == "dummy"
+ assert rt_layer2.options["opt"] == 1
def test_module_split_ppe_to_config():
# Deprecated "config" option.
module = MyModule()
- ppe.to(module, 'dummy', runtime_class=MyRuntime, config={'opt': 1})
+ ppe.to(module, "dummy", runtime_class=MyRuntime, config={"opt": 1})
rt_layer1 = ppe.runtime._runtime._module_runtime_tag(module)
assert isinstance(rt_layer1, MyRuntime)
- assert rt_layer1.options['opt'] == 1
+ assert rt_layer1.options["opt"] == 1
class NonPicklableRuntime(ppe.runtime.BaseRuntime):
diff --git a/tests/pytorch_pfn_extras_tests/test_config.py b/tests/pytorch_pfn_extras_tests/test_config.py
index 302b6b0e7..60c151d5b 100644
--- a/tests/pytorch_pfn_extras_tests/test_config.py
+++ b/tests/pytorch_pfn_extras_tests/test_config.py
@@ -2,28 +2,26 @@
import os
import tempfile
import unittest
-import pytest
-from pytorch_pfn_extras.config import Config
-from pytorch_pfn_extras.config import customize_type
+import pytest
+from pytorch_pfn_extras.config import Config, customize_type
def func_0(a, b, c=10):
return a + b + c
-@customize_type(c='/foo/v0')
+@customize_type(c="/foo/v0")
def func_1(a, b, c):
- return {'d': a * b, 'e': c}
+ return {"d": a * b, "e": c}
-@customize_type(config='!/')
+@customize_type(config="!/")
def func_2(config):
return json.dumps(config)
class Cls0(object):
-
def __init__(self, a, b, c=10):
self.a = a
self.b = b
@@ -33,9 +31,8 @@ def __eq__(self, other):
return (self.a, self.b, self.c) == (other.a, other.b, other.c)
-@customize_type(c='../../foo/v0')
+@customize_type(c="../../foo/v0")
class Cls1(object):
-
def __init__(self, a, b, c):
self.d = a * b
self.e = c
@@ -45,263 +42,280 @@ def __eq__(self, other):
class TestConfig(unittest.TestCase):
-
types = {
- 'func_0': func_0,
- 'func_1': func_1,
- 'func_2': func_2,
- 'cls_0': Cls0,
- 'cls_1': Cls1,
+ "func_0": func_0,
+ "func_1": func_1,
+ "func_2": func_2,
+ "cls_0": Cls0,
+ "cls_1": Cls1,
}
def test_config(self):
- config = Config({
- 'foo': {
- 'v0': {'type': 'func_0', 'a': 1, 'b': 2},
- 'v1': {'type': 'func_0', 'a': 1, 'b': 2, 'c': 3},
- 'v2': {'type': 'func_1', 'a': 1, 'b': 2},
- 'v3': {'type': 'func_1', 'a': 1, 'b': 2, 'c': 3},
- },
- 'bar': [
- {'type': 'cls_0', 'a': 1, 'b': 2},
- {'type': 'cls_0', 'a': 1, 'b': 2, 'c': 3},
- {'type': 'cls_1', 'a': 1, 'b': 2},
- {'type': 'cls_1', 'a': 1, 'b': 2, 'c': 3},
- ],
- 'baz': {
- 'v0': '@/foo/v2.d',
- 'v1': '@../bar/1/c',
- 'v2': '@/bar/3.d',
- 'v3': '@../foo/v3',
- }
- }, self.types)
-
- self.assertEqual(config['/'], {
- 'foo': {
- 'v0': 13,
- 'v1': 6,
- 'v2': {'d': 2, 'e': 13},
- 'v3': {'d': 2, 'e': 3},
+ config = Config(
+ {
+ "foo": {
+ "v0": {"type": "func_0", "a": 1, "b": 2},
+ "v1": {"type": "func_0", "a": 1, "b": 2, "c": 3},
+ "v2": {"type": "func_1", "a": 1, "b": 2},
+ "v3": {"type": "func_1", "a": 1, "b": 2, "c": 3},
+ },
+ "bar": [
+ {"type": "cls_0", "a": 1, "b": 2},
+ {"type": "cls_0", "a": 1, "b": 2, "c": 3},
+ {"type": "cls_1", "a": 1, "b": 2},
+ {"type": "cls_1", "a": 1, "b": 2, "c": 3},
+ ],
+ "baz": {
+ "v0": "@/foo/v2.d",
+ "v1": "@../bar/1/c",
+ "v2": "@/bar/3.d",
+ "v3": "@../foo/v3",
+ },
},
- 'bar': [
- Cls0(1, 2, 10),
- Cls0(1, 2, 3),
- Cls1(1, 2, 13),
- Cls1(1, 2, 3),
- ],
- 'baz': {
- 'v0': 2,
- 'v1': 3,
- 'v2': 2,
- 'v3': {'d': 2, 'e': 3},
+ self.types,
+ )
+
+ self.assertEqual(
+ config["/"],
+ {
+ "foo": {
+ "v0": 13,
+ "v1": 6,
+ "v2": {"d": 2, "e": 13},
+ "v3": {"d": 2, "e": 3},
+ },
+ "bar": [
+ Cls0(1, 2, 10),
+ Cls0(1, 2, 3),
+ Cls1(1, 2, 13),
+ Cls1(1, 2, 3),
+ ],
+ "baz": {
+ "v0": 2,
+ "v1": 3,
+ "v2": 2,
+ "v3": {"d": 2, "e": 3},
+ },
},
- })
+ )
def test_config_escape(self):
pre_eval_config = {
- 'foo': {
- 'v0': {'type': 'func_0', 'a': 1, 'b': 2},
+ "foo": {
+ "v0": {"type": "func_0", "a": 1, "b": 2},
},
- 'bar': {'type': 'func_2'},
+ "bar": {"type": "func_2"},
}
config = Config(pre_eval_config, self.types)
- self.assertEqual(config['!/foo'], {
- 'v0': {'type': 'func_0', 'a': 1, 'b': 2},
- })
- self.assertEqual(json.loads(config['/bar']), pre_eval_config)
+ self.assertEqual(
+ config["!/foo"],
+ {
+ "v0": {"type": "func_0", "a": 1, "b": 2},
+ },
+ )
+ self.assertEqual(json.loads(config["/bar"]), pre_eval_config)
def test_config_load_path(self):
- with tempfile.TemporaryDirectory() as temp0, \
- tempfile.TemporaryDirectory() as temp1:
- with open(os.path.join(temp0, 'foo.json'), mode='w') as f:
- json.dump({
- 'foo': {'v0': {'type': 'func_0', 'a': 1, 'b': 2}},
- 'bar': {'import': os.path.join(temp1, 'bar.json')},
- 'baz': {
- 'import': 'baz.json',
- '0/b': 3,
- '1/d': [1, 2],
+ with tempfile.TemporaryDirectory() as temp0, tempfile.TemporaryDirectory() as temp1:
+ with open(os.path.join(temp0, "foo.json"), mode="w") as f:
+ json.dump(
+ {
+ "foo": {"v0": {"type": "func_0", "a": 1, "b": 2}},
+ "bar": {"import": os.path.join(temp1, "bar.json")},
+ "baz": {
+ "import": "baz.json",
+ "0/b": 3,
+ "1/d": [1, 2],
+ },
},
- }, f)
- with open(os.path.join(temp1, 'bar.json'), mode='w') as f:
- json.dump({'type': 'func_0', 'a': 3, 'b': 4}, f)
- with open(os.path.join(temp0, 'baz.json'), mode='w') as f:
- json.dump([
- {'type': 'func_1', 'a': 1, 'b': 2},
- {'d': 3, 'e': 4},
- ], f)
+ f,
+ )
+ with open(os.path.join(temp1, "bar.json"), mode="w") as f:
+ json.dump({"type": "func_0", "a": 3, "b": 4}, f)
+ with open(os.path.join(temp0, "baz.json"), mode="w") as f:
+ json.dump(
+ [
+ {"type": "func_1", "a": 1, "b": 2},
+ {"d": 3, "e": 4},
+ ],
+ f,
+ )
config = Config.load_path(
- os.path.join(temp0, 'foo.json'), types=self.types)
-
- self.assertEqual(config['!/foo'],
- {'v0': {'type': 'func_0', 'a': 1, 'b': 2}})
- self.assertEqual(config['/foo'], {'v0': 13})
- self.assertEqual(config['!/bar'], {'type': 'func_0', 'a': 3, 'b': 4})
- self.assertEqual(config['/bar'], 17)
- self.assertEqual(config['!/baz'], [
- {'type': 'func_1', 'a': 1, 'b': 3},
- {'d': [1, 2], 'e': 4},
- ])
- self.assertEqual(config['/baz'], [
- {'d': 3, 'e': 13},
- {'d': [1, 2], 'e': 4},
- ])
+ os.path.join(temp0, "foo.json"), types=self.types
+ )
+
+ self.assertEqual(
+ config["!/foo"], {"v0": {"type": "func_0", "a": 1, "b": 2}}
+ )
+ self.assertEqual(config["/foo"], {"v0": 13})
+ self.assertEqual(config["!/bar"], {"type": "func_0", "a": 3, "b": 4})
+ self.assertEqual(config["/bar"], 17)
+ self.assertEqual(
+ config["!/baz"],
+ [
+ {"type": "func_1", "a": 1, "b": 3},
+ {"d": [1, 2], "e": 4},
+ ],
+ )
+ self.assertEqual(
+ config["/baz"],
+ [
+ {"d": 3, "e": 13},
+ {"d": [1, 2], "e": 4},
+ ],
+ )
def test_config_with_config_key_invalid_index(self):
- config = Config([['a'], [['b', ['c', 'd']]]])
+ config = Config([["a"], [["b", ["c", "d"]]]])
with self.assertRaises(IndexError) as cm:
- config['/1/2/3']
+ config["/1/2/3"]
self.assertEqual(
cm.exception.args[-2:],
- ('2 not in !/1',
- '/1/2/3 -> !/1/2/3 -> !/1/2'))
+ ("2 not in !/1", "/1/2/3 -> !/1/2/3 -> !/1/2"),
+ )
def test_config_with_config_key_invalid_key(self):
- config = Config({'foo': {'bar': {'baz': None}}})
+ config = Config({"foo": {"bar": {"baz": None}}})
with self.assertRaises(KeyError) as cm:
- config['/foo/Bar/baz']
+ config["/foo/Bar/baz"]
self.assertEqual(
cm.exception.args[-2:],
- ('Bar not in !/foo',
- '/foo/Bar/baz -> !/foo/Bar/baz -> !/foo/Bar'))
+ ("Bar not in !/foo", "/foo/Bar/baz -> !/foo/Bar/baz -> !/foo/Bar"),
+ )
def test_config_with_config_key_invalid_type(self):
- config = Config({'foo': [['b', {'baz': None}]]})
+ config = Config({"foo": [["b", {"baz": None}]]})
with self.assertRaises(TypeError) as cm:
- config['/foo/bar/baz']
+ config["/foo/bar/baz"]
self.assertEqual(
cm.exception.args[-2:],
- ('bar not in !/foo',
- '/foo/bar/baz -> !/foo/bar/baz -> !/foo/bar'))
+ ("bar not in !/foo", "/foo/bar/baz -> !/foo/bar/baz -> !/foo/bar"),
+ )
def test_config_with_attr_key_invalid_index(self):
- config = Config([['a'], [['b', ['c', 'd']]]])
+ config = Config([["a"], [["b", ["c", "d"]]]])
with self.assertRaises(IndexError) as cm:
- config['/.1.2.3']
+ config["/.1.2.3"]
self.assertEqual(
cm.exception.args[-2:],
- ('2 not in /.1 ([[\'b\', [\'c\', \'d\']]])',
- '/.1.2.3 -> /.1.2'))
+ ("2 not in /.1 ([['b', ['c', 'd']]])", "/.1.2.3 -> /.1.2"),
+ )
def test_config_with_attr_key_invalid_key(self):
- config = Config({'foo': {'bar': {'baz': None}}})
+ config = Config({"foo": {"bar": {"baz": None}}})
with self.assertRaises(KeyError) as cm:
- config['/.foo.Bar.baz']
+ config["/.foo.Bar.baz"]
self.assertEqual(
cm.exception.args[-2:],
- ('Bar not in /.foo ({\'bar\': {\'baz\': None}})',
- '/.foo.Bar.baz -> /.foo.Bar'))
+ (
+ "Bar not in /.foo ({'bar': {'baz': None}})",
+ "/.foo.Bar.baz -> /.foo.Bar",
+ ),
+ )
def test_config_with_attr_key_invalid_type(self):
- config = Config({'foo': [['b', {'baz': None}]]})
+ config = Config({"foo": [["b", {"baz": None}]]})
with self.assertRaises(TypeError) as cm:
- config['/.foo.bar.baz']
+ config["/.foo.bar.baz"]
self.assertEqual(
cm.exception.args[-2:],
- ('bar not in /.foo ([[\'b\', {\'baz\': None}]])',
- '/.foo.bar.baz -> /.foo.bar'))
+ (
+ "bar not in /.foo ([['b', {'baz': None}]])",
+ "/.foo.bar.baz -> /.foo.bar",
+ ),
+ )
def test_config_with_invalid_type(self):
- config = Config({'foo': [{'type': 'foo', 'a': 0, 'b': 1}]})
+ config = Config({"foo": [{"type": "foo", "a": 0, "b": 1}]})
with self.assertRaises(KeyError) as cm:
- config['/']
+ config["/"]
self.assertEqual(
- cm.exception.args[-2:],
- ('foo not in types',
- '/ -> /foo -> /foo/0'))
+ cm.exception.args[-2:], ("foo not in types", "/ -> /foo -> /foo/0")
+ )
def test_config_with_invalid_call(self):
def foo():
- raise RuntimeError('foo')
+ raise RuntimeError("foo")
config = Config(
- {'foo': [{'type': 'foo', 'a': 0, 'b': 1}]},
- types={'foo': foo})
+ {"foo": [{"type": "foo", "a": 0, "b": 1}]}, types={"foo": foo}
+ )
with self.assertRaises(TypeError) as cm:
- config['/']
+ config["/"]
- self.assertEqual(
- cm.exception.args[-1:],
- ('/ -> /foo -> /foo/0',))
+ self.assertEqual(cm.exception.args[-1:], ("/ -> /foo -> /foo/0",))
def test_config_with_circular_dependency(self):
- config = Config({'foo': '@/bar', 'bar': '@foo.d'})
+ config = Config({"foo": "@/bar", "bar": "@foo.d"})
with self.assertRaises(RuntimeError) as cm:
- config['/']
+ config["/"]
self.assertIn(
cm.exception.args,
{
- ('Circular dependency', '/ -> /foo -> /bar -> /foo.d -> /foo'),
- ('Circular dependency', '/ -> /bar -> /foo.d -> /foo -> /bar'),
- })
+ ("Circular dependency", "/ -> /foo -> /bar -> /foo.d -> /foo"),
+ ("Circular dependency", "/ -> /bar -> /foo.d -> /foo -> /bar"),
+ },
+ )
def test_config_with_circular_import(self):
with tempfile.TemporaryDirectory() as temp:
- with open(os.path.join(temp, 'foo.json'), mode='w') as f:
- json.dump({'a': {'import': 'bar.json'}}, f)
- with open(os.path.join(temp, 'bar.json'), mode='w') as f:
- json.dump([{'import': './foo.json'}], f)
+ with open(os.path.join(temp, "foo.json"), mode="w") as f:
+ json.dump({"a": {"import": "bar.json"}}, f)
+ with open(os.path.join(temp, "bar.json"), mode="w") as f:
+ json.dump([{"import": "./foo.json"}], f)
with self.assertRaises(RuntimeError) as cm:
- Config.load_path(os.path.join(temp, 'foo.json'))
+ Config.load_path(os.path.join(temp, "foo.json"))
self.assertEqual(
cm.exception.args,
- ('Circular import',
- '!/ of {foo} -> !/a of {foo} -> !/ of {bar}'
- ' -> !/0 of {bar} -> !/ of {foo}'.format(
- foo=os.path.join(temp, 'foo.json'),
- bar=os.path.join(temp, 'bar.json'))))
+ (
+ "Circular import",
+ "!/ of {foo} -> !/a of {foo} -> !/ of {bar}"
+ " -> !/0 of {bar} -> !/ of {foo}".format(
+ foo=os.path.join(temp, "foo.json"),
+ bar=os.path.join(temp, "bar.json"),
+ ),
+ ),
+ )
def test_config_with_args_update(self):
- config = Config({
- 'foo': {
- 'ls': ['first']
- }
- }, self.types)
+ config = Config({"foo": {"ls": ["first"]}}, self.types)
- assert config['/foo/ls/0'] == 'first'
- config.update_via_args([('/foo/ls/0', 'changed')])
- assert config['/foo/ls/0'] == 'changed'
+ assert config["/foo/ls/0"] == "first"
+ config.update_via_args([("/foo/ls/0", "changed")])
+ assert config["/foo/ls/0"] == "changed"
def test_config_with_args_update_type_conversion(self):
- config = Config({
- 'foo': {
- 'ls': [0]
- }
- }, self.types)
+ config = Config({"foo": {"ls": [0]}}, self.types)
- assert config['/foo/ls/0'] == 0
- config.update_via_args([('/foo/ls/0', '16')])
- assert config['/foo/ls/0'] == 16
+ assert config["/foo/ls/0"] == 0
+ config.update_via_args([("/foo/ls/0", "16")])
+ assert config["/foo/ls/0"] == 16
def test_config_with_args_update_type_conversion_bool(self):
- config = Config({
- 'foo': {
- 'ls': [True]
- }
- }, self.types)
-
- assert config['/foo/ls/0']
- config.update_via_args([('/foo/ls/0', False)])
- assert not config['/foo/ls/0']
- config.update_via_args([('/foo/ls/0', "False")])
- assert not config['/foo/ls/0']
- config.update_via_args([('/foo/ls/0', "true")])
- assert config['/foo/ls/0']
- config.update_via_args([('/foo/ls/0', "TRUE")])
- assert config['/foo/ls/0']
- config.update_via_args([('/foo/ls/0', "false")])
- assert not config['/foo/ls/0']
+ config = Config({"foo": {"ls": [True]}}, self.types)
+
+ assert config["/foo/ls/0"]
+ config.update_via_args([("/foo/ls/0", False)])
+ assert not config["/foo/ls/0"]
+ config.update_via_args([("/foo/ls/0", "False")])
+ assert not config["/foo/ls/0"]
+ config.update_via_args([("/foo/ls/0", "true")])
+ assert config["/foo/ls/0"]
+ config.update_via_args([("/foo/ls/0", "TRUE")])
+ assert config["/foo/ls/0"]
+ config.update_via_args([("/foo/ls/0", "false")])
+ assert not config["/foo/ls/0"]
with pytest.raises(ValueError):
- config.update_via_args([('/foo/ls/0', "alse")])
+ config.update_via_args([("/foo/ls/0", "alse")])
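The test_config.py hunks above exercise pytorch_pfn_extras.config.Config path resolution ("type" entries built from a types mapping, "@/path[.key]" references, and the "!" prefix for the raw, pre-evaluation config). A minimal sketch of that behaviour; the names make_optimizer and lr_copy are illustrative, not taken from the test data:

    from pytorch_pfn_extras.config import Config

    def make_optimizer(lr, momentum=0.9):          # illustrative builder registered under "make_optimizer"
        return {"lr": lr, "momentum": momentum}

    config = Config(
        {
            # Entries with a "type" key are built by calling types["type"](**kwargs).
            "optimizer": {"type": "make_optimizer", "lr": 0.1},
            # "@/path" references another entry's evaluated value; ".lr" digs into the result.
            "lr_copy": "@/optimizer.lr",
        },
        {"make_optimizer": make_optimizer},
    )

    assert config["/optimizer"] == {"lr": 0.1, "momentum": 0.9}
    assert config["/lr_copy"] == 0.1
    # A leading "!" returns the raw (pre-evaluation) configuration.
    assert config["!/optimizer"] == {"type": "make_optimizer", "lr": 0.1}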
diff --git a/tests/pytorch_pfn_extras_tests/test_config_types.py b/tests/pytorch_pfn_extras_tests/test_config_types.py
index fe637b05a..95d559abb 100644
--- a/tests/pytorch_pfn_extras_tests/test_config_types.py
+++ b/tests/pytorch_pfn_extras_tests/test_config_types.py
@@ -4,90 +4,99 @@
import unittest
import optuna
-
-from pytorch_pfn_extras.config_types import optuna_types
-from pytorch_pfn_extras.config_types import load_path_with_optuna_types
+from pytorch_pfn_extras.config_types import (
+ load_path_with_optuna_types,
+ optuna_types,
+)
class TestConfigTypes(unittest.TestCase):
-
study = optuna.create_study()
def test_config_optuna_types(self):
def objective(trial):
types = optuna_types(trial)
self.assertEqual(
- types['optuna_suggest_categorical'],
- trial.suggest_categorical)
- self.assertEqual(
- types['optuna_suggest_discrete_uniform'],
- trial.suggest_discrete_uniform)
- self.assertEqual(
- types['optuna_suggest_float'],
- trial.suggest_float)
+ types["optuna_suggest_categorical"], trial.suggest_categorical
+ )
self.assertEqual(
- types['optuna_suggest_int'],
- trial.suggest_int)
+ types["optuna_suggest_discrete_uniform"],
+ trial.suggest_discrete_uniform,
+ )
+ self.assertEqual(types["optuna_suggest_float"], trial.suggest_float)
+ self.assertEqual(types["optuna_suggest_int"], trial.suggest_int)
self.assertEqual(
- types['optuna_suggest_loguniform'],
- trial.suggest_loguniform)
+ types["optuna_suggest_loguniform"], trial.suggest_loguniform
+ )
self.assertEqual(
- types['optuna_suggest_uniform'],
- trial.suggest_uniform)
+ types["optuna_suggest_uniform"], trial.suggest_uniform
+ )
return 0.0
+
self.study.optimize(objective, n_trials=1)
def test_load_path_with_optuna_types(self):
low = 0
high = 8
with tempfile.TemporaryDirectory() as temp0:
- with open(os.path.join(temp0, 'foo.json'), mode='w') as f:
- json.dump({
- 'foo': {
- 'type': 'optuna_suggest_int',
- 'name': 'a',
- 'low': low,
- 'high': high
- }
- }, f)
+ with open(os.path.join(temp0, "foo.json"), mode="w") as f:
+ json.dump(
+ {
+ "foo": {
+ "type": "optuna_suggest_int",
+ "name": "a",
+ "low": low,
+ "high": high,
+ }
+ },
+ f,
+ )
def objective(trial):
config = load_path_with_optuna_types(
- os.path.join(temp0, 'foo.json'), trial)
- self.assertIsInstance(config['/foo'], int)
- self.assertGreaterEqual(config['/foo'], low)
- self.assertLessEqual(config['/foo'], high)
+ os.path.join(temp0, "foo.json"), trial
+ )
+ self.assertIsInstance(config["/foo"], int)
+ self.assertGreaterEqual(config["/foo"], low)
+ self.assertLessEqual(config["/foo"], high)
return 0.0
+
self.study.optimize(objective, n_trials=2 * (high - low + 1))
def test_load_path_with_optuna_types_with_types_argument(self):
low = 0
high = 8
with tempfile.TemporaryDirectory() as temp0:
- with open(os.path.join(temp0, 'foo.json'), mode='w') as f:
- json.dump({
- 'foo': {
- 'type': 'optuna_suggest_int',
- 'name': 'a',
- 'low': low,
- 'high': high
+ with open(os.path.join(temp0, "foo.json"), mode="w") as f:
+ json.dump(
+ {
+ "foo": {
+ "type": "optuna_suggest_int",
+ "name": "a",
+ "low": low,
+ "high": high,
+ },
+ "bar": {"type": "dict", "x": 0},
},
- 'bar': {
- 'type': 'dict',
- 'x': 0
- }
- }, f)
+ f,
+ )
def objective(trial):
config = load_path_with_optuna_types(
- os.path.join(temp0, 'foo.json'), trial,
- types={'optuna_suggest_int': float, 'dict': dict})
- self.assertIsInstance(config['/foo'], int)
- self.assertGreaterEqual(config['/foo'], low)
- self.assertLessEqual(config['/foo'], high)
- self.assertIsInstance(config['/bar'], dict)
- self.assertEqual(config['/bar']['x'], 0)
+ os.path.join(temp0, "foo.json"),
+ trial,
+ types={"optuna_suggest_int": float, "dict": dict},
+ )
+ self.assertIsInstance(config["/foo"], int)
+ self.assertGreaterEqual(config["/foo"], low)
+ self.assertLessEqual(config["/foo"], high)
+ self.assertIsInstance(config["/bar"], dict)
+ self.assertEqual(config["/bar"]["x"], 0)
return 0.0
+
self.assertWarns(
- UserWarning, self.study.optimize, objective,
- n_trials=2 * (high - low + 1))
+ UserWarning,
+ self.study.optimize,
+ objective,
+ n_trials=2 * (high - low + 1),
+ )
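The test_config_types.py hunks above resolve config entries through Optuna's suggest functions. A minimal sketch of load_path_with_optuna_types following the same pattern; the key and parameter name "units" are illustrative:

    import json
    import os
    import tempfile

    import optuna
    from pytorch_pfn_extras.config_types import load_path_with_optuna_types

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "cfg.json")
        with open(path, mode="w") as f:
            # "optuna_suggest_int" maps to trial.suggest_int for the current trial.
            json.dump(
                {"units": {"type": "optuna_suggest_int", "name": "units", "low": 1, "high": 8}},
                f,
            )

        def objective(trial):
            config = load_path_with_optuna_types(path, trial)
            return float(config["/units"])  # the value suggested for this trial

        optuna.create_study().optimize(objective, n_trials=3)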
diff --git a/tests/pytorch_pfn_extras_tests/test_handler.py b/tests/pytorch_pfn_extras_tests/test_handler.py
index 87e7b381f..cf3615f2e 100644
--- a/tests/pytorch_pfn_extras_tests/test_handler.py
+++ b/tests/pytorch_pfn_extras_tests/test_handler.py
@@ -1,10 +1,9 @@
-import unittest.mock
import contextlib
+import unittest.mock
-import torch
import pytest
-
import pytorch_pfn_extras as ppe
+import torch
def torch_testing_assert_close(*args, **kwargs):
@@ -91,14 +90,14 @@ def forward(self, x):
class MockTrainer:
def __init__(self):
- self.models = {'main': MockModule()}
+ self.models = {"main": MockModule()}
self.optimizers = {}
self.epoch = 0
class MockEvaluator:
def __init__(self):
- self.models = {'main': MockModule()}
+ self.models = {"main": MockModule()}
self.optimizers = {}
self.epoch = 0
@@ -112,7 +111,7 @@ def train_epoch_end(self, epoch, models):
def train_step(self, models, optimizers, batch_idx, batch):
assert batch.converted
- return models['main'](batch)
+ return models["main"](batch)
def train_step_optimizers(self, models, optimizers, batch_idx):
self._train_step_optimizers_called = True
@@ -122,111 +121,110 @@ def train_validation_begin(self, models):
def eval_step(self, models, batch_idx, batch):
assert batch.converted
- return models['main'](batch)
+ return models["main"](batch)
class HandlerTester:
def _get_handler(self, options=None):
if options is None:
options = {}
- ppe.runtime.runtime_registry.register('test_rt', MockRuntime)
+ ppe.runtime.runtime_registry.register("test_rt", MockRuntime)
trainer = MockTrainer()
logic = MockLogic()
handler = ppe.handler.Handler(
- logic, MockRuntime('test_rt', {}), options
+ logic, MockRuntime("test_rt", {}), options
)
return handler, trainer, logic
def _move_modules(self, module, to_move):
for name in to_move:
- if name == 'self':
- ppe.to(module, 'test_rt')
+ if name == "self":
+ ppe.to(module, "test_rt")
else:
- ppe.to(getattr(module, name), 'test_rt')
+ ppe.to(getattr(module, name), "test_rt")
def _assert_called(self, module, to_move, function):
for name, mod in module.named_modules():
- if (mod is module and 'self' in to_move) or (name in to_move):
- assert getattr(mod._ppe_runtime, f'_{function}_called')
+ if (mod is module and "self" in to_move) or (name in to_move):
+ assert getattr(mod._ppe_runtime, f"_{function}_called")
assert mod._ppe_runtime._called_module == mod
else:
- if hasattr(mod, '_ppe_runtime'):
+ if hasattr(mod, "_ppe_runtime"):
assert mod._ppe_runtime._called_module != mod
class TestHandlerTrainSync(HandlerTester):
-
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_setup(self, to_move):
handler, trainer, _ = self._get_handler()
- module = trainer.models['main']
+ module = trainer.models["main"]
self._move_modules(module, to_move)
handler.train_setup(trainer, [])
- self._assert_called(module, to_move, 'initialize')
+ self._assert_called(module, to_move, "initialize")
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_cleanup(self, to_move):
handler, trainer, _ = self._get_handler()
- module = trainer.models['main']
+ module = trainer.models["main"]
self._move_modules(module, to_move)
handler.train_cleanup(trainer)
- self._assert_called(module, to_move, 'train_cleanup')
+ self._assert_called(module, to_move, "train_cleanup")
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_epoch_begin(self, to_move):
handler, trainer, logic = self._get_handler()
- module = trainer.models['main']
+ module = trainer.models["main"]
self._move_modules(module, to_move)
handler.train_epoch_begin(trainer, [])
- self._assert_called(module, to_move, 'train_epoch_begin')
+ self._assert_called(module, to_move, "train_epoch_begin")
assert logic._train_epoch_begin_called
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_epoch_end(self, to_move):
handler, trainer, logic = self._get_handler()
- module = trainer.models['main']
+ module = trainer.models["main"]
self._move_modules(module, to_move)
# Should check that the handler completes
handler.train_epoch_end(trainer)
- self._assert_called(module, to_move, 'train_epoch_end')
+ self._assert_called(module, to_move, "train_epoch_end")
assert logic._train_epoch_end_called
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_step(self, to_move):
handler, trainer, logic = self._get_handler()
- module = trainer.models['main']
+ module = trainer.models["main"]
self._move_modules(module, to_move)
callback = unittest.mock.Mock(return_value=None)
handler.train_step(trainer, 0, None, callback)
callback.assert_called_once_with(0, 1)
- self._assert_called(module, to_move, 'train_pre_step')
+ self._assert_called(module, to_move, "train_pre_step")
assert logic._train_step_optimizers_called
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_post_step(self, to_move):
- options = {'train_report_keys': ['output']}
+ options = {"train_report_keys": ["output"]}
handler, trainer, _ = self._get_handler(options)
- module = trainer.models['main']
+ module = trainer.models["main"]
self._move_modules(module, to_move)
reporter = ppe.reporting.Reporter()
with reporter:
- handler.train_post_step(trainer, 0, None, {'output': 1})
- assert reporter.observation['train/output'] == 1
- self._assert_called(module, to_move, 'train_post_step')
+ handler.train_post_step(trainer, 0, None, {"output": 1})
+ assert reporter.observation["train/output"] == 1
+ self._assert_called(module, to_move, "train_post_step")
class TestHandlerValidationSync(HandlerTester):
@@ -236,53 +234,53 @@ def _get_handler(self, options=None):
return handler, evaluator, logic
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_eval_setup(self, to_move):
handler, evaluator, _ = self._get_handler()
- module = evaluator.models['main']
+ module = evaluator.models["main"]
self._move_modules(module, to_move)
handler.eval_setup(evaluator, [])
- self._assert_called(module, to_move, 'initialize')
+ self._assert_called(module, to_move, "initialize")
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_validation_begin(self, to_move):
handler, evaluator, logic = self._get_handler()
- module = evaluator.models['main']
+ module = evaluator.models["main"]
self._move_modules(module, to_move)
handler.train_validation_begin(None, evaluator)
- self._assert_called(module, to_move, 'train_validation_begin')
+ self._assert_called(module, to_move, "train_validation_begin")
assert logic._train_validation_begin_called
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_eval_step(self, to_move):
handler, evaluator, logic = self._get_handler()
- module = evaluator.models['main']
+ module = evaluator.models["main"]
self._move_modules(module, to_move)
callback = unittest.mock.Mock(return_value=None)
handler.eval_step(evaluator, 0, None, callback)
callback.assert_called_once_with(0, 1)
- self._assert_called(module, to_move, 'eval_pre_step')
+ self._assert_called(module, to_move, "eval_pre_step")
@pytest.mark.parametrize(
- 'to_move', [('self',), ('sm1',), ('sm2',), ('sm1', 'sm2')]
+ "to_move", [("self",), ("sm1",), ("sm2",), ("sm1", "sm2")]
)
def test_train_post_step(self, to_move):
- options = {'eval_report_keys': ['output']}
+ options = {"eval_report_keys": ["output"]}
handler, evaluator, _ = self._get_handler(options)
- module = evaluator.models['main']
+ module = evaluator.models["main"]
self._move_modules(module, to_move)
reporter = ppe.reporting.Reporter()
with reporter:
- handler.eval_post_step(evaluator, 0, None, {'output': 1})
- assert reporter.observation['val/output'] == 1
- self._assert_called(module, to_move, 'eval_post_step')
+ handler.eval_post_step(evaluator, 0, None, {"output": 1})
+ assert reporter.observation["val/output"] == 1
+ self._assert_called(module, to_move, "eval_post_step")
@pytest.mark.gpu
@@ -291,7 +289,7 @@ def run_autocast(self, options):
trainer = MockTrainer()
logic = ppe.handler.Logic(options=options)
handler = ppe.handler.Handler(
- logic, ppe.runtime.PyTorchRuntime('cuda', {}), {}
+ logic, ppe.runtime.PyTorchRuntime("cuda", {}), {}
)
completed = False
@@ -300,11 +298,11 @@ class _MModule(torch.nn.Module):
def forward(self, x, y):
return torch.mm(x, y)
- trainer.models['main'] = _MModule()
- trainer.optimizers['main'] = torch.optim.SGD(
+ trainer.models["main"] = _MModule()
+ trainer.optimizers["main"] = torch.optim.SGD(
[torch.nn.Parameter(torch.zeros(10))], 0.01
)
- ppe.to(trainer.models['main'], 'cuda')
+ ppe.to(trainer.models["main"], "cuda")
completed = False
def callback(batch_idx, outs):
@@ -322,38 +320,38 @@ def callback(batch_idx, outs):
completed = True
inputs = {
- 'x': torch.rand((2, 2)).cuda(),
- 'y': torch.rand((2, 2)).cuda(),
+ "x": torch.rand((2, 2)).cuda(),
+ "y": torch.rand((2, 2)).cuda(),
}
handler.train_step(trainer, 0, inputs, callback)
assert completed
- @pytest.mark.parametrize('autocast', [True, False])
+ @pytest.mark.parametrize("autocast", [True, False])
def test_autocast(self, autocast):
- self.run_autocast({'autocast': autocast})
+ self.run_autocast({"autocast": autocast})
def test_autocast_not_enabled(self):
old_enable = ppe.runtime._autocast._cuda_amp_available
try:
ppe.runtime._autocast._cuda_amp_available = False
with pytest.raises(RuntimeError):
- ppe.handler.Logic(options={'autocast': True})
+ ppe.handler.Logic(options={"autocast": True})
finally:
ppe.runtime._autocast._cuda_amp_available = old_enable
- @pytest.mark.skipif(not ppe.requires("1.10.0"), reason="requires PyTorch>=1.10")
+ @pytest.mark.skipif(
+ not ppe.requires("1.10.0"), reason="requires PyTorch>=1.10"
+ )
@pytest.mark.parametrize(
- 'device_type, dtype',
- [("cpu", torch.bfloat16), ("cuda", torch.float16)]
+ "device_type, dtype", [("cpu", torch.bfloat16), ("cuda", torch.float16)]
)
def test_autocast_options(self, device_type, dtype):
self.run_autocast(
- {"autocast": {'device_type': device_type, "dtype": dtype}}
+ {"autocast": {"device_type": device_type, "dtype": dtype}}
)
class TestLogic:
-
def test_train_epoch_begin(self):
# Check that the DataLoader has the sampler updated
class _MockedDL:
@@ -361,15 +359,17 @@ def __init__(self):
class _Sampler:
def set_epoch(self, epoch):
self.epoch = epoch
+
self.sampler = _Sampler()
+
logic = ppe.handler.Logic()
loader = _MockedDL()
- models = {'main': torch.nn.Linear(1, 1)}
+ models = {"main": torch.nn.Linear(1, 1)}
# The model should be set to train mode
- models['main'].eval()
- assert not models['main'].training
+ models["main"].eval()
+ assert not models["main"].training
logic.train_epoch_begin(models, 10, loader)
- assert models['main'].training
+ assert models["main"].training
assert loader.sampler.epoch == 10
def _run_step(self, logic, device):
@@ -384,26 +384,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x).sum()
model = _Module().to(device)
- models = {'main': model}
- optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0, 0)}
+ models = {"main": model}
+ optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0, 0)}
out = logic.train_step(models, optimizers, 0, input)
return models, optimizers, input, out
def test_train_step(self):
logic = ppe.handler.Logic()
- models, optimizers, input, out = self._run_step(logic, 'cpu')
- model = models['main']
+ models, optimizers, input, out = self._run_step(logic, "cpu")
+ model = models["main"]
assert input.grad is not None
# The gradient of a linear layer is its transposed weight
torch_testing_assert_close(input.grad, model.weight.T)
torch_testing_assert_close(out, model(input))
@pytest.mark.parametrize(
- 'to_backprop',
- [None, ('0',), ('0', '1'), ('0', '1', '2'), ('1', '2'), ('2',)]
+ "to_backprop",
+ [None, ("0",), ("0", "1"), ("0", "1", "2"), ("1", "2"), ("2",)],
)
def test_train_step_backward(self, to_backprop):
- logic = ppe.handler.Logic(options={'backward_outputs': to_backprop})
+ logic = ppe.handler.Logic(options={"backward_outputs": to_backprop})
input = torch.rand(1, 1)
input.requires_grad = True
@@ -415,20 +415,22 @@ def __init__(self):
self.l2 = torch.nn.Linear(1, 1)
def forward(self, x):
- return {'0': self.l0(x), '1': self.l1(x), '2': self.l2(x)}
+ return {"0": self.l0(x), "1": self.l1(x), "2": self.l2(x)}
+
model = _MultiOutModel()
- models = {'main': model}
- optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0)}
+ models = {"main": model}
+ optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0)}
assert input.grad is None
if to_backprop is None:
- to_backprop = ('0', '1', '2')
+ to_backprop = ("0", "1", "2")
# Copy the original parameters to check that they were not updated
original_parameters = {}
for val in to_backprop:
- original_parameters[val] = getattr(
- model, f'l{val}').weight.detach().clone()
+ original_parameters[val] = (
+ getattr(model, f"l{val}").weight.detach().clone()
+ )
outs = logic.train_step(models, optimizers, 0, input)
@@ -436,13 +438,14 @@ def forward(self, x):
assert len(outs.keys()) == 3
grad = torch.zeros(1)
for val in to_backprop:
- grad = grad + getattr(model, f'l{val}').weight.T
+ grad = grad + getattr(model, f"l{val}").weight.T
torch_testing_assert_close(input.grad, grad)
# Check that logic step does not change the value of weight
for val in original_parameters:
torch_testing_assert_close(
- original_parameters[val], getattr(model, f'l{val}').weight)
+ original_parameters[val], getattr(model, f"l{val}").weight
+ )
def test_train_step_backward_nograd(self):
logic = ppe.handler.Logic()
@@ -455,19 +458,19 @@ def __init__(self):
self.l0 = torch.nn.Linear(1, 1)
def forward(self, x):
- return {'0': x}
+ return {"0": x}
model = _DummyModel()
- models = {'main': model}
- optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0)}
+ models = {"main": model}
+ optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0)}
assert input.grad is None
outs = logic.train_step(models, optimizers, 0, input)
- assert outs['0'].grad is None
+ assert outs["0"].grad is None
def test_train_step_backward_invalid(self):
- logic = ppe.handler.Logic(options={'backward_outputs': 'abcd'})
+ logic = ppe.handler.Logic(options={"backward_outputs": "abcd"})
input = torch.rand(1, 1)
input.requires_grad = True
@@ -477,20 +480,20 @@ def __init__(self):
self.l0 = torch.nn.Linear(1, 1)
def forward(self, x):
- return {'0': x}
+ return {"0": x}
model = _DummyModel()
- models = {'main': model}
- optimizers = {'main': torch.optim.SGD(model.parameters(), 1.0)}
+ models = {"main": model}
+ optimizers = {"main": torch.optim.SGD(model.parameters(), 1.0)}
assert input.grad is None
- with pytest.warns(UserWarning, match='backward value: abcd'):
+ with pytest.warns(UserWarning, match="backward value: abcd"):
logic.train_step(models, optimizers, 0, input)
def test_train_step_optimizers(self):
logic = ppe.handler.Logic()
- models, optimizers, input, out = self._run_step(logic, 'cpu')
- model = models['main']
+ models, optimizers, input, out = self._run_step(logic, "cpu")
+ model = models["main"]
m_weight = model.weight.clone().detach()
w_grad = model.weight.grad.clone().detach()
logic.train_step_optimizers(model, optimizers, 0)
@@ -500,10 +503,10 @@ def test_train_step_optimizers(self):
@pytest.mark.gpu
def test_grad_scaler(self):
scaler = torch.cuda.amp.GradScaler()
- options = {'grad_scaler': scaler}
+ options = {"grad_scaler": scaler}
logic = ppe.handler.Logic(options=options)
- models, optimizers, input, out = self._run_step(logic, 'cuda')
- model = models['main']
+ models, optimizers, input, out = self._run_step(logic, "cuda")
+ model = models["main"]
m_weight = model.weight.clone().detach()
w_grad = model.weight.grad.clone().detach()
# The gradient of a linear layer is its transposed weight
@@ -513,11 +516,12 @@ def test_grad_scaler(self):
# Checks that the value was correctly updated and gradients deescaled
# before the update
torch_testing_assert_close(
- scaler.scale(m_weight) - w_grad, scaler.scale(model.weight.T))
+ scaler.scale(m_weight) - w_grad, scaler.scale(model.weight.T)
+ )
@pytest.mark.gpu
def test_invalid_grad_scaler(self):
- options = {'grad_scaler': object()}
+ options = {"grad_scaler": object()}
with pytest.raises(RuntimeError):
ppe.handler.Logic(options=options)
@@ -526,7 +530,7 @@ def test_disabled_grad_scaler(self):
old_enable = ppe.runtime._autocast._cuda_amp_available
try:
ppe.runtime._autocast._cuda_amp_available = False
- options = {'grad_scaler': torch.cuda.amp.GradScaler()}
+ options = {"grad_scaler": torch.cuda.amp.GradScaler()}
with pytest.raises(RuntimeError):
ppe.handler.Logic(options=options)
finally:
@@ -534,23 +538,23 @@ def test_disabled_grad_scaler(self):
def test_train_validation_begin(self):
logic = ppe.handler.Logic()
- models = {'main': torch.nn.Linear(1, 1)}
- models['main'].train()
- assert models['main'].training
+ models = {"main": torch.nn.Linear(1, 1)}
+ models["main"].train()
+ assert models["main"].training
logic.train_validation_begin(models)
- assert not models['main'].training
+ assert not models["main"].training
def test_eval_step(self):
logic = ppe.handler.Logic()
input = torch.rand(1, 1)
model = torch.nn.Linear(1, 1)
- models = {'main': model}
- models['main'].eval()
+ models = {"main": model}
+ models["main"].eval()
out = logic.eval_step(models, 0, input)
torch_testing_assert_close(out, model(input))
@pytest.mark.gpu
def test_use_grad_scaler_with_clousure(self):
- options = {'grad_scaler': torch.cuda.amp.GradScaler()}
+ options = {"grad_scaler": torch.cuda.amp.GradScaler()}
with pytest.raises(RuntimeError):
ppe.handler.ClousureLogic(options=options)
diff --git a/tests/pytorch_pfn_extras_tests/test_logging.py b/tests/pytorch_pfn_extras_tests/test_logging.py
index e4709afa9..55e8ee9cb 100644
--- a/tests/pytorch_pfn_extras_tests/test_logging.py
+++ b/tests/pytorch_pfn_extras_tests/test_logging.py
@@ -7,17 +7,17 @@ def test_file_output():
try:
with tempfile.NamedTemporaryFile() as logfile:
logfile.close() # this is needed for Windows
- logging._configure_logging(filename=logfile.name, level='DEBUG')
+ logging._configure_logging(filename=logfile.name, level="DEBUG")
logger = logging._get_root_logger()
- logger.info('TEST LOG MESSAGE')
+ logger.info("TEST LOG MESSAGE")
with open(logfile.name) as f:
- assert 'TEST LOG MESSAGE' in f.read()
+ assert "TEST LOG MESSAGE" in f.read()
finally:
logging._configure_logging()
def test_get_logger():
- logger = logging.get_logger('app')
+ logger = logging.get_logger("app")
logger.setLevel(logging.DEBUG)
- assert logger.name == 'ppe.app'
+ assert logger.name == "ppe.app"
assert logger.level == logging.DEBUG
diff --git a/tests/pytorch_pfn_extras_tests/test_logic.py b/tests/pytorch_pfn_extras_tests/test_logic.py
index 9d85f8821..6bf97fd33 100644
--- a/tests/pytorch_pfn_extras_tests/test_logic.py
+++ b/tests/pytorch_pfn_extras_tests/test_logic.py
@@ -1,13 +1,13 @@
from typing import Any, Mapping
-import pytest
+from unittest import mock
+import pytest
+import pytorch_pfn_extras as ppe
import torch
from torch import nn
+from torch.nn import Module
+from torch.nn import functional as F
from torch.optim import Optimizer
-from torch.nn import functional as F, Module
-from unittest import mock
-
-import pytorch_pfn_extras as ppe
class MyModel(torch.nn.Module):
@@ -35,15 +35,15 @@ def __init__(self, model):
def forward(self, x, t):
y = self.model(x)
- prefix = 'train' if self.training else 'val'
+ prefix = "train" if self.training else "val"
loss = F.l1_loss(y, t)
- ppe.reporting.report({prefix + '/loss': loss})
+ ppe.reporting.report({prefix + "/loss": loss})
return loss
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_trainer(device):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
iters_per_epoch = 10
epochs = 20
@@ -52,19 +52,41 @@ def test_trainer(device):
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(iters_per_epoch)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(iters_per_epoch)
+ ]
+ )
backward_fn = mock.Mock(return_value=None)
trainer = ppe.engine.create_trainer(
- model_with_loss, optimizer, epochs,
+ model_with_loss,
+ optimizer,
+ epochs,
device=device,
- options={'backward_function': backward_fn}
+ options={"backward_function": backward_fn},
)
trainer.run(data)
assert backward_fn.call_count == epochs * iters_per_epoch
-@pytest.mark.parametrize("trigger", [(1, "epoch"), (0.5, "epoch"), (10, "iteration"), (5, "iteration"), (1, "iteration")])
+@pytest.mark.parametrize(
+ "trigger",
+ [
+ (1, "epoch"),
+ (0.5, "epoch"),
+ (10, "iteration"),
+ (5, "iteration"),
+ (1, "iteration"),
+ ],
+)
def test_train_step_mode_with_evaluator(trigger):
iters_per_epoch = 10
epochs = 20
@@ -73,11 +95,28 @@ def test_train_step_mode_with_evaluator(trigger):
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(iters_per_epoch)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(iters_per_epoch)
+ ]
+ )
backward_fn = mock.Mock(return_value=None)
class LogicWithTrainStepCheck(ppe.handler.Logic):
- def train_step(self, models: Mapping[str, Module], optimizers: Mapping[str, Optimizer], batch_idx: int, batch: Any) -> Any:
+ def train_step(
+ self,
+ models: Mapping[str, Module],
+ optimizers: Mapping[str, Optimizer],
+ batch_idx: int,
+ batch: Any,
+ ) -> Any:
model = models[self.model_name]
assert model.training
return super().train_step(models, optimizers, batch_idx, batch)
@@ -94,7 +133,7 @@ def train_step(self, models: Mapping[str, Module], optimizers: Mapping[str, Opti
),
trigger,
),
- options={'backward_function': backward_fn}
+ options={"backward_function": backward_fn},
)
trainer.run(data, data)
assert backward_fn.call_count == epochs * iters_per_epoch
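The test_logic.py hunks above build a trainer from a module that returns its loss. A minimal sketch of that set-up; ModelWithLoss is an illustrative stand-in for the MyModelWithLossFn wrapper used in the tests:

    import torch
    import torch.nn.functional as F

    import pytorch_pfn_extras as ppe

    class ModelWithLoss(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x, t):
            return F.l1_loss(self.model(x), t)  # the returned loss is backpropagated by the Logic

    model = torch.nn.Linear(20, 10)
    model_with_loss = ModelWithLoss(model)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    data = torch.utils.data.DataLoader(
        [(torch.rand(20), torch.rand(10)) for _ in range(10)]
    )
    trainer = ppe.engine.create_trainer(model_with_loss, optimizer, 2, device="cpu")
    trainer.run(data)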
diff --git a/tests/pytorch_pfn_extras_tests/test_reporter.py b/tests/pytorch_pfn_extras_tests/test_reporter.py
index 463e5d907..44b6e59e0 100644
--- a/tests/pytorch_pfn_extras_tests/test_reporter.py
+++ b/tests/pytorch_pfn_extras_tests/test_reporter.py
@@ -6,9 +6,8 @@
import numpy
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
def test_empty_reporter():
@@ -41,12 +40,8 @@ def thread_func(reporter, record):
record2 = []
reporter1 = ppe.reporting.Reporter()
reporter2 = ppe.reporting.Reporter()
- thread1 = threading.Thread(
- target=thread_func,
- args=(reporter1, record1))
- thread2 = threading.Thread(
- target=thread_func,
- args=(reporter2, record2))
+ thread1 = threading.Thread(target=thread_func, args=(reporter1, record1))
+ thread2 = threading.Thread(target=thread_func, args=(reporter2, record2))
thread1.daemon = True
thread2.daemon = True
thread1.start()
@@ -72,71 +67,72 @@ def test_scope():
def test_add_observer():
reporter = ppe.reporting.Reporter()
observer = object()
- reporter.add_observer('o', observer)
+ reporter.add_observer("o", observer)
- reporter.report({'x': 1}, observer)
+ reporter.report({"x": 1}, observer)
observation = reporter.observation
- assert 'o/x' in observation
- assert observation['o/x'] == 1
- assert 'x'not in observation
+ assert "o/x" in observation
+ assert observation["o/x"] == 1
+ assert "x" not in observation
def test_add_observers():
reporter = ppe.reporting.Reporter()
observer1 = object()
- reporter.add_observer('o1', observer1)
+ reporter.add_observer("o1", observer1)
observer2 = object()
- reporter.add_observer('o2', observer2)
+ reporter.add_observer("o2", observer2)
- reporter.report({'x': 1}, observer1)
- reporter.report({'y': 2}, observer2)
+ reporter.report({"x": 1}, observer1)
+ reporter.report({"y": 2}, observer2)
observation = reporter.observation
- assert 'o1/x' in observation
- assert observation['o1/x'] == 1
- assert 'o2/y' in observation
- assert observation['o2/y'] == 2
- assert 'x' not in observation
- assert 'y' not in observation
- assert 'o1/y' not in observation
- assert 'o2/x' not in observation
+ assert "o1/x" in observation
+ assert observation["o1/x"] == 1
+ assert "o2/y" in observation
+ assert observation["o2/y"] == 2
+ assert "x" not in observation
+ assert "y" not in observation
+ assert "o1/y" not in observation
+ assert "o2/x" not in observation
def test_report_without_observer():
reporter = ppe.reporting.Reporter()
- reporter.report({'x': 1})
+ reporter.report({"x": 1})
observation = reporter.observation
- assert 'x' in observation
- assert observation['x'] == 1
+ assert "x" in observation
+ assert observation["x"] == 1
# ppe.reporting.report
+
def test_report_without_reporter():
observer = object()
- ppe.reporting.report({'x': 1}, observer)
+ ppe.reporting.report({"x": 1}, observer)
def test_report():
reporter = ppe.reporting.Reporter()
with reporter:
- ppe.reporting.report({'x': 1})
+ ppe.reporting.report({"x": 1})
observation = reporter.observation
- assert 'x' in observation
- assert observation['x'] == 1
+ assert "x" in observation
+ assert observation["x"] == 1
def test_report_with_observer():
reporter = ppe.reporting.Reporter()
observer = object()
- reporter.add_observer('o', observer)
+ reporter.add_observer("o", observer)
with reporter:
- ppe.reporting.report({'x': 1}, observer)
+ ppe.reporting.report({"x": 1}, observer)
observation = reporter.observation
- assert 'o/x' in observation
- assert observation['o/x'] == 1
+ assert "o/x" in observation
+ assert observation["o/x"] == 1
def test_report_with_unregistered_observer():
@@ -144,7 +140,7 @@ def test_report_with_unregistered_observer():
observer = object()
with reporter:
with pytest.raises(KeyError):
- ppe.reporting.report({'x': 1}, observer)
+ ppe.reporting.report({"x": 1}, observer)
def test_report_scope():
@@ -153,21 +149,21 @@ def test_report_scope():
with reporter:
with ppe.reporting.report_scope(observation):
- ppe.reporting.report({'x': 1})
+ ppe.reporting.report({"x": 1})
- assert 'x' in observation
- assert observation['x'] == 1
- assert 'x' not in reporter.observation
+ assert "x" in observation
+ assert observation["x"] == 1
+ assert "x" not in reporter.observation
def test_report_tensor_detached():
reporter = ppe.reporting.Reporter()
- x = torch.tensor(numpy.array(1, 'float32'), requires_grad=True)
+ x = torch.tensor(numpy.array(1, "float32"), requires_grad=True)
with reporter:
- ppe.reporting.report({'x': x})
+ ppe.reporting.report({"x": x})
observation = reporter.observation
- assert 'x' in observation
- assert not observation['x'].requires_grad
+ assert "x" in observation
+ assert not observation["x"].requires_grad
assert x.requires_grad
@@ -182,17 +178,18 @@ def test_report_callable():
# ppe.reporting.Summary
+
def test_summary_basic():
summary = ppe.reporting.Summary()
- summary.add(torch.Tensor(numpy.array(1, 'float32')))
- summary.add(torch.Tensor(numpy.array(-2, 'float32')))
+ summary.add(torch.Tensor(numpy.array(1, "float32")))
+ summary.add(torch.Tensor(numpy.array(-2, "float32")))
mean = summary.compute_mean()
- numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, 'f'))
+ numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, "f"))
mean, std = summary.make_statistics()
- numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, 'f'))
- numpy.testing.assert_allclose(std.numpy(), numpy.array(1.5, 'f'))
+ numpy.testing.assert_allclose(mean.numpy(), numpy.array(-0.5, "f"))
+ numpy.testing.assert_allclose(std.numpy(), numpy.array(1.5, "f"))
def test_summary_int():
@@ -206,28 +203,28 @@ def test_summary_int():
mean, std = summary.make_statistics()
numpy.testing.assert_allclose(mean, 2)
- numpy.testing.assert_allclose(std, numpy.sqrt(2. / 3.))
+ numpy.testing.assert_allclose(std, numpy.sqrt(2.0 / 3.0))
def test_summary_float():
summary = ppe.reporting.Summary()
- summary.add(1.)
- summary.add(2.)
- summary.add(3.)
+ summary.add(1.0)
+ summary.add(2.0)
+ summary.add(3.0)
mean = summary.compute_mean()
- numpy.testing.assert_allclose(mean, 2.)
+ numpy.testing.assert_allclose(mean, 2.0)
mean, std = summary.make_statistics()
- numpy.testing.assert_allclose(mean, 2.)
- numpy.testing.assert_allclose(std, numpy.sqrt(2. / 3.))
+ numpy.testing.assert_allclose(mean, 2.0)
+ numpy.testing.assert_allclose(std, numpy.sqrt(2.0 / 3.0))
def test_summary_weight():
summary = ppe.reporting.Summary()
- summary.add(1., 0.5)
- summary.add(2., numpy.array(0.4))
- summary.add(3., torch.autograd.Variable(torch.Tensor(numpy.array(0.3))))
+ summary.add(1.0, 0.5)
+ summary.add(2.0, numpy.array(0.4))
+ summary.add(3.0, torch.autograd.Variable(torch.Tensor(numpy.array(0.3))))
mean = summary.compute_mean()
val = (1 * 0.5 + 2 * 0.4 + 3 * 0.3) / (0.5 + 0.4 + 0.3)
@@ -244,21 +241,21 @@ def test_summary_deferred_add():
def test_summary_add_operator():
s1 = ppe.reporting.Summary()
- s1.add(1.)
- s1.add(2.)
- s1.add(3.)
+ s1.add(1.0)
+ s1.add(2.0)
+ s1.add(3.0)
s2 = ppe.reporting.Summary()
- s2.add(4.)
- s2.add(5.)
- s2.add(6.)
- s2.add(7.)
+ s2.add(4.0)
+ s2.add(5.0)
+ s2.add(6.0)
+ s2.add(7.0)
s = s1 + s2
assert s.state_dict() == {
- '_x': 28.0,
- '_x2': 140.0,
- '_n': 7,
+ "_x": 28.0,
+ "_x2": 140.0,
+ "_n": 7,
}
@@ -279,12 +276,14 @@ def _check_summary_serialize(value1, value2, value3):
torch.save(summary.state_dict(), f.name)
# Load tensors in CPU to simulate a snapshot restore
summary2.load_state_dict(
- torch.load(f.name, map_location=torch.device('cpu')))
+ torch.load(f.name, map_location=torch.device("cpu"))
+ )
summary2.add(value3)
- expected_mean = float((value1 + value2 + value3) / 3.)
+ expected_mean = float((value1 + value2 + value3) / 3.0)
expected_std = math.sqrt(
- (value1**2 + value2**2 + value3**2) / 3. - expected_mean**2)
+ (value1**2 + value2**2 + value3**2) / 3.0 - expected_mean**2
+ )
mean = summary2.compute_mean()
if isinstance(mean, torch.Tensor):
@@ -305,21 +304,25 @@ def test_serialize_array_float():
numpy.array(1.5, numpy.float32),
numpy.array(2.0, numpy.float32),
# sum of the above two is non-integer
- numpy.array(3.5, numpy.float32))
+ numpy.array(3.5, numpy.float32),
+ )
def test_serialize_array_int():
_check_summary_serialize(
numpy.array(1, numpy.int32),
numpy.array(-2, numpy.int32),
- numpy.array(2, numpy.int32))
+ numpy.array(2, numpy.int32),
+ )
def test_serialize_scalar_float():
_check_summary_serialize(
- 1.5, 2.0,
+ 1.5,
+ 2.0,
# sum of the above two is non-integer
- 3.5)
+ 3.5,
+ )
def test_serialize_scalar_int():
@@ -328,9 +331,8 @@ def test_serialize_scalar_int():
def test_serialize_tensor():
_check_summary_serialize(
- torch.tensor(1.5),
- torch.tensor(2.0),
- torch.tensor(3.5))
+ torch.tensor(1.5), torch.tensor(2.0), torch.tensor(3.5)
+ )
@pytest.mark.gpu
@@ -338,18 +340,21 @@ def test_serialize_tensor_cuda():
_check_summary_serialize(
torch.tensor(1.5).cuda(),
torch.tensor(2.0).cuda(),
- torch.tensor(3.5).cuda())
+ torch.tensor(3.5).cuda(),
+ )
def test_serialize_tensor_with_grad():
_check_summary_serialize(
torch.tensor(1.5, requires_grad=True),
torch.tensor(2.0, requires_grad=True),
- 3.5)
+ 3.5,
+ )
# ppe.reporting.DictSummary
+
def _check_dict_summary(summary, data):
mean = summary.compute_mean()
assert set(mean.keys()) == set(data.keys())
@@ -358,150 +363,171 @@ def _check_dict_summary(summary, data):
numpy.testing.assert_allclose(mean[name], m)
stats = summary.make_statistics()
- assert (
- set(stats.keys())
- == set(data.keys()).union(name + '.std' for name in data.keys()))
+ assert set(stats.keys()) == set(data.keys()).union(
+ name + ".std" for name in data.keys()
+ )
for name in data.keys():
m = sum(data[name]) / float(len(data[name]))
s = numpy.sqrt(
- sum(x * x for x in data[name]) / float(len(data[name]))
- - m * m)
+ sum(x * x for x in data[name]) / float(len(data[name])) - m * m
+ )
numpy.testing.assert_allclose(stats[name], m)
- numpy.testing.assert_allclose(stats[name + '.std'], s)
+ numpy.testing.assert_allclose(stats[name + ".std"], s)
def test_dict_summary():
summary = ppe.reporting.DictSummary()
- summary.add({'numpy': numpy.array(3, 'f'), 'int': 1, 'float': 4.})
- summary.add({'numpy': numpy.array(1, 'f'), 'int': 5, 'float': 9.})
- summary.add({'numpy': numpy.array(2, 'f'), 'int': 6, 'float': 5.})
- summary.add({'numpy': numpy.array(3, 'f'), 'int': 5, 'float': 8.})
+ summary.add({"numpy": numpy.array(3, "f"), "int": 1, "float": 4.0})
+ summary.add({"numpy": numpy.array(1, "f"), "int": 5, "float": 9.0})
+ summary.add({"numpy": numpy.array(2, "f"), "int": 6, "float": 5.0})
+ summary.add({"numpy": numpy.array(3, "f"), "int": 5, "float": 8.0})
- _check_dict_summary(summary, {
- 'numpy': (3., 1., 2., 3.),
- 'int': (1, 5, 6, 5),
- 'float': (4., 9., 5., 8.),
- })
+ _check_dict_summary(
+ summary,
+ {
+ "numpy": (3.0, 1.0, 2.0, 3.0),
+ "int": (1, 5, 6, 5),
+ "float": (4.0, 9.0, 5.0, 8.0),
+ },
+ )
def test_dit_summary_sparse():
summary = ppe.reporting.DictSummary()
- summary.add({'a': 3., 'b': 1.})
- summary.add({'a': 1., 'b': 5., 'c': 9.})
- summary.add({'b': 6.})
- summary.add({'a': 3., 'b': 5., 'c': 8.})
+ summary.add({"a": 3.0, "b": 1.0})
+ summary.add({"a": 1.0, "b": 5.0, "c": 9.0})
+ summary.add({"b": 6.0})
+ summary.add({"a": 3.0, "b": 5.0, "c": 8.0})
- _check_dict_summary(summary, {
- 'a': (3., 1., 3.),
- 'b': (1., 5., 6., 5.),
- 'c': (9., 8.),
- })
+ _check_dict_summary(
+ summary,
+ {
+ "a": (3.0, 1.0, 3.0),
+ "b": (1.0, 5.0, 6.0, 5.0),
+ "c": (9.0, 8.0),
+ },
+ )
def test_dict_summary_weight():
summary = ppe.reporting.DictSummary()
- summary.add({'a': (1., 0.5)})
- summary.add({'a': (2., numpy.array(0.4))})
+ summary.add({"a": (1.0, 0.5)})
+ summary.add({"a": (2.0, numpy.array(0.4))})
summary.add(
- {'a': (3., torch.autograd.Variable(torch.Tensor(numpy.array(0.3))))})
+ {"a": (3.0, torch.autograd.Variable(torch.Tensor(numpy.array(0.3))))}
+ )
mean = summary.compute_mean()
val = (1 * 0.5 + 2 * 0.4 + 3 * 0.3) / (0.5 + 0.4 + 0.3)
- numpy.testing.assert_allclose(mean['a'].numpy(), val)
+ numpy.testing.assert_allclose(mean["a"].numpy(), val)
arr = numpy.array([0.5])
with pytest.raises(ValueError):
- summary.add({'a': (4., arr)})
+ summary.add({"a": (4.0, arr)})
var = torch.autograd.Variable(torch.Tensor(numpy.array([0.5])))
with pytest.raises(ValueError):
- summary.add({'a': (4., var)})
+ summary.add({"a": (4.0, var)})
def test_dict_summary_serialize():
summary = ppe.reporting.DictSummary()
- summary.add({'numpy': numpy.array(3, 'f'), 'int': 1, 'float': 4.})
- summary.add({'numpy': numpy.array(1, 'f'), 'int': 5, 'float': 9.})
- summary.add({'numpy': numpy.array(2, 'f'), 'int': 6, 'float': 5.})
+ summary.add({"numpy": numpy.array(3, "f"), "int": 1, "float": 4.0})
+ summary.add({"numpy": numpy.array(1, "f"), "int": 5, "float": 9.0})
+ summary.add({"numpy": numpy.array(2, "f"), "int": 6, "float": 5.0})
summary2 = ppe.reporting.DictSummary()
summary2.load_state_dict(summary.state_dict())
- summary2.add({'numpy': numpy.array(3, 'f'), 'int': 5, 'float': 8.})
+ summary2.add({"numpy": numpy.array(3, "f"), "int": 5, "float": 8.0})
- _check_dict_summary(summary2, {
- 'numpy': (3., 1., 2., 3.),
- 'int': (1, 5, 6, 5),
- 'float': (4., 9., 5., 8.),
- })
+ _check_dict_summary(
+ summary2,
+ {
+ "numpy": (3.0, 1.0, 2.0, 3.0),
+ "int": (1, 5, 6, 5),
+ "float": (4.0, 9.0, 5.0, 8.0),
+ },
+ )
-@pytest.mark.parametrize('delimiter', ['/', '.'])
+@pytest.mark.parametrize("delimiter", ["/", "."])
@pytest.mark.parametrize(
# How the state of the summary is transferred.
- 'transfer_protocol',
+ "transfer_protocol",
[
- 'direct', # Use state_dict() and load_state_dict()
- 'torch', # Use torch.save() and torch.load()
- ])
+ "direct", # Use state_dict() and load_state_dict()
+ "torch", # Use torch.save() and torch.load()
+ ],
+)
def test_dict_summary_serialize_names_with_delimiter(
- delimiter, transfer_protocol):
- key1 = 'a{d}b'.format(d=delimiter)
- key2 = '{d}a{d}b'.format(d=delimiter)
- key3 = 'a{d}b{d}'.format(d=delimiter)
+ delimiter, transfer_protocol
+):
+ key1 = "a{d}b".format(d=delimiter)
+ key2 = "{d}a{d}b".format(d=delimiter)
+ key3 = "a{d}b{d}".format(d=delimiter)
summary = ppe.reporting.DictSummary()
- summary.add({key1: 3., key2: 1., key3: 4.})
- summary.add({key1: 1., key2: 5., key3: 9.})
- summary.add({key1: 2., key2: 6., key3: 5.})
+ summary.add({key1: 3.0, key2: 1.0, key3: 4.0})
+ summary.add({key1: 1.0, key2: 5.0, key3: 9.0})
+ summary.add({key1: 2.0, key2: 6.0, key3: 5.0})
- if transfer_protocol == 'direct':
+ if transfer_protocol == "direct":
summary2 = ppe.reporting.DictSummary()
summary2.load_state_dict(summary.state_dict())
else:
- assert transfer_protocol == 'torch'
+ assert transfer_protocol == "torch"
f = io.BytesIO()
torch.save(summary, f)
summary2 = torch.load(io.BytesIO(f.getvalue()))
- summary2.add({key1: 3., key2: 5., key3: 8.})
+ summary2.add({key1: 3.0, key2: 5.0, key3: 8.0})
- _check_dict_summary(summary2, {
- key1: (3., 1., 2., 3.),
- key2: (1., 5., 6., 5.),
- key3: (4., 9., 5., 8.),
- })
+ _check_dict_summary(
+ summary2,
+ {
+ key1: (3.0, 1.0, 2.0, 3.0),
+ key2: (1.0, 5.0, 6.0, 5.0),
+ key3: (4.0, 9.0, 5.0, 8.0),
+ },
+ )
def test_serialize_overwrite_different_names():
summary = ppe.reporting.DictSummary()
- summary.add({'a': 3., 'b': 1.})
- summary.add({'a': 1., 'b': 5.})
+ summary.add({"a": 3.0, "b": 1.0})
+ summary.add({"a": 1.0, "b": 5.0})
summary2 = ppe.reporting.DictSummary()
- summary2.add({'c': 5.})
+ summary2.add({"c": 5.0})
summary2.load_state_dict(summary.state_dict())
- _check_dict_summary(summary2, {
- 'a': (3., 1.),
- 'b': (1., 5.),
- })
+ _check_dict_summary(
+ summary2,
+ {
+ "a": (3.0, 1.0),
+ "b": (1.0, 5.0),
+ },
+ )
def test_serialize_overwrite_rollback():
summary = ppe.reporting.DictSummary()
- summary.add({'a': 3., 'b': 1.})
- summary.add({'a': 1., 'b': 5.})
+ summary.add({"a": 3.0, "b": 1.0})
+ summary.add({"a": 1.0, "b": 5.0})
state = summary.state_dict()
- summary.add({'a': 2., 'b': 6., 'c': 5.})
- summary.add({'a': 3., 'b': 4., 'c': 6.})
+ summary.add({"a": 2.0, "b": 6.0, "c": 5.0})
+ summary.add({"a": 3.0, "b": 4.0, "c": 6.0})
summary.load_state_dict(state)
- summary.add({'a': 3., 'b': 5., 'c': 8.})
+ summary.add({"a": 3.0, "b": 5.0, "c": 8.0})
- _check_dict_summary(summary, {
- 'a': (3., 1., 3.),
- 'b': (1., 5., 5.),
- 'c': (8.,),
- })
+ _check_dict_summary(
+ summary,
+ {
+ "a": (3.0, 1.0, 3.0),
+ "b": (1.0, 5.0, 5.0),
+ "c": (8.0,),
+ },
+ )
def test_dict_summary_deferred_add():
@@ -509,26 +535,32 @@ def test_dict_summary_deferred_add():
summary.add({"x": lambda: 1.0, "y": lambda: 2.0})
summary.add({"x": lambda: -1.0})
- _check_dict_summary(summary, {
- "x": (1.0, -1.0),
- "y": (2.0,),
- })
+ _check_dict_summary(
+ summary,
+ {
+ "x": (1.0, -1.0),
+ "y": (2.0,),
+ },
+ )
def test_dict_summary_add_operator():
s1 = ppe.reporting.DictSummary()
- s1.add({'a': 1., 'b': 0.1, 'c': 0.01})
- s1.add({'a': 2., 'b': 0.2, 'c': 0.02})
+ s1.add({"a": 1.0, "b": 0.1, "c": 0.01})
+ s1.add({"a": 2.0, "b": 0.2, "c": 0.02})
s2 = ppe.reporting.DictSummary()
- s2.add({'a': 3., 'b': 0.3, 'f': 0.03})
- s2.add({'a': 4., 'b': 0.4, 'f': 0.04})
- s2.add({'a': 5., 'b': 0.5, 'f': 0.05})
+ s2.add({"a": 3.0, "b": 0.3, "f": 0.03})
+ s2.add({"a": 4.0, "b": 0.4, "f": 0.04})
+ s2.add({"a": 5.0, "b": 0.5, "f": 0.05})
s = s1 + s2
- _check_dict_summary(s, {
- 'a': [1, 2, 3, 4, 5],
- 'b': [0.1, 0.2, 0.3, 0.4, 0.5],
- 'c': [0.01, 0.02],
- 'f': [0.03, 0.04, 0.05],
- })
+ _check_dict_summary(
+ s,
+ {
+ "a": [1, 2, 3, 4, 5],
+ "b": [0.1, 0.2, 0.3, 0.4, 0.5],
+ "c": [0.01, 0.02],
+ "f": [0.03, 0.04, 0.05],
+ },
+ )
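The test_reporter.py hunks above exercise the Reporter/DictSummary flow. A minimal sketch of that flow; the observer name "train" and the "loss" key are illustrative:

    import pytorch_pfn_extras as ppe

    reporter = ppe.reporting.Reporter()
    model = object()                       # any object can act as an observer
    reporter.add_observer("train", model)  # values reported for it get a "train/" prefix

    with reporter:                         # make this the active reporter
        ppe.reporting.report({"loss": 0.5}, model)
    assert reporter.observation["train/loss"] == 0.5

    # DictSummary accumulates observations and exposes mean / std statistics.
    summary = ppe.reporting.DictSummary()
    summary.add({"loss": 1.0})
    summary.add({"loss": 3.0})
    assert summary.compute_mean()["loss"] == 2.0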
diff --git a/tests/pytorch_pfn_extras_tests/test_tensor.py b/tests/pytorch_pfn_extras_tests/test_tensor.py
index a05ff280a..7f1d6ad10 100644
--- a/tests/pytorch_pfn_extras_tests/test_tensor.py
+++ b/tests/pytorch_pfn_extras_tests/test_tensor.py
@@ -1,10 +1,9 @@
import numpy
import pytest
+import pytorch_pfn_extras as ppe
import torch
import torch.utils.dlpack
-import pytorch_pfn_extras as ppe
-
def test_from_ndarray_numpy():
np_arr = numpy.arange(24).reshape(2, 3, 4)
@@ -21,7 +20,7 @@ def test_from_ndarray_numpy_neg():
def test_from_ndarray_cupy():
- cupy = pytest.importorskip('cupy')
+ cupy = pytest.importorskip("cupy")
cp_arr = cupy.arange(24).reshape(2, 3, 4)
tensor = ppe.from_ndarray(cp_arr)
assert cp_arr.data.ptr == tensor.data_ptr()
@@ -29,7 +28,7 @@ def test_from_ndarray_cupy():
def test_from_ndarray_cupy_neg():
- cupy = pytest.importorskip('cupy')
+ cupy = pytest.importorskip("cupy")
cp_arr = cupy.flip(cupy.arange(24).reshape(2, 3, 4))
tensor = ppe.from_ndarray(cp_arr) # copy
assert cp_arr.data.ptr != tensor.data_ptr()
@@ -50,7 +49,7 @@ def test_as_ndarray_cpu():
def test_as_ndarray_cupy():
- cupy = pytest.importorskip('cupy')
+ cupy = pytest.importorskip("cupy")
tensor = torch.arange(24).reshape(2, 3, 4).cuda()
arr = ppe.as_ndarray(tensor)
assert isinstance(arr, cupy.ndarray)
@@ -60,14 +59,14 @@ def test_as_ndarray_cupy():
def test_get_xp_numpy():
assert ppe.get_xp(torch.ones(4)) is numpy
- assert ppe.get_xp(torch.device('cpu')) is numpy
+ assert ppe.get_xp(torch.device("cpu")) is numpy
assert ppe.get_xp(numpy.ones(4)) is numpy
def test_get_xp_cupy():
- cupy = pytest.importorskip('cupy')
+ cupy = pytest.importorskip("cupy")
assert ppe.get_xp(torch.ones(4).cuda()) is cupy
- assert ppe.get_xp(torch.device('cuda:0')) is cupy
+ assert ppe.get_xp(torch.device("cuda:0")) is cupy
assert ppe.get_xp(cupy.ones(1)) is cupy
@@ -76,12 +75,22 @@ def test_get_xp_invalid_type():
ppe.get_xp([1, 2, 3])
-@pytest.mark.parametrize('dtype', [
- 'bool',
- 'uint8', 'int8', 'int16', 'int32', 'int64',
- 'float16', 'float32', 'float64',
- 'complex64', 'complex128',
-])
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ "bool",
+ "uint8",
+ "int8",
+ "int16",
+ "int32",
+ "int64",
+ "float16",
+ "float32",
+ "float64",
+ "complex64",
+ "complex128",
+ ],
+)
def test_torch_numpy_dtype(dtype):
torch_dtype = getattr(torch, dtype)
numpy_dtype = numpy.dtype(dtype)
@@ -93,6 +102,6 @@ def test_torch_numpy_dtype_unsupported():
with pytest.raises(TypeError):
ppe.as_numpy_dtype(torch.bfloat16)
with pytest.raises(TypeError):
- ppe.from_numpy_dtype(numpy.dtype('object'))
+ ppe.from_numpy_dtype(numpy.dtype("object"))
with pytest.raises(TypeError):
ppe.from_numpy_dtype(None)
diff --git a/tests/pytorch_pfn_extras_tests/test_torchscript.py b/tests/pytorch_pfn_extras_tests/test_torchscript.py
index 840a9c9c7..d3460b573 100644
--- a/tests/pytorch_pfn_extras_tests/test_torchscript.py
+++ b/tests/pytorch_pfn_extras_tests/test_torchscript.py
@@ -1,10 +1,10 @@
-import torch
import pytorch_pfn_extras.torchscript as ts
+import torch
def test_find_inplace():
def f(v: torch.Tensor) -> None:
- v += torch.ones((1,2,3))
+ v += torch.ones((1, 2, 3))
def g(v: torch.Tensor):
f(v)
@@ -17,7 +17,7 @@ def g(v: torch.Tensor):
def test_find_inplace_not_found():
def f(v: torch.Tensor) -> torch.Tensor:
- return torch.ones((1,2,3))
+ return torch.ones((1, 2, 3))
s = torch.jit.script(f)
diff --git a/tests/pytorch_pfn_extras_tests/test_writing.py b/tests/pytorch_pfn_extras_tests/test_writing.py
index 47106dd6f..b8ea93274 100644
--- a/tests/pytorch_pfn_extras_tests/test_writing.py
+++ b/tests/pytorch_pfn_extras_tests/test_writing.py
@@ -1,20 +1,22 @@
-import tempfile
import os
+import tempfile
import pytest
-
import pytorch_pfn_extras as ppe
-@pytest.mark.filterwarnings("ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning")
+@pytest.mark.filterwarnings(
+ "ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning"
+)
def test_tensorboard_writing():
- pytest.importorskip('tensorboard')
+ pytest.importorskip("tensorboard")
data = {"a": 1, "iteration": 1}
with tempfile.TemporaryDirectory() as tempd:
writer = ppe.writing.TensorBoardWriter(
- out_dir=tempd, filename_suffix='_test')
+ out_dir=tempd, filename_suffix="_test"
+ )
writer(None, None, data)
# Check that the file was generated
for snap in os.listdir(tempd):
- assert '_test' in snap
+ assert "_test" in snap
writer.finalize()
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py
index ab5c37b68..908da5624 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_best_value.py
@@ -1,45 +1,55 @@
import tempfile
import pytest
-
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions
-
params = [
- ([None, 4.0, 4.5, 3.0, 3.5],
- [4.0, 4.0, 3.0, 3.0],
- [1, 1, 3, 3],
- [3, 3, 9, 9],
- extensions.MinValue),
- ([None, 3.0, 4.5, 4.0, 5.0],
- [3.0, 4.5, 4.5, 5.0],
- [1, 2, 2, 4],
- [3, 6, 6, 12],
- extensions.MaxValue),
+ (
+ [None, 4.0, 4.5, 3.0, 3.5],
+ [4.0, 4.0, 3.0, 3.0],
+ [1, 1, 3, 3],
+ [3, 3, 9, 9],
+ extensions.MinValue,
+ ),
+ (
+ [None, 3.0, 4.5, 4.0, 5.0],
+ [3.0, 4.5, 4.5, 5.0],
+ [1, 2, 2, 4],
+ [3, 6, 6, 12],
+ extensions.MaxValue,
+ ),
]
-@pytest.mark.parametrize('observed_values,expected_best_values,'
- 'expected_best_epochs,expected_best_iterations,BestValueT',
- params)
-def test_best_observation(observed_values, expected_best_values,
- expected_best_epochs, expected_best_iterations, BestValueT):
+@pytest.mark.parametrize(
+ "observed_values,expected_best_values,"
+ "expected_best_epochs,expected_best_iterations,BestValueT",
+ params,
+)
+def test_best_observation(
+ observed_values,
+ expected_best_values,
+ expected_best_epochs,
+ expected_best_iterations,
+ BestValueT,
+):
max_epochs = 4
iters_per_epoch = 3
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir
+ )
def observer_fn(manager):
return observed_values[manager.epoch]
- observe = extensions.observe_value('value', observer_fn)
- manager.extend(observe, trigger=(1, 'epoch'))
+ observe = extensions.observe_value("value", observer_fn)
+ manager.extend(observe, trigger=(1, "epoch"))
- best_value = BestValueT('value')
- manager.extend(best_value, trigger=(1, 'epoch'))
+ best_value = BestValueT("value")
+ manager.extend(best_value, trigger=(1, "epoch"))
for epoch in range(max_epochs):
for _ in range(iters_per_epoch):
@@ -51,16 +61,16 @@ def observer_fn(manager):
# Save/Load state dict (snapshot support)
assert best_value.state_dict() == {
- '_best_trigger': {
- '_best_value': expected_best_values[-1],
- '_summary': {},
- 'interval_trigger': {}
+ "_best_trigger": {
+ "_best_value": expected_best_values[-1],
+ "_summary": {},
+ "interval_trigger": {},
},
- '_best_it': expected_best_iterations[-1],
- '_best_epoch': expected_best_epochs[-1],
+ "_best_it": expected_best_iterations[-1],
+ "_best_epoch": expected_best_epochs[-1],
}
- best_value2 = BestValueT('value')
+ best_value2 = BestValueT("value")
best_value2.load_state_dict(best_value.state_dict())
assert best_value2.best_value == expected_best_values[-1]
assert best_value2.best_iteration == expected_best_iterations[-1]
@@ -73,10 +83,11 @@ def test_key_error():
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir
+ )
- best_observation = extensions.BestValue('value', lambda a, b: a < b)
- manager.extend(best_observation, trigger=(1, 'epoch'))
+ best_observation = extensions.BestValue("value", lambda a, b: a < b)
+ manager.extend(best_observation, trigger=(1, "epoch"))
with pytest.raises(KeyError) as e:
for _ in range(max_epochs):
@@ -87,7 +98,7 @@ def test_key_error():
def test_error_before_first_call():
- best_observation = extensions.BestValue('value', lambda a, b: a < b)
+ best_observation = extensions.BestValue("value", lambda a, b: a < b)
with pytest.raises(RuntimeError):
best_observation.best_value
with pytest.raises(RuntimeError):
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py
index fb27501eb..a38b3f6aa 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py
@@ -2,26 +2,24 @@
import os
import tempfile
-import torch
import pytest
-
-from pytorch_pfn_extras import distributed
-from pytorch_pfn_extras import training
+import torch
+from pytorch_pfn_extras import distributed, training
from pytorch_pfn_extras.training import extensions
def _create_distributed_model(gpu=True):
comm_size, comm_rank, comm_local_rank, device = _init_distributed(True)
- device = torch.device('cuda:{}'.format(comm_local_rank)
- if gpu else 'cpu')
+ device = torch.device("cuda:{}".format(comm_local_rank) if gpu else "cpu")
model = torch.nn.Linear(128, 1)
if torch.distributed.is_initialized():
if not gpu:
raise pytest.skip("Distributed tests require GPUs.")
model = torch.nn.parallel.DistributedDataParallel(
- model.to(device), device_ids=[comm_local_rank])
+ model.to(device), device_ids=[comm_local_rank]
+ )
else:
model = model.to(device)
@@ -32,27 +30,27 @@ def get_trainer(path):
epochs = 10 # FIXME
model = _create_distributed_model()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
- optimizers = {'main': optimizer}
- models = {'main': model}
+ optimizers = {"main": optimizer}
+ models = {"main": model}
return training.ExtensionsManager(
- models, optimizers, epochs, iters_per_epoch=1, out_dir=path)
+ models, optimizers, epochs, iters_per_epoch=1, out_dir=path
+ )
def _init_distributed(use_cuda):
- if ('OMPI_COMM_WORLD_SIZE' in os.environ):
- size, rank, local_rank = (
- distributed.initialize_ompi_environment(
- backend="nccl", init_method="env"))
+ if "OMPI_COMM_WORLD_SIZE" in os.environ:
+ size, rank, local_rank = distributed.initialize_ompi_environment(
+ backend="nccl", init_method="env"
+ )
else:
pytest.skip("This test requires MPI to run")
- device = torch.device(
- "cuda:{}".format(local_rank) if use_cuda else "cpu")
+ device = torch.device("cuda:{}".format(local_rank) if use_cuda else "cpu")
return size, rank, local_rank, device
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def path():
with tempfile.TemporaryDirectory() as t_path:
yield t_path
@@ -73,11 +71,11 @@ def test_distributed_snapshot(path):
torch.distributed.barrier()
saver_rank = 0
- fmt = 'snapshot_iter_{.iteration}'
+ fmt = "snapshot_iter_{.iteration}"
snapshot = extensions.snapshot(filename=fmt, saver_rank=saver_rank)
trainer = get_trainer(path)
- trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)
+ trainer.extend(snapshot, trigger=(1, "iteration"), priority=2)
for _ in range(1):
with trainer.run_iteration():
pass
@@ -91,7 +89,8 @@ def test_distributed_snapshot(path):
new_trainer = get_trainer(path)
new_trainer.load_state_dict(torch.load(os.path.join(path, found[0])))
assert _model_params_equal(
- trainer._models['main'], new_trainer._models['main'])
+ trainer._models["main"], new_trainer._models["main"]
+ )
if comm_size > 1:
torch.distributed.barrier()
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py
index 54e9be181..60e4957a1 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_evaluator.py
@@ -2,38 +2,35 @@
import numpy
import pytest
-import torch
-import torch.distributed as dist
-
import pytorch_pfn_extras as ppe
import pytorch_pfn_extras.training.extensions as ext
+import torch
+import torch.distributed as dist
class DummyModel(torch.nn.Module):
-
def __init__(self):
super().__init__()
self.args = []
def forward(self, x):
self.args.append(x)
- ppe.reporting.report({'loss': x.sum()}, self)
+ ppe.reporting.report({"loss": x.sum()}, self)
return x
def custom_metric(batch, out, last_iter):
- ppe.reporting.report({'custom-metric': out.sum()})
+ ppe.reporting.report({"custom-metric": out.sum()})
class DummyModelTwoArgs(torch.nn.Module):
-
def __init__(self):
super().__init__()
self.args = []
def forward(self, x, y):
self.args.append((x, y))
- ppe.reporting.report({'loss': x.sum() + y.sum()}, self)
+ ppe.reporting.report({"loss": x.sum() + y.sum()}, self)
def _torch_batch_to_numpy(batch):
@@ -43,11 +40,11 @@ def _torch_batch_to_numpy(batch):
return batch.squeeze(0).numpy()
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def evaluator_dummies():
data = [
- numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f')
- for _ in range(2)]
+ numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f") for _ in range(2)
+ ]
data_loader = torch.utils.data.DataLoader(data)
target = DummyModel()
@@ -60,7 +57,7 @@ def test_evaluate(evaluator_dummies):
data, data_loader, target, evaluator, expect_mean = evaluator_dummies
reporter = ppe.reporting.Reporter()
- reporter.add_observer('target', target)
+ reporter.add_observer("target", target)
with reporter:
mean = evaluator.evaluate()
@@ -71,10 +68,12 @@ def test_evaluate(evaluator_dummies):
assert len(target.args) == len(data)
for i in range(len(data)):
numpy.testing.assert_array_equal(
- _torch_batch_to_numpy(target.args[i]), data[i])
+ _torch_batch_to_numpy(target.args[i]), data[i]
+ )
numpy.testing.assert_almost_equal(
- mean['target/loss'], expect_mean, decimal=4)
+ mean["target/loss"], expect_mean, decimal=4
+ )
def test_metric(evaluator_dummies):
@@ -82,9 +81,8 @@ def test_metric(evaluator_dummies):
evaluator.add_metric(custom_metric)
mean = evaluator()
# 'main' is used by default
- assert 'custom-metric' in mean
- numpy.testing.assert_almost_equal(
- mean['main/loss'], expect_mean, decimal=4)
+ assert "custom-metric" in mean
+ numpy.testing.assert_almost_equal(mean["main/loss"], expect_mean, decimal=4)
def test_call(evaluator_dummies):
@@ -92,18 +90,18 @@ def test_call(evaluator_dummies):
mean = evaluator()
# 'main' is used by default
- numpy.testing.assert_almost_equal(
- mean['main/loss'], expect_mean, decimal=4)
+ numpy.testing.assert_almost_equal(mean["main/loss"], expect_mean, decimal=4)
def test_evaluator_name(evaluator_dummies):
data, data_loader, target, evaluator, expect_mean = evaluator_dummies
- evaluator.name = 'eval'
+ evaluator.name = "eval"
mean = evaluator()
# name is used as a prefix
numpy.testing.assert_almost_equal(
- mean['eval/main/loss'], expect_mean, decimal=4)
+ mean["eval/main/loss"], expect_mean, decimal=4
+ )
def test_current_report(evaluator_dummies):
@@ -118,16 +116,19 @@ def test_current_report(evaluator_dummies):
def test_evaluator_tuple_data():
data = [
- (numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f'),
- numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f'))
- for _ in range(2)]
+ (
+ numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"),
+ numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"),
+ )
+ for _ in range(2)
+ ]
data_loader = torch.utils.data.DataLoader(data)
target = DummyModelTwoArgs()
evaluator = ppe.training.extensions.Evaluator(data_loader, target)
reporter = ppe.reporting.Reporter()
- reporter.add_observer('target', target)
+ reporter.add_observer("target", target)
with reporter:
mean = evaluator.evaluate()
@@ -135,72 +136,83 @@ def test_evaluator_tuple_data():
for i in range(len(data)):
assert len(target.args[i]) == len(data[i])
numpy.testing.assert_array_equal(
- _torch_batch_to_numpy(target.args[i][0]), data[i][0])
+ _torch_batch_to_numpy(target.args[i][0]), data[i][0]
+ )
numpy.testing.assert_array_equal(
- _torch_batch_to_numpy(target.args[i][1]), data[i][1])
+ _torch_batch_to_numpy(target.args[i][1]), data[i][1]
+ )
expect_mean = numpy.mean([numpy.sum(x) for x in data])
numpy.testing.assert_almost_equal(
- mean['target/loss'], expect_mean, decimal=4)
+ mean["target/loss"], expect_mean, decimal=4
+ )
def test_evaluator_dict_data():
data = [
- {'x': numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f'),
- 'y': numpy.random.uniform(-1, 1, (2, 3, 4)).astype('f')}
- for _ in range(2)]
+ {
+ "x": numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"),
+ "y": numpy.random.uniform(-1, 1, (2, 3, 4)).astype("f"),
+ }
+ for _ in range(2)
+ ]
data_loader = torch.utils.data.DataLoader(data)
target = DummyModelTwoArgs()
evaluator = ppe.training.extensions.Evaluator(data_loader, target)
reporter = ppe.reporting.Reporter()
- reporter.add_observer('target', target)
+ reporter.add_observer("target", target)
with reporter:
mean = evaluator.evaluate()
assert len(target.args) == len(data)
for i in range(len(data)):
numpy.testing.assert_array_equal(
- _torch_batch_to_numpy(target.args[i][0]), data[i]['x'])
+ _torch_batch_to_numpy(target.args[i][0]), data[i]["x"]
+ )
numpy.testing.assert_array_equal(
- _torch_batch_to_numpy(target.args[i][1]), data[i]['y'])
+ _torch_batch_to_numpy(target.args[i][1]), data[i]["y"]
+ )
expect_mean = numpy.mean(
- [numpy.sum(x['x']) + numpy.sum(x['y']) for x in data])
+ [numpy.sum(x["x"]) + numpy.sum(x["y"]) for x in data]
+ )
numpy.testing.assert_almost_equal(
- mean['target/loss'], expect_mean, decimal=4)
+ mean["target/loss"], expect_mean, decimal=4
+ )
def test_evaluator_with_eval_func():
- data = [
- numpy.random.uniform(-1, 1, (3, 4)).astype('f') for _ in range(2)]
+ data = [numpy.random.uniform(-1, 1, (3, 4)).astype("f") for _ in range(2)]
data_loader = torch.utils.data.DataLoader(data)
target = DummyModel()
evaluator = ppe.training.extensions.Evaluator(
- data_loader, {}, eval_func=target)
+ data_loader, {}, eval_func=target
+ )
reporter = ppe.reporting.Reporter()
- reporter.add_observer('target', target)
+ reporter.add_observer("target", target)
with reporter:
evaluator.evaluate()
assert len(target.args) == len(data)
for i in range(len(data)):
numpy.testing.assert_array_equal(
- _torch_batch_to_numpy(target.args[i]), data[i])
+ _torch_batch_to_numpy(target.args[i]), data[i]
+ )
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def evaluator_with_progress():
- data = [
- numpy.random.uniform(-1, 1, (3, 4)).astype('f') for _ in range(2)]
+ data = [numpy.random.uniform(-1, 1, (3, 4)).astype("f") for _ in range(2)]
data_loader = torch.utils.data.DataLoader(data, batch_size=1)
target = DummyModel()
evaluator = ppe.training.extensions.Evaluator(
- data_loader, {}, eval_func=target, progress_bar=True)
+ data_loader, {}, eval_func=target, progress_bar=True
+ )
return evaluator, target
@@ -208,43 +220,43 @@ def test_evaluator_progress_bar(capsys, evaluator_with_progress):
evaluator, target = evaluator_with_progress
reporter = ppe.reporting.Reporter()
- reporter.add_observer('target', target)
+ reporter.add_observer("target", target)
with reporter:
evaluator.evaluate()
stdout = capsys.readouterr().out
- assert 'validation [.........' in stdout
- assert ' 0 iterations' in stdout
- assert ' inf iters/sec.' in stdout
- assert 'validation [#########' in stdout
- assert ' 1 iterations' in stdout
+ assert "validation [........." in stdout
+ assert " 0 iterations" in stdout
+ assert " inf iters/sec." in stdout
+ assert "validation [#########" in stdout
+ assert " 1 iterations" in stdout
def test_evaluator_progress_bar_custom_label(capsys, evaluator_with_progress):
evaluator, target = evaluator_with_progress
- evaluator.name = 'my_own_evaluator' # Set a (long) custom name
+ evaluator.name = "my_own_evaluator" # Set a (long) custom name
reporter = ppe.reporting.Reporter()
- reporter.add_observer('target', target)
+ reporter.add_observer("target", target)
with reporter:
evaluator.evaluate()
stdout = capsys.readouterr().out
- assert 'my_own_evaluator [.........' in stdout
- assert ' 0 iterations' in stdout
- assert ' inf iters/sec.' in stdout
- assert 'my_own_evaluator [#########' in stdout
- assert ' 1 iterations' in stdout
+ assert "my_own_evaluator [........." in stdout
+ assert " 0 iterations" in stdout
+ assert " inf iters/sec." in stdout
+ assert "my_own_evaluator [#########" in stdout
+ assert " 1 iterations" in stdout
# Code excerpts to test IgniteEvaluator
class IgniteDummyModel(torch.nn.Module):
def __init__(self):
super(IgniteDummyModel, self).__init__()
- self.count = 0.
+ self.count = 0.0
def forward(self, *args):
- ppe.reporting.report({'x': self.count}, self)
- self.count += 1.
- return 0.
+ ppe.reporting.report({"x": self.count}, self)
+ self.count += 1.0
+ return 0.0
def create_dummy_evaluator(model):
@@ -265,7 +277,7 @@ def test_ignite_evaluator_reporting_metrics():
try:
from ignite.metrics import MeanSquaredError
except ImportError:
- pytest.skip('pytorch-ignite is not installed')
+ pytest.skip("pytorch-ignite is not installed")
    # This test verifies that both user-reported metrics
    # and ignite-calculated ones are correctly reflected in the reporter
@@ -279,7 +291,7 @@ def test_ignite_evaluator_reporting_metrics():
evaluator = create_dummy_evaluator(model)
# Attach metrics to the evaluator
metric = MeanSquaredError()
- metric.attach(evaluator, 'mse')
+ metric.attach(evaluator, "mse")
evaluator_ignite_ext = ppe.training.extensions.IgniteEvaluator(
evaluator, loader, model, progress_bar=False
)
@@ -287,51 +299,58 @@ def test_ignite_evaluator_reporting_metrics():
with reporter:
result = evaluator_ignite_ext()
# Internally reported metrics
- assert result['main/x'] == 1.5
+ assert result["main/x"] == 1.5
# Ignite calculated metric
- assert result['val/mse'] == 0.0
+ assert result["val/mse"] == 0.0
def test_distributed_evaluation():
- dummy_data = [] # Note: has no effect to the evaluation
+    dummy_data = []  # Note: has no effect on the evaluation
data_loader = torch.utils.data.DataLoader(dummy_data)
target = DummyModel()
- with mock.patch.object(dist, 'is_initialized', return_value=True):
+ with mock.patch.object(dist, "is_initialized", return_value=True):
evaluator = ext.DistributedEvaluator(data_loader, target)
# Make a (simulated) summary for each rank
worker_evaluations = [
- [1., 2., 3.], # assuming rank=0
- [4., 5., 6.], # rank=1
- [7., 8., 9.], # ...
- [10., 11., 12.],
+ [1.0, 2.0, 3.0], # assuming rank=0
+ [4.0, 5.0, 6.0], # rank=1
+ [7.0, 8.0, 9.0], # ...
+ [10.0, 11.0, 12.0],
]
worker_summaries = []
for accs in worker_evaluations:
s = ppe.reporting.DictSummary()
for acc in accs:
- s.add({'target/score': acc})
+ s.add({"target/score": acc})
worker_summaries.append(s)
reporter = ppe.reporting.Reporter()
- reporter.add_observer('target', target)
+ reporter.add_observer("target", target)
with reporter:
- with mock.patch.object(ppe.training.extensions.evaluator, '_dist_gather',
- return_value=worker_summaries):
+ with mock.patch.object(
+ ppe.training.extensions.evaluator,
+ "_dist_gather",
+ return_value=worker_summaries,
+ ):
mean = evaluator.evaluate()
- assert mean['target/score'] == 6.5
+ assert mean["target/score"] == 6.5
def test_distributed_evaluator_progress_bar():
- with mock.patch.object(dist, 'is_initialized', return_value=True):
+ with mock.patch.object(dist, "is_initialized", return_value=True):
data_loader = torch.utils.data.DataLoader([])
target = DummyModel()
- with mock.patch.object(dist, 'get_rank', return_value=0):
- evaluator = ext.DistributedEvaluator(data_loader, target, progress_bar=True)
+ with mock.patch.object(dist, "get_rank", return_value=0):
+ evaluator = ext.DistributedEvaluator(
+ data_loader, target, progress_bar=True
+ )
assert evaluator._progress_bar
# rank != 0 will forcibly set progress_bar=False
- with mock.patch.object(dist, 'get_rank', return_value=1):
- evaluator = ext.DistributedEvaluator(data_loader, target, progress_bar=True)
+ with mock.patch.object(dist, "get_rank", return_value=1):
+ evaluator = ext.DistributedEvaluator(
+ data_loader, target, progress_bar=True
+ )
assert not evaluator._progress_bar
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py
index 5284b2289..c086b63bb 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_fail_on_non_number.py
@@ -1,9 +1,7 @@
-import pytest
import numpy
+import pytest
import torch
-
from pytorch_pfn_extras import training
-
from pytorch_pfn_extras.training.extensions import FailOnNonNumber
@@ -16,8 +14,9 @@ def __len__(self):
return len(self.values)
def __getitem__(self, idx):
- return numpy.array(
- [self.values[idx]], numpy.float32), numpy.int64(idx % 2)
+ return numpy.array([self.values[idx]], numpy.float32), numpy.int64(
+ idx % 2
+ )
class Model(torch.nn.Module):
@@ -40,18 +39,18 @@ def forward(ctx, i):
@staticmethod
def backward(ctx, grad_output):
- return grad_output + float('nan')
+ return grad_output + float("nan")
def get_manager_model_optimizer(*, check_grad=True, grad_error=False):
epochs = 3
model = Model(grad_error)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
- optimizers = {'main': optimizer}
- models = {'main': model}
+ optimizers = {"main": optimizer}
+ models = {"main": model}
manager = training.ExtensionsManager(
- models, optimizers, epochs,
- iters_per_epoch=4)
+ models, optimizers, epochs, iters_per_epoch=4
+ )
manager.extend(FailOnNonNumber(check_grad=check_grad))
return manager, model, optimizer
@@ -79,27 +78,29 @@ def test_valid():
def test_nan():
manager, model, optimizer = get_manager_model_optimizer()
with torch.no_grad():
- model.l1.weight[1, 0] = float('NaN')
- with pytest.raises(RuntimeError, match='diverge'):
+ model.l1.weight[1, 0] = float("NaN")
+ with pytest.raises(RuntimeError, match="diverge"):
run_train(manager, model, optimizer)
def test_inf():
manager, model, optimizer = get_manager_model_optimizer()
with torch.no_grad():
- model.l1.weight[2, 0] = float('inf')
- with pytest.raises(RuntimeError, match='diverge'):
+ model.l1.weight[2, 0] = float("inf")
+ with pytest.raises(RuntimeError, match="diverge"):
run_train(manager, model, optimizer)
def test_check_grad():
manager, model, optimizer = get_manager_model_optimizer(
- check_grad=True, grad_error=True)
- with pytest.raises(RuntimeError, match='diverge'):
+ check_grad=True, grad_error=True
+ )
+ with pytest.raises(RuntimeError, match="diverge"):
run_train(manager, model, optimizer, optimizer_step=False)
def test_no_check_grad():
manager, model, optimizer = get_manager_model_optimizer(
- check_grad=False, grad_error=True)
+ check_grad=False, grad_error=True
+ )
run_train(manager, model, optimizer, optimizer_step=False)
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py
index 277b2626b..4e947c642 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_buffer.py
@@ -1,28 +1,30 @@
import io
import os
import tempfile
-import yaml
import pytorch_pfn_extras as ppe
+import yaml
from pytorch_pfn_extras.training import extensions
-from pytorch_pfn_extras.training.extensions import log_report as log_report_module
+from pytorch_pfn_extras.training.extensions import (
+ log_report as log_report_module,
+)
def test_log_buffer():
buf = log_report_module._LogBuffer()
looker = buf.emit_new_looker()
assert buf.size() == 0
- buf.append('mes1')
- buf.append('mes2')
+ buf.append("mes1")
+ buf.append("mes2")
assert buf.size() == 2
- assert looker.get() == ['mes1', 'mes2']
+ assert looker.get() == ["mes1", "mes2"]
assert buf.size() == 2
looker.clear()
assert buf.size() == 0
assert looker.get() == []
- buf.append('mes3')
+ buf.append("mes3")
assert buf.size() == 1
- assert looker.get() == ['mes3']
+ assert looker.get() == ["mes3"]
assert buf.size() == 1
looker.clear()
assert buf.size() == 0
@@ -33,15 +35,15 @@ def test_log_buffer_multiple_lookers():
buf = log_report_module._LogBuffer()
looker1 = buf.emit_new_looker()
looker2 = buf.emit_new_looker()
- buf.append('mes1')
- assert looker1.get() == ['mes1']
- assert looker2.get() == ['mes1']
+ buf.append("mes1")
+ assert looker1.get() == ["mes1"]
+ assert looker2.get() == ["mes1"]
assert buf.size() == 1
looker2.clear()
assert buf.size() == 1
- buf.append('mes2')
- assert looker1.get() == ['mes1', 'mes2']
- assert looker2.get() == ['mes2']
+ buf.append("mes2")
+ assert looker1.get() == ["mes1", "mes2"]
+ assert looker2.get() == ["mes2"]
assert buf.size() == 2
looker2.clear()
assert buf.size() == 2
@@ -55,11 +57,13 @@ def test_buffer_size_log_report():
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir
+ )
log_report = extensions.LogReport(
- filename='out', format='yaml', append=True)
- manager.extend(log_report, (1, 'epoch'))
+ filename="out", format="yaml", append=True
+ )
+ manager.extend(log_report, (1, "epoch"))
for _ in range(max_epochs):
for _ in range(iters_per_epoch):
@@ -67,7 +71,7 @@ def test_buffer_size_log_report():
with manager.run_iteration():
pass
- with open(os.path.join(tmpdir, 'out')) as f:
+ with open(os.path.join(tmpdir, "out")) as f:
data = f.read()
values = yaml.load(data, Loader=yaml.SafeLoader)
assert len(values) == max_epochs
@@ -79,15 +83,17 @@ def test_buffer_size_log_report_and_print_report():
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch, out_dir=tmpdir
+ )
log_report = extensions.LogReport(
- filename='out', format='yaml', append=True)
- manager.extend(log_report, trigger=(1, 'epoch'))
+ filename="out", format="yaml", append=True
+ )
+ manager.extend(log_report, trigger=(1, "epoch"))
out = io.StringIO()
print_report = extensions.PrintReport(out=out)
- manager.extend(print_report, trigger=(3, 'epoch'))
+ manager.extend(print_report, trigger=(3, "epoch"))
for _ in range(max_epochs):
for _ in range(iters_per_epoch):
@@ -95,7 +101,7 @@ def test_buffer_size_log_report_and_print_report():
with manager.run_iteration():
pass
- with open(os.path.join(tmpdir, 'out')) as f:
+ with open(os.path.join(tmpdir, "out")) as f:
data = f.read()
values = yaml.load(data, Loader=yaml.SafeLoader)
assert len(values) == max_epochs
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py
index 221500b8b..f241b9fdf 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_log_report.py
@@ -3,21 +3,20 @@
import tempfile
import pytest
-import yaml
-
import pytorch_pfn_extras as ppe
+import yaml
from pytorch_pfn_extras.training import extensions
@pytest.mark.parametrize(
- 'filename,expected_format',
+ "filename,expected_format",
[
- ('out.json', 'json'),
- ('out.xyz', 'json'),
- (None, 'json'),
- ('out.yaml', 'yaml'),
- ('out.jsonl', 'json-lines'),
- ]
+ ("out.json", "json"),
+ ("out.xyz", "json"),
+ (None, "json"),
+ ("out.yaml", "yaml"),
+ ("out.jsonl", "json-lines"),
+ ],
)
def test_format_from_ext(filename, expected_format):
log_report = extensions.LogReport(filename=filename, format=None)
@@ -25,14 +24,14 @@ def test_format_from_ext(filename, expected_format):
@pytest.mark.parametrize(
- 'format,append',
+ "format,append",
[
- ('json', False),
- ('json-lines', True),
- ('json-lines', False),
- ('yaml', True),
- ('yaml', False),
- ]
+ ("json", False),
+ ("json-lines", True),
+ ("json-lines", False),
+ ("yaml", True),
+ ("yaml", False),
+ ],
)
def test_output(format, append):
max_epochs = 3
@@ -40,34 +39,42 @@ def test_output(format, append):
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch,
- out_dir=tmpdir)
+ {},
+ {},
+ max_epochs=max_epochs,
+ iters_per_epoch=iters_per_epoch,
+ out_dir=tmpdir,
+ )
log_report = extensions.LogReport(
- filename='out', format=format, append=append)
+ filename="out", format=format, append=append
+ )
manager.extend(log_report)
for epoch_idx in range(max_epochs):
for _ in range(iters_per_epoch):
with manager.run_iteration():
pass
- with open(os.path.join(tmpdir, 'out')) as f:
+ with open(os.path.join(tmpdir, "out")) as f:
data = f.read()
- if format == 'json':
+ if format == "json":
values = json.loads(data)
- elif format == 'json-lines':
+ elif format == "json-lines":
values = [json.loads(x) for x in data.splitlines()]
- elif format == 'yaml':
+ elif format == "yaml":
values = yaml.load(data, Loader=yaml.SafeLoader)
assert len(values) == epoch_idx + 1
this_epoch = values.pop()
- assert this_epoch['epoch'] == epoch_idx + 1
- assert (this_epoch['iteration']
- == (epoch_idx + 1) * iters_per_epoch)
- assert 0 < this_epoch['elapsed_time']
+ assert this_epoch["epoch"] == epoch_idx + 1
+ assert (
+ this_epoch["iteration"] == (epoch_idx + 1) * iters_per_epoch
+ )
+ assert 0 < this_epoch["elapsed_time"]
-@pytest.mark.filterwarnings("ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning")
+@pytest.mark.filterwarnings(
+ "ignore:`np.bool8` is a deprecated alias for `np.bool_`:DeprecationWarning"
+)
def test_tensorboard_writer():
- pytest.importorskip('tensorboard')
+ pytest.importorskip("tensorboard")
max_epochs = 3
iters_per_epoch = 5
@@ -75,10 +82,15 @@ def test_tensorboard_writer():
with tempfile.TemporaryDirectory() as tmpdir:
writer = ppe.writing.TensorBoardWriter(out_dir=tmpdir)
log_report = extensions.LogReport(
- writer=writer, trigger=(1, 'iteration'))
+ writer=writer, trigger=(1, "iteration")
+ )
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch,
- out_dir=tmpdir)
+ {},
+ {},
+ max_epochs=max_epochs,
+ iters_per_epoch=iters_per_epoch,
+ out_dir=tmpdir,
+ )
manager.extend(log_report)
for _ in range(max_epochs):
for _ in range(iters_per_epoch):
@@ -89,13 +101,13 @@ def test_tensorboard_writer():
files = os.listdir(tmpdir)
assert len(files) == 1
tb_file = files[0]
- assert tb_file.startswith('events.out.')
+ assert tb_file.startswith("events.out.")
# Won't play with protobuf, just ensure that our keys are in.
- with open(os.path.join(tmpdir, tb_file), 'rb') as f:
+ with open(os.path.join(tmpdir, tb_file), "rb") as f:
tb_data = f.read()
- for key in ['epoch', 'iteration', 'elapsed_time']:
- assert key.encode('ascii') in tb_data
+ for key in ["epoch", "iteration", "elapsed_time"]:
+ assert key.encode("ascii") in tb_data
def test_deferred_values():
@@ -104,15 +116,19 @@ def test_deferred_values():
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch,
- out_dir=tmpdir)
+ {},
+ {},
+ max_epochs=max_epochs,
+ iters_per_epoch=iters_per_epoch,
+ out_dir=tmpdir,
+ )
log_report = extensions.LogReport(filename="out")
manager.extend(log_report)
for epoch_idx in range(max_epochs):
for _ in range(iters_per_epoch):
with manager.run_iteration():
ppe.reporting.report({"x": lambda: epoch_idx})
- with open(os.path.join(tmpdir, 'out')) as f:
+ with open(os.path.join(tmpdir, "out")) as f:
data = f.read()
values = json.loads(data)
assert len(values) == epoch_idx + 1
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py
index 24612c896..d0be24cd2 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_lr_scheduler.py
@@ -1,9 +1,8 @@
import tempfile
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
def _setup_manager():
@@ -12,63 +11,73 @@ def _setup_manager():
sched = torch.optim.lr_scheduler.MultiStepLR(
optim, milestones=[1, 2, 3], gamma=0.1, last_epoch=-1
)
- ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, 'iteration'))
+ ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, "iteration"))
manager = ppe.training.ExtensionsManager(
- {}, {'main': optim}, 1, extensions=[ext], iters_per_epoch=40)
+ {}, {"main": optim}, 1, extensions=[ext], iters_per_epoch=40
+ )
return optim, manager
def test_lr_scheduler():
optim, manager = _setup_manager()
for i in range(4):
- with manager.run_iteration(step_optimizers=['main']):
+ with manager.run_iteration(step_optimizers=["main"]):
if i < 1:
- assert optim.param_groups[0]['lr'] == pytest.approx(1.0)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1.0)
elif i < 2:
- assert optim.param_groups[0]['lr'] == pytest.approx(1e-1)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1e-1)
elif i < 3:
- assert optim.param_groups[0]['lr'] == pytest.approx(1e-2)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1e-2)
elif i < 4:
- assert optim.param_groups[0]['lr'] == pytest.approx(1e-3)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1e-3)
def test_serialize_scheduler():
optim, manager = _setup_manager()
for i in range(2):
- with manager.run_iteration(step_optimizers=['main']):
+ with manager.run_iteration(step_optimizers=["main"]):
if i < 1:
- assert optim.param_groups[0]['lr'] == pytest.approx(1.0)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1.0)
else:
- assert optim.param_groups[0]['lr'] == pytest.approx(1e-1)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1e-1)
state = manager.state_dict()
optim, manager = _setup_manager()
manager.load_state_dict(state)
for i in range(2):
- with manager.run_iteration(step_optimizers=['main']):
+ with manager.run_iteration(step_optimizers=["main"]):
if i < 1:
- assert optim.param_groups[0]['lr'] == pytest.approx(1e-2)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1e-2)
else:
- assert optim.param_groups[0]['lr'] == pytest.approx(1e-3)
+ assert optim.param_groups[0]["lr"] == pytest.approx(1e-3)
def test_reduce_lr_on_plateau():
param = torch.nn.Parameter(torch.zeros(10))
optim = torch.optim.SGD([param], 1.0)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
- optim, patience=1, verbose=True)
- ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, 'iteration'))
+ optim, patience=1, verbose=True
+ )
+ ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, "iteration"))
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {'main': optim}, 1, extensions=[ext], iters_per_epoch=4,
- out_dir=tmpdir)
- manager.extend(ppe.training.extensions.LogReport(
- filename=None, trigger=(1, "iteration")))
+ {},
+ {"main": optim},
+ 1,
+ extensions=[ext],
+ iters_per_epoch=4,
+ out_dir=tmpdir,
+ )
+ manager.extend(
+ ppe.training.extensions.LogReport(
+ filename=None, trigger=(1, "iteration")
+ )
+ )
for _ in range(4):
with manager.run_iteration():
- ppe.reporting.report({'val/loss': 1.0})
+ ppe.reporting.report({"val/loss": 1.0})
lr = optim.param_groups[0]["lr"]
assert lr == pytest.approx(1e-1)
@@ -77,11 +86,13 @@ def test_reduce_lr_on_plateau_no_report():
param = torch.nn.Parameter(torch.zeros(10))
optim = torch.optim.SGD([param], 1.0)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
- optim, patience=1, verbose=True)
- ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, 'iteration'))
+ optim, patience=1, verbose=True
+ )
+ ext = ppe.training.extensions.LRScheduler(sched, trigger=(1, "iteration"))
manager = ppe.training.ExtensionsManager(
- {}, {'main': optim}, 1, extensions=[ext], iters_per_epoch=4)
+ {}, {"main": optim}, 1, extensions=[ext], iters_per_epoch=4
+ )
with pytest.raises(ValueError):
with manager.run_iteration():
pass
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py
index 339f0e1bb..bc7eff230 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_micro_average.py
@@ -1,5 +1,4 @@
import numpy
-
import pytorch_pfn_extras as ppe
@@ -11,22 +10,26 @@ def test_run():
# NumPy<1.17 does not support array-like inputs in `numpy.random.randint`.
data_correct = numpy.random.randint(10000, size=data_shape) % data_total
- manager = ppe.training.ExtensionsManager(
- {}, [], 100,
- iters_per_epoch=5)
+ manager = ppe.training.ExtensionsManager({}, [], 100, iters_per_epoch=5)
extension = ppe.training.extensions.MicroAverage(
- 'main/correct', 'main/total', 'main/accuracy',
- (trigger_iters, 'iteration'))
- manager.extend(extension, trigger=(1, 'iteration'))
+ "main/correct",
+ "main/total",
+ "main/accuracy",
+ (trigger_iters, "iteration"),
+ )
+ manager.extend(extension, trigger=(1, "iteration"))
for js in numpy.ndindex(data_shape):
with manager.run_iteration():
- ppe.reporting.report({
- 'main/correct': data_correct[js],
- 'main/total': data_total[js],
- })
+ ppe.reporting.report(
+ {
+ "main/correct": data_correct[js],
+ "main/total": data_total[js],
+ }
+ )
assert (
# average is computed every trigger_iters
- ('main/accuracy' in manager.observation)
- == (js[1] == trigger_iters - 1))
+ ("main/accuracy" in manager.observation)
+ == (js[1] == trigger_iters - 1)
+ )
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py
index e9bfa81b6..2d56f1a2a 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_plot_report.py
@@ -1,7 +1,6 @@
import warnings
import pytest
-
from pytorch_pfn_extras.training import extensions
@@ -9,6 +8,7 @@
def matplotlib_or_none():
try:
import matplotlib
+
return matplotlib
except ImportError:
return None
@@ -17,7 +17,7 @@ def matplotlib_or_none():
@pytest.fixture(scope="module")
def matplotlib(matplotlib_or_none):
if matplotlib_or_none is None:
- pytest.skip('matplotlib is not installed')
+ pytest.skip("matplotlib is not installed")
return matplotlib_or_none
@@ -36,9 +36,9 @@ def test_lazy_import(matplotlib):
# has to be called earlier.
with warnings.catch_warnings():
- warnings.simplefilter('error')
- matplotlib.use('Agg')
+ warnings.simplefilter("error")
+ matplotlib.use("Agg")
# Test again with a different backend, because the above does not
# generate a warning if matplotlib.use('Agg') is called and then
# matplotlib.pyplot is imported.
- matplotlib.use('PS')
+ matplotlib.use("PS")
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py
index 8b0f013bd..fa4e5d672 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report.py
@@ -8,7 +8,8 @@ def test_run_print_report():
max_epochs = 5
iters_per_epoch = 5
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch
+ )
out = io.StringIO()
log_report = extensions.LogReport()
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py
index c4d91f431..3007b6b40 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py
@@ -1,7 +1,6 @@
import io
import pytest
-
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training.extensions import _ipython_module_available
from pytorch_pfn_extras.training.extensions.log_report import _pandas_available
@@ -10,13 +9,14 @@
@pytest.mark.skipif(
not _ipython_module_available or not _pandas_available,
reason="print report notebook import failed, "
- "maybe ipython is not installed"
+ "maybe ipython is not installed",
)
def test_run_print_report_notebook():
max_epochs = 5
iters_per_epoch = 5
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch
+ )
out = io.StringIO()
log_report = ppe.training.extensions.LogReport()
@@ -32,5 +32,5 @@ def test_run_print_report_notebook():
pass
-if __name__ == '__main__':
- pytest.main([__file__, '-v', '-s'])
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "-s"])
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py
index a3844cc65..789f39a5e 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_profile_report.py
@@ -1,12 +1,11 @@
+import json
+import os
import tempfile
import time
-import os
-import json
import pytest
-import yaml
-
import pytorch_pfn_extras as ppe
+import yaml
def _body():
@@ -15,14 +14,14 @@ def _body():
@pytest.mark.parametrize(
- 'format,append',
+ "format,append",
[
- ('json', False),
- ('json-lines', True),
- ('json-lines', False),
- ('yaml', True),
- ('yaml', False),
- ]
+ ("json", False),
+ ("json-lines", True),
+ ("json-lines", False),
+ ("yaml", True),
+ ("yaml", False),
+ ],
)
def test_profile_report(format, append):
ext = ppe.training.extensions.ProfileReport(format=format, append=append)
@@ -31,22 +30,26 @@ def test_profile_report(format, append):
# ppe.profiler.time_summary.clear()
with tempfile.TemporaryDirectory() as tmpdir:
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs=max_epochs, iters_per_epoch=iters_per_epoch,
- out_dir=tmpdir)
+ {},
+ {},
+ max_epochs=max_epochs,
+ iters_per_epoch=iters_per_epoch,
+ out_dir=tmpdir,
+ )
manager.extend(ext)
for _epoch_idx in range(max_epochs):
for _ in range(iters_per_epoch):
with manager.run_iteration():
_body()
- with open(os.path.join(tmpdir, 'log')) as f:
+ with open(os.path.join(tmpdir, "log")) as f:
data = f.read()
- if format == 'json':
+ if format == "json":
values = json.loads(data)
- elif format == 'json-lines':
+ elif format == "json-lines":
values = [json.loads(x) for x in data.splitlines()]
- elif format == 'yaml':
+ elif format == "yaml":
values = yaml.load(data, Loader=yaml.SafeLoader)
assert len(values) == _epoch_idx + 1
for value in values:
- assert abs(value['iter-time'] - 0.1) < 2e-2
+ assert abs(value["iter-time"] - 0.1) < 2e-2
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py
index 84a0c7246..cb7874dce 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar.py
@@ -9,7 +9,8 @@ def test_run():
max_epochs = 5
iters_per_epoch = 5
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch
+ )
out = io.StringIO()
extension = ppe.training.extensions.ProgressBar(
@@ -27,8 +28,13 @@ def test_run():
if manager.iteration < 2:
continue
status = out.getvalue()
- assert '{} iter, {} epoch / {} epochs'.format(
- manager.iteration, epoch, max_epochs) in status
+ assert (
+ "{} iter, {} epoch / {} epochs".format(
+ manager.iteration, epoch, max_epochs
+ )
+ in status
+ )
iters_per_sec = float(
- re.findall(r'([0-9]+\.[0-9]*) iters/sec', status)[-1])
+ re.findall(r"([0-9]+\.[0-9]*) iters/sec", status)[-1]
+ )
assert 7 <= iters_per_sec <= 12
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py
index 57559b329..4496820bd 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_progress_bar_notebook.py
@@ -1,26 +1,26 @@
import io
import pytest
+import pytorch_pfn_extras as ppe
import torch
+from pytorch_pfn_extras import training
+from pytorch_pfn_extras.training.extensions import _ipython_module_available
from torch import nn
from torch.nn import Linear
from torch.optim.adam import Adam
-import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras import training
-from pytorch_pfn_extras.training.extensions import _ipython_module_available
-
@pytest.mark.skipif(
not _ipython_module_available,
reason="progress bar notebook import failed, "
- "maybe ipython is not installed"
+ "maybe ipython is not installed",
)
def test_run_progress_bar_notebook():
max_epochs = 5
iters_per_epoch = 5
manager = ppe.training.ExtensionsManager(
- {}, {}, max_epochs, iters_per_epoch=iters_per_epoch)
+ {}, {}, max_epochs, iters_per_epoch=iters_per_epoch
+ )
out = io.StringIO()
extension = ppe.training.extensions.ProgressBarNotebook(
@@ -36,22 +36,22 @@ def test_run_progress_bar_notebook():
with manager.run_iteration():
if manager.iteration < 2:
continue
- status = '{} iter, {} epoch / {} epochs'.format(
- manager.iteration, epoch, max_epochs)
+ status = "{} iter, {} epoch / {} epochs".format(
+ manager.iteration, epoch, max_epochs
+ )
assert status in extension._status_html.value
@pytest.mark.skipif(
not _ipython_module_available,
reason="progress bar notebook import failed, "
- "maybe ipython is not installed"
+ "maybe ipython is not installed",
)
def test_ignite_extensions_manager_with_progressbar_notebook():
-
try:
from ignite.engine import create_supervised_trainer
except ImportError:
- pytest.skip('pytorch-ignite not found')
+ pytest.skip("pytorch-ignite not found")
max_epochs = 5
iters_per_epoch = 4
@@ -70,21 +70,21 @@ def forward(self, *args):
def _fake_loss(*args):
return torch.tensor([0.0], requires_grad=True)
- trainer = create_supervised_trainer(
- model, optimizer, _fake_loss)
+ trainer = create_supervised_trainer(model, optimizer, _fake_loss)
manager = training.IgniteExtensionsManager(
trainer,
- {'model_name': model},
- {'optimizer_name': optimizer},
+ {"model_name": model},
+ {"optimizer_name": optimizer},
max_epochs,
)
manager.extend(ppe.training.extensions.ProgressBarNotebook())
loader = torch.utils.data.DataLoader(
- [(i, i) for i in range(iters_per_epoch)])
+ [(i, i) for i in range(iters_per_epoch)]
+ )
trainer.run(loader, max_epochs=max_epochs)
-if __name__ == '__main__':
- pytest.main([__file__, '-v', '-s'])
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "-s"])
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py
index 8dbce0b8e..d9f6e039f 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_slack.py
@@ -3,150 +3,157 @@
from unittest import mock
import pytest
-
import pytorch_pfn_extras as ppe
@pytest.mark.skipif(
not ppe.training.extensions.slack._slack_sdk_available,
- reason="Slack SDK not installed"
+ reason="Slack SDK not installed",
)
class TestSlack:
def _get_manager(self):
return ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=5)
- @pytest.mark.parametrize('thread',[False, True])
+ @pytest.mark.parametrize("thread", [False, True])
def test_post_message(self, thread):
manager = self._get_manager()
- message = 'It {manager.iteration} loss: {loss}'
+ message = "It {manager.iteration} loss: {loss}"
extension = ppe.training.extensions.Slack(
- '0', message, token='123', thread=thread)
+ "0", message, token="123", thread=thread
+ )
t_ts = None
if thread:
t_ts = 1
- manager.extend(extension, trigger=(1, 'iteration'))
+ manager.extend(extension, trigger=(1, "iteration"))
with mock.patch(
- 'slack_sdk.WebClient.chat_postMessage',
- return_value={'ok': True, 'ts': t_ts},
+ "slack_sdk.WebClient.chat_postMessage",
+ return_value={"ok": True, "ts": t_ts},
) as patched:
with manager.run_iteration():
- assert 'Training started' in patched.call_args.kwargs["text"]
- ppe.reporting.report({'loss': 0.5})
+ assert "Training started" in patched.call_args.kwargs["text"]
+ ppe.reporting.report({"loss": 0.5})
patched.assert_called_with(
- channel='0', text='It 1 loss: 0.5', thread_ts=t_ts
+ channel="0", text="It 1 loss: 0.5", thread_ts=t_ts
)
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.75})
+ ppe.reporting.report({"loss": 0.75})
patched.assert_called_with(
- channel='0', text='It 2 loss: 0.75', thread_ts=t_ts
+ channel="0", text="It 2 loss: 0.75", thread_ts=t_ts
)
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.75})
+ ppe.reporting.report({"loss": 0.75})
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.75})
+ ppe.reporting.report({"loss": 0.75})
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.75})
- assert 'Training finish' in patched.call_args.kwargs["text"]
+ ppe.reporting.report({"loss": 0.75})
+ assert "Training finish" in patched.call_args.kwargs["text"]
def test_post_message_on_error(self):
manager = self._get_manager()
- message = 'It {manager.iteration} loss: {loss}'
+ message = "It {manager.iteration} loss: {loss}"
extension = ppe.training.extensions.Slack(
- '0', message, token='123', thread=False)
+ "0", message, token="123", thread=False
+ )
t_ts = None
- manager.extend(extension, trigger=(1, 'iteration'))
+ manager.extend(extension, trigger=(1, "iteration"))
with mock.patch(
- 'slack_sdk.WebClient.chat_postMessage',
- return_value={'ok': True, 'ts': t_ts},
+ "slack_sdk.WebClient.chat_postMessage",
+ return_value={"ok": True, "ts": t_ts},
) as patched:
try:
with manager.run_iteration():
- raise RuntimeError('error')
+ raise RuntimeError("error")
except RuntimeError:
- assert 'Error during' in patched.call_args.kwargs["text"]
+ assert "Error during" in patched.call_args.kwargs["text"]
def test_post_message_webhook(self):
manager = self._get_manager()
- message = 'It {manager.iteration} loss: {loss}'
+ message = "It {manager.iteration} loss: {loss}"
extension = ppe.training.extensions.SlackWebhook(
- url="http://test", msg=message)
+ url="http://test", msg=message
+ )
- manager.extend(extension, trigger=(1, 'iteration'))
- payload_1 = json.dumps({'text': "It 1 loss: 0.5"}).encode('utf-8')
- payload_2 = json.dumps({'text': "It 2 loss: 0.75"}).encode('utf-8')
+ manager.extend(extension, trigger=(1, "iteration"))
+ payload_1 = json.dumps({"text": "It 1 loss: 0.5"}).encode("utf-8")
+ payload_2 = json.dumps({"text": "It 2 loss: 0.75"}).encode("utf-8")
with mock.patch(
- 'urllib.request.urlopen',
+ "urllib.request.urlopen",
return_value=SimpleNamespace(status=200),
) as patched:
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.5})
+ ppe.reporting.report({"loss": 0.5})
assert patched.call_args.args[0].data == payload_1
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.75})
+ ppe.reporting.report({"loss": 0.75})
assert patched.call_args.args[0].data == payload_2
@pytest.mark.parametrize(
- 'message',
+ "message",
[
- 'It {manager.iteration} loss: {loss} custom: {context.foo}',
- lambda m, c: 'It {manager.iteration} loss: {loss} custom: {context.foo}'.format( # NOQA
- manager=m, context=c, **m.observation)
- ]
+ "It {manager.iteration} loss: {loss} custom: {context.foo}",
+ lambda m, c: "It {manager.iteration} loss: {loss} custom: {context.foo}".format( # NOQA
+ manager=m, context=c, **m.observation
+ ),
+ ],
)
def test_post_message_context(self, message):
class _CustomContext:
def __init__(self):
- self.foo = 'bar'
+ self.foo = "bar"
manager = self._get_manager()
context = _CustomContext()
extension = ppe.training.extensions.Slack(
- '0', message, context=context, token='123')
- manager.extend(extension, trigger=(1, 'iteration'))
+ "0", message, context=context, token="123"
+ )
+ manager.extend(extension, trigger=(1, "iteration"))
with mock.patch(
- 'slack_sdk.WebClient.chat_postMessage',
- return_value={'ok': True, 'ts': 1},
+ "slack_sdk.WebClient.chat_postMessage",
+ return_value={"ok": True, "ts": 1},
) as patched:
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.5})
+ ppe.reporting.report({"loss": 0.5})
patched.assert_called_with(
- channel='0', text='It 1 loss: 0.5 custom: bar',
- thread_ts=1
+ channel="0", text="It 1 loss: 0.5 custom: bar", thread_ts=1
)
- context.foo = 'test'
+ context.foo = "test"
with manager.run_iteration():
- ppe.reporting.report({'loss': 0.75})
+ ppe.reporting.report({"loss": 0.75})
patched.assert_called_with(
- channel='0', text='It 2 loss: 0.75 custom: test',
- thread_ts=1
+ channel="0", text="It 2 loss: 0.75 custom: test", thread_ts=1
)
def test_post_message_files(self):
manager = self._get_manager()
- message = 'it: {manager.iteration}'
- filenames = ['file_{manager.iteration}', '{manager._out}/abc']
+ message = "it: {manager.iteration}"
+ filenames = ["file_{manager.iteration}", "{manager._out}/abc"]
extension = ppe.training.extensions.Slack(
- '0', message, filenames=filenames, token='123')
- manager.extend(extension, trigger=(1, 'iteration'))
+ "0", message, filenames=filenames, token="123"
+ )
+ manager.extend(extension, trigger=(1, "iteration"))
with mock.patch(
- 'slack_sdk.WebClient.chat_postMessage',
- return_value={'ok': True, 'ts': 1},
- ), mock.patch('slack_sdk.WebClient.files_upload') as upload:
+ "slack_sdk.WebClient.chat_postMessage",
+ return_value={"ok": True, "ts": 1},
+ ), mock.patch("slack_sdk.WebClient.files_upload") as upload:
with manager.run_iteration():
pass
- upload.assert_has_calls([
- mock.call(file='file_1'),
- mock.call(file='result/abc'),
- ], any_order=True)
+ upload.assert_has_calls(
+ [
+ mock.call(file="file_1"),
+ mock.call(file="result/abc"),
+ ],
+ any_order=True,
+ )
def test_invalid(self):
- message = 'it: {manager.iteration}'
- filenames = ['file_{manager.iteration}', '{manager._out}/abc']
- with pytest.raises(RuntimeError, match='needed for communicating'):
+ message = "it: {manager.iteration}"
+ filenames = ["file_{manager.iteration}", "{manager._out}/abc"]
+ with pytest.raises(RuntimeError, match="needed for communicating"):
ppe.training.extensions.Slack(
- '0', message, start_msg=None, end_msg=None, filenames=filenames)
+ "0", message, start_msg=None, end_msg=None, filenames=filenames
+ )
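# Illustrative aside (not part of the diff): how the Slack extension exercised by the
# tests above is typically registered. The channel id and token are placeholders; a real
# slack_sdk installation and a valid bot token are assumed before anything is posted.
import pytorch_pfn_extras as ppe

manager = ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=10, out_dir="result")
extension = ppe.training.extensions.Slack(
    "C0123456789",                           # channel id (placeholder)
    "It {manager.iteration} loss: {loss}",   # message template filled from the observation
    token="xoxb-placeholder",
)
manager.extend(extension, trigger=(1, "iteration"))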
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py
index 75675bb0f..dd1608bff 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot.py
@@ -5,29 +5,29 @@
import time
from unittest import mock
-import torch
import pytest
-
import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras import training
+import torch
+from pytorch_pfn_extras import training, writing
from pytorch_pfn_extras.training import extensions
from pytorch_pfn_extras.training.extensions._snapshot import (
- _find_latest_snapshot, _find_snapshot_files, _find_stale_snapshots)
-from pytorch_pfn_extras import writing
+ _find_latest_snapshot,
+ _find_snapshot_files,
+ _find_stale_snapshots,
+)
def get_trainer(*, out_dir, state_to_load=None, epochs=10):
model_state_dict = {}
optimizer_state_dict = {}
- models = {'main': _StateDictModel(state_dict=model_state_dict)}
- optimizers = {'main': _StateDictObj(state_dict=optimizer_state_dict)}
+ models = {"main": _StateDictModel(state_dict=model_state_dict)}
+ optimizers = {"main": _StateDictObj(state_dict=optimizer_state_dict)}
return training.ExtensionsManager(
- models, optimizers, epochs,
- iters_per_epoch=10,
- out_dir=out_dir)
+ models, optimizers, epochs, iters_per_epoch=10, out_dir=out_dir
+ )
-class _StateDictObj():
+class _StateDictObj:
def __init__(self, *, state_dict=None):
super().__init__()
self.called_load_state_dict = 0
@@ -62,7 +62,8 @@ def test_call():
def test_savefun_and_writer_exclusive():
# savefun and writer arguments cannot be specified together.
def savefun(*args, **kwargs):
- pytest.fail('never reach')
+ pytest.fail("never reach")
+
writer = writing.SimpleWriter()
with pytest.raises(TypeError):
extensions.snapshot(savefun=savefun, writer=writer)
@@ -72,104 +73,98 @@ def savefun(*args, **kwargs):
extensions.snapshot_object(trainer, savefun=savefun, writer=writer)
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def remover():
yield
- if os.path.exists('myfile.dat'):
- os.remove('myfile.dat')
+ if os.path.exists("myfile.dat"):
+ os.remove("myfile.dat")
def test_save_file(remover):
- trainer = get_trainer(out_dir='.')
+ trainer = get_trainer(out_dir=".")
trainer._done = True
w = writing.SimpleWriter()
- snapshot = extensions.snapshot_object(trainer, 'myfile.dat',
- writer=w)
+ snapshot = extensions.snapshot_object(trainer, "myfile.dat", writer=w)
snapshot(trainer)
- assert os.path.exists('myfile.dat')
+ assert os.path.exists("myfile.dat")
def test_multi_target(remover):
- trainer = get_trainer(out_dir='.')
+ trainer = get_trainer(out_dir=".")
trainer._done = True
- other_state_dict = {'test': True}
+ other_state_dict = {"test": True}
other = _StateDictObj(state_dict=other_state_dict)
w = ppe.writing.SimpleWriter()
- target = {'trainer': trainer, 'other': other}
- snapshot = extensions.snapshot_object(target, 'myfile.dat',
- writer=w)
+ target = {"trainer": trainer, "other": other}
+ snapshot = extensions.snapshot_object(target, "myfile.dat", writer=w)
snapshot(trainer)
- assert os.path.exists('myfile.dat')
+ assert os.path.exists("myfile.dat")
# Load the snapshot and verify it
- state = torch.load('myfile.dat')
- new_trainer = get_trainer(out_dir='.')
+ state = torch.load("myfile.dat")
+ new_trainer = get_trainer(out_dir=".")
new_other = _StateDictObj(state_dict={})
- new_trainer.load_state_dict(state['trainer'])
- new_other.load_state_dict(state['other'])
+ new_trainer.load_state_dict(state["trainer"])
+ new_other.load_state_dict(state["other"])
assert new_trainer.state_dict() == trainer.state_dict()
assert new_other.state_dict() == other_state_dict
def test_multi_target_autoload(remover):
- trainer = get_trainer(out_dir='.')
+ trainer = get_trainer(out_dir=".")
trainer._done = True
- other_state_dict = {'test': True}
+ other_state_dict = {"test": True}
other = _StateDictObj(state_dict=other_state_dict)
w = ppe.writing.SimpleWriter()
- target = {'trainer': trainer, 'other': other}
- snapshot = extensions.snapshot_object(target, 'myfile.dat',
- writer=w)
+ target = {"trainer": trainer, "other": other}
+ snapshot = extensions.snapshot_object(target, "myfile.dat", writer=w)
snapshot(trainer)
- assert os.path.exists('myfile.dat')
- new_trainer = get_trainer(out_dir='.')
+ assert os.path.exists("myfile.dat")
+ new_trainer = get_trainer(out_dir=".")
new_other = _StateDictObj(state_dict={})
- target = {'trainer': new_trainer, 'other': new_other}
- snapshot2 = extensions.snapshot_object(target, 'myfile.dat',
- autoload=True)
+ target = {"trainer": new_trainer, "other": new_other}
+ snapshot2 = extensions.snapshot_object(target, "myfile.dat", autoload=True)
# Load the snapshot and verify it
- assert snapshot2.initialize(new_trainer) == 'myfile.dat'
+ assert snapshot2.initialize(new_trainer) == "myfile.dat"
assert new_trainer.state_dict() == trainer.state_dict()
assert new_other.state_dict() == other_state_dict
def test_multi_target_autoload_not_found(remover):
- trainer = get_trainer(out_dir='.')
- other = _StateDictObj(state_dict={'original': 'state'})
+ trainer = get_trainer(out_dir=".")
+ other = _StateDictObj(state_dict={"original": "state"})
- target = {'trainer': trainer, 'other': other}
- snapshot = extensions.snapshot_object(target, 'myfile.dat',
- autoload=True)
+ target = {"trainer": trainer, "other": other}
+ snapshot = extensions.snapshot_object(target, "myfile.dat", autoload=True)
assert snapshot.initialize(trainer) is None
- assert other.state_dict() == {'original': 'state'}
+ assert other.state_dict() == {"original": "state"}
def test_clean_up_tempdir(remover):
- trainer = get_trainer(out_dir='.')
+ trainer = get_trainer(out_dir=".")
trainer._done = True
- snapshot = extensions.snapshot_object(trainer, 'myfile.dat')
+ snapshot = extensions.snapshot_object(trainer, "myfile.dat")
snapshot(trainer)
- left_tmps = [fn for fn in os.listdir('.')
- if fn.startswith('tmpmyfile.dat')]
+ left_tmps = [fn for fn in os.listdir(".") if fn.startswith("tmpmyfile.dat")]
assert len(left_tmps) == 0
def test_on_error():
# Will fail when accessing the dummy optimizer
- optimizers = {'main': object()}
+ optimizers = {"main": object()}
trainer = training.ExtensionsManager(
- {}, optimizers, 1,
- iters_per_epoch=1,
- out_dir='.')
- filename = 'myfile-deadbeef.dat'
+ {}, optimizers, 1, iters_per_epoch=1, out_dir="."
+ )
+ filename = "myfile-deadbeef.dat"
- snapshot = extensions.snapshot_object(trainer, filename,
- snapshot_on_error=True)
+ snapshot = extensions.snapshot_object(
+ trainer, filename, snapshot_on_error=True
+ )
trainer.extend(snapshot)
assert not os.path.exists(filename)
with pytest.raises(AttributeError):
@@ -178,7 +173,7 @@ def test_on_error():
assert not os.path.exists(filename)
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def path():
with tempfile.TemporaryDirectory() as t_path:
yield t_path
@@ -190,19 +185,22 @@ def snapshot_path():
yield t_path
-@pytest.mark.parametrize('fmt', [
- 'snapshot_iter_{}',
- 'snapshot_iter_{}.npz',
- '{}_snapshot_man_suffix.npz',
-])
+@pytest.mark.parametrize(
+ "fmt",
+ [
+ "snapshot_iter_{}",
+ "snapshot_iter_{}.npz",
+ "{}_snapshot_man_suffix.npz",
+ ],
+)
def test_find_snapshot_files(fmt, path):
files = (fmt.format(i) for i in range(1, 100))
- noise = ('dummy-foobar-iter{}'.format(i) for i in range(10, 304))
- noise2 = ('tmpsnapshot_iter_{}'.format(i) for i in range(10, 304))
+ noise = ("dummy-foobar-iter{}".format(i) for i in range(10, 304))
+ noise2 = ("tmpsnapshot_iter_{}".format(i) for i in range(10, 304))
for file in itertools.chain(noise, files, noise2):
file = os.path.join(path, file)
- open(file, 'w').close()
+ open(file, "w").close()
writer = ppe.writing.SimpleWriter()
snapshot_files = _find_snapshot_files(fmt, path, writer.fs)
@@ -213,18 +211,21 @@ def test_find_snapshot_files(fmt, path):
assert expected == sorted(list(snapshot_files))
-@pytest.mark.parametrize('fmt', [
- 'snapshot_iter_{}',
- 'snapshot_iter_{}.npz',
- '{}_snapshot_man_suffix.npz',
-])
+@pytest.mark.parametrize(
+ "fmt",
+ [
+ "snapshot_iter_{}",
+ "snapshot_iter_{}.npz",
+ "{}_snapshot_man_suffix.npz",
+ ],
+)
def test_find_latest_snapshot(fmt, path):
files = [fmt.format(i) for i in range(1, 100)]
base_timestamp = time.time()
for i, file in enumerate(files):
file = os.path.join(path, file)
- open(file, 'w').close()
+ open(file, "w").close()
# mtime resolution of some filesystems e.g. ext3 or HFS+
# is a second and thus snapshot files such as
@@ -240,27 +241,36 @@ def test_find_latest_snapshot(fmt, path):
assert fmt.format(99) == _find_latest_snapshot(fmt, path, writer.fs)
-@pytest.mark.parametrize('fmt', [
- 'snapshot_iter_{}_{}',
- 'snapshot_iter_{}_{}.npz',
- '{}_snapshot_man_{}-suffix.npz',
- 'snapshot_iter_{}.{}',
-])
+@pytest.mark.parametrize(
+ "fmt",
+ [
+ "snapshot_iter_{}_{}",
+ "snapshot_iter_{}_{}.npz",
+ "{}_snapshot_man_{}-suffix.npz",
+ "snapshot_iter_{}.{}",
+ ],
+)
def test_find_snapshot_files2(fmt, path):
- files = (fmt.format(i * 10, j * 10) for i, j
- in itertools.product(range(0, 10), range(0, 10)))
- noise = ('tmpsnapshot_iter_{}.{}'.format(i, j)
- for i, j in zip(range(10, 304), range(10, 200)))
+ files = (
+ fmt.format(i * 10, j * 10)
+ for i, j in itertools.product(range(0, 10), range(0, 10))
+ )
+ noise = (
+ "tmpsnapshot_iter_{}.{}".format(i, j)
+ for i, j in zip(range(10, 304), range(10, 200))
+ )
for file in itertools.chain(noise, files):
file = os.path.join(path, file)
- open(file, 'w').close()
+ open(file, "w").close()
writer = ppe.writing.SimpleWriter()
snapshot_files = _find_snapshot_files(fmt, path, writer.fs)
- expected = [fmt.format(i * 10, j * 10)
- for i, j in itertools.product(range(0, 10), range(0, 10))]
+ expected = [
+ fmt.format(i * 10, j * 10)
+ for i, j in itertools.product(range(0, 10), range(0, 10))
+ ]
timestamps, snapshot_files = zip(*snapshot_files)
expected.sort()
@@ -268,19 +278,27 @@ def test_find_snapshot_files2(fmt, path):
assert expected == snapshot_files
-@pytest.mark.parametrize('length_retain', [
- (100, 30), (10, 30), (1, 1000),
- (1000, 1), (1, 1), (1, 3), (2, 3),
-])
+@pytest.mark.parametrize(
+ "length_retain",
+ [
+ (100, 30),
+ (10, 30),
+ (1, 1000),
+ (1000, 1),
+ (1, 1),
+ (1, 3),
+ (2, 3),
+ ],
+)
def test_find_stale_snapshot(length_retain, path):
length, retain = length_retain
- fmt = 'snapshot_iter_{}'
+ fmt = "snapshot_iter_{}"
files = [fmt.format(i) for i in range(0, length)]
base_timestamp = time.time() - length * 2
for i, file in enumerate(files):
file = os.path.join(path, file)
- open(file, 'w').close()
+ open(file, "w").close()
# Same comment applies here. See comment in ``TestFindSnapshot``
t = base_timestamp + i
@@ -294,17 +312,18 @@ def test_find_stale_snapshot(length_retain, path):
def test_remove_stale_snapshots(path):
- fmt = 'snapshot_iter_{.iteration}'
+ fmt = "snapshot_iter_{.iteration}"
retain = 3
- snapshot = extensions.snapshot(filename=fmt, n_retains=retain,
- autoload=False)
+ snapshot = extensions.snapshot(
+ filename=fmt, n_retains=retain, autoload=False
+ )
trainer = get_trainer(out_dir=path)
- trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)
+ trainer.extend(snapshot, trigger=(1, "iteration"), priority=2)
class TimeStampUpdater(training.Extension):
t = time.time() - 100
- name = 'ts_updater'
+ name = "ts_updater"
priority = 1 # This must be called after snapshot taken
def __call__(self, _trainer):
@@ -313,7 +332,7 @@ def __call__(self, _trainer):
# For filesystems that have low timestamp precision
os.utime(filename, (self.t, self.t))
- trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
+ trainer.extend(TimeStampUpdater(), trigger=(1, "iteration"))
for _ in range(10):
with trainer.run_iteration():
pass
@@ -324,30 +343,32 @@ def __call__(self, _trainer):
assert retain == len(found)
found.sort()
# snapshot_iter_(8, 9, 10) expected
- expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)]
+ expected = ["snapshot_iter_{}".format(i) for i in range(8, 11)]
expected.sort()
assert expected == found
- trainer2 = get_trainer(
- out_dir=path, state_to_load=trainer.state_dict())
+ trainer2 = get_trainer(out_dir=path, state_to_load=trainer.state_dict())
snapshot2 = extensions.snapshot(filename=fmt, autoload=True)
# Just making sure no error occurs
snapshot2.initialize(trainer2)
def test_remove_stale_snapshots_with_writer(path, snapshot_path):
- fmt = 'snapshot_iter_{.iteration}'
+ fmt = "snapshot_iter_{.iteration}"
retain = 3
- snapshot = extensions.snapshot(filename=fmt, n_retains=retain,
- writer=ppe.writing.SimpleWriter(out_dir=snapshot_path),
- autoload=False)
+ snapshot = extensions.snapshot(
+ filename=fmt,
+ n_retains=retain,
+ writer=ppe.writing.SimpleWriter(out_dir=snapshot_path),
+ autoload=False,
+ )
trainer = get_trainer(out_dir=path)
- trainer.extend(snapshot, trigger=(1, 'iteration'), priority=2)
+ trainer.extend(snapshot, trigger=(1, "iteration"), priority=2)
class TimeStampUpdater(training.Extension):
t = time.time() - 100
- name = 'ts_updater'
+ name = "ts_updater"
priority = 1 # This must be called after snapshot taken
def __call__(self, _trainer):
@@ -356,7 +377,7 @@ def __call__(self, _trainer):
# For filesystems that have low timestamp precision
os.utime(filename, (self.t, self.t))
- trainer.extend(TimeStampUpdater(), trigger=(1, 'iteration'))
+ trainer.extend(TimeStampUpdater(), trigger=(1, "iteration"))
for _ in range(10):
with trainer.run_iteration():
pass
@@ -371,13 +392,16 @@ def __call__(self, _trainer):
assert retain == len(found)
found.sort()
# snapshot_iter_(8, 9, 10) expected
- expected = ['snapshot_iter_{}'.format(i) for i in range(8, 11)]
+ expected = ["snapshot_iter_{}".format(i) for i in range(8, 11)]
expected.sort()
assert expected == found
- trainer2 = get_trainer(
- out_dir=path, state_to_load=trainer.state_dict())
- snapshot2 = extensions.snapshot(filename=fmt, autoload=True, writer=ppe.writing.SimpleWriter(out_dir=snapshot_path))
+ trainer2 = get_trainer(out_dir=path, state_to_load=trainer.state_dict())
+ snapshot2 = extensions.snapshot(
+ filename=fmt,
+ autoload=True,
+ writer=ppe.writing.SimpleWriter(out_dir=snapshot_path),
+ )
# Just making sure no error occurs
snapshot2.initialize(trainer2)
@@ -408,7 +432,7 @@ def test_model_transformations(path):
transform_model=lambda n, x: x.wrapper_module(),
)
- snapshot = extensions.snapshot(filename='test')
+ snapshot = extensions.snapshot(filename="test")
snapshot(manager)
assert model.accessed
@@ -417,7 +441,7 @@ def test_model_transformations(path):
def test_snapshot_autoload_twice(path):
max_epochs = 10
iters_per_epoch = 10
- fmt = 'snapshot_iter_{.iteration}'
+ fmt = "snapshot_iter_{.iteration}"
def get_epoch_indices():
manager = get_trainer(out_dir=path, epochs=max_epochs)
@@ -461,13 +485,20 @@ def test_snapshot_autoload_with_writer(path, snapshot_path):
trainer = get_trainer(out_dir=path, epochs=10)
trainer.models["main"]._state_dict = {"value": 0}
- snapshot = extensions.snapshot(filename=snapshot_filename, writer=ppe.writing.SimpleWriter(out_dir=snapshot_path))
+ snapshot = extensions.snapshot(
+ filename=snapshot_filename,
+ writer=ppe.writing.SimpleWriter(out_dir=snapshot_path),
+ )
snapshot(trainer)
assert os.path.isfile(os.path.join(snapshot_path, snapshot_filename))
assert not os.path.isfile(os.path.join(path, snapshot_filename))
trainer2 = get_trainer(out_dir=path, epochs=0)
- snapshot2 = extensions.snapshot(filename=snapshot_filename, writer=ppe.writing.SimpleWriter(out_dir=snapshot_path), autoload=True)
+ snapshot2 = extensions.snapshot(
+ filename=snapshot_filename,
+ writer=ppe.writing.SimpleWriter(out_dir=snapshot_path),
+ autoload=True,
+ )
assert trainer2.state_dict() != trainer.state_dict()
assert snapshot2.initialize(trainer2) == snapshot_filename
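# Illustrative aside (not part of the diff): a minimal sketch of the snapshot pattern these
# tests exercise -- periodic snapshots with only the newest few retained. All calls mirror
# the test code above; "result" is just a placeholder output directory.
import torch
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
manager = ppe.training.ExtensionsManager(
    {"main": model}, {"main": optimizer}, 1, iters_per_epoch=10, out_dir="result"
)
# Keep only the three most recent snapshots, taken once per iteration.
manager.extend(
    extensions.snapshot(filename="snapshot_iter_{.iteration}", n_retains=3),
    trigger=(1, "iteration"),
)
while not manager.stop_trigger:
    with manager.run_iteration():
        pass  # the training step would go here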
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py
index 41b9ecfd9..c6578b111 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_snapshot_writers.py
@@ -1,14 +1,12 @@
import multiprocessing
-import threading
import tempfile
+import threading
from unittest import mock
import pytest
-
from pytorch_pfn_extras import writing
-
-spshot_writers_path = 'pytorch_pfn_extras.writing'
+spshot_writers_path = "pytorch_pfn_extras.writing"
def test_simple_writer():
@@ -16,10 +14,10 @@ def test_simple_writer():
w = writing.SimpleWriter(foo=True)
savefun = mock.MagicMock()
with tempfile.TemporaryDirectory() as tempd:
- w('myfile.dat', tempd, target, savefun=savefun)
+ w("myfile.dat", tempd, target, savefun=savefun)
assert savefun.call_count == 1
assert savefun.call_args[0][0] == target
- assert savefun.call_args[1]['foo'] is True
+ assert savefun.call_args[1]["foo"] is True
def test_standard_writer():
@@ -27,11 +25,11 @@ def test_standard_writer():
w = writing.StandardWriter()
worker = mock.MagicMock()
worker.exitcode = 0
- name = spshot_writers_path + '.StandardWriter.create_worker'
+ name = spshot_writers_path + ".StandardWriter.create_worker"
with mock.patch(name, return_value=worker):
with tempfile.TemporaryDirectory() as tempd:
- w('myfile.dat', tempd, target)
- w('myfile.dat', tempd, target)
+ w("myfile.dat", tempd, target)
+ w("myfile.dat", tempd, target)
w.finalize()
assert worker.start.call_count == 2
@@ -42,16 +40,16 @@ def test_thread_writer_create_worker():
target = mock.MagicMock()
w = writing.ThreadWriter()
with tempfile.TemporaryDirectory() as tempd:
- worker = w.create_worker('myfile.dat', tempd, target, append=False)
+ worker = w.create_worker("myfile.dat", tempd, target, append=False)
assert isinstance(worker, threading.Thread)
- w('myfile2.dat', tempd, 'test')
+ w("myfile2.dat", tempd, "test")
w.finalize()
def test_thread_writer_fail():
w = writing.ThreadWriter(savefun=None)
with tempfile.TemporaryDirectory() as tempd:
- w('myfile2.dat', tempd, 'test')
+ w("myfile2.dat", tempd, "test")
with pytest.raises(RuntimeError):
w.finalize()
@@ -60,16 +58,16 @@ def test_process_writer_create_worker():
target = mock.MagicMock()
w = writing.ProcessWriter()
with tempfile.TemporaryDirectory() as tempd:
- worker = w.create_worker('myfile.dat', tempd, target, append=False)
+ worker = w.create_worker("myfile.dat", tempd, target, append=False)
assert isinstance(worker, multiprocessing.Process)
- w('myfile2.dat', tempd, 'test')
+ w("myfile2.dat", tempd, "test")
w.finalize()
def test_process_writer_fail():
w = writing.ProcessWriter(savefun=None)
with tempfile.TemporaryDirectory() as tempd:
- w('myfile2.dat', tempd, 'test')
+ w("myfile2.dat", tempd, "test")
with pytest.raises(RuntimeError):
w.finalize()
@@ -78,15 +76,17 @@ def test_queue_writer():
target = mock.MagicMock()
q = mock.MagicMock()
consumer = mock.MagicMock()
- names = [spshot_writers_path + '.QueueWriter.create_queue',
- spshot_writers_path + '.QueueWriter.create_consumer']
+ names = [
+ spshot_writers_path + ".QueueWriter.create_queue",
+ spshot_writers_path + ".QueueWriter.create_consumer",
+ ]
with mock.patch(names[0], return_value=q):
with mock.patch(names[1], return_value=consumer):
w = writing.QueueWriter()
with tempfile.TemporaryDirectory() as tempd:
- w('myfile.dat', tempd, target)
- w('myfile.dat', tempd, target)
+ w("myfile.dat", tempd, target)
+ w("myfile.dat", tempd, target)
w.finalize()
assert consumer.start.call_count == 1
@@ -96,8 +96,10 @@ def test_queue_writer():
def test_queue_writer_consume():
- names = [spshot_writers_path + '.QueueWriter.create_queue',
- spshot_writers_path + '.QueueWriter.create_consumer']
+ names = [
+ spshot_writers_path + ".QueueWriter.create_queue",
+ spshot_writers_path + ".QueueWriter.create_consumer",
+ ]
with mock.patch(names[0]):
with mock.patch(names[1]):
task = mock.MagicMock()
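# Illustrative aside (not part of the diff): a direct SimpleWriter call as exercised above;
# the writer saves the target into out_dir/filename using the given savefun.
import tempfile
import torch
from pytorch_pfn_extras import writing

w = writing.SimpleWriter()
with tempfile.TemporaryDirectory() as tempd:
    w("myfile.dat", tempd, {"value": 1}, savefun=torch.save)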
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py
index 3f1bb0296..cba6dd4be 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_value_observation.py
@@ -1,30 +1,25 @@
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
def test_observe_value():
lr = 0.1
- manager = ppe.training.ExtensionsManager(
- {}, [], 1,
- iters_per_epoch=1)
- extension = ppe.training.extensions.observe_value('lr', lambda x: lr)
+ manager = ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=1)
+ extension = ppe.training.extensions.observe_value("lr", lambda x: lr)
manager.extend(extension)
with manager.run_iteration():
pass
- assert manager.observation['lr'] == lr
+ assert manager.observation["lr"] == lr
def test_observe_lr():
lr = 0.01
- manager = ppe.training.ExtensionsManager(
- {}, [], 1,
- iters_per_epoch=1)
+ manager = ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=1)
optimizer = torch.optim.Adam({torch.nn.Parameter()}, lr=lr)
extension = ppe.training.extensions.observe_lr(optimizer)
manager.extend(extension)
with manager.run_iteration():
pass
- assert manager.observation['lr'] == lr
+ assert manager.observation["lr"] == lr
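# Illustrative aside (not part of the diff): observe_value records a value into
# manager.observation on every iteration, which is exactly what the two tests above check.
import pytorch_pfn_extras as ppe

manager = ppe.training.ExtensionsManager({}, [], 1, iters_per_epoch=1)
manager.extend(ppe.training.extensions.observe_value("lr", lambda manager: 0.1))
with manager.run_iteration():
    pass
assert manager.observation["lr"] == 0.1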
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py
index e6dad4b9b..3b433815a 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_variable_statistics_plot.py
@@ -2,32 +2,31 @@
import numpy
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
@pytest.fixture(scope="module")
def matplotlib():
try:
import matplotlib
- matplotlib.use('Agg')
+
+ matplotlib.use("Agg")
return matplotlib
except ImportError:
- pytest.skip('matplotlib is not installed')
+ pytest.skip("matplotlib is not installed")
def test_run_and_save_plot(matplotlib):
- filename = 'variable_statistics_plot_test.png'
+ filename = "variable_statistics_plot_test.png"
iterations = 2
- extension_trigger = (1, 'iteration')
- manager = ppe.training.ExtensionsManager(
- {}, [], 2,
- iters_per_epoch=1)
+ extension_trigger = (1, "iteration")
+ manager = ppe.training.ExtensionsManager({}, [], 2, iters_per_epoch=1)
x = torch.rand(1, 2, 3)
extension = ppe.training.extensions.VariableStatisticsPlot(
- x, trigger=extension_trigger, filename=filename)
+ x, trigger=extension_trigger, filename=filename
+ )
manager.extend(extension, trigger=extension_trigger)
# In the following we explicitly use plot_report._available instead of
@@ -47,12 +46,11 @@ def test_reservoir_size():
shape = (2, 7, 3)
n = 5
reservoir_size = 3
- xs = [
- 2 * torch.rand(shape) - 1 for i in range(n)]
+ xs = [2 * torch.rand(shape) - 1 for i in range(n)]
- reservoir = (
- ppe.training.extensions.variable_statistics_plot.Reservoir(
- size=reservoir_size, data_shape=shape))
+ reservoir = ppe.training.extensions.variable_statistics_plot.Reservoir(
+ size=reservoir_size, data_shape=shape
+ )
for x in xs:
reservoir.add(x)
idxs, data = reservoir.get_data()
@@ -68,20 +66,23 @@ def test_statistician_percentile():
shape = (2, 7, 3)
x = 2 * torch.rand(shape) - 1
- percentile_sigmas = (0., 100.) # min, max
+ percentile_sigmas = (0.0, 100.0) # min, max
statistician = (
ppe.training.extensions.variable_statistics_plot.Statistician(
- collect_mean=True, collect_std=True,
- percentile_sigmas=percentile_sigmas))
+ collect_mean=True,
+ collect_std=True,
+ percentile_sigmas=percentile_sigmas,
+ )
+ )
stat = statistician(x, axis=None, dtype=x.dtype)
for s in stat.values():
assert s.dtype == x.dtype
- assert torch.allclose(stat['mean'], torch.mean(x))
- assert torch.allclose(stat['std'], torch.std(x))
+ assert torch.allclose(stat["mean"], torch.mean(x))
+ assert torch.allclose(stat["std"], torch.std(x))
- percentile = stat['percentile']
+ percentile = stat["percentile"]
assert len(percentile) == 2
assert torch.allclose(percentile[0], torch.min(x))
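# Illustrative aside (not part of the diff): VariableStatisticsPlot, as exercised above,
# collects statistics of a tensor and writes a plot file; matplotlib with the Agg backend
# is assumed, and "stats.png" is a placeholder filename.
import torch
import pytorch_pfn_extras as ppe

manager = ppe.training.ExtensionsManager({}, [], 2, iters_per_epoch=1)
x = torch.rand(1, 2, 3)
extension = ppe.training.extensions.VariableStatisticsPlot(
    x, trigger=(1, "iteration"), filename="stats.png"
)
manager.extend(extension, trigger=(1, "iteration"))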
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py b/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py
index 0ddaaf776..f0f1b49be 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_engine.py
@@ -5,7 +5,8 @@
class TestEngine:
def test_engine_extension(self):
engine = ppe.training._trainer.Trainer(
- None, evaluator=None, models={}, optimizers={}, max_epochs=10)
+ None, evaluator=None, models={}, optimizers={}, max_epochs=10
+ )
extension = ppe.training.extensions.LogReport()
engine.extend(extension)
# Create the actual manager object
@@ -15,26 +16,32 @@ def test_engine_extension(self):
def test_engine_state_dict(self):
manager = ppe.training.ExtensionsManager(
- {}, {}, 10, iters_per_epoch=100)
+ {}, {}, 10, iters_per_epoch=100
+ )
engine = ppe.training._trainer.Trainer(
- None, evaluator=None, models={}, optimizers={}, max_epochs=10)
+ None, evaluator=None, models={}, optimizers={}, max_epochs=10
+ )
engine._setup_manager(100)
assert engine.state_dict() == manager.state_dict()
def test_engine_load_state_dict(self):
manager = ppe.training.ExtensionsManager(
- {}, {}, 10, iters_per_epoch=100)
+ {}, {}, 10, iters_per_epoch=100
+ )
engine = ppe.training._trainer.Trainer(
- None, evaluator=None, models={}, optimizers={}, max_epochs=1)
+ None, evaluator=None, models={}, optimizers={}, max_epochs=1
+ )
engine.load_state_dict(manager.state_dict())
engine._setup_manager(20)
assert engine.state_dict() == manager.state_dict()
def test_engine_load_state_dict_2(self):
manager = ppe.training.ExtensionsManager(
- {}, {}, 10, iters_per_epoch=100)
+ {}, {}, 10, iters_per_epoch=100
+ )
engine = ppe.training._trainer.Trainer(
- None, evaluator=None, models={}, optimizers={}, max_epochs=1)
+ None, evaluator=None, models={}, optimizers={}, max_epochs=1
+ )
engine._setup_manager(20)
engine.load_state_dict(manager.state_dict())
assert engine.state_dict() == manager.state_dict()
@@ -42,22 +49,29 @@ def test_engine_load_state_dict_2(self):
class TestEngineInvalid:
def test_engine_wrong_models(self):
- with pytest.raises(ValueError, match='model must be an instance'):
+ with pytest.raises(ValueError, match="model must be an instance"):
ppe.training._trainer.Trainer(
- None, evaluator=None, models=object(), optimizers={}, max_epochs=10)
+ None,
+ evaluator=None,
+ models=object(),
+ optimizers={},
+ max_epochs=10,
+ )
def test_engine_not_started(self):
engine = ppe.training._trainer.Trainer(
- None, evaluator=None, models={}, optimizers={}, max_epochs=10)
- with pytest.raises(RuntimeError, match='is not started'):
+ None, evaluator=None, models={}, optimizers={}, max_epochs=10
+ )
+ with pytest.raises(RuntimeError, match="is not started"):
engine.state_dict()
- with pytest.raises(RuntimeError, match='is not started'):
+ with pytest.raises(RuntimeError, match="is not started"):
engine.manager
def test_extend_after_init(self):
engine = ppe.training._trainer.Trainer(
- None, evaluator=None, models={}, optimizers={}, max_epochs=10)
+ None, evaluator=None, models={}, optimizers={}, max_epochs=10
+ )
engine._setup_manager(10)
extension = ppe.training.extensions.LogReport()
- with pytest.raises(RuntimeError, match='cannot extend after'):
+ with pytest.raises(RuntimeError, match="cannot extend after"):
engine.extend(extension)
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py b/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py
index 111b3dc59..43180affe 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_evaluator_metrics.py
@@ -1,8 +1,6 @@
import pytest
-
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
from pytorch_pfn_extras import engine
@@ -15,25 +13,28 @@ def forward(self, x, t):
g = t.clone()
to_alter = int(10 * (1 - self.correct_ratio))
g[0:to_alter][:] -= 1
- return {'y': g}
+ return {"y": g}
-@pytest.mark.parametrize('device', ['cpu'])
-@pytest.mark.parametrize('accuracy', [0, 0.5, 1.0])
+@pytest.mark.parametrize("device", ["cpu"])
+@pytest.mark.parametrize("accuracy", [0, 0.5, 1.0])
def test_evaluator_with_metric(device, accuracy):
model = MyModel(accuracy)
data = torch.utils.data.DataLoader(
- [{'x': torch.rand(20), 't': torch.rand(1)} for i in range(10)],
- batch_size=10)
+ [{"x": torch.rand(20), "t": torch.rand(1)} for i in range(10)],
+ batch_size=10,
+ )
ppe.to(model, device)
evaluator = engine.create_evaluator(
- model, device=device,
- metrics=[ppe.training.metrics.AccuracyMetric('t', 'y')],
- options={'eval_report_keys': ['accuracy']})
+ model,
+ device=device,
+ metrics=[ppe.training.metrics.AccuracyMetric("t", "y")],
+ options={"eval_report_keys": ["accuracy"]},
+ )
evaluator.handler.eval_setup(evaluator, data)
reporter = ppe.reporting.Reporter()
observation = {}
with reporter.scope(observation):
evaluator.run(data)
- assert pytest.approx(observation['val/accuracy']) == accuracy
+ assert pytest.approx(observation["val/accuracy"]) == accuracy
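# Illustrative aside (not part of the diff): an evaluator with an AccuracyMetric whose
# result is reported under "val/accuracy", mirroring the parametrized test above. The
# model here is a stand-in that simply echoes the target, so accuracy comes out as 1.0.
import torch
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras import engine

class EchoModel(torch.nn.Module):
    def forward(self, x, t):
        # A real model would return predictions under the "y" key.
        return {"y": t.clone()}

model = EchoModel()
ppe.to(model, "cpu")
data = torch.utils.data.DataLoader(
    [{"x": torch.rand(20), "t": torch.rand(1)} for _ in range(10)], batch_size=10
)
evaluator = engine.create_evaluator(
    model,
    device="cpu",
    metrics=[ppe.training.metrics.AccuracyMetric("t", "y")],
    options={"eval_report_keys": ["accuracy"]},
)
evaluator.handler.eval_setup(evaluator, data)
reporter = ppe.reporting.Reporter()
observation = {}
with reporter.scope(observation):
    evaluator.run(data)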
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py b/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py
index e81beb433..f6f73e2a3 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_extension.py
@@ -1,15 +1,14 @@
-import pytest
-import torch
-
from unittest import mock
+import pytest
import pytorch_pfn_extras as ppe
+import torch
def _get_dummy_manager():
model = torch.nn.Module()
return ppe.training.ExtensionsManager(
- {'main': model},
+ {"main": model},
[], # optimizers
10, # max_epochs
iters_per_epoch=1,
@@ -31,7 +30,7 @@ class MyExtension(ppe.training.Extension):
pass
ext = MyExtension()
- assert ext.default_name == 'MyExtension'
+ assert ext.default_name == "MyExtension"
def test_deleted_invoke_before_training():
@@ -46,13 +45,17 @@ class MyExtension(ppe.training.Extension):
def test_make_extension():
initialize = mock.Mock()
- @ppe.training.make_extension(trigger=(2, 'epoch'), default_name='my_ext',
- priority=50, initializer=initialize)
+ @ppe.training.make_extension(
+ trigger=(2, "epoch"),
+ default_name="my_ext",
+ priority=50,
+ initializer=initialize,
+ )
def my_extension(trainer):
pass
- assert my_extension.trigger == (2, 'epoch')
- assert my_extension.default_name == 'my_ext'
+ assert my_extension.trigger == (2, "epoch")
+ assert my_extension.default_name == "my_ext"
assert my_extension.priority == 50
trainer = object()
@@ -66,8 +69,8 @@ def test_make_extension_default_values():
def my_extension(trainer):
pass
- assert my_extension.trigger == (1, 'iteration')
- assert my_extension.default_name == 'my_extension'
+ assert my_extension.trigger == (1, "iteration")
+ assert my_extension.default_name == "my_extension"
assert my_extension.priority == ppe.training.PRIORITY_READER
manager = object()
my_extension.initialize(manager)
@@ -75,6 +78,7 @@ def my_extension(trainer):
def test_make_extension_unexpected_kwargs():
with pytest.raises(TypeError):
+
@ppe.training.make_extension(foo=1)
def my_extension(_):
pass
@@ -92,9 +96,10 @@ def on_error(self, manager, exc, tb):
assert isinstance(exc, RuntimeError)
self.call_cnt += 1
- optimizers = {'main': object()}
+ optimizers = {"main": object()}
manager = ppe.training.ExtensionsManager(
- {}, optimizers, 1, iters_per_epoch=2)
+ {}, optimizers, 1, iters_per_epoch=2
+ )
ext = DummyExt()
manager.extend(ext)
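# Illustrative aside (not part of the diff): make_extension turns a plain function into an
# extension carrying the trigger/name/priority defaults checked by test_make_extension above.
import pytorch_pfn_extras as ppe

@ppe.training.make_extension(trigger=(2, "epoch"), default_name="my_ext", priority=50)
def my_extension(manager):
    pass

assert my_extension.trigger == (2, "epoch")
assert my_extension.default_name == "my_ext"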
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py b/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py
index f4801e7e2..87cafe79b 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_extension_entry.py
@@ -1,12 +1,11 @@
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
def _get_dummy_manager():
model = torch.nn.Module()
return ppe.training.ExtensionsManager(
- {'main': model},
+ {"main": model},
[], # optimizers
10, # max_epochs
iters_per_epoch=1,
@@ -16,25 +15,25 @@ def _get_dummy_manager():
def test_default_name():
class MyExtension(ppe.training.Extension):
name = None
- default_name = 'defalut_name'
+ default_name = "defalut_name"
ext = MyExtension()
entry = ppe.training.ExtensionEntry(ext)
assert entry.name == MyExtension.default_name
- entry = ppe.training.ExtensionEntry(ext, name='updated')
- assert entry.name == 'updated'
+ entry = ppe.training.ExtensionEntry(ext, name="updated")
+ assert entry.name == "updated"
def test_name():
class MyExtension(ppe.training.Extension):
- name = 'name'
- default_name = 'defalut_name'
+ name = "name"
+ default_name = "defalut_name"
ext = MyExtension()
entry = ppe.training.ExtensionEntry(ext)
assert entry.name == MyExtension.name
- entry = ppe.training.ExtensionEntry(ext, name='updated')
- assert entry.name == 'updated'
+ entry = ppe.training.ExtensionEntry(ext, name="updated")
+ assert entry.name == "updated"
def test_priority():
@@ -50,14 +49,14 @@ class MyExtension(ppe.training.Extension):
def test_trigger():
class MyExtension(ppe.training.Extension):
- trigger = (1, 'iteration')
+ trigger = (1, "iteration")
ext = MyExtension()
entry = ppe.training.ExtensionEntry(ext)
assert isinstance(entry.trigger, ppe.training.triggers.IntervalTrigger)
assert entry.trigger.period == 1
- assert entry.trigger.unit == 'iteration'
- entry = ppe.training.ExtensionEntry(ext, trigger=(3, 'epoch'))
+ assert entry.trigger.unit == "iteration"
+ entry = ppe.training.ExtensionEntry(ext, trigger=(3, "epoch"))
assert isinstance(entry.trigger, ppe.training.triggers.IntervalTrigger)
assert entry.trigger.period == 3
- assert entry.trigger.unit == 'epoch'
+ assert entry.trigger.unit == "epoch"
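# Illustrative aside (not part of the diff): ExtensionEntry wraps an extension so that its
# name and trigger can be overridden at registration time, as the tests above verify.
import pytorch_pfn_extras as ppe

class MyExtension(ppe.training.Extension):
    def __call__(self, manager):
        pass

entry = ppe.training.ExtensionEntry(MyExtension(), name="updated", trigger=(3, "epoch"))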
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py b/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py
index 8cafab62b..f0aa1b9a7 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_manager.py
@@ -1,18 +1,13 @@
import pytest
-
-import torch
-from torch import nn
-
import pytorch_pfn_extras as ppe
+import torch
from pytorch_pfn_extras import training
+from torch import nn
def test_manager_status_info():
manager = training.ExtensionsManager(
- nn.Module(),
- object(),
- 10,
- iters_per_epoch=4
+ nn.Module(), object(), 10, iters_per_epoch=4
)
manager.iteration = 9
assert manager.iteration == 9
@@ -25,9 +20,7 @@ def test_manager_status_info():
class _DummyExtension(training.Extension):
-
- def __init__(
- self, extension_id, call_record, init_record, use_model=False):
+ def __init__(self, extension_id, call_record, init_record, use_model=False):
self.extension_id = extension_id
self.call_record = call_record
self.init_record = init_record
@@ -37,12 +30,11 @@ def __call__(self, manager):
self.call_record.append(self.extension_id)
if self.use_model:
# Check if models are accessible.
- assert manager.models['main'] is not None
- assert manager.raw_models['main'] is not None
+ assert manager.models["main"] is not None
+ assert manager.raw_models["main"] is not None
class _DummyExtensionInitialize(_DummyExtension):
-
def initialize(self, manager):
self.init_record.append(self.extension_id)
@@ -53,8 +45,8 @@ def test_extensions_manager_extensions():
max_epochs = 5
iters_per_epoch = 4
manager = training.ExtensionsManager(
- {'model_name': model},
- {'optimizer_name': optimizer},
+ {"model_name": model},
+ {"optimizer_name": optimizer},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
@@ -72,33 +64,37 @@ def test_extensions_manager_extensions():
lambda manager: dummy5(manager),
training.ExtensionEntry(
_DummyExtension(6, call_record, init_record),
- name='ext6', priority=-3, call_before_training=True,
+ name="ext6",
+ priority=-3,
+ call_before_training=True,
),
training.ExtensionEntry(
_DummyExtension(7, call_record, init_record),
- name='ext7_', priority=-2, call_before_training=True,
+ name="ext7_",
+ priority=-2,
+ call_before_training=True,
),
]
- manager.extend(exts[0], 'ext0', priority=2, call_before_training=True)
- manager.extend(exts[1], 'ext1', priority=1, call_before_training=False)
- manager.extend(exts[2], 'ext2', priority=3, call_before_training=False)
- manager.extend(exts[3], 'ext3', priority=0, call_before_training=True)
- manager.extend(exts[4], 'ext4', priority=4, call_before_training=True)
- manager.extend(exts[5], 'ext5', priority=-1, call_before_training=True)
+ manager.extend(exts[0], "ext0", priority=2, call_before_training=True)
+ manager.extend(exts[1], "ext1", priority=1, call_before_training=False)
+ manager.extend(exts[2], "ext2", priority=3, call_before_training=False)
+ manager.extend(exts[3], "ext3", priority=0, call_before_training=True)
+ manager.extend(exts[4], "ext4", priority=4, call_before_training=True)
+ manager.extend(exts[5], "ext5", priority=-1, call_before_training=True)
manager.extend(exts[6])
- manager.extend(exts[7], 'ext7', priority=-4, call_before_training=False)
+ manager.extend(exts[7], "ext7", priority=-4, call_before_training=False)
- assert manager.get_extension('ext0') is exts[0]
- assert manager.get_extension('ext1') is exts[1]
- assert manager.get_extension('ext2') is exts[2]
- assert manager.get_extension('ext3') is exts[3]
- assert manager.get_extension('ext4') is exts[4]
- assert manager.get_extension('ext6') is exts[6].extension
- assert manager.get_extension('ext7') is exts[7].extension
+ assert manager.get_extension("ext0") is exts[0]
+ assert manager.get_extension("ext1") is exts[1]
+ assert manager.get_extension("ext2") is exts[2]
+ assert manager.get_extension("ext3") is exts[3]
+ assert manager.get_extension("ext4") is exts[4]
+ assert manager.get_extension("ext6") is exts[6].extension
+ assert manager.get_extension("ext7") is exts[7].extension
with pytest.raises(ValueError):
- manager.get_extension('ext10')
+ manager.get_extension("ext10")
for it in range(max_epochs * iters_per_epoch):
call_record.clear()
@@ -121,7 +117,7 @@ def test_extensions_manager_extensions():
assert init_record == []
-class _StateDictObj():
+class _StateDictObj:
def __init__(self, *, state_dict=None, state_dict_to_be_loaded=None):
super().__init__()
self.called_load_state_dict = 0
@@ -137,13 +133,11 @@ def load_state_dict(self, state_dict):
class _StateDictModel(_StateDictObj, nn.Module):
-
def forward(self, *args):
pass
class _StateDictOptimizer(_StateDictObj):
-
def zero_grad(self):
pass
@@ -152,7 +146,6 @@ def step(self):
class _StateDictExtension(_StateDictObj, training.Extension):
-
def __call__(self, manager):
pass
@@ -170,15 +163,16 @@ def test_extensions_manager_state_dict():
passed_iteration = 11
manager = training.ExtensionsManager(
- {'model_name': _StateDictModel(state_dict=model_state_dict)},
- {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
+ {"model_name": _StateDictModel(state_dict=model_state_dict)},
+ {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
manager.extend(
- _StateDictExtension(
- state_dict=extension_state_dict), name='extension_name')
+ _StateDictExtension(state_dict=extension_state_dict),
+ name="extension_name",
+ )
for _ in range(passed_iteration):
with manager.run_iteration():
@@ -187,29 +181,31 @@ def test_extensions_manager_state_dict():
state_dict = manager.state_dict()
assert state_dict == {
- 'ppe_version': ppe.__version__,
- '_start_execution': passed_iteration,
- '_start_iteration': passed_iteration,
- 'models': {'model_name': model_state_dict},
- 'optimizers': {'optimizer_name': optimizer_state_dict},
- 'extensions': {'extension_name': {
- 'extension': extension_state_dict,
- 'trigger': {
- },
- }},
+ "ppe_version": ppe.__version__,
+ "_start_execution": passed_iteration,
+ "_start_iteration": passed_iteration,
+ "models": {"model_name": model_state_dict},
+ "optimizers": {"optimizer_name": optimizer_state_dict},
+ "extensions": {
+ "extension_name": {
+ "extension": extension_state_dict,
+ "trigger": {},
+ }
+ },
}
new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict)
new_extension = _StateDictExtension(
- state_dict_to_be_loaded=extension_state_dict)
+ state_dict_to_be_loaded=extension_state_dict
+ )
new_manager = training.ExtensionsManager(
- {'model_name': new_model},
- {'optimizer_name': new_optimizer},
+ {"model_name": new_model},
+ {"optimizer_name": new_optimizer},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
- new_manager.extend(new_extension, name='extension_name')
+ new_manager.extend(new_extension, name="extension_name")
new_manager.load_state_dict(state_dict)
assert new_model.called_load_state_dict == 1
assert new_optimizer.called_load_state_dict == 1
@@ -224,8 +220,8 @@ def test_extensions_manager_state_dict_old_ppe_no_version():
passed_iteration = 11
manager = training.ExtensionsManager(
- {'model_name': _StateDictModel(state_dict=model_state_dict)},
- {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
+ {"model_name": _StateDictModel(state_dict=model_state_dict)},
+ {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
@@ -237,15 +233,15 @@ def test_extensions_manager_state_dict_old_ppe_no_version():
new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict)
manager_2 = training.ExtensionsManager(
- {'model_name': new_model},
- {'optimizer_name': new_optimizer},
+ {"model_name": new_model},
+ {"optimizer_name": new_optimizer},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
state_dict = manager.state_dict()
- del state_dict['ppe_version']
- with pytest.warns(UserWarning, match='version'):
+ del state_dict["ppe_version"]
+ with pytest.warns(UserWarning, match="version"):
manager_2.load_state_dict(state_dict)
@@ -257,8 +253,8 @@ def test_extensions_manager_state_dict_old_ppe_version():
passed_iteration = 11
manager = training.ExtensionsManager(
- {'model_name': _StateDictModel(state_dict=model_state_dict)},
- {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
+ {"model_name": _StateDictModel(state_dict=model_state_dict)},
+ {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
@@ -270,15 +266,15 @@ def test_extensions_manager_state_dict_old_ppe_version():
new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict)
manager_2 = training.ExtensionsManager(
- {'model_name': new_model},
- {'optimizer_name': new_optimizer},
+ {"model_name": new_model},
+ {"optimizer_name": new_optimizer},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
state_dict = manager.state_dict()
- state_dict['ppe_version'] = '0.4.0'
- with pytest.warns(UserWarning, match='version'):
+ state_dict["ppe_version"] = "0.4.0"
+ with pytest.warns(UserWarning, match="version"):
manager_2.load_state_dict(state_dict)
@@ -290,8 +286,8 @@ def test_extensions_manager_state_dict_future_ppe_version():
passed_iteration = 11
manager = training.ExtensionsManager(
- {'model_name': _StateDictModel(state_dict=model_state_dict)},
- {'optimizer_name': _StateDictObj(state_dict=optimizer_state_dict)},
+ {"model_name": _StateDictModel(state_dict=model_state_dict)},
+ {"optimizer_name": _StateDictObj(state_dict=optimizer_state_dict)},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
@@ -303,24 +299,23 @@ def test_extensions_manager_state_dict_future_ppe_version():
new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
new_optimizer = _StateDictObj(state_dict_to_be_loaded=optimizer_state_dict)
manager_2 = training.ExtensionsManager(
- {'model_name': new_model},
- {'optimizer_name': new_optimizer},
+ {"model_name": new_model},
+ {"optimizer_name": new_optimizer},
max_epochs,
iters_per_epoch=iters_per_epoch,
)
state_dict = manager.state_dict()
- state_dict['ppe_version'] = '23.0.0'
- with pytest.warns(UserWarning, match='version'):
+ state_dict["ppe_version"] = "23.0.0"
+ with pytest.warns(UserWarning, match="version"):
manager_2.load_state_dict(state_dict)
def test_ignite_extensions_manager_state_dict():
-
try:
from ignite.engine import create_supervised_trainer
except ImportError:
- pytest.skip('pytorch-ignite not found')
+ pytest.skip("pytorch-ignite not found")
model_state_dict = object()
optimizer_state_dict = object()
@@ -332,54 +327,57 @@ def test_ignite_extensions_manager_state_dict():
model = _StateDictModel(state_dict=model_state_dict)
optimizer = _StateDictOptimizer(state_dict=optimizer_state_dict)
- trainer = create_supervised_trainer(
- model, optimizer, _fake_loss)
+ trainer = create_supervised_trainer(model, optimizer, _fake_loss)
manager = training.IgniteExtensionsManager(
trainer,
- {'model_name': model},
- {'optimizer_name': optimizer},
+ {"model_name": model},
+ {"optimizer_name": optimizer},
max_epochs,
)
manager.extend(
- _StateDictExtension(
- state_dict=extension_state_dict), name='extension_name')
+ _StateDictExtension(state_dict=extension_state_dict),
+ name="extension_name",
+ )
loader = torch.utils.data.DataLoader(
- [(i, i) for i in range(iters_per_epoch)])
+ [(i, i) for i in range(iters_per_epoch)]
+ )
trainer.run(loader, max_epochs=max_epochs)
state_dict = manager.state_dict()
assert state_dict == {
- 'ppe_version': ppe.__version__,
- '_start_execution': passed_iteration,
- '_start_iteration': passed_iteration,
- '_epoch_length': iters_per_epoch,
- 'models': {'model_name': model_state_dict},
- 'optimizers': {'optimizer_name': optimizer_state_dict},
- 'extensions': {'extension_name': {
- 'extension': extension_state_dict,
- 'trigger': {
- },
- }},
+ "ppe_version": ppe.__version__,
+ "_start_execution": passed_iteration,
+ "_start_iteration": passed_iteration,
+ "_epoch_length": iters_per_epoch,
+ "models": {"model_name": model_state_dict},
+ "optimizers": {"optimizer_name": optimizer_state_dict},
+ "extensions": {
+ "extension_name": {
+ "extension": extension_state_dict,
+ "trigger": {},
+ }
+ },
}
new_model = _StateDictModel(state_dict_to_be_loaded=model_state_dict)
new_optimizer = _StateDictOptimizer(
- state_dict_to_be_loaded=optimizer_state_dict)
+ state_dict_to_be_loaded=optimizer_state_dict
+ )
new_extension = _StateDictExtension(
- state_dict_to_be_loaded=extension_state_dict)
+ state_dict_to_be_loaded=extension_state_dict
+ )
- new_trainer = create_supervised_trainer(
- model, optimizer, _fake_loss)
+ new_trainer = create_supervised_trainer(model, optimizer, _fake_loss)
new_manager = training.IgniteExtensionsManager(
new_trainer,
- {'model_name': new_model},
- {'optimizer_name': new_optimizer},
+ {"model_name": new_model},
+ {"optimizer_name": new_optimizer},
max_epochs,
)
- new_manager.extend(new_extension, name='extension_name')
+ new_manager.extend(new_extension, name="extension_name")
new_manager.load_state_dict(state_dict)
assert new_model.called_load_state_dict == 1
assert new_optimizer.called_load_state_dict == 1
@@ -401,12 +399,12 @@ def test_extensions_manager_with_plain_model_and_optimizer():
state_dict = manager.state_dict()
assert state_dict == {
- 'ppe_version': ppe.__version__,
- '_start_execution': 0,
- '_start_iteration': 0,
- 'models': {'main': model_state_dict},
- 'optimizers': {'main': optimizer_state_dict},
- 'extensions': {}
+ "ppe_version": ppe.__version__,
+ "_start_execution": 0,
+ "_start_iteration": 0,
+ "models": {"main": model_state_dict},
+ "optimizers": {"main": optimizer_state_dict},
+ "extensions": {},
}
@@ -435,7 +433,7 @@ def test_model_transformations():
transform_model=lambda n, x: x.wrapper_module(),
)
- assert not isinstance(manager.models['main'], Wrapper)
+ assert not isinstance(manager.models["main"], Wrapper)
assert model.accessed
@@ -466,7 +464,7 @@ def test_model_transformations_in_state_dict():
transform_model=lambda n, x: x.wrapper_module(),
)
new_manager.load_state_dict(state_dict)
- assert isinstance(new_manager.models['main'], _StateDictModel)
+ assert isinstance(new_manager.models["main"], _StateDictModel)
def test_call_optimizers():
@@ -479,9 +477,9 @@ def test_call_optimizers():
1,
iters_per_epoch=1,
)
- with manager.run_iteration(step_optimizers=['main']):
+ with manager.run_iteration(step_optimizers=["main"]):
a.grad = torch.tensor([2.0])
- assert torch.equal(a.detach(), torch.tensor([-1.]))
+ assert torch.equal(a.detach(), torch.tensor([-1.0]))
def test_needs_state_this_iteration():
@@ -489,15 +487,11 @@ def test_needs_state_this_iteration():
a = torch.ones(1, requires_grad=True)
optimizer = torch.optim.SGD(lr=1.0, params=[a])
extension = _DummyExtension(0, [], [], True)
- extension.name = 'Dummy'
+ extension.name = "Dummy"
extension.needs_model_state = True
- extension.trigger = (50, 'iteration')
+ extension.trigger = (50, "iteration")
manager = training.ExtensionsManager(
- m,
- optimizer,
- 1,
- iters_per_epoch=100,
- extensions=[extension]
+ m, optimizer, 1, iters_per_epoch=100, extensions=[extension]
)
while not manager.stop_trigger:
with manager.run_iteration():
@@ -509,27 +503,26 @@ def test_needs_state_this_iteration():
assert not manager.needs_state_this_iteration()
-@pytest.mark.parametrize('priority', [
- None,
- training.extension.PRIORITY_SNAPSHOT,
- training.PRIORITY_WRITER,
-])
+@pytest.mark.parametrize(
+ "priority",
+ [
+ None,
+ training.extension.PRIORITY_SNAPSHOT,
+ training.PRIORITY_WRITER,
+ ],
+)
def test_extensions_accessing_models_without_flag(priority):
m = torch.nn.Linear(5, 5)
a = torch.ones(1, requires_grad=True)
optimizer = torch.optim.SGD(lr=1.0, params=[a])
extension = _DummyExtension(0, [], [], True)
- extension.name = 'Dummy'
+ extension.name = "Dummy"
extension.needs_model_state = False
- extension.trigger = (1, 'iteration')
+ extension.trigger = (1, "iteration")
if priority is not None:
extension.priority = priority
manager = training.ExtensionsManager(
- m,
- optimizer,
- 1,
- iters_per_epoch=5,
- extensions=[extension]
+ m, optimizer, 1, iters_per_epoch=5, extensions=[extension]
)
while not manager.stop_trigger:
with pytest.raises(RuntimeError):
@@ -549,9 +542,10 @@ def __call__(self, manager):
def finalize(self, manager):
self.finalized = True
- optimizers = {'main': object()}
+ optimizers = {"main": object()}
manager = ppe.training.ExtensionsManager(
- {}, optimizers, 1, iters_per_epoch=1)
+ {}, optimizers, 1, iters_per_epoch=1
+ )
ext = DummyExt()
manager.extend(ext)
@@ -565,5 +559,5 @@ def finalize(self, manager):
pass
-if __name__ == '__main__':
- pytest.main([__file__, '-v', '-s'])
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "-s"])
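# Illustrative aside (not part of the diff): the state_dict round trip these manager tests
# rely on. state_dict() returns a plain dict (models, optimizers, extensions, iteration
# counters) and load_state_dict() restores it into a freshly built manager.
import torch
import pytorch_pfn_extras as ppe

def make_manager():
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return ppe.training.ExtensionsManager(
        {"main": model}, {"main": optimizer}, 1, iters_per_epoch=5
    )

manager = make_manager()
with manager.run_iteration():
    pass
state = manager.state_dict()

new_manager = make_manager()
new_manager.load_state_dict(state)  # resumes from the saved iteration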
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
index 564c7cec5..4cfd57f87 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
@@ -1,20 +1,17 @@
import os
import tempfile
import typing
+from unittest import mock
import pytest
-
+import pytorch_pfn_extras as ppe
import torch
+from pytorch_pfn_extras import engine, training
from torch import nn
from torch.nn import functional as F
-from unittest import mock
-
-import pytorch_pfn_extras as ppe
-from pytorch_pfn_extras import engine
-from pytorch_pfn_extras import training
-@pytest.fixture(scope='function')
+@pytest.fixture(scope="function")
def path():
with tempfile.TemporaryDirectory() as t_path:
yield t_path
@@ -45,45 +42,59 @@ def __init__(self, model):
def forward(self, x, t):
y = self.model(x)
- prefix = 'train' if self.training else 'val'
+ prefix = "train" if self.training else "val"
loss = F.l1_loss(y, t)
- ppe.reporting.report({prefix + '/loss': loss})
+ ppe.reporting.report({prefix + "/loss": loss})
return loss
def _make_extensions():
return [
- training.extensions.LogReport(trigger=(10, 'iteration')),
+ training.extensions.LogReport(trigger=(10, "iteration")),
training.extensions.ProgressBar(update_interval=2),
training.extensions.PrintReport(
[
- 'epoch',
- 'iteration',
- 'train/loss',
- 'val/loss',
- 'val/accuracy',
- 'elapsed_time',
- 'time',
+ "epoch",
+ "iteration",
+ "train/loss",
+ "val/loss",
+ "val/accuracy",
+ "elapsed_time",
+ "time",
]
),
]
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_trainer(device, path):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, extensions=extensions,
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ extensions=extensions,
out_dir=path,
)
trainer.run(data)
@@ -94,12 +105,26 @@ def test_trainer_no_to(path):
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device='cpu', extensions=extensions,
+ model_with_loss,
+ optimizer,
+ 20,
+ device="cpu",
+ extensions=extensions,
out_dir=path,
)
with pytest.raises(RuntimeError, match="ppe.to"):
@@ -107,44 +132,63 @@ def test_trainer_no_to(path):
def test_trainer_invalid_options(path):
- device = 'cpu'
+ device = "cpu"
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
extensions = _make_extensions()
- options = {'UNKNOWN_OPTIONS': True}
+ options = {"UNKNOWN_OPTIONS": True}
with pytest.raises(ValueError, match="UNKNOWN_OPTIONS"):
engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, extensions=extensions,
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ extensions=extensions,
out_dir=path,
options=options,
)
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
-@pytest.mark.parametrize('progress_bar', [True, False])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("progress_bar", [True, False])
def test_train_with_evaluator(device, progress_bar, path):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar)
+ model_with_loss, device=device, progress_bar=progress_bar
+ )
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator=evaluator, extensions=extensions,
- out_dir=path
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator=evaluator,
+ extensions=extensions,
+ out_dir=path,
)
- mpath = 'pytorch_pfn_extras.training._evaluator.Evaluator.run'
+ mpath = "pytorch_pfn_extras.training._evaluator.Evaluator.run"
with mock.patch(mpath) as patched:
trainer.run(data, data)
assert patched.call_count == 20
@@ -155,68 +199,120 @@ def test_train_with_evaluator(device, progress_bar, path):
[(20, (1, "epoch")), (40, (5, "iteration"))],
)
def test_evaluator_trigger(evaluator_trigger, path):
- device = 'cpu'
+ device = "cpu"
progress_bar = False
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar)
+ model_with_loss, device=device, progress_bar=progress_bar
+ )
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator=(evaluator, evaluator_trigger[1]),
- extensions=extensions, out_dir=path
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator=(evaluator, evaluator_trigger[1]),
+ extensions=extensions,
+ out_dir=path,
)
- path = 'pytorch_pfn_extras.training._evaluator.Evaluator.run'
+ path = "pytorch_pfn_extras.training._evaluator.Evaluator.run"
with mock.patch(path) as patched:
trainer.run(data, data)
assert patched.call_count == evaluator_trigger[0]
def test_evaluator_dict(path):
- device = 'cpu'
+ device = "cpu"
progress_bar = False
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator1 = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar)
+ model_with_loss, device=device, progress_bar=progress_bar
+ )
evaluator2 = engine.create_evaluator(
- model, device=device, progress_bar=progress_bar)
+ model, device=device, progress_bar=progress_bar
+ )
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator={
- '1': evaluator1, # called 20 times.
- '2': (evaluator2, (5, 'iteration')), # called 40 times.
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator={
+ "1": evaluator1, # called 20 times.
+ "2": (evaluator2, (5, "iteration")), # called 40 times.
},
- extensions=extensions, out_dir=path
+ extensions=extensions,
+ out_dir=path,
)
- path = 'pytorch_pfn_extras.training._evaluator.Evaluator.run'
+ path = "pytorch_pfn_extras.training._evaluator.Evaluator.run"
with mock.patch(path) as patched:
trainer.run(data, data)
assert patched.call_count == 20 + 40
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_train_result_equal(device, path):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
train_data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
data = torch.utils.data.DataLoader(
- [(torch.rand(20,),) for i in range(30)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ )
+ for i in range(30)
+ ]
+ )
def get_result_from_trainer():
model = MyModel()
@@ -226,9 +322,12 @@ def get_result_from_trainer():
extensions = _make_extensions()
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, extensions=extensions,
- out_dir=path
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ extensions=extensions,
+ out_dir=path,
)
trainer.run(train_data)
@@ -293,14 +392,17 @@ def _compare_states(s1, s2):
class TestTrainerState:
def _get_trainer(self, epochs, out_dir):
model = MyModel()
- ppe.to(model, 'cpu')
+ ppe.to(model, "cpu")
model_with_loss = MyModelWithLossFn(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
extensions = _make_extensions()
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device='cpu', extensions=extensions,
- out_dir=out_dir
+ model_with_loss,
+ optimizer,
+ 20,
+ device="cpu",
+ extensions=extensions,
+ out_dir=out_dir,
)
return trainer
@@ -308,7 +410,18 @@ def test_trainer_state(self, path):
torch.manual_seed(0)
trainer = self._get_trainer(20, path)
data = torch.utils.data.DataLoader(
- [(torch.ones(20,), torch.ones(10,)) for i in range(10)])
+ [
+ (
+ torch.ones(
+ 20,
+ ),
+ torch.ones(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
trainer.run(data)
# State to be compared to
state = trainer.state_dict()
@@ -324,7 +437,18 @@ def test_trainer_state(self, path):
def test_trainer_autoload(self, path):
trainer = self._get_trainer(20, path)
data = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(10)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(10)
+ ]
+ )
trainer.extend(ppe.training.extensions.snapshot())
trainer.run(data)
@@ -333,8 +457,7 @@ def test_trainer_autoload(self, path):
# This forces engine initialization
new_trainer._setup_manager(len(data))
assert new_trainer.epoch == 20
- assert _compare_states(
- trainer.state_dict(), new_trainer.state_dict())
+ assert _compare_states(trainer.state_dict(), new_trainer.state_dict())
class MyModelWithLossDictOutput(torch.nn.Module):
@@ -344,32 +467,48 @@ def __init__(self, model):
def forward(self, x, t):
y = self.model(x)
- prefix = 'train' if self.training else 'val'
+ prefix = "train" if self.training else "val"
loss = F.l1_loss(y, t)
- ppe.reporting.report({prefix + '/loss': loss})
- return {'y': y, 'loss': loss}
+ ppe.reporting.report({prefix + "/loss": loss})
+ return {"y": y, "loss": loss}
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
-@pytest.mark.parametrize('progress_bar', [True, False])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("progress_bar", [True, False])
def test_trainer_dict_input(device, progress_bar, path):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = MyModelWithLossDictOutput(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)])
+ [
+ {
+ "x": torch.rand(
+ 20,
+ ),
+ "t": torch.rand(
+ 10,
+ ),
+ }
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar)
+ model_with_loss, device=device, progress_bar=progress_bar
+ )
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator=evaluator, extensions=extensions,
- out_dir=path
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator=evaluator,
+ extensions=extensions,
+ out_dir=path,
)
trainer.run(data, data)
@@ -393,65 +532,103 @@ def __init__(self, model):
def forward(self, input):
y = self.model(input.x)
- prefix = 'train' if self.training else 'val'
+ prefix = "train" if self.training else "val"
loss = F.l1_loss(y, input.t)
- ppe.reporting.report({prefix + '/loss': loss})
+ ppe.reporting.report({prefix + "/loss": loss})
return Output(y, loss, input.v)
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
-@pytest.mark.parametrize('progress_bar', [True, False])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("progress_bar", [True, False])
def test_trainer_namedtuple_input(device, progress_bar, path):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = MyModel()
ppe.to(model, device)
model_with_loss = ModelNamedTupleIO(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [Input(torch.rand(20,), torch.rand(10,), str(i)) for i in range(10)])
+ [
+ Input(
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ str(i),
+ )
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar)
+ model_with_loss, device=device, progress_bar=progress_bar
+ )
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator=evaluator, extensions=extensions,
- out_dir=path
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator=evaluator,
+ extensions=extensions,
+ out_dir=path,
)
trainer.run(data, data)
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
-@pytest.mark.parametrize('progress_bar', [True, False])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("progress_bar", [True, False])
def test_trainer_with_code_block(device, progress_bar, path):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = MyModel()
model_with_loss = MyModelWithLossDictOutput(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)])
+ [
+ {
+ "x": torch.rand(
+ 20,
+ ),
+ "t": torch.rand(
+ 10,
+ ),
+ }
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar,
- logic=ppe.handler.CodeBlockLogic())
+ model_with_loss,
+ device=device,
+ progress_bar=progress_bar,
+ logic=ppe.handler.CodeBlockLogic(),
+ )
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator=evaluator, extensions=extensions,
- out_dir=path, logic=ppe.handler.CodeBlockLogic()
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator=evaluator,
+ extensions=extensions,
+ out_dir=path,
+ logic=ppe.handler.CodeBlockLogic(),
)
trainer.run(data, data)
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
-@pytest.mark.parametrize('progress_bar', [True, False])
-def test_trainer_with_code_block_with_multiple_optimizers(device, progress_bar, path):
- if not torch.cuda.is_available() and device == 'cuda':
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("progress_bar", [True, False])
+def test_trainer_with_code_block_with_multiple_optimizers(
+ device, progress_bar, path
+):
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = MyModel()
model_with_loss = MyModelWithLossDictOutput(model)
@@ -459,37 +636,66 @@ def test_trainer_with_code_block_with_multiple_optimizers(device, progress_bar,
optimizer0 = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer1 = torch.optim.Adam(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)])
+ [
+ {
+ "x": torch.rand(
+ 20,
+ ),
+ "t": torch.rand(
+ 10,
+ ),
+ }
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar,
- logic=ppe.handler.CodeBlockLogic())
+ model_with_loss,
+ device=device,
+ progress_bar=progress_bar,
+ logic=ppe.handler.CodeBlockLogic(),
+ )
trainer = engine.create_trainer(
- model_with_loss, {"0": optimizer0, "1": optimizer1}, 20,
- device=device, evaluator=evaluator, extensions=extensions,
- out_dir=path, logic=ppe.handler.CodeBlockLogic()
+ model_with_loss,
+ {"0": optimizer0, "1": optimizer1},
+ 20,
+ device=device,
+ evaluator=evaluator,
+ extensions=extensions,
+ out_dir=path,
+ logic=ppe.handler.CodeBlockLogic(),
)
trainer.run(data, data)
@pytest.mark.skipif(
- os.name == 'nt' and not ppe.requires("1.9"),
- reason='torch.profiler.profile is not supported.',
+ os.name == "nt" and not ppe.requires("1.9"),
+ reason="torch.profiler.profile is not supported.",
)
def test_trainer_profile():
- device = 'cpu'
+ device = "cpu"
model = MyModel()
model_with_loss = MyModelWithLossDictOutput(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)])
+ [
+ {
+ "x": torch.rand(
+ 20,
+ ),
+ "t": torch.rand(
+ 10,
+ ),
+ }
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
- evaluator = engine.create_evaluator(
- model_with_loss, device=device)
+ evaluator = engine.create_evaluator(model_with_loss, device=device)
trace_handler = mock.Mock()
warmup = 1
@@ -500,35 +706,58 @@ def test_trainer_profile():
schedule=torch.profiler.schedule(wait=0, warmup=warmup, active=active),
)
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator=evaluator, extensions=extensions,
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator=evaluator,
+ extensions=extensions,
profile=profile,
)
trainer.run(data, data)
assert trace_handler.call_count == 20 # n_epochs
-@pytest.mark.parametrize('device', ['cpu', 'cuda'])
-@pytest.mark.parametrize('progress_bar', [True, False])
+@pytest.mark.parametrize("device", ["cpu", "cuda"])
+@pytest.mark.parametrize("progress_bar", [True, False])
def test_trainer_with_clousure_logic(device, progress_bar, path):
- if not torch.cuda.is_available() and device == 'cuda':
+ if not torch.cuda.is_available() and device == "cuda":
pytest.skip()
model = MyModel()
model_with_loss = MyModelWithLossFn(model)
ppe.to(model_with_loss, device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.utils.data.DataLoader(
- [{'x': torch.rand(20,), 't': torch.rand(10,)} for i in range(10)])
+ [
+ {
+ "x": torch.rand(
+ 20,
+ ),
+ "t": torch.rand(
+ 10,
+ ),
+ }
+ for i in range(10)
+ ]
+ )
extensions = _make_extensions()
evaluator = engine.create_evaluator(
- model_with_loss, device=device, progress_bar=progress_bar,
- logic=ppe.handler.ClousureLogic())
+ model_with_loss,
+ device=device,
+ progress_bar=progress_bar,
+ logic=ppe.handler.ClousureLogic(),
+ )
trainer = engine.create_trainer(
- model_with_loss, optimizer, 20,
- device=device, evaluator=evaluator, extensions=extensions,
- out_dir=path, logic=ppe.handler.ClousureLogic(options={"backward_outputs": ["loss"]})
+ model_with_loss,
+ optimizer,
+ 20,
+ device=device,
+ evaluator=evaluator,
+ extensions=extensions,
+ out_dir=path,
+ logic=ppe.handler.ClousureLogic(options={"backward_outputs": ["loss"]}),
)
trainer.run(data, data)
@@ -544,11 +773,16 @@ def test_trainer_with_autocast(path):
extensions = []
autocast_options = {"autocast": True}
evaluator = engine.create_evaluator(
- model_with_loss, device="cuda",
- options=autocast_options)
+ model_with_loss, device="cuda", options=autocast_options
+ )
engine.create_trainer(
- model_with_loss, optimizer, 20,
- device="cuda", evaluator=evaluator, extensions=extensions,
- out_dir=path, options=autocast_options
+ model_with_loss,
+ optimizer,
+ 20,
+ device="cuda",
+ evaluator=evaluator,
+ extensions=extensions,
+ out_dir=path,
+ options=autocast_options,
)
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py
index 700c0e478..1ca566a8e 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trigger_util.py
@@ -1,34 +1,38 @@
import pytest
-
from pytorch_pfn_extras import training
-from pytorch_pfn_extras.training import _trigger_util
-from pytorch_pfn_extras.training import triggers
+from pytorch_pfn_extras.training import _trigger_util, triggers
@pytest.mark.parametrize(
- 'iters_per_epoch,trigger_args,expected',
+ "iters_per_epoch,trigger_args,expected",
[
# Never fire trigger
(2, None, [False, False, False, False, False, False, False]),
-
# Interval trigger
- (2, (2, 'iteration'),
- [False, True, False, True, False, True, False]),
- (2, (2, 'epoch'),
- [False, False, False, True, False, False, False]),
-
+ (2, (2, "iteration"), [False, True, False, True, False, True, False]),
+ (2, (2, "epoch"), [False, False, False, True, False, False, False]),
# Callable object
- (2, _trigger_util.get_trigger(None),
- [False, False, False, False, False, False, False]),
- (2, triggers.IntervalTrigger(2, 'iteration'),
- [False, True, False, True, False, True, False]),
- (2, (lambda trainer: trainer.iteration == 3),
- [False, False, True, False, False, False, False]),
- ]
+ (
+ 2,
+ _trigger_util.get_trigger(None),
+ [False, False, False, False, False, False, False],
+ ),
+ (
+ 2,
+ triggers.IntervalTrigger(2, "iteration"),
+ [False, True, False, True, False, True, False],
+ ),
+ (
+ 2,
+ (lambda trainer: trainer.iteration == 3),
+ [False, False, True, False, False, False, False],
+ ),
+ ],
)
def test_get_trigger(iters_per_epoch, trigger_args, expected):
trainer = training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = _trigger_util.get_trigger(trigger_args)
# before the first iteration, trigger should be False
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py
index 9d8e1e16d..10c539aa6 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_early_stopping_trigger.py
@@ -1,12 +1,10 @@
import numpy
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
def _test_trigger(trigger, key, accuracies, expected):
- manager = ppe.training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=1)
+ manager = ppe.training.ExtensionsManager({}, [], 100, iters_per_epoch=1)
for a, e in zip(accuracies, expected):
with manager.run_iteration():
pass
@@ -15,56 +13,59 @@ def _test_trigger(trigger, key, accuracies, expected):
def test_early_stopping_trigger_with_accuracy():
- key = 'main/accuracy'
+ key = "main/accuracy"
trigger = ppe.training.triggers.EarlyStoppingTrigger(
- monitor=key,
- patience=3,
- check_trigger=(1, 'epoch'),
- verbose=False)
+ monitor=key, patience=3, check_trigger=(1, "epoch"), verbose=False
+ )
accuracies = [
torch.Tensor(numpy.asarray(acc, dtype=numpy.float32))
- for acc in [0.5, 0.5, 0.6, 0.7, 0.6, 0.4, 0.3, 0.2]]
+ for acc in [0.5, 0.5, 0.6, 0.7, 0.6, 0.4, 0.3, 0.2]
+ ]
expected = [False, False, False, False, False, False, True, True]
_test_trigger(trigger, key, accuracies, expected)
def test_early_stopping_trigger_with_loss():
- key = 'main/loss'
+ key = "main/loss"
trigger = ppe.training.triggers.EarlyStoppingTrigger(
- monitor=key,
- patience=3,
- check_trigger=(1, 'epoch'))
+ monitor=key, patience=3, check_trigger=(1, "epoch")
+ )
accuracies = [
torch.Tensor(numpy.asarray(acc, dtype=numpy.float32))
- for acc in [100, 80, 30, 10, 20, 24, 30, 35]]
+ for acc in [100, 80, 30, 10, 20, 24, 30, 35]
+ ]
expected = [False, False, False, False, False, False, True, True]
_test_trigger(trigger, key, accuracies, expected)
def test_early_stopping_trigger_with_max_epoch():
- key = 'main/loss'
+ key = "main/loss"
trigger = ppe.training.triggers.EarlyStoppingTrigger(
monitor=key,
patience=3,
- check_trigger=(1, 'epoch'),
- max_trigger=(3, 'epoch'))
+ check_trigger=(1, "epoch"),
+ max_trigger=(3, "epoch"),
+ )
accuracies = [
torch.Tensor(numpy.asarray(acc, dtype=numpy.float32))
- for acc in [100, 80, 30]]
+ for acc in [100, 80, 30]
+ ]
expected = [False, False, True]
_test_trigger(trigger, key, accuracies, expected)
def test_early_stopping_trigger_with_max_iteration():
- key = 'main/loss'
+ key = "main/loss"
trigger = ppe.training.triggers.EarlyStoppingTrigger(
monitor=key,
patience=3,
- check_trigger=(1, 'epoch'),
- max_trigger=(3, 'iteration'))
+ check_trigger=(1, "epoch"),
+ max_trigger=(3, "iteration"),
+ )
accuracies = [
torch.Tensor(numpy.asarray(acc, dtype=numpy.float32))
- for acc in [100, 80, 30]]
+ for acc in [100, 80, 30]
+ ]
expected = [False, False, True]
_test_trigger(trigger, key, accuracies, expected)
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py
index 9f4630434..bc203b765 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_interval_trigger.py
@@ -1,24 +1,22 @@
import pytest
-
from pytorch_pfn_extras import training
from pytorch_pfn_extras.training import triggers
-
_argvalues = [
# iteration
- (5, (2, 'iteration'), [False, True, False, True, False, True, False], 4),
+ (5, (2, "iteration"), [False, True, False, True, False, True, False], 4),
# basic epoch
- (1, (3, 'epoch'), [False, False, True, False, False, True, False], 4),
+ (1, (3, "epoch"), [False, False, True, False, False, True, False], 4),
# fractional epoch
- (2, (1.5, 'epoch'), [False, False, True, False, False, True, False], 4),
+ (2, (1.5, "epoch"), [False, False, True, False, False, True, False], 4),
]
-@pytest.mark.parametrize(
- 'iters_per_epoch,interval,expected,resume', _argvalues)
+@pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues)
def test_trigger(iters_per_epoch, interval, expected, resume):
trainer = training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = triggers.IntervalTrigger(*interval)
for e in expected:
@@ -28,11 +26,11 @@ def test_trigger(iters_per_epoch, interval, expected, resume):
assert trigger(trainer) == e
-@pytest.mark.parametrize(
- 'iters_per_epoch,interval,expected,resume', _argvalues)
+@pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues)
def test_resumed_trigger(iters_per_epoch, interval, expected, resume):
trainer = training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = triggers.IntervalTrigger(*interval)
for e in expected[:resume]:
@@ -52,12 +50,11 @@ def test_resumed_trigger(iters_per_epoch, interval, expected, resume):
assert new_trigger(trainer) == e
-@pytest.mark.parametrize(
- 'iters_per_epoch,interval,expected,resume', _argvalues)
+@pytest.mark.parametrize("iters_per_epoch,interval,expected,resume", _argvalues)
def test_str(iters_per_epoch, interval, expected, resume):
trigger = triggers.IntervalTrigger(*interval)
- expected = 'IntervalTrigger({}, \'{}\')'.format(*interval)
+ expected = "IntervalTrigger({}, '{}')".format(*interval)
actual = str(trigger)
assert expected == actual, 'Expected "{}" == "{}"'.format(expected, actual)
@@ -65,4 +62,4 @@ def test_str(iters_per_epoch, interval, expected, resume):
def test_invalid_unit():
with pytest.raises(ValueError):
- triggers.IntervalTrigger(1, 'day')
+ triggers.IntervalTrigger(1, "day")
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py
index b72ca885c..ec307f43d 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_minmax_value_trigger.py
@@ -1,7 +1,6 @@
import pytest
-
-from pytorch_pfn_extras.training import triggers
from pytorch_pfn_extras import training
+from pytorch_pfn_extras.training import triggers
def _test_trigger(manager, trigger, key, accuracies, expected):
@@ -17,80 +16,136 @@ def _compare(best_value, new_value):
_trigger_test_params = [
# interval = 1 iterations
- (triggers.MaxValueTrigger, ((1, 'iteration'),), 1,
- [0.5, 0.5, 0.4, 0.6], [True, False, False, True], 1),
- (triggers.MinValueTrigger, ((1, 'iteration'),), 1,
- [0.5, 0.5, 0.4, 0.6], [True, False, True, False], 1),
+ (
+ triggers.MaxValueTrigger,
+ ((1, "iteration"),),
+ 1,
+ [0.5, 0.5, 0.4, 0.6],
+ [True, False, False, True],
+ 1,
+ ),
+ (
+ triggers.MinValueTrigger,
+ ((1, "iteration"),),
+ 1,
+ [0.5, 0.5, 0.4, 0.6],
+ [True, False, True, False],
+ 1,
+ ),
# interval = 2 iterations
- (triggers.MaxValueTrigger, ((2, 'iteration'),), 1,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, False, False, True], 2),
- (triggers.MinValueTrigger, ((2, 'iteration'),), 1,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, True, False, False], 2),
+ (
+ triggers.MaxValueTrigger,
+ ((2, "iteration"),),
+ 1,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, False, False, True],
+ 2,
+ ),
+ (
+ triggers.MinValueTrigger,
+ ((2, "iteration"),),
+ 1,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, True, False, False],
+ 2,
+ ),
# interval = 2 iterations, unaligned resume
- (triggers.MaxValueTrigger, ((2, 'iteration'),), 1,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, False, False, True], 3),
- (triggers.MinValueTrigger, ((2, 'iteration'),), 1,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, True, False, False], 3),
+ (
+ triggers.MaxValueTrigger,
+ ((2, "iteration"),),
+ 1,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, False, False, True],
+ 3,
+ ),
+ (
+ triggers.MinValueTrigger,
+ ((2, "iteration"),),
+ 1,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, True, False, False],
+ 3,
+ ),
# interval = 1 epoch, 1 epoch = 2 iterations
- (triggers.MaxValueTrigger, ((1, 'epoch'),), 2,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, False, False, True], 2),
- (triggers.MinValueTrigger, ((1, 'epoch'),), 2,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, True, False, False], 2),
+ (
+ triggers.MaxValueTrigger,
+ ((1, "epoch"),),
+ 2,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, False, False, True],
+ 2,
+ ),
+ (
+ triggers.MinValueTrigger,
+ ((1, "epoch"),),
+ 2,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, True, False, False],
+ 2,
+ ),
# interval = 1 epoch, 1 epoch = 2 iterations, unaligned resume
- (triggers.MaxValueTrigger, ((1, 'epoch'),), 2,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, False, False, True], 3),
- (triggers.MinValueTrigger, ((1, 'epoch'),), 2,
- [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
- [False, True, False, False, False, True, False, False], 3),
-
+ (
+ triggers.MaxValueTrigger,
+ ((1, "epoch"),),
+ 2,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, False, False, True],
+ 3,
+ ),
+ (
+ triggers.MinValueTrigger,
+ ((1, "epoch"),),
+ 2,
+ [0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.6, 0.6],
+ [False, True, False, False, False, True, False, False],
+ 3,
+ ),
# best_value trigger test
- (triggers.BestValueTrigger, (_compare, (1, 'iteration')), 2,
- [0.5, -0.5, -0.6, 0.6, 0.4, -0.4, -0.3, 0.3],
- [True, False, False, False, True, False, True, False], 3),
+ (
+ triggers.BestValueTrigger,
+ (_compare, (1, "iteration")),
+ 2,
+ [0.5, -0.5, -0.6, 0.6, 0.4, -0.4, -0.3, 0.3],
+ [True, False, False, False, True, False, True, False],
+ 3,
+ ),
]
@pytest.mark.parametrize(
- 'trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume',
- _trigger_test_params
+ "trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume",
+ _trigger_test_params,
)
def test_trigger(
- trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
- resume):
- key = 'main/accuracy'
+ trigger_type, trigger_args, iters_per_epoch, accuracies, expected, resume
+):
+ key = "main/accuracy"
manager = training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = trigger_type(key, *trigger_args)
- _test_trigger(
- manager, trigger, key, accuracies, expected)
+ _test_trigger(manager, trigger, key, accuracies, expected)
@pytest.mark.parametrize(
- 'trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume',
- _trigger_test_params
+ "trigger_type,trigger_args,iters_per_epoch,accuracies,expected,resume",
+ _trigger_test_params,
)
def test_resumed_trigger(
- trigger_type, trigger_args, iters_per_epoch, accuracies, expected,
- resume):
- key = 'main/accuracy'
+ trigger_type, trigger_args, iters_per_epoch, accuracies, expected, resume
+):
+ key = "main/accuracy"
manager = training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = trigger_type(key, *trigger_args)
- _test_trigger(
- manager, trigger, key, accuracies[:resume],
- expected[:resume])
+ _test_trigger(manager, trigger, key, accuracies[:resume], expected[:resume])
state = trigger.state_dict()
new_trigger = trigger_type(key, *trigger_args)
new_trigger.load_state_dict(state)
_test_trigger(
- manager, new_trigger, key, accuracies[resume:], expected[resume:])
+ manager, new_trigger, key, accuracies[resume:], expected[resume:]
+ )
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py
index 5b27a05e4..66404ceb4 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_once_trigger.py
@@ -1,17 +1,17 @@
import random
-import pytest

+import pytest
import pytorch_pfn_extras as ppe
-
_parametrize = pytest.mark.parametrize(
- 'iters_per_epoch,call_on_resume,resume',
+ "iters_per_epoch,call_on_resume,resume",
[
# basic
(5, False, 4),
# call on resume
(5, True, 4),
- ])
+ ],
+)
@_parametrize
@@ -20,8 +20,8 @@ def test_trigger(iters_per_epoch, call_on_resume, resume):
expected = [True] + [False] * 6
finished = [False] + [True] * 6
manager = ppe.training.ExtensionsManager(
- {}, [], 100,
- iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = ppe.training.triggers.OnceTrigger(call_on_resume)
for e, f in zip(expected, finished):
assert trigger.finished == f
@@ -38,8 +38,8 @@ def test_resumed_trigger(iters_per_epoch, call_on_resume, resume):
expected[resume] = True
finished[resume] = False
manager = ppe.training.ExtensionsManager(
- {}, [], 100,
- iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = ppe.training.triggers.OnceTrigger(call_on_resume)
for e, f in zip(expected[:resume], finished[:resume]):
with manager.run_iteration():
@@ -64,8 +64,8 @@ def test_trigger_sparse_call(iters_per_epoch, call_on_resume, resume):
finished = [False] + [True] * 6
for _ in range(10):
manager = ppe.training.ExtensionsManager(
- {}, [], 100,
- iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = ppe.training.triggers.OnceTrigger(call_on_resume)
accumulated = False
accumulated_finished = True
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py
index d06858b4a..9d19eeb20 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_schedule_trigger.py
@@ -1,25 +1,30 @@
import pytest
-
from pytorch_pfn_extras import training
from pytorch_pfn_extras.training import triggers
-
_scheduled_trigger_test_params = [
# single iteration
- (2, (2, 'iteration'),
- [False, True, False, False, False, False, False], 3),
+ (2, (2, "iteration"), [False, True, False, False, False, False, False], 3),
# multiple iteration
- (2, ([2, 4], 'iteration'),
- [False, True, False, True, False, False, False], 3),
+ (
+ 2,
+ ([2, 4], "iteration"),
+ [False, True, False, True, False, False, False],
+ 3,
+ ),
# single epoch
- (3, (1, 'epoch'), [False, False, True, False, False, False, False], 3),
+ (3, (1, "epoch"), [False, False, True, False, False, False, False], 3),
# multiple epoch
- (3, ([1, 2], 'epoch'), [False, False, True, False, False, True, False], 4),
+ (3, ([1, 2], "epoch"), [False, False, True, False, False, True, False], 4),
# single fractional epoch
- (2, (1.5, 'epoch'), [False, False, True, False, False, False, False], 4),
+ (2, (1.5, "epoch"), [False, False, True, False, False, False, False], 4),
# multiple fractional epoch
- (2, ([1.5, 2.5], 'epoch'),
- [False, False, True, False, True, False, False], 4),
+ (
+ 2,
+ ([1.5, 2.5], "epoch"),
+ [False, False, True, False, True, False, False],
+ 4,
+ ),
# TODO(imanishi): Restore these tests after supported.
# # single unaligned epoch
# (2.5, (1, 'epoch'), [False, False, True, False, False, False, False], 4),
@@ -42,30 +47,27 @@ def _test_trigger(trainer, trigger, expected):
@pytest.mark.parametrize(
- 'iters_per_epoch,schedule,expected,resume',
- _scheduled_trigger_test_params
+ "iters_per_epoch,schedule,expected,resume", _scheduled_trigger_test_params
)
def test_trigger(iters_per_epoch, schedule, expected, resume):
trainer = training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = triggers.ManualScheduleTrigger(*schedule)
_test_trigger(trainer, trigger, expected)
@pytest.mark.parametrize(
- 'iters_per_epoch,schedule,expected,resume',
- _scheduled_trigger_test_params
+ "iters_per_epoch,schedule,expected,resume", _scheduled_trigger_test_params
)
-def test_resumed_trigger(
- iters_per_epoch, schedule, expected, resume):
+def test_resumed_trigger(iters_per_epoch, schedule, expected, resume):
trainer = training.ExtensionsManager(
- {}, [], 100, iters_per_epoch=iters_per_epoch)
+ {}, [], 100, iters_per_epoch=iters_per_epoch
+ )
trigger = triggers.ManualScheduleTrigger(*schedule)
- _test_trigger(
- trainer, trigger,
- expected[:resume])
+ _test_trigger(trainer, trigger, expected[:resume])
state = trigger.state_dict()
new_trigger = triggers.ManualScheduleTrigger(*schedule)
@@ -76,4 +78,4 @@ def test_resumed_trigger(
def test_invalid_unit():
with pytest.raises(ValueError):
- triggers.ManualScheduleTrigger(1, 'day')
+ triggers.ManualScheduleTrigger(1, "day")
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py
index 6f7c76624..cf3efb694 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/triggers_tests/test_time_trigger.py
@@ -2,7 +2,6 @@
class DummyTrainer:
-
def __init__(self):
self.elapsed_time = 0
diff --git a/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py b/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py
index 8b44c2eb8..0a51a6fe8 100644
--- a/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py
+++ b/tests/pytorch_pfn_extras_tests/utils_tests/test_checkpoint.py
@@ -1,7 +1,6 @@
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
class SubNet(torch.nn.Module):
@@ -27,9 +26,9 @@ def __init__(self, checkpoint_type):
def forward(self, x):
x = self.bn1(self.conv1(x)).relu()
- if self.checkpoint_type == 'none':
+ if self.checkpoint_type == "none":
x = self.part1(x)
- elif self.checkpoint_type == 'bnaware':
+ elif self.checkpoint_type == "bnaware":
x = ppe.utils.checkpoint.checkpoint(self.part1, x)
x = self.part2(x)
@@ -58,7 +57,7 @@ def _get_bn_stats_test_checkpoint(cp_type):
@pytest.mark.gpu
def test_checkpoint():
- baseline = _get_bn_stats_test_checkpoint('none')
- ckpt = _get_bn_stats_test_checkpoint('bnaware')
+ baseline = _get_bn_stats_test_checkpoint("none")
+ ckpt = _get_bn_stats_test_checkpoint("bnaware")
for p_b, p_c in zip(baseline, ckpt):
assert torch.allclose(p_b, p_c)
diff --git a/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py b/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py
index afcac9f0e..9ac934f30 100644
--- a/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py
+++ b/tests/pytorch_pfn_extras_tests/utils_tests/test_comparer.py
@@ -1,9 +1,8 @@
import typing
import pytest
-import torch
-
import pytorch_pfn_extras as ppe
+import torch
class Model(torch.nn.Module):
@@ -42,13 +41,15 @@ def _get_trainer_with_evaluator(device, ret_val, model_class=Model):
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
evaluator = ppe.engine.create_evaluator(model, device=device)
trainer = ppe.engine.create_trainer(
- model, optimizer, 1, device=device, evaluator=evaluator)
+ model, optimizer, 1, device=device, evaluator=evaluator
+ )
return trainer
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_every_iter(engine_fn):
engine_cpu = engine_fn("cpu", 1.0)
engine_gpu = engine_fn("cuda:0", 1.0)
@@ -66,8 +67,9 @@ def test_compare_every_iter(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_comparer_wrong(engine_fn):
engine_cpu = engine_fn("cpu", 1.0)
engine_gpu = engine_fn("cuda:0", 0.5)
@@ -101,8 +103,9 @@ def __call__(self, eng_name_1, eng_name_2, out_name, out_1, out_2):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_comparer_n_iters(engine_fn):
engine_cpu = engine_fn("cpu", 1.0)
engine_gpu = engine_fn("cuda:0", 1.0)
@@ -119,7 +122,8 @@ def test_comparer_n_iters(engine_fn):
eval_2 = list(torch.ones(10) for _ in range(10))
comp.compare(
{"cpu": (train_1, eval_1), "gpu": (train_2, eval_2)},
- n_iters=n_iters)
+ n_iters=n_iters,
+ )
assert comp.compare_fn.times_called == 6
else:
comp.compare({"cpu": train_1, "gpu": train_2}, n_iters=n_iters)
@@ -127,14 +131,16 @@ def test_comparer_n_iters(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_comparer_kwargs(engine_fn):
engine_cpu = engine_fn("cpu", 1.0)
engine_gpu = engine_fn("cuda:0", 0.991)
compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
comp = ppe.utils.comparer.OutputsComparer(
- {"cpu": engine_cpu, "gpu": engine_gpu}, "a",
+ {"cpu": engine_cpu, "gpu": engine_gpu},
+ "a",
compare_fn=compare_fn,
)
train_1 = list(torch.ones(10) for _ in range(10))
@@ -150,22 +156,29 @@ def test_comparer_kwargs(engine_fn):
@pytest.mark.gpu
def test_comparer_incompat_trigger():
model_cpu = Model("cpu", 1.0)
- ppe.to(model_cpu, 'cpu')
+ ppe.to(model_cpu, "cpu")
optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=1.0)
trainer_cpu = ppe.engine.create_trainer(
- model_cpu, optimizer_cpu, 1, device="cpu",
+ model_cpu,
+ optimizer_cpu,
+ 1,
+ device="cpu",
)
model_gpu = Model("cuda:0", 1.0)
- ppe.to(model_gpu, 'cuda:0')
+ ppe.to(model_gpu, "cuda:0")
optimizer_gpu = torch.optim.SGD(model_gpu.parameters(), lr=1.0)
trainer_gpu = ppe.engine.create_trainer(
- model_gpu, optimizer_gpu, 1, device="cuda:0",
+ model_gpu,
+ optimizer_gpu,
+ 1,
+ device="cuda:0",
stop_trigger=(1, "iteration"),
)
comp = ppe.utils.comparer.OutputsComparer(
- {"cpu": trainer_cpu, "gpu": trainer_gpu}, "a",
+ {"cpu": trainer_cpu, "gpu": trainer_gpu},
+ "a",
)
train_1 = list(torch.ones(10) for _ in range(10))
train_2 = list(torch.ones(10) for _ in range(10))
@@ -174,13 +187,15 @@ def test_comparer_incompat_trigger():
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_concurrency(engine_fn):
engine_cpu = engine_fn("cpu", 1.0)
engine_gpu = engine_fn("cuda:0", 1.0)
comp = ppe.utils.comparer.OutputsComparer(
- {"cpu": engine_cpu, "gpu": engine_gpu}, "a",
+ {"cpu": engine_cpu, "gpu": engine_gpu},
+ "a",
concurrency=1,
)
train_1 = list(torch.ones(10) for _ in range(10))
@@ -194,13 +209,15 @@ def test_compare_concurrency(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_concurrency_wrong(engine_fn):
engine_cpu = engine_fn("cpu", 1.0)
engine_gpu = engine_fn("cuda:0", 0.5)
comp = ppe.utils.comparer.OutputsComparer(
- {"cpu": engine_cpu, "gpu": engine_gpu}, "a",
+ {"cpu": engine_cpu, "gpu": engine_gpu},
+ "a",
concurrency=1,
)
train_1 = list(torch.ones(10) for _ in range(10))
@@ -232,22 +249,23 @@ def forward(self, x):
def test_model_comparer():
model_cpu = ModelForComparer()
model_gpu = ModelForComparer()
- ppe.to(model_cpu, 'cpu')
- ppe.to(model_gpu, 'cuda:0')
+ ppe.to(model_cpu, "cpu")
+ ppe.to(model_gpu, "cuda:0")
# Make the models have the same initial weights
model_gpu.load_state_dict(model_cpu.state_dict())
- ppe.to(model_gpu, device='cuda:0')
+ ppe.to(model_gpu, device="cuda:0")
optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01)
trainer_cpu = ppe.engine.create_trainer(
- model_cpu, optimizer_cpu, 1, device='cpu')
+ model_cpu, optimizer_cpu, 1, device="cpu"
+ )
optimizer_gpu = torch.optim.SGD(model_gpu.parameters(), lr=0.01)
trainer_gpu = ppe.engine.create_trainer(
- model_gpu, optimizer_gpu, 1, device='cuda:0')
+ model_gpu, optimizer_gpu, 1, device="cuda:0"
+ )
compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
comp = ppe.utils.comparer.ModelComparer(
- {"cpu": trainer_cpu, "gpu": trainer_gpu},
- compare_fn=compare_fn
+ {"cpu": trainer_cpu, "gpu": trainer_gpu}, compare_fn=compare_fn
)
train_1 = list(torch.ones(2, 10, 10, 10) for _ in range(10))
@@ -259,19 +277,20 @@ def test_model_comparer():
def test_model_comparer_invalid():
model_cpu = ModelForComparer()
model_gpu = ModelForComparer()
- ppe.to(model_cpu, 'cpu')
- ppe.to(model_gpu, device='cuda:0')
+ ppe.to(model_cpu, "cpu")
+ ppe.to(model_gpu, device="cuda:0")
optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01)
trainer_cpu = ppe.engine.create_trainer(
- model_cpu, optimizer_cpu, 1, device='cpu')
+ model_cpu, optimizer_cpu, 1, device="cpu"
+ )
optimizer_gpu = torch.optim.SGD(model_gpu.parameters(), lr=0.01)
trainer_gpu = ppe.engine.create_trainer(
- model_gpu, optimizer_gpu, 1, device='cuda:0')
+ model_gpu, optimizer_gpu, 1, device="cuda:0"
+ )
compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
comp = ppe.utils.comparer.ModelComparer(
- {"cpu": trainer_cpu, "gpu": trainer_gpu},
- compare_fn=compare_fn
+ {"cpu": trainer_cpu, "gpu": trainer_gpu}, compare_fn=compare_fn
)
train_1 = list(torch.ones(2, 10, 10, 10) for _ in range(10))
@@ -294,12 +313,15 @@ def forward(self, x):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_tuple_output(engine_fn):
engine_cpu = engine_fn("cpu", 1.0, model_class=ModelRetTuple)
engine_gpu = engine_fn("cuda:0", 1.0, model_class=ModelRetTuple)
- comp = ppe.utils.comparer.OutputsComparer({"cpu": engine_cpu, "gpu": engine_gpu})
+ comp = ppe.utils.comparer.OutputsComparer(
+ {"cpu": engine_cpu, "gpu": engine_gpu}
+ )
train_1 = list(torch.ones(10) for _ in range(10))
train_2 = list(torch.ones(10) for _ in range(10))
if engine_fn is _get_trainer_with_evaluator:
@@ -329,12 +351,15 @@ def forward(self, x):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_namedtuple_output(engine_fn):
engine_cpu = engine_fn("cpu", 1.0, model_class=ModelRetNamedTuple)
engine_gpu = engine_fn("cuda:0", 1.0, model_class=ModelRetNamedTuple)
- comp = ppe.utils.comparer.OutputsComparer({"cpu": engine_cpu, "gpu": engine_gpu})
+ comp = ppe.utils.comparer.OutputsComparer(
+ {"cpu": engine_cpu, "gpu": engine_gpu}
+ )
train_1 = list(torch.ones(10) for _ in range(10))
train_2 = list(torch.ones(10) for _ in range(10))
if engine_fn is _get_trainer_with_evaluator:
diff --git a/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py b/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py
index 51532836e..bccda9b34 100644
--- a/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py
+++ b/tests/pytorch_pfn_extras_tests/utils_tests/test_new_comparer.py
@@ -2,10 +2,9 @@
import typing
import pytest
+import pytorch_pfn_extras as ppe
import torch
import torch.nn.functional as F
-
-import pytorch_pfn_extras as ppe
from pytorch_pfn_extras_tests.runtime_tests.test_jit_runtime import JITRuntime
@@ -25,18 +24,28 @@ def forward(self, x):
def _get_trainer(
- model_class, device, args, loader, *,
- seed=0, max_epochs=10, stop_trigger=None):
+ model_class,
+ device,
+ args,
+ loader,
+ *,
+ seed=0,
+ max_epochs=10,
+ stop_trigger=None,
+):
torch.manual_seed(seed)
model = model_class(device, *args)
ppe.to(model, device)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
trainer = ppe.engine.create_trainer(
- model, optimizer, max_epochs, device=device, stop_trigger=stop_trigger)
+ model, optimizer, max_epochs, device=device, stop_trigger=stop_trigger
+ )
return trainer, (loader,)
-def _get_evaluator(model_class, device, args, loader, *, seed=0, max_epochs=None):
+def _get_evaluator(
+ model_class, device, args, loader, *, seed=0, max_epochs=None
+):
torch.manual_seed(seed)
model = model_class(device, *args)
ppe.to(model, device)
@@ -45,22 +54,35 @@ def _get_evaluator(model_class, device, args, loader, *, seed=0, max_epochs=None
def _get_trainer_with_evaluator(
- model_class, device, args, loader, *,
- seed=0, max_epochs=10, stop_trigger=None):
+ model_class,
+ device,
+ args,
+ loader,
+ *,
+ seed=0,
+ max_epochs=10,
+ stop_trigger=None,
+):
torch.manual_seed(seed)
model = model_class(device, *args)
ppe.to(model, device)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
evaluator = ppe.engine.create_evaluator(model, device=device)
trainer = ppe.engine.create_trainer(
- model, optimizer, max_epochs, device=device,
- evaluator=evaluator, stop_trigger=stop_trigger)
+ model,
+ optimizer,
+ max_epochs,
+ device=device,
+ evaluator=evaluator,
+ stop_trigger=stop_trigger,
+ )
return trainer, (loader, loader)
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_every_epoch(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -72,8 +94,9 @@ def test_compare_every_epoch(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_comparer_wrong(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -91,7 +114,10 @@ def __init__(self, n_iters=None):
self.n_iters = n_iters
def __call__(self, eng_name_1, eng_name_2, out_name, out_1, out_2):
- assert out_name in ("output:a", "output:iter",)
+ assert out_name in (
+ "output:a",
+ "output:iter",
+ )
assert eng_name_1 in ("cpu", "gpu")
assert eng_name_1 != eng_name_2
if out_name == "output:iter":
@@ -104,16 +130,22 @@ def __call__(self, eng_name_1, eng_name_2, out_name, out_1, out_2):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_trainer_with_evaluator]
+)
def test_comparer_trigger(engine_fn):
n_iters = 3
loader = list(torch.ones(10) for _ in range(10))
- trainer_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader, max_epochs=1)
- trainer_gpu, loaders_gpu = engine_fn(Model, "cuda:0", [1.0], loader, max_epochs=1)
+ trainer_cpu, loaders_cpu = engine_fn(
+ Model, "cpu", [1.0], loader, max_epochs=1
+ )
+ trainer_gpu, loaders_gpu = engine_fn(
+ Model, "cuda:0", [1.0], loader, max_epochs=1
+ )
compare_fn = _CustomComparer(n_iters)
comp = ppe.utils.comparer.Comparer(
- trigger=(n_iters, "iteration"), compare_fn=compare_fn)
+ trigger=(n_iters, "iteration"), compare_fn=compare_fn
+ )
comp.add_engine("cpu", trainer_cpu, *loaders_cpu)
comp.add_engine("gpu", trainer_gpu, *loaders_gpu)
comp.compare()
@@ -124,8 +156,9 @@ def test_comparer_trigger(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_comparer_kwargs(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -138,13 +171,15 @@ def test_comparer_kwargs(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_trainer_with_evaluator]
+)
def test_comparer_incompat_trigger(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
trainer_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
- trainer_gpu, loaders_gpu = engine_fn(Model, "cuda:0", [1.0], loader,
- stop_trigger=(1, "iteration"))
+ trainer_gpu, loaders_gpu = engine_fn(
+ Model, "cuda:0", [1.0], loader, stop_trigger=(1, "iteration")
+ )
comp = ppe.utils.comparer.Comparer(outputs=["a"])
comp.add_engine("cpu", trainer_cpu, *loaders_cpu)
comp.add_engine("gpu", trainer_gpu, *loaders_gpu)
@@ -153,8 +188,9 @@ def test_comparer_incompat_trigger(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_concurrency(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -166,8 +202,9 @@ def test_compare_concurrency(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_concurrency_wrong(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)
@@ -195,8 +232,9 @@ def forward(self, x):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_model_comparer(engine_fn):
loader = list(torch.ones(2, 10, 10, 10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader)
@@ -210,12 +248,17 @@ def test_model_comparer(engine_fn):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_model_comparer_invalid(engine_fn):
loader = list(torch.ones(2, 10, 10, 10) for _ in range(10))
- engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader, seed=0)
- engine_gpu, loaders_gpu = engine_fn(ModelForComparer, "cuda:0", [], loader, seed=1)
+ engine_cpu, loaders_cpu = engine_fn(
+ ModelForComparer, "cpu", [], loader, seed=0
+ )
+ engine_gpu, loaders_gpu = engine_fn(
+ ModelForComparer, "cuda:0", [], loader, seed=1
+ )
comp = ppe.utils.comparer.Comparer(outputs=["a"])
compare_fn = ppe.utils.comparer.get_default_comparer(rtol=1e-2, atol=1e-2)
comp = ppe.utils.comparer.Comparer(compare_fn=compare_fn, params=True)
@@ -239,8 +282,9 @@ def forward(self, x):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_tuple_output(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(ModelRetTuple, "cpu", [1.0], loader)
@@ -270,12 +314,17 @@ def forward(self, x):
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_namedtuple_output(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
- engine_cpu, loaders_cpu = engine_fn(ModelRetNamedTuple, "cpu", [1.0], loader)
- engine_gpu, loaders_gpu = engine_fn(ModelRetNamedTuple, "cuda:0", [1.0], loader)
+ engine_cpu, loaders_cpu = engine_fn(
+ ModelRetNamedTuple, "cpu", [1.0], loader
+ )
+ engine_gpu, loaders_gpu = engine_fn(
+ ModelRetNamedTuple, "cuda:0", [1.0], loader
+ )
comp = ppe.utils.comparer.Comparer()
comp.add_engine("cpu", engine_cpu, *loaders_cpu)
comp.add_engine("gpu", engine_gpu, *loaders_gpu)
@@ -290,76 +339,122 @@ def __init__(self, device, offset):
def forward(self, x, t):
y = self.model(x)
- prefix = 'train' if self.training else 'val'
+ prefix = "train" if self.training else "val"
loss = F.l1_loss(y, t)
- ppe.reporting.report({prefix + '/loss': loss})
+ ppe.reporting.report({prefix + "/loss": loss})
return loss + self.offset
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_jit_runtime_output_comparer(engine_fn):
loader = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(100)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
engine_cpu, loaders_cpu = engine_fn(MyModel, "cpu", [0.0], loader)
engine_gpu, loaders_gpu = engine_fn(MyModel, "cuda:0", [0.0], loader)
engine_jit, loaders_jit = engine_fn(MyModel, "jit-cpu", [0.0], loader)
comp = ppe.utils.comparer.Comparer()
- comp.add_engine('cpu', engine_cpu, *loaders_cpu)
- comp.add_engine('gpu', engine_gpu, *loaders_gpu)
- comp.add_engine('jit', engine_jit, *loaders_jit)
+ comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+ comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+ comp.add_engine("jit", engine_jit, *loaders_jit)
comp.compare()
@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_jit_runtime_output_comparer_invalid(engine_fn):
loader = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(100)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
engine_cpu, loaders_cpu = engine_fn(MyModel, "cpu", [0.0], loader)
engine_gpu, loaders_gpu = engine_fn(MyModel, "cuda:0", [0.0], loader)
engine_jit, loaders_jit = engine_fn(MyModel, "jit-cpu", [0.5], loader)
comp = ppe.utils.comparer.Comparer()
- comp.add_engine('cpu', engine_cpu, *loaders_cpu)
- comp.add_engine('gpu', engine_gpu, *loaders_gpu)
- comp.add_engine('jit', engine_jit, *loaders_jit)
+ comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+ comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+ comp.add_engine("jit", engine_jit, *loaders_jit)
with pytest.raises(AssertionError):
comp.compare()


@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_jit_runtime_model_comparer(engine_fn):
- loader = torch.utils.data.DataLoader([torch.rand(20,) for i in range(100)])
+ loader = torch.utils.data.DataLoader(
+ [
+ torch.rand(
+ 20,
+ )
+ for i in range(100)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader)
engine_gpu, loaders_gpu = engine_fn(ModelForComparer, "cuda:0", [], loader)
engine_jit, loaders_jit = engine_fn(ModelForComparer, "jit-cpu", [], loader)
comp = ppe.utils.comparer.Comparer(params=True)
- comp.add_engine('cpu', engine_cpu, *loaders_cpu)
- comp.add_engine('gpu', engine_gpu, *loaders_gpu)
- comp.add_engine('jit', engine_jit, *loaders_jit)
+ comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+ comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+ comp.add_engine("jit", engine_jit, *loaders_jit)
comp.compare()


@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_jit_runtime_model_comparer_invalid(engine_fn):
- loader = torch.utils.data.DataLoader([torch.rand(20,) for i in range(100)])
+ loader = torch.utils.data.DataLoader(
+ [
+ torch.rand(
+ 20,
+ )
+ for i in range(100)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
- engine_cpu, loaders_cpu = engine_fn(ModelForComparer, "cpu", [], loader, seed=0)
- engine_gpu, loaders_gpu = engine_fn(ModelForComparer, "cuda:0", [], loader, seed=0)
- engine_jit, loaders_jit = engine_fn(ModelForComparer, "jit-cpu", [], loader, seed=1)
+ engine_cpu, loaders_cpu = engine_fn(
+ ModelForComparer, "cpu", [], loader, seed=0
+ )
+ engine_gpu, loaders_gpu = engine_fn(
+ ModelForComparer, "cuda:0", [], loader, seed=0
+ )
+ engine_jit, loaders_jit = engine_fn(
+ ModelForComparer, "jit-cpu", [], loader, seed=1
+ )
comp = ppe.utils.comparer.Comparer(params=True)
- comp.add_engine('cpu', engine_cpu, *loaders_cpu)
- comp.add_engine('gpu', engine_gpu, *loaders_gpu)
- comp.add_engine('jit', engine_jit, *loaders_jit)
+ comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+ comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+ comp.add_engine("jit", engine_jit, *loaders_jit)
with pytest.raises(AssertionError):
comp.compare()
@@ -373,104 +468,156 @@ def __init__(self, device, intermediate_value):
def forward(self, x, t):
y = self.model(x)
for i in range(5):
- ppe.utils.comparer.intermediate_value('y', y + self.hidden + i)
+ ppe.utils.comparer.intermediate_value("y", y + self.hidden + i)
loss = F.l1_loss(y, t)
return loss


@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_intermediate(engine_fn):
loader = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(100)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
engine_cpu, loaders_cpu = engine_fn(
- ModelForIntermediateValue, "cpu", [10.0], loader)
+ ModelForIntermediateValue, "cpu", [10.0], loader
+ )
engine_gpu, loaders_gpu = engine_fn(
- ModelForIntermediateValue, "cuda:0", [10.0], loader)
+ ModelForIntermediateValue, "cuda:0", [10.0], loader
+ )
engine_jit, loaders_jit = engine_fn(
- ModelForIntermediateValue, "jit-cpu", [10.0], loader)
+ ModelForIntermediateValue, "jit-cpu", [10.0], loader
+ )
comp = ppe.utils.comparer.Comparer()
- comp.add_engine('cpu', engine_cpu, *loaders_cpu)
- comp.add_engine('gpu', engine_gpu, *loaders_gpu)
- comp.add_engine('jit', engine_jit, *loaders_jit)
+ comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+ comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+ comp.add_engine("jit", engine_jit, *loaders_jit)
comp.compare()


@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_intermediate_invalid(engine_fn):
loader = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(100)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(100)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
engine_cpu, loaders_cpu = engine_fn(
- ModelForIntermediateValue, "cpu", [10.0], loader)
+ ModelForIntermediateValue, "cpu", [10.0], loader
+ )
engine_gpu, loaders_gpu = engine_fn(
- ModelForIntermediateValue, "cuda:0", [10.0], loader)
+ ModelForIntermediateValue, "cuda:0", [10.0], loader
+ )
engine_jit, loaders_jit = engine_fn(
- ModelForIntermediateValue, "jit-cpu", [11.1], loader)
+ ModelForIntermediateValue, "jit-cpu", [11.1], loader
+ )
comp = ppe.utils.comparer.Comparer()
- comp.add_engine('cpu', engine_cpu, *loaders_cpu)
- comp.add_engine('gpu', engine_gpu, *loaders_gpu)
- comp.add_engine('jit', engine_jit, *loaders_jit)
+ comp.add_engine("cpu", engine_cpu, *loaders_cpu)
+ comp.add_engine("gpu", engine_gpu, *loaders_gpu)
+ comp.add_engine("jit", engine_jit, *loaders_jit)
with pytest.raises(AssertionError):
comp.compare()


@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
@pytest.mark.parametrize("model_class", [MyModel, ModelForIntermediateValue])
@pytest.mark.parametrize("params", [False, True])
def test_dump(engine_fn, model_class, params):
loader = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(5)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(5)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
engine_cpu, loaders_cpu = engine_fn(model_class, "cpu", [1.0], loader)
engine_gpu, loaders_gpu = engine_fn(model_class, "cuda:0", [1.0], loader)
engine_jit, loaders_jit = engine_fn(model_class, "jit-cpu", [1.0], loader)
comp = ppe.utils.comparer.Comparer(params=params)
with tempfile.TemporaryDirectory() as tmpdir:
- comp.dump(engine_cpu, f'{tmpdir}/cpu', *loaders_cpu)
- comp.dump(engine_gpu, f'{tmpdir}/gpu', *loaders_gpu)
- comp.dump(engine_jit, f'{tmpdir}/jit', *loaders_jit)
- comp.add_dump('cpu', f'{tmpdir}/cpu')
- comp.add_dump('gpu', f'{tmpdir}/gpu')
- comp.add_dump('jit', f'{tmpdir}/jit')
+ comp.dump(engine_cpu, f"{tmpdir}/cpu", *loaders_cpu)
+ comp.dump(engine_gpu, f"{tmpdir}/gpu", *loaders_gpu)
+ comp.dump(engine_jit, f"{tmpdir}/jit", *loaders_jit)
+ comp.add_dump("cpu", f"{tmpdir}/cpu")
+ comp.add_dump("gpu", f"{tmpdir}/gpu")
+ comp.add_dump("jit", f"{tmpdir}/jit")
comp.compare()


@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
-@pytest.mark.parametrize("model_class", [
- MyModel, ModelForIntermediateValue])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
+@pytest.mark.parametrize("model_class", [MyModel, ModelForIntermediateValue])
def test_dump_invalid(engine_fn, model_class):
loader = torch.utils.data.DataLoader(
- [(torch.rand(20,), torch.rand(10,)) for i in range(5)])
+ [
+ (
+ torch.rand(
+ 20,
+ ),
+ torch.rand(
+ 10,
+ ),
+ )
+ for i in range(5)
+ ]
+ )
ppe.runtime.runtime_registry.register("jit-cpu", JITRuntime)
engine_cpu, loaders_cpu = engine_fn(model_class, "cpu", [1.0], loader)
engine_gpu, loaders_gpu = engine_fn(model_class, "cuda:0", [1.0], loader)
engine_jit, loaders_jit = engine_fn(model_class, "jit-cpu", [2.0], loader)
comp = ppe.utils.comparer.Comparer()
- with tempfile.TemporaryDirectory() as tmpdir1, \
- tempfile.TemporaryDirectory() as tmpdir2, \
- tempfile.TemporaryDirectory() as tmpdir3:
+ with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2, tempfile.TemporaryDirectory() as tmpdir3:
comp.dump(engine_cpu, tmpdir1, *loaders_cpu)
comp.dump(engine_gpu, tmpdir2, *loaders_gpu)
comp.dump(engine_jit, tmpdir3, *loaders_jit)
- comp.add_dump('cpu', tmpdir1)
- comp.add_dump('gpu', tmpdir2)
- comp.add_dump('jit', tmpdir3)
+ comp.add_dump("cpu", tmpdir1)
+ comp.add_dump("gpu", tmpdir2)
+ comp.add_dump("jit", tmpdir3)
with pytest.raises(AssertionError):
comp.compare()


@pytest.mark.gpu
-@pytest.mark.parametrize("engine_fn", [
- _get_trainer, _get_evaluator, _get_trainer_with_evaluator])
+@pytest.mark.parametrize(
+ "engine_fn", [_get_trainer, _get_evaluator, _get_trainer_with_evaluator]
+)
def test_compare_baseline(engine_fn):
loader = list(torch.ones(10) for _ in range(10))
engine_cpu, loaders_cpu = engine_fn(Model, "cpu", [1.0], loader)