🧠 Introduction: What Is torch.func?
In PyTorch, most models are built using stateful objects like nn.Module. While this is intuitive and powerful, functional programming can unlock more control, composability, and efficiency in advanced machine learning workflows. That’s where torch.func comes in.
torch.func is a PyTorch module introduced to bring functional programming paradigms into the PyTorch ecosystem. It allows users to:
- Apply automatic differentiation (grad, vjp, jvp)
- Perform vectorization using vmap
- Use functional calls with stateless modules
- Write pure functions, reducing side effects
This is especially useful in higher-order optimization, meta-learning, batched operations, and advanced gradient computations.
🔧 Getting Started with torch.func
Before diving in, make sure you’re using PyTorch 2.0 or later, as torch.func is part of the newer functional APIs.
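A quick way to confirm your install supports it is to check the version and try the imports (a minimal sketch; the exact version string will vary):
import torch
print(torch.__version__)  # should be 2.0 or later
from torch.func import grad, vmap, functional_call  # succeeds on supported versions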
🛠️ Code Examples Using torch.func
1. Stateless Model Execution with functional_call
import torch
import torch.nn as nn
from torch.func import functional_call
model = nn.Linear(2, 2)
x = torch.randn(1, 2)
# Extract model parameters
params = dict(model.named_parameters())
# Run model in a functional/stateless way
output = functional_call(model, params, (x,))
print("Functional output:", output)
2. Automatic Differentiation with grad
from torch.func import grad
def f(x):
    return (x**2).sum()
# Compute gradient of f w.r.t x
grad_f = grad(f)
x = torch.tensor([1.0, 2.0, 3.0])
print("Gradient:", grad_f(x))
3. Batching with vmap
from torch.func import vmap
def linear(x, w, b):
    return x @ w.T + b
# Batch inputs
xs = torch.randn(10, 3)
ws = torch.randn(10, 2, 3)
bs = torch.randn(10, 2)
# Apply vectorization
batched_output = vmap(linear)(xs, ws, bs)
print("Batched output:", batched_output.shape)
📚 Common Methods in torch.func
Function | Description
---|---
grad(func) | Computes the gradient of a scalar-valued function
vmap(func) | Vectorizes a function over batch dimensions
functional_call(model, params, args) | Executes a model statelessly
jacrev(func) | Computes the Jacobian via reverse-mode AD
jacfwd(func) | Computes the Jacobian via forward-mode AD
jvp(func, primals, tangents) | Computes a Jacobian-vector product
vjp(func, *primals) | Computes a vector-Jacobian product
These methods work seamlessly with PyTorch’s autograd and enable functional transformations similar to JAX.
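For instance, jacrev materializes a full Jacobian, while jvp gives you a Jacobian-vector product without building the matrix. A small sketch for an elementwise function (whose Jacobian is diagonal):
import torch
from torch.func import jacrev, jvp

def f(x):
    return x ** 2

x = torch.tensor([1.0, 2.0, 3.0])
print(jacrev(f)(x))  # 3x3 diagonal Jacobian: diag(2x)
# Jacobian-vector product against a vector of ones
out, jvp_out = jvp(f, (x,), (torch.ones_like(x),))
print(jvp_out)  # tensor([2., 4., 6.])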
❗ Common Errors and Debugging Tips
❌ Error: TypeError: 'Linear' object is not callable
Fix: When using functional_call, ensure the third argument is a tuple of inputs:
functional_call(model, params, (x,))
❌ Error: RuntimeError: grad can only be applied to scalar outputs
Cause: torch.func.grad requires the function to return a scalar (like a loss).
Fix:
def f(x): return x.sum() # or use .mean(), .norm(), etc.
❌ Shape Mismatch in vmap
Tip: Ensure all batched inputs have the same leading dimensions:
vmap(fn)(batched_input1, batched_input2)
Check the .shape of each input if unexpected broadcasting occurs.
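If one of the arguments is intentionally unbatched, declare that with in_dims=None rather than reshaping it; a sketch with a made-up scale function:
import torch
from torch.func import vmap

def scale(x, s):
    return x * s

xs = torch.randn(8, 3)
s = torch.tensor(2.0)  # deliberately not batched
# vmap(scale)(xs, s) would complain about s's missing batch dimension;
# in_dims=(0, None) maps over xs only
out = vmap(scale, in_dims=(0, None))(xs, s)
print(out.shape)  # torch.Size([8, 3])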
❌ Using Stateful Layers with functional_call
functional_call expects pure functions. If your model uses BatchNorm or other layers that track state, you may need to pass buffers as well:
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
Then pass both into functional_call.
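One way to do that is to merge the two dicts, since functional_call accepts a single mapping of names to tensors. A minimal sketch using a BatchNorm model in eval mode (so the running statistics are only read, not updated):
import torch
import torch.nn as nn
from torch.func import functional_call

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model.eval()  # read running stats instead of updating them
x = torch.randn(2, 4)

params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
# Merge parameters and buffers into one name -> tensor mapping
out = functional_call(model, {**params, **buffers}, (x,))
print(out.shape)  # torch.Size([2, 4])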
✅ People Also Ask (FAQ)
🔹 What is the torch function?
In general, the term “torch function” refers to any callable provided by the PyTorch library. This includes mathematical functions, neural network modules, tensor operations, and higher-order functions under torch.func.
🔹 What is the function of the torch library?
PyTorch (torch) is a Python-based deep learning library used for:
- Building neural networks
- Automatic differentiation
- Tensor computation
- GPU acceleration
- Functional and modular modeling
It’s widely used in both research and production for training and deploying ML models.
🔹 What does torch.where do?
torch.where(condition, x, y) returns elements chosen from x or y depending on the condition.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
c = torch.where(a > 2, a, b)
# c = [4, 5, 3]
This is useful for conditional selection or masking.
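For example, a common masking pattern replaces disallowed entries with -inf before a softmax (a sketch; the logits and mask values are made up):
import torch

logits = torch.tensor([0.5, -1.0, 2.0])
mask = torch.tensor([True, False, True])
# Masked-out positions get -inf, so softmax assigns them zero probability
masked = torch.where(mask, logits, torch.tensor(float("-inf")))
print(torch.softmax(masked, dim=0))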
🔹 Is torch.func the same as torch.nn.functional?
No.
- torch.nn.functional provides stateless versions of PyTorch layers, like F.relu and F.linear.
- torch.func provides functional programming utilities, like vmap, grad, and functional_call.
🔹 When should I use torch.func?
Use torch.func when:
- You need per-sample gradients (see the sketch after this list)
- You want to apply batched operations efficiently
- You’re implementing meta-learning or higher-order optimization
- You need full control over function transformation and differentiation
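As a sketch of the first point, per-sample gradients fall out of composing vmap with grad (the toy loss and shapes here are assumptions for illustration):
import torch
from torch.func import grad, vmap

def loss(w, x, y):
    # Squared error of a single example under a linear model
    return (x @ w - y) ** 2

w = torch.randn(3)
xs = torch.randn(10, 3)
ys = torch.randn(10)
# grad differentiates w.r.t. w; vmap maps over the 10 examples
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(w, xs, ys)
print(per_sample_grads.shape)  # torch.Size([10, 3])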
📌 Final Thoughts
torch.func marks a new era in PyTorch’s evolution by empowering developers with functional programming capabilities. It helps build stateless, composable, and differentiable functions, a crucial foundation for advanced ML workflows.