Stochastic Gradient Descent

Understanding optimization dynamics, loss landscapes, and surprising generalization behaviors

Stochastic Methods in Machine Learning: AGH

Stochastic Gradient Descent

$$\theta_{t+1} = \theta_t - \eta \cdot \nabla L(\theta_t;\, x_i, y_i)$$

Why the Negative Gradient?

The gradient $\nabla L$ points in the direction of steepest ascent, so moving along the negative gradient gives steepest descent: the fastest local decrease of the loss. Here $\eta$ is the learning rate and $(x_i, y_i)$ is a randomly sampled training example.
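The update rule above can be sketched in a few lines of NumPy. This is a minimal illustration, not a framework implementation: `sgd_step` and `lsq_grad` are hypothetical helper names, and the per-example loss is assumed to be squared error $\tfrac{1}{2}(\theta^\top x - y)^2$.

```python
import numpy as np

def sgd_step(theta, grad_fn, x_i, y_i, eta=0.1):
    """One SGD update: theta <- theta - eta * gradient of the per-example loss."""
    return theta - eta * grad_fn(theta, x_i, y_i)

# Example: squared-error loss L = 0.5 * (theta @ x - y)^2,
# whose gradient w.r.t. theta is (theta @ x - y) * x.
def lsq_grad(theta, x, y):
    return (theta @ x - y) * x

theta = np.zeros(2)
x, y = np.array([1.0, 2.0]), 3.0
theta = sgd_step(theta, lsq_grad, x, y, eta=0.1)  # theta moves toward the example
```

In a full training loop, the same step is applied with a fresh random example (or mini-batch) each iteration.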

Full-Batch GD vs Mini-Batch SGD

  • Full-batch GD: Computes gradient over entire dataset: smooth, deterministic path
  • Mini-batch SGD: Approximates gradient from a random subset: noisy but faster per step
  • The noise is not a bug: it helps escape sharp minima and acts as implicit regularization

Learning Rate $\eta$

Too large → overshoots minimum, may diverge
Too small → very slow convergence
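Both failure modes can be seen exactly on a one-dimensional quadratic $f(\theta) = \tfrac{1}{2}\lambda\theta^2$, where each GD step multiplies $\theta$ by $(1 - \eta\lambda)$, so the iterates converge if and only if $\eta < 2/\lambda$. A toy sketch (the function name is ours):

```python
import numpy as np

def gd_quadratic(theta0, lam, eta, steps):
    """GD on f(theta) = 0.5 * lam * theta^2.
    Each step computes theta - eta * lam * theta = (1 - eta * lam) * theta."""
    theta = theta0
    for _ in range(steps):
        theta = theta - eta * lam * theta
    return theta

lam = 1.0                                        # curvature; stability needs eta < 2/lam = 2
print(gd_quadratic(1.0, lam, eta=0.5, steps=50))  # shrinks toward 0
print(gd_quadratic(1.0, lam, eta=2.5, steps=50))  # |1 - eta*lam| = 1.5 > 1: diverges
```

The same $2/\lambda$ threshold, with $\lambda$ the largest Hessian eigenvalue, is exactly the stability boundary discussed in the Edge of Stability section below.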

[Interactive demo: full-batch GD vs mini-batch SGD trajectories; click anywhere on the contour plot to place a starting point]

Edge of Stability

Cohen, Kaur, Li, Kolter & Talwalkar, 2021: "Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability"

Sharpness as $\lambda_{\max}(H)$

Sharpness is measured as the largest eigenvalue of the Hessian $\lambda_{\max}(\nabla^2 L)$. Classical theory says GD converges only when $\lambda_{\max} < 2/\eta$.
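In practice $\lambda_{\max}$ is estimated without forming the Hessian, using Hessian-vector products and power iteration. A sketch under our own naming (`hvp`, `sharpness`), with the Hessian-vector product approximated by finite differences of the gradient:

```python
import numpy as np

def hvp(grad_fn, theta, v, eps=1e-5):
    """Finite-difference Hessian-vector product:
    H v ~= (g(theta + eps*v) - g(theta - eps*v)) / (2*eps)."""
    return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2 * eps)

def sharpness(grad_fn, theta, iters=100, seed=0):
    """Estimate lambda_max of the Hessian at theta by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.normal(size=theta.shape)
    v /= np.linalg.norm(v)
    for _ in range(iters):
        hv = hvp(grad_fn, theta, v)
        v = hv / np.linalg.norm(hv)
    return v @ hvp(grad_fn, theta, v)  # Rayleigh quotient at the converged direction

# Quadratic with Hessian diag(3, 1): lambda_max should come out ~3,
# so plain GD on this loss is stable only for eta < 2/3.
H = np.diag([3.0, 1.0])
grad = lambda th: H @ th
print(sharpness(grad, np.ones(2)))
```

For a neural network one would use autodiff Hessian-vector products instead of finite differences, but the power-iteration structure is the same.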

Full-Batch GD, Not SGD

Unlike most analyses, this paper studies deterministic gradient descent (full batch). The surprising behaviors arise without stochastic noise: they are intrinsic to the optimization landscape of neural networks.

The Edge of Stability Phenomenon

  • Phase 1 (Progressive sharpening): GD steadily increases sharpness toward the threshold $2/\eta$
  • Phase 2 (Edge of stability): Sharpness hovers at $\approx 2/\eta$, loss decreases non-monotonically with local oscillations

Why It Matters

Classical convergence guarantees assume sharpness stays below $2/\eta$ throughout training. In practice, GD self-tunes to violate this assumption: it operates right at the stability boundary. This means standard theoretical results do not explain why GD works on neural networks.

[Figure: Edge of Stability. Top row: train loss. Bottom row: sharpness $\lambda_{\max}$. Dashed lines mark the $2/\eta$ stability threshold for each learning rate.]

Three Factors Influencing Minima in SGD

Jastrzebski, Kenton, Arpit, Ballas, Fischer, Bengio & Storkey, 2018

$$d\theta = -g(\theta)\,dt + \sqrt{\frac{\eta}{S}}\, R(\theta)\, dW(t)$$

The Key Ratio: $\eta / S$

The noise scale (effective temperature) of SGD is $\eta / S$: learning rate divided by batch size. In the SDE above, $g(\theta)$ is the full-batch gradient, $W(t)$ is a Wiener process, and $R(\theta)R(\theta)^\top$ is the per-example gradient covariance. Only the ratio $\eta/S$ matters, not the individual values of $\eta$ and $S$!

  • Higher $\eta/S$ → more noise → escapes sharp minima → finds flat minima
  • Lower $\eta/S$ → less noise → gets trapped in sharp minima
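The ratio can be checked numerically on a toy quadratic loss, modeling the mini-batch gradient as the true gradient plus Gaussian noise of variance $\sigma^2/S$ (a simplifying assumption; the function name is ours). Runs with the same $\eta/S$ should show a similar stationary spread of the iterates, while a smaller ratio concentrates them more tightly:

```python
import numpy as np

def sgd_stationary_std(eta, S, steps=200_000, sigma=1.0, seed=0):
    """SGD on f(theta) = 0.5 * theta^2 with minibatch gradient noise of
    variance sigma^2 / S. Returns the empirical std of theta after burn-in,
    i.e. how widely the iterates wander around the minimum."""
    rng = np.random.default_rng(seed)
    theta, trace = 0.0, []
    for t in range(steps):
        grad = theta + rng.normal(0.0, sigma / np.sqrt(S))  # noisy gradient estimate
        theta -= eta * grad
        if t > steps // 2:                                  # discard burn-in
            trace.append(theta)
    return float(np.std(trace))

print(sgd_stationary_std(eta=0.1, S=10))   # ratio eta/S = 0.01
print(sgd_stationary_std(eta=0.2, S=20))   # same ratio -> similar spread
print(sgd_stationary_std(eta=0.1, S=100))  # ratio 0.001 -> tighter spread
```

The wider stationary distribution at higher $\eta/S$ is exactly the mechanism that lets SGD hop out of narrow basins while settling comfortably in wide ones.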

Three Factors

  • Learning rate $\eta$: higher → more noise → wider minima
  • Batch size $S$: smaller → more noise → wider minima
  • Gradient covariance: determines the noise structure (anisotropic, not isotropic)

Large-Batch Training: Sharp vs Flat Minima

Keskar, Mudigere, Nocedal, Smelyanskiy & Tang, 2017

The Generalization Gap

Large batches achieve similar training accuracy but up to 5% worse test accuracy. Why?

  • Large batch → low noise → sharp minima → poor generalization
  • Small batch → high noise → flat minima → good generalization

Sharpness Metric

$$\phi(\epsilon, A) = \frac{\max_{y \in \mathcal{C}_\epsilon} f(x+Ay) - f(x)}{1 + f(x)} \times 100$$

Measures how much loss can increase in a neighborhood. Sharp minima → high $\phi$.
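A crude way to approximate this metric is Monte Carlo sampling of perturbations in the $\epsilon$-box (the paper solves the inner maximization with an optimizer and allows a projection matrix $A$; here we take $A = I$ and just sample, and the function names are ours):

```python
import numpy as np

def sharpness_metric(f, x, eps=1e-3, n_samples=1000, seed=0):
    """Monte Carlo approximation of the Keskar et al. sharpness phi:
    largest observed increase of f over the box ||y||_inf <= eps around x,
    normalized by 1 + f(x) and expressed as a percentage."""
    rng = np.random.default_rng(seed)
    fx = f(x)
    best = 0.0
    for _ in range(n_samples):
        y = rng.uniform(-eps, eps, size=x.shape)
        best = max(best, f(x + y) - fx)
    return 100.0 * best / (1.0 + fx)

# Two minima at x = 0 with very different curvature:
sharp = lambda x: 100.0 * np.sum(x**2)  # high curvature -> large phi
flat = lambda x: 0.1 * np.sum(x**2)     # low curvature  -> small phi
x0 = np.zeros(5)
print(sharpness_metric(sharp, x0), sharpness_metric(flat, x0))
```

Random sampling underestimates the true maximum, but it already separates sharp from flat minima clearly.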


Double Descent

Belkin et al., 2019: "Reconciling modern ML practice and the bias-variance trade-off"

The Classical View

Bias-variance trade-off predicts a U-shaped test error curve: underfitting → sweet spot → overfitting. But modern overparameterized models defy this!

Three Regimes

  • Under-parameterized $(p < n)$: the classical U-shape; more parameters reduce bias, then eventually increase variance.
  • Interpolation threshold $(p \approx n)$: the model barely fits all training data; the fit is maximally jagged and test error peaks.
  • Over-parameterized $(p \gg n)$: many interpolating solutions exist; the optimizer's implicit bias selects smoother ones, so test error decreases again.
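All three regimes appear in a small random-features regression, where the minimum-norm interpolating solution can be computed directly with a pseudoinverse. This is a toy sketch (function name and feature choice are ours): $n = 20$ noisy samples of a sine wave, fitted with $p$ random ReLU features for $p$ below, at, and far above $n$:

```python
import numpy as np

def random_feature_fit(n=20, p_list=(5, 20, 400), n_test=200, noise=0.1, seed=0):
    """Min-norm least squares with p random ReLU features.
    Near p = n the interpolating fit is wildly jagged (test error spikes);
    for p >> n the min-norm solution via pinv is smoother and error drops again."""
    rng = np.random.default_rng(seed)
    x_tr = rng.uniform(-1, 1, n)
    y_tr = np.sin(2 * np.pi * x_tr) + noise * rng.normal(size=n)
    x_te = rng.uniform(-1, 1, n_test)
    y_te = np.sin(2 * np.pi * x_te)

    errs = {}
    for p in p_list:
        w = rng.normal(size=p)
        b = rng.uniform(-1, 1, p)
        feats = lambda x: np.maximum(0.0, np.outer(x, w) + b)  # random ReLU features
        beta = np.linalg.pinv(feats(x_tr)) @ y_tr              # min-norm interpolator
        errs[p] = float(np.mean((feats(x_te) @ beta - y_te) ** 2))
    return errs

print(random_feature_fit())  # test error spikes near p = n, falls again for p >> n
```

The pseudoinverse is what supplies the implicit bias here: among all interpolating weight vectors it returns the one of minimum norm, which is the smoother solution in the over-parameterized regime.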

Grokking

Power, Burda, Edwards, Babuschkin & Misra, 2022

Critical Role of Weight Decay

Weight decay is essential for grokking. Without it, the model memorizes but never generalizes. It acts as regularization pressure, pushing the network from memorization toward discovering true algebraic structure.
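The regularization pressure itself is a one-line addition to the update rule. A sketch of an SGD step with decoupled weight decay (the function name is ours; the grokking experiments in the paper use an adaptive optimizer, but the decay term plays the same role):

```python
import numpy as np

def sgd_weight_decay_step(theta, grad, eta=0.1, wd=1e-2):
    """SGD step plus decoupled weight decay: besides the gradient step,
    shrink the weights toward zero by a factor (1 - eta * wd)."""
    return theta - eta * grad - eta * wd * theta

# With zero loss gradient (a perfectly memorizing solution), weight decay
# alone still shrinks the weights geometrically, eroding the memorized fit.
theta = np.array([1.0, -2.0])
for _ in range(100):
    theta = sgd_weight_decay_step(theta, grad=np.zeros(2), eta=0.1, wd=0.1)
print(theta)  # each step multiplies theta by (1 - eta * wd) = 0.99
```

This is why a memorizing network does not simply sit still after reaching zero training error: weight decay keeps pushing it toward lower-norm solutions, and the generalizing algebraic solution is one of them.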

Key Insight

Generalization can happen long after memorization, so conventional early stopping would miss it entirely. This challenges the view that memorization and generalization are competing alternatives that occur on similar training timescales.


"Explain Overfitting"

Naive Answer

"Overfitting is when the model memorizes the training data instead of learning the underlying pattern. It happens when the model is too complex. The solution is to use regularization, early stopping, or reduce model size."

Nuanced Answer

"The classical view says more parameters inevitably leads to overfitting, but modern deep learning shows this isn't the full picture:

  • Double descent: Beyond the interpolation threshold ($p \approx n$), test error decreases again. Overparameterized models can generalize well without explicit regularization.
  • SGD noise matters: The ratio $\eta/S$ acts as implicit regularization. Small batches and larger learning rates push optimization toward flat minima, which generalize better than sharp ones.
  • Sharp vs flat minima: Overfitting isn't just about model size: it's about which minimum you land in. Large-batch training finds sharp minima with poor generalization, even at the same training loss.
  • Grokking: Memorization and generalization aren't mutually exclusive. A model can perfectly memorize the training data, then suddenly generalize much later, provided you train long enough with weight decay."