Merge pull request #693 from linshokaku/code-format
fix code format (exclude pfto)
emcastillo authored May 29, 2023
2 parents 12effd8 + 2686502 commit 25b742d
Showing 179 changed files with 8,865 additions and 5,784 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pretest-and-test.yml

@@ -38,7 +38,7 @@ jobs:
       - name: Code Style
         run: |
-          pip install pysen black==21.11b1 flake8==4.0.1 isort==5.10.1 mypy==0.991
+          pip install pysen black==23.3.0 flake8==4.0.1 isort==5.10.1 mypy==0.991
           pip install types-PyYAML types-setuptools
           cp "$(pip show torch | awk '/^Location:/ { print $2 }')/torch/__init__.py" stubs/torch/__init__.py
           MYPYPATH="${PWD}/stubs" pysen run lint
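The only change in this workflow is the black pin, bumped from 21.11b1 to 23.3.0 (flake8, isort, and mypy keep their versions); the file diffs below are the result of reformatting the tree to satisfy the newer pin, plus isort's import ordering. A minimal sketch of reproducing just the black check locally through its Python API, assuming black==23.3.0 and a 79-character line length; the length is inferred from how the reformatted files below wrap, not stated in this diff, and the file path is only an example:

import black

# Reformat the source in memory with the pinned style and compare:
# black.format_str always returns the formatted text, so an equality
# check tells us whether the file is already clean.
path = "example/ignite-mnist.py"
source = open(path).read()
formatted = black.format_str(source, mode=black.Mode(line_length=79))
print("clean" if formatted == source else f"{path} would be reformatted")

CI runs the full suite through pysen run lint, so this sketch covers only the formatting half of the check.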
18 changes: 9 additions & 9 deletions docs/source/conf.py

@@ -17,9 +17,9 @@
 
 # -- Project information -----------------------------------------------------
 
-project = 'pytorch-pfn-extras'
-copyright = '2021, Preferred Networks, Inc.'
-author = 'Preferred Networks, Inc.'
+project = "pytorch-pfn-extras"
+copyright = "2021, Preferred Networks, Inc."
+author = "Preferred Networks, Inc."
 
 
 # -- General configuration ---------------------------------------------------
@@ -28,15 +28,15 @@
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
 extensions = [
-    'sphinx.ext.autosummary',
-    'sphinx.ext.napoleon',
-    'myst_parser',
+    "sphinx.ext.autosummary",
+    "sphinx.ext.napoleon",
+    "myst_parser",
 ]
 
 autodoc_typehints = "description"
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
@@ -51,9 +51,9 @@
 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'pydata_sphinx_theme'
+html_theme = "pydata_sphinx_theme"
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
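Every hunk in conf.py is black's string-quote normalization: single-quoted literals are rewritten to double quotes, with no semantic change. A minimal sketch of that behavior through black's format_str API, assuming the black==23.3.0 pinned above:

import black

# black normalizes string quotes to double quotes by default.
src = "project = 'pytorch-pfn-extras'\n"
print(black.format_str(src, mode=black.Mode()), end="")
# prints: project = "pytorch-pfn-extras"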
151 changes: 99 additions & 52 deletions example/ignite-mnist.py

@@ -1,20 +1,20 @@
 from argparse import ArgumentParser
 
+import pytorch_pfn_extras as ppe
+import pytorch_pfn_extras.training.extensions as extensions
+import torch
+import torch.nn.functional as F
+from ignite.engine import (
+    Events,
+    create_supervised_evaluator,
+    create_supervised_trainer,
+)
+from ignite.metrics import Accuracy, Loss
 from torch import nn
 from torch.optim import SGD
 from torch.utils.data import DataLoader
-import torch
-import torch.nn.functional as F
-from torchvision.transforms import Compose, ToTensor, Normalize
 from torchvision.datasets import MNIST
-
-from ignite.engine import Events
-from ignite.engine import create_supervised_trainer
-from ignite.engine import create_supervised_evaluator
-from ignite.metrics import Accuracy, Loss
-
-import pytorch_pfn_extras as ppe
-import pytorch_pfn_extras.training.extensions as extensions
+from torchvision.transforms import Compose, Normalize, ToTensor
 
 
 class Net(nn.Module):
@@ -40,55 +40,75 @@ def get_data_loaders(train_batch_size, val_batch_size):
     data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
 
     train_loader = DataLoader(
-        MNIST(download=True, root="../data", transform=data_transform,
-              train=True),
-        batch_size=train_batch_size, shuffle=True)
+        MNIST(
+            download=True, root="../data", transform=data_transform, train=True
+        ),
+        batch_size=train_batch_size,
+        shuffle=True,
+    )
 
     val_loader = DataLoader(
-        MNIST(download=False, root="../data", transform=data_transform,
-              train=False),
-        batch_size=val_batch_size, shuffle=False)
+        MNIST(
+            download=False,
+            root="../data",
+            transform=data_transform,
+            train=False,
+        ),
+        batch_size=val_batch_size,
+        shuffle=False,
+    )
     return train_loader, val_loader
 
 
 def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
     train_loader, val_loader = get_data_loaders(
-        train_batch_size, val_batch_size)
+        train_batch_size, val_batch_size
+    )
     model = Net()
-    device = 'cpu'
+    device = "cpu"
     if torch.cuda.is_available():
-        device = 'cuda:0'
+        device = "cuda:0"
     model = model.to(device)
     optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
     optimizer.step()
     trainer = create_supervised_trainer(
-        model, optimizer, F.nll_loss, device=device)
+        model, optimizer, F.nll_loss, device=device
+    )
     evaluator = create_supervised_evaluator(
         model,
-        metrics={'acc': Accuracy(), 'loss': Loss(F.nll_loss)},
-        device=device)
+        metrics={"acc": Accuracy(), "loss": Loss(F.nll_loss)},
+        device=device,
+    )
 
     # manager.extend(...) also works
     my_extensions = [
         extensions.LogReport(),
         extensions.ProgressBar(),
         extensions.observe_lr(optimizer=optimizer),
-        extensions.ParameterStatistics(model, prefix='model'),
+        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
         extensions.snapshot(),
         extensions.IgniteEvaluator(
-            evaluator, val_loader, model, progress_bar=True),
-        extensions.PlotReport(['train/loss'], 'epoch', filename='loss.png'),
-        extensions.PrintReport([
-            'epoch', 'iteration', 'train/loss', 'lr',
-            'model/fc2.bias/grad/min', 'val/loss', 'val/acc',
-        ]),
+            evaluator, val_loader, model, progress_bar=True
+        ),
+        extensions.PlotReport(["train/loss"], "epoch", filename="loss.png"),
+        extensions.PrintReport(
+            [
+                "epoch",
+                "iteration",
+                "train/loss",
+                "lr",
+                "model/fc2.bias/grad/min",
+                "val/loss",
+                "val/acc",
+            ]
+        ),
     ]
-    models = {'main': model}
-    optimizers = {'main': optimizer}
+    models = {"main": model}
+    optimizers = {"main": optimizer}
     manager = ppe.training.IgniteExtensionsManager(
-        trainer, models, optimizers, args.epochs,
-        extensions=my_extensions)
+        trainer, models, optimizers, args.epochs, extensions=my_extensions
+    )
 
     # Lets load the snapshot
     if args.snapshot is not None:
@@ -97,30 +117,57 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
 
     @trainer.on(Events.ITERATION_COMPLETED)
     def report_loss(engine):
-        ppe.reporting.report({'train/loss': engine.state.output})
+        ppe.reporting.report({"train/loss": engine.state.output})
 
     trainer.run(train_loader, max_epochs=epochs)
 
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument('--batch_size', type=int, default=64,
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--val_batch_size', type=int, default=1000,
-                        help='input batch size for validation (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=10,
-                        help='number of epochs to train (default: 10)')
-    parser.add_argument('--lr', type=float, default=0.01,
-                        help='learning rate (default: 0.01)')
-    parser.add_argument('--momentum', type=float, default=0.5,
-                        help='SGD momentum (default: 0.5)')
-    parser.add_argument('--log_interval', type=int, default=10,
-                        help='how many batches to wait before logging '
-                             'training status')
-    parser.add_argument('--snapshot', type=str, default=None,
-                        help='path to snapshot file')
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=64,
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--val_batch_size",
+        type=int,
+        default=1000,
+        help="input batch size for validation (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=10,
+        help="number of epochs to train (default: 10)",
+    )
+    parser.add_argument(
+        "--lr", type=float, default=0.01, help="learning rate (default: 0.01)"
+    )
+    parser.add_argument(
+        "--momentum",
+        type=float,
+        default=0.5,
+        help="SGD momentum (default: 0.5)",
+    )
+    parser.add_argument(
+        "--log_interval",
+        type=int,
+        default=10,
+        help="how many batches to wait before logging " "training status",
+    )
+    parser.add_argument(
+        "--snapshot", type=str, default=None, help="path to snapshot file"
+    )
 
     args = parser.parse_args()
 
-    run(args.batch_size, args.val_batch_size, args.epochs, args.lr,
-        args.momentum, args.log_interval)
+    run(
+        args.batch_size,
+        args.val_batch_size,
+        args.epochs,
+        args.lr,
+        args.momentum,
+        args.log_interval,
+    )
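The churn in ignite-mnist.py is mechanical: isort groups and alphabetizes the imports, and black explodes any call that no longer fits on one line into one argument per line with a trailing "magic" comma. A minimal sketch of that splitting behavior, again assuming black==23.3.0 and a 79-character line length (an inference from the wrapping above, not a setting shown in this diff):

import black

# A call too long for one line is split one argument per line with a
# trailing comma, exactly as in the add_argument hunks above.
src = (
    'parser.add_argument("--batch_size", type=int, default=64, '
    'help="input batch size for training (default: 64)")\n'
)
print(black.format_str(src, mode=black.Mode(line_length=79)))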
(176 of the 179 changed files are not shown in this view.)
