An interactive tour of sampling a continuous mixture distribution with Trajectory Balance — the pedagogical bug, mode collapse, and the off-policy fix.
Most introductions to GFlowNets target the discrete case. This post walks through the continuous version — sampling real numbers proportional to a reward — using a tiny pretrained policy that runs entirely in your browser. Every plot below is live: drag the temperature slider to rescale the action noise at sampling time, or hit resample to draw a new batch.
The example is adapted from the continuous intro notebook in torchgfn. The training script that produced the models on this page lives at scripts/train_distill_gflownets.py — run it yourself to regenerate them.
The target is a mixture of two Gaussians on the real line. We start from $S_0 = 0$ and take $N = 5$ steps, each adding a scalar action $\Delta x \sim \mathcal{N}(\mu(s), \sigma(s))$ whose parameters come from a small MLP. The state is $(x, t)$ — we carry the step counter so the space stays acyclic.
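The rollout can be sketched in a few lines. This is a minimal NumPy illustration, not torchgfn's actual API — the `policy` callable and all names here are placeholders for the MLP described below:

```python
import numpy as np

def rollout(policy, n_steps=5, rng=np.random.default_rng(0)):
    """Roll out one trajectory. The state is (x, t); each step adds a
    Gaussian scalar action to x and increments the step counter."""
    x, traj = 0.0, [(0.0, 0)]           # S_0 = 0, t = 0
    for t in range(n_steps):
        mu, sigma = policy(x, t)        # per-state Gaussian parameters
        x = x + rng.normal(mu, sigma)   # Delta x ~ N(mu(s), sigma(s))
        traj.append((x, t + 1))
    return traj

# The placeholder policy from the demo: mu = 0, sigma = 0.3.
traj = rollout(lambda x, t: (0.0, 0.3))
```

Carrying `t` in the state is what keeps the space acyclic: two visits to the same `x` at different steps are different states.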
The model is a 2-layer MLP (hidden dim 64), about 4.4K parameters. It emits two numbers per state: a mean and a pre-sigmoid standard deviation that we squash into $[0.1, 1.0]$. Nothing exotic.
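The output head is the only mildly fiddly part. A sketch of the squashing step (the constants `0.1` and `1.0` are the bounds stated above; the function name is illustrative):

```python
import numpy as np

def policy_head(raw_mu, raw_sigma):
    """Map the MLP's two outputs to Gaussian parameters: the mean is
    used as-is, the std is sigmoid-squashed into [0.1, 1.0]."""
    sigma = 0.1 + 0.9 / (1.0 + np.exp(-raw_sigma))
    return raw_mu, sigma

mu, sigma = policy_head(0.5, -10.0)  # very negative raw output -> sigma near 0.1
```

The lower bound of 0.1 keeps the per-step density from collapsing to a point, which would make log-probabilities blow up during training.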
Here is the reward landscape. Dashed lines mark the modes; the yellow dot is $S_0$:
Before the real models load you will see samples drawn from a placeholder policy with $\mu = 0, \sigma = 0.3$. Once the ONNX files are deployed, the histogram should concentrate around the two modes.
The first model was trained with Trajectory Balance:
Slide the temperature down toward 0 and the samples concentrate on the highest-reward regions (this is the GFlowNet-as-argmax limit). Push it past 1 and the distribution flattens — you are effectively doing a noisier version of the learned policy.
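What the slider does is simple enough to state in code. A hedged sketch with made-up learned parameters — only the standard deviation is rescaled, never the mean:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.0, 0.5          # stand-ins for the learned per-step parameters

def sample_actions(temperature, n=10_000):
    # The slider multiplies sigma at sampling time; the learned
    # parameters themselves are untouched.
    return rng.normal(mu, sigma * temperature, size=n)

cold = sample_actions(0.1)    # concentrates near the mean
hot = sample_actions(2.0)     # flatter, noisier
```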
The Trajectory Balance loss requires matching the forward flow $Z_\theta\, P_F(\tau)$ against the reward-weighted backward probability $R(x)\, P_B(\tau \mid x)$, in log space:
\[\mathcal{L}(\tau) = \left( \log Z_\theta + \sum_t \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_t \log P_B(s_t \mid s_{t+1}) \right)^2.\]

In our setup the very last backward transition, $S_1 \to S_0$, is forced — all trajectories start at the same $S_0$, so under any backward policy that reaches $S_1$, the probability of stepping to $S_0$ is 1, i.e. log-probability 0. If you accidentally include that transition in the $\log P_B$ sum, you end up training a backward network to concentrate mass on a specific, arbitrary $\Delta$, which distorts the gradient. The symptom: samples look almost right but skewed.
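The loss itself is a one-liner once the per-transition log-probabilities are in hand. A toy NumPy sketch (the numbers are made up; when the two sides of the balance condition agree exactly, the residual — and hence the loss — is zero):

```python
import numpy as np

def tb_loss(log_Z, log_pf_steps, log_pb_steps, log_reward):
    """Squared trajectory-balance residual for a single trajectory.
    log_pf_steps / log_pb_steps are per-transition log-probabilities."""
    residual = log_Z + np.sum(log_pf_steps) - log_reward - np.sum(log_pb_steps)
    return residual ** 2

# Toy numbers chosen so both sides balance: the loss comes out ~0.
loss = tb_loss(log_Z=1.0, log_pf_steps=[-0.5, -0.5],
               log_pb_steps=[-0.7, -0.3], log_reward=1.0)
```

With the forced $S_1 \to S_0$ transition correctly excluded, `log_pb_steps` simply has one fewer entry than `log_pf_steps` — its contribution is exactly 0.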
Below is the same environment with a model trained with that bug in the backward loop:
The fix in the training loop is a one-character change: iterate the backward loop over `range(N, 1, -1)` instead of `range(N, 0, -1)`. The last (forced) transition then correctly contributes 0 to $\log P_B$ without being sampled.
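One way to see the difference (the indices are the step counters $t$, iterating over transitions $s_t \to s_{t-1}$; variable names are illustrative):

```python
N = 5

# Buggy loop: includes the forced S_1 -> S_0 transition, so the
# backward network gets trained to put mass on one arbitrary Delta.
buggy_steps = list(range(N, 0, -1))

# Fixed loop: stops before the forced transition, which contributes
# exactly 0 to log P_B and need not be sampled at all.
fixed_steps = list(range(N, 1, -1))
```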
Things get harder when the modes are far from $S_0$. Below, the modes are at $\pm 3$, still with trajectory length 5. On-policy training — sampling actions only from the current learned policy — gets stuck early on whichever mode the initial exploration happens to stumble into:
The histogram at $T = 1$ should sit almost entirely on one side. Cranking the temperature up diffuses the samples, but it does not rescue the other mode — the policy simply never learned it.
The fix is to sample training actions from a noisier distribution than the current policy. Concretely: keep the same learned $(\mu, \sigma)$, but inflate $\sigma$ by a schedule that starts at 2 and linearly decays to 0 over training. The forward log-probability is still evaluated under the learned policy (not the exploration policy — that would bias the gradient), but the sampled trajectories now cover both modes:
Same architecture, same number of iterations, just a different sampling distribution during rollout.
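A sketch of the schedule, assuming the inflation is additive (the post says the schedule "starts at 2 and linearly decays to 0"; whether it is added to or multiplied into $\sigma$ is an implementation detail of the training script, and the function here is illustrative):

```python
def exploration_sigma(sigma_learned, step, total_steps, start=2.0):
    """Sigma used for *sampling* trajectories during training: the
    learned sigma plus an exploration bonus that decays linearly from
    `start` to 0. log P_F is still evaluated with sigma_learned."""
    eps = start * max(0.0, 1.0 - step / total_steps)
    return sigma_learned + eps
```

Early in training every step is wide enough to reach either mode at $\pm 3$; by the end, rollouts match the learned policy exactly.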
Four modes with varying scales, $S_0$ closest to the first mode, 10K iterations. This one is deliberately at the edge of what the default hyperparameters can solve — the tallest modes get learned first, and the small, far-away mode often remains undersampled:
Longer trajectories, more iterations, or a non-constant exploration schedule would all help here. The takeaway is that continuous GFlowNets scale with trajectory length, not with the number of modes directly.
The slider on each demo multiplies the per-step $\sigma$ at sampling time — it has no effect on what was learned, only on how we draw from it. One thing worth noticing:
Temperature-tempering of a trained GFlowNet is the continuous analogue of the standard tempered-sampling trick. It is cheap, deterministic, and often enough to push a well-trained policy toward whichever end of the exploration-exploitation axis you need at inference.