Table of Contents
PYTORCH DOCUMENTATION
PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.
Notes
Automatic Mixed Precision examples
Autograd mechanics
Broadcasting semantics
CPU threading and TorchScript inference
CUDA semantics
Distributed Data Parallel
Extending PyTorch
Frequently Asked Questions
Features for large-scale deployments
Multiprocessing best practices
Reproducibility
Serialization semantics
Windows FAQ
Language Bindings
C++
Javadoc
Python API
torch
torch.nn
torch.nn.functional
torch.Tensor
Tensor Attributes
Tensor Views
torch.autograd
torch.cuda
torch.cuda.amp
torch.distributed
torch.distributions
torch.hub
torch.jit
torch.nn.init
torch.onnx
torch.optim
Quantization
Distributed RPC Framework
torch.random
torch.sparse
torch.Storage
torch.utils.bottleneck
torch.utils.checkpoint
torch.utils.cpp_extension
torch.utils.data
torch.utils.dlpack
torch.utils.model_zoo
torch.utils.tensorboard
Type Info
Named Tensors
Named Tensors operator coverage
torch.__config__
Libraries
torchaudio
torchtext
torchvision
TorchElastic
TorchServe
PyTorch on XLA Devices
Community
PyTorch Contribution Guide
PyTorch Governance
PyTorch Governance | Persons of Interest
INDICES AND TABLES
Index
Module Index
© Copyright 2019, Torch Contributors.
Built with Sphinx using a theme provided by Read the Docs.
AUTOMATIC MIXED PRECISION EXAMPLES
WARNING
torch.cuda.amp.GradScaler is not a complete implementation of automatic mixed precision. GradScaler is only useful if you manually run regions of your model in float16. If you aren’t sure how to choose op precision manually, the master branch and nightly pip/conda builds include a context manager that chooses op precision automatically wherever it’s enabled. See the master documentation for details.
Gradient Scaling
Typical Use
Working with Unscaled Gradients
Gradient clipping
Working with Scaled Gradients
Gradient penalty
Working with Multiple Losses and Optimizers
Gradient Scaling
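Gradient scaling helps prevent gradient underflow when training with mixed precision. Gradients whose magnitude is too small to be represented in float16 flush to zero. Multiplying the loss by a scale factor before calling backward() multiplies every gradient by the same factor, which keeps small values representable. The gradients are then unscaled before the optimizer uses them.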
Instances of torch.cuda.amp.GradScaler help perform the steps of gradient scaling conveniently, as shown in the following code snippets.
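To see the underflow problem concretely, the sketch below uses only the Python standard library. struct’s "e" format rounds a Python float to IEEE 754 half precision (float16). The helper name to_float16 and the gradient value 1e-8 are illustrative assumptions. The scale 65536.0 (2**16) is GradScaler’s default init_scale.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Magnitudes beyond float16's range make struct.pack raise; they are mapped
    to infinity here to mimic float16 overflow.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return float("inf") if x > 0 else float("-inf")


true_grad = 1e-8   # hypothetical tiny gradient value
scale = 65536.0    # 2**16, GradScaler's default init_scale

unscaled_fp16 = to_float16(true_grad)          # stored directly in float16
scaled_fp16 = to_float16(true_grad * scale)    # stored after loss scaling
recovered = scaled_fp16 / scale                # unscaled again in higher precision

print(unscaled_fp16)  # 0.0: the gradient underflowed and was lost
print(recovered)      # close to 1e-8
```

Scaling trades underflow risk for overflow risk. A scale that is too large produces inf, which is why GradScaler checks for infs/NaNs and adjusts the scale dynamically.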
Typical Use
from torch.cuda.amp import GradScaler
# Assumes model, optimizer, loss_fn, data and epochs are already defined.
# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        # Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
        scaler.scale(loss).backward()
        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called;
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)
        # Updates the scale for next iteration.
        scaler.update()
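The step()/update() pair above can be modeled, very roughly, by the pure-Python sketch below. ToyLossScaler is a hypothetical class, not GradScaler’s implementation. It reduces "parameters" to plain floats, uses plain SGD, and covers only the skip-on-inf and back-off/growth logic. Its constructor defaults mirror GradScaler’s documented defaults (init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000).

```python
import math


class ToyLossScaler:
    """Simplified pure-Python model of dynamic loss scaling (not PyTorch's code)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0      # consecutive iterations without inf/NaN
        self._found_inf = False   # outcome of the most recent step() check

    def step(self, params, scaled_grads, lr):
        """Unscale the gradients, then apply plain SGD only if all are finite."""
        grads = [g / self.scale for g in scaled_grads]
        self._found_inf = not all(math.isfinite(g) for g in grads)
        if self._found_inf:
            return params  # skip the parameter update for this iteration
        return [p - lr * g for p, g in zip(params, grads)]

    def update(self):
        """Back off after an inf/NaN; grow after growth_interval clean steps."""
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0


# Usage: an overflowing iteration skips the update and halves the scale.
scaler = ToyLossScaler(init_scale=8.0, growth_interval=2)
params = scaler.step([1.0], [float("inf")], lr=0.1)
scaler.update()
print(params, scaler.scale)  # [1.0] 4.0
```

The design point is that the scale is only ever adjusted in update(), after the skip decision has been made. That matches the ordering of scaler.step(optimizer) followed by scaler.update() in the loop above.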
Working with Unscaled Gradients
All gradients produced by scaler.scale(loss).backward() are scaled. If you wish to modify or inspect the parameters’ .grad attributes between backward() and scaler.step(optimizer), you should unscale them first. For example, gradient clipping manipulates a set of gradients such that their global norm (see torch.nn.utils.clip_grad_norm_()) or maximum magnitude (see torch.nn.utils.clip_grad_value_()) is less than or equal to some user-imposed threshold. If you attempted to clip without unscaling, the gradients’ norm/maximum magnitude would also be scaled, so your requested threshold (which was meant to be the threshold for unscaled gradients) would be invalid.
scaler.unscale_(optimizer) unscales gradients held by optimizer’s assigned parameters. If your model or models contain other parameters that were assigned to another optimizer (say optimizer2), you may call scaler.unscale_(optimizer2) separately to unscale those parameters’ gradients as well.
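A small pure-Python illustration of why the threshold becomes invalid follows. The helper clip_by_global_norm is hypothetical. It follows the same idea as torch.nn.utils.clip_grad_norm_ (rescale by max_norm / (total_norm + eps), never scaling up), applied to plain floats. The gradient values and the scale of 1024 are made up for the example.

```python
import math


def global_norm(grads):
    """L2 norm over all gradient values, as used for global-norm clipping."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale grads so their global L2 norm is at most max_norm (never scales up)."""
    coef = max_norm / (global_norm(grads) + eps)
    return [g * min(coef, 1.0) for g in grads]


true_grads = [3.0, 4.0]  # global norm 5.0
max_norm = 10.0          # threshold meant for unscaled gradients
scale = 1024.0           # hypothetical loss scale

# Wrong: clip the scaled gradients, then unscale them.
scaled = [g * scale for g in true_grads]
wrong = [g / scale for g in clip_by_global_norm(scaled, max_norm)]

# Right: unscale first (the role of scaler.unscale_), then clip.
right = clip_by_global_norm([g / scale for g in scaled], max_norm)

print(global_norm(wrong))  # about 10 / 1024, far below the intended threshold
print(right)               # [3.0, 4.0]: norm 5 < 10, so nothing is clipped
```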
Gradient clipping
Calling scaler.unscale_(optimizer) before clipping enables you to clip unscaled gradients as usual:
import torch
from torch.cuda.amp import GradScaler
# Assumes model, optimizer, loss_fn, data, epochs and max_norm are already defined.
scaler = GradScaler()
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        # Unscales the gradients of optimizer's assigned params in-place
        scaler.unscale_(optimizer)
        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
        scaler.step(optimizer)
        # Updates the scale for next iteration.
        scaler.update()
scaler records that scaler.unscale_(optimizer) was already called for this optimizer this iteration, so scaler.step(optimizer) knows not to redundantly unscale gradients before (internally) calling optimizer.step().
WARNING
unscale_() should only be called once per optimizer per step() call, and only after all gradients for that optimizer’s assigned parameters have been accumulated. Calling unscale_() twice for a given optimizer between each step() triggers a RuntimeError.
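The per-optimizer bookkeeping described above can be sketched as a small state machine. ToyUnscaleTracker is a hypothetical class written in the spirit of GradScaler, not its real code. Optimizers are represented by plain string keys, and the error messages are made up. State is cleared in update(), which starts a new iteration.

```python
class ToyUnscaleTracker:
    """Toy per-optimizer bookkeeping modeled on the behavior described above."""

    def __init__(self):
        self._stage = {}  # optimizer key -> "unscaled" or "stepped"

    def unscale_(self, opt):
        stage = self._stage.get(opt)
        if stage == "unscaled":
            raise RuntimeError("unscale_() already called for this optimizer this iteration")
        if stage == "stepped":
            raise RuntimeError("unscale_() called after step()")
        self._stage[opt] = "unscaled"  # gradients would be divided by the scale here

    def step(self, opt):
        """Return True if step() had to unscale the gradients itself."""
        stage = self._stage.get(opt)
        if stage == "stepped":
            raise RuntimeError("step() already called for this optimizer this iteration")
        self._stage[opt] = "stepped"
        return stage is None  # unscale only if unscale_() was not called earlier

    def update(self):
        self._stage.clear()  # a new iteration starts with no recorded stages


tracker = ToyUnscaleTracker()
tracker.unscale_("optimizer")
print(tracker.step("optimizer"))  # False: gradients were already unscaled
```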
Working with Scaled Gradients
For some operations, you may need to work with scaled gradients in a setting where scaler.unscale_ is unsuitable.
Gradient penalty
A gradient penalty implementation typically creates gradients out-of-place using torch.autograd.grad() , combines them to create the penalty value, and adds the penalty value to
the loss.
Here’s an ordinary example of an L2 penalty without gradient scaling:
import torch
# Assumes model, optimizer, loss_fn, data and epochs are already defined.
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        # Creates some gradients out-of-place
        grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
        # Computes the penalty term and adds it to the loss
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        loss = loss + grad_norm
        loss.backward()
        optimizer.step()
To implement a gradient penalty with gradient scaling, the loss passed to torch.autograd.grad() should be scaled. The resulting out-of-place gradients will therefore be scaled, and
should be unscaled before being combined to create the penalty value.
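The reason for the unscaling step follows from the linearity of the gradient. The notation is introduced here for illustration: S is the current scale factor (the value scaler.get_scale() returns), L is the loss, θ are the model parameters, and P is the penalty term. Every out-of-place gradient, and therefore its L2 norm, is multiplied by S. Computing P from the scaled gradients would weight the penalty by S instead of 1. Dividing by S first recovers the intended value. The final scaled backward pass then yields S times the gradient of the penalized loss, which scaler.step() unscales as usual.

```latex
% S = scaler.get_scale(), L = loss, \theta = model parameters, P = penalty term
\nabla_\theta (S L) = S \, \nabla_\theta L
\quad\Longrightarrow\quad
\left\lVert \nabla_\theta (S L) \right\rVert_2 = S \left\lVert \nabla_\theta L \right\rVert_2

P = \left\lVert S^{-1} \, \nabla_\theta (S L) \right\rVert_2 = \left\lVert \nabla_\theta L \right\rVert_2

\nabla_\theta \big( S (L + P) \big) = S \left( \nabla_\theta L + \nabla_\theta P \right)
```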
Here’s how that looks for the same L2 penalty:
import torch
from torch.cuda.amp import GradScaler
# Assumes model, optimizer, loss_fn, data and epochs are already defined.
scaler = GradScaler()
for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        # Scales the loss for the out-of-place backward pass, resulting in scaled grad_params
        scaled_grad_params = torch.autograd.grad(scaler.scale(loss), model.parameters(), create_graph=True)
        # Unscales grad_params before computing the penalty. grad_params are not owned
        # by any optimizer, so ordinary division is used instead of scaler.unscale_:
        inv_scale = 1. / scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]
        # Computes the penalty term and adds it to the loss
        grad_norm = 0
        for grad in grad_params:
            grad_norm += grad.pow(2).sum()
        grad_norm = grad_norm.sqrt()
        loss = loss + grad_norm
        # Applies scaling to the backward call as usual. Accumulates leaf gradients that are correctly scaled.
        scaler.scale(loss).backward()
        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()
Working with Multiple Losses and Optimizers
If your network has multiple losses, you must call scaler.scale on each of them individually. If your network has multiple optimizers, you may call scaler.unscale_ on any of
them individually, and you must call scaler.step on each of them individually.
However, scaler.update() should only be called once, after all optimizers used this iteration have been stepped:
import torch
# Assumes model0, model1, optimizer0, optimizer1, loss_fn, data and epochs are already defined.
scaler = torch.cuda.amp.GradScaler()
for epoch in epochs:
    for input, target in data:
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        output0 = model0(input)
        output1 = model1(input)
        loss0 = loss_fn(2 * output0 + 3 * output1, target)
        loss1 = loss_fn(3 * output0 - 5 * output1, target)
        scaler.scale(loss0).backward(retain_graph=True)
        scaler.scale(loss1).backward()
        # You can choose which optimizers receive explicit unscaling, if you
        # want to inspect or modify the gradients of the params they own.
        scaler.unscale_(optimizer0)
        scaler.step(optimizer0)
        scaler.step(optimizer1)
        scaler.update()
Each optimizer independently checks its gradients for infs/NaNs, and therefore makes an independent decision whether or not to skip the step. This may result in one optimizer skipping the
step while the other one does not. Since step skipping occurs rarely (every several hundred iterations) this should not impede convergence. If you observe poor convergence after adding
gradient scaling to a multiple-optimizer model, please file an issue.
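The independent skip decisions, followed by a single scale update, can be modeled with plain floats as below. step_all is a hypothetical helper, not part of PyTorch. Each inner list stands in for one optimizer’s parameters. The scale backs off if any optimizer found an inf/NaN, and scale growth is omitted to keep the sketch short.

```python
import math


def step_all(param_groups, scaled_grad_groups, scale, lr, backoff_factor=0.5):
    """Toy multi-optimizer step: each group decides on its own whether to skip.

    Returns the updated parameter groups and the scale for the next iteration.
    """
    found_any_inf = False
    updated = []
    for params, scaled_grads in zip(param_groups, scaled_grad_groups):
        grads = [g / scale for g in scaled_grads]
        if all(math.isfinite(g) for g in grads):
            updated.append([p - lr * g for p, g in zip(params, grads)])
        else:
            found_any_inf = True
            updated.append(list(params))  # this "optimizer" skips its step
    # One scale update per iteration, after every group has been stepped.
    new_scale = scale * backoff_factor if found_any_inf else scale
    return updated, new_scale


groups, next_scale = step_all([[1.0], [2.0]], [[4.0], [float("inf")]], scale=4.0, lr=0.5)
print(groups, next_scale)  # [[0.5], [2.0]] 2.0
```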
AUTOGRAD MECHANICS
This note will present an overview of how autograd works and records the operations. It’s not strictly necessary to understand all this, but we recommend getting familiar with it, as it will
help you write more efficient, cleaner programs, and can aid you in debugging.
Excluding subgraphs from backward
Every Tensor has a flag, requires_grad, that allows fine-grained exclusion of subgraphs from gradient computation and can increase efficiency.
requires_grad
If any input to an operation requires gradient, its output will also require gradient. Conversely, the output does not require gradient only if none of the inputs require it.
Backward computation is never performed in subgraphs where no Tensor requires gradients.
>>> x = torch.randn(5, 5) # requires_grad=False by default
>>> y = torch.randn(5, 5) # requires_grad=False by default
>>> z = torch.randn((5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True
This is especially useful when you want to freeze part of your model, or when you know in advance that you’re not going to use gradients w.r.t. some parameters. For example, if you want to fine-tune a pretrained CNN, it’s enough to switch the requires_grad flags in the frozen base. No intermediate buffers will be saved until the computation reaches the last layer, where the affine transform will use weights that require gradient, and the output of the network will also require them.
import torch.nn as nn
import torch.optim as optim
import torchvision
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# Replace the last fully-connected layer
# Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)
# Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
How autograd encodes the history
Autograd is a reverse-mode automatic differentiation system. Conceptually, autograd records a graph of all of the operations that created the data as you execute them. This gives you a directed acyclic graph whose leaves are the input tensors and whose roots are the output tensors. By tracing this graph from roots to leaves, you can automatically compute the gradients using the chain rule.
Internally, autograd represents this graph as a graph of Function objects (really expressions), which can be apply()-ed to compute the result of evaluating the graph. When computing the forward pass, autograd simultaneously performs the requested computations and builds up a graph representing the function that computes the gradient. The .grad_fn attribute of each torch.Tensor is an entry point into this graph. When the forward pass is completed, we evaluate this graph in the backward pass to compute the gradients.
An important thing to note is that the graph is recreated from scratch at every iteration. This is exactly what allows you to use arbitrary Python control flow statements that can change the overall shape and size of the graph at every iteration. You don’t have to encode all possible paths before you launch the training: what you run is what you differentiate.
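The recording and replay described above can be illustrated with a toy scalar reverse-mode differentiator in pure Python. Var and f are hypothetical names, and this is not how PyTorch implements autograd (its engine and Function nodes live in C++). Each Var stores (parent, local derivative) pairs as a stand-in for grad_fn. backward() orders the recorded graph topologically and applies the chain rule from the root to the leaves. Because f uses a data-dependent while loop, different inputs record differently shaped graphs.

```python
class Var:
    """Toy scalar node that records how it was created during the forward pass."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        # (parent, local_derivative) pairs: a stand-in for PyTorch's grad_fn
        self.parents = parents

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def backward(self):
        # Topologically order the recorded graph (parents before children) ...
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent, _ in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        # ... then apply the chain rule from the root back to the leaves.
        self.grad = 1.0
        for node in reversed(order):
            for parent, local in node.parents:
                parent.grad += local * node.grad


def f(x, y):
    # Ordinary Python control flow: the recorded graph depends on x's value.
    out = x * y
    while out.value < 100:
        out = out * x
    return out + y


x, y = Var(2.0), Var(3.0)
z = f(x, y)        # loop runs 5 times here: z = x**6 * y + y
z.backward()
print(x.grad, y.grad)  # 576.0 65.0
```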
In-place operations with autograd
Supporting in-place operations in autograd is a hard matter, and we discourage their use in most cases. Autograd’s aggressive buffer freeing and reuse makes it very efficient and there are
very few occasions when in-place operations actually lower memory usage by any significant amount. Unless you’re operating under heavy memory pressure, you might never need to use
them.
There are two main reasons that limit the applicability of in-place operations:
1. In-place operations can potentially overwrite values required to compute gradients.
2. Every in-place operation actually requires the implementation to rewrite the computational graph. Out-of-place versions simply allocate new objects and keep references to the old graph, while in-place operations require changing the creator of all inputs to the Function representing this operation. This can be tricky, especially if there are many Tensors that reference the same storage (e.g. created by indexing or transposing), and in-place functions will actually raise an error if the storage of modified inputs is referenced by any other Tensor.
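Reason 1 can be illustrated without PyTorch. In the toy below, Buffer stands in for a tensor’s storage, and square is a hypothetical forward function that saves a reference (not a copy) to its input for the backward pass. An in-place update made after the forward pass silently changes the gradient that backward computes, because this toy has no safety check.

```python
class Buffer:
    """Mutable container standing in for a tensor's storage (toy)."""

    def __init__(self, value):
        self.value = value


def square(x):
    """Forward for y = x**2; saves a reference to x for the backward pass."""
    saved = x  # no copy, so later in-place writes to x are visible here

    def backward(grad_out):
        return grad_out * 2 * saved.value  # dy/dx = 2x, read at backward time

    return Buffer(x.value ** 2), backward


x = Buffer(3.0)
y, y_backward = square(x)
x.value += 7.0             # in-place update after x was saved for backward
print(y_backward(1.0))     # 20.0, not the correct 6.0: the saved input was overwritten
```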
In-place correctness checks
Every tensor keeps a version counter that is incremented every time it is marked dirty in any operation. When a Function saves any tensors for backward, the version counter of their containing Tensor is saved as well. When you access self.saved_tensors, the current version is checked, and if it is greater than the saved value, an error is raised. This ensures that if you’re using in-place functions and not seeing any errors, you can be sure that the computed gradients are correct.
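A minimal model of this check follows, assuming a scalar "tensor" for simplicity. VersionedTensor, SavedTensor and square are hypothetical names, not PyTorch classes. The in-place method add_ bumps the version counter. The saved-tensor wrapper records the version at save time and raises on unpack if the version has grown, turning the silent wrong-gradient case into an error.

```python
class VersionedTensor:
    """Toy scalar tensor whose version counter is bumped by in-place ops."""

    def __init__(self, value):
        self.value = value
        self._version = 0

    def add_(self, other):
        self.value += other
        self._version += 1  # mark dirty
        return self


class SavedTensor:
    """Remembers the version at save time and checks it when unpacked."""

    def __init__(self, tensor):
        self.tensor = tensor
        self.saved_version = tensor._version

    def unpack(self):
        if self.tensor._version > self.saved_version:
            raise RuntimeError("a tensor saved for backward was modified in-place")
        return self.tensor.value


def square(x):
    """Forward for y = x**2 that saves x through the version-checked wrapper."""
    saved = SavedTensor(x)

    def backward(grad_out):
        return grad_out * 2 * saved.unpack()

    return x.value ** 2, backward


a = VersionedTensor(3.0)
_, back_a = square(a)
print(back_a(1.0))  # 6.0: no in-place change, so the check passes

b = VersionedTensor(3.0)
_, back_b = square(b)
b.add_(7.0)         # back_b(1.0) would now raise RuntimeError
```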