The interfaces/ package defines a unified abstraction around diffusion, flow matching, and consistency objectives. Each interface wraps a backbone network (DiT, Lightning DiT, etc.), provides consistent sampling/training APIs, and owns the math that connects the model to a chosen stochastic process.
All concrete interfaces inherit from `continuous.Interfaces`, an `nnx.Module` that also acts as an abstract base class (ABC). The ABC specifies the end-to-end contract required by the trainers and samplers:
- Time & noise handling: `sample_t` and `sample_n` draw per-example timesteps and noise with a shared RNG infrastructure (`network.rngs`).
- Transport coefficients: `c_in`, `c_out`, `c_skip`, and `c_noise` describe the preconditioning factors applied before/after the backbone forward pass.
- Forward simulation: `sample_x_t` combines clean data and noise into the noisy state required by each formulation, according to its noise schedule, while `target` produces the regression target used in the loss.
- Model outputs: `pred` returns the velocity predicted by the network, which drives the ODE sampler; `score` maps that velocity into a score function, when available, for SDE sampling; and `loss` orchestrates the full training step (including time-shift heuristics and auxiliary returns).
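To illustrate how the four transport coefficients fit around a backbone call, here is a minimal JAX sketch. It uses the coefficient formulas from the EDM paper (Karras et al., 2022) with an assumed `SIGMA_DATA = 0.5`. The free-function signatures and the `preconditioned` helper are hypothetical and may not match the methods defined in this package.

```python
import jax.numpy as jnp

SIGMA_DATA = 0.5  # assumption: data standard deviation (the EDM paper's default)


def c_skip(sigma):
    return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)


def c_out(sigma):
    return sigma * SIGMA_DATA / jnp.sqrt(sigma**2 + SIGMA_DATA**2)


def c_in(sigma):
    return 1.0 / jnp.sqrt(sigma**2 + SIGMA_DATA**2)


def c_noise(sigma):
    return 0.25 * jnp.log(sigma)


def preconditioned(backbone, x_t, sigma):
    """Hypothetical wrapper: scale the input, run the backbone, mix with a skip path."""
    # Reshape per-example sigma from (B,) to (B, 1, ..., 1) so it broadcasts over x_t.
    s = sigma.reshape((-1,) + (1,) * (x_t.ndim - 1))
    raw = backbone(c_in(s) * x_t, c_noise(sigma))
    return c_skip(s) * x_t + c_out(s) * raw
```

With these formulas `c_in` rescales the noisy input to roughly unit variance, since `c_in(sigma)**2 * (sigma**2 + SIGMA_DATA**2)` equals 1 for every `sigma`.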
Because the base class also sets `__call__ = loss`, all interfaces can be treated as callable modules inside Flax/NNX training loops.
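The aliasing pattern can be sketched in plain Python. `ToyInterfaces` and `ToyVelocity` below are hypothetical stand-ins (no `nnx.Module`, no backbone) that only show how calling an instance runs the loss.

```python
from abc import ABC, abstractmethod


class ToyInterfaces(ABC):
    """Hypothetical stand-in for the abstract base; not the real class."""

    @abstractmethod
    def target(self, x, n):
        """Regression target for one (data, noise) pair."""

    def loss(self, x, n, prediction):
        # Squared error against whatever target the subclass defines.
        return (prediction - self.target(x, n)) ** 2

    # Alias: calling the instance runs the training loss.
    __call__ = loss


class ToyVelocity(ToyInterfaces):
    def target(self, x, n):
        return n - x


iface = ToyVelocity()
same = iface(1.0, 3.0, 2.0) == iface.loss(1.0, 3.0, 2.0)
```

One Python detail to keep in mind: a class-level alias binds to the base class's `loss` function when the class body runs, so a subclass that overrides `loss` would need to repeat the alias (or the base would need a delegating `__call__`) for instance calls to reach the override.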
- `SiTInterface`: Straight-through transport with linear interpolation between data and noise; uses logit-normal time sampling and targets `n - x`.
- `EDMInterface`: Implements EDM-style variance-preserving diffusion with log-normal time families, EDM preconditioning, and weighted losses.
- `MeanFlowInterface`: Extends `SiTInterface` with guidance mixing, stochastic jump times (`r`), instantaneous velocities, and auxiliary regression for the Mean Flow objective.
- `sCTInterface` / `sCDInterface`: Skeletons for score-based consistency training variants; the base methods are stubbed for future contributions.
- `repa.py`: Wraps any `Interfaces` implementation with REPA auxiliary losses and feature detectors (DINOv2, etc.).
- `discrete.py`: Hosts discrete-time counterparts (currently experimental).
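As a rough illustration of the linear-interpolation objective described for `SiTInterface`, the sketch below writes the pieces as free JAX functions. It assumes the convention that `t = 0` is data and `t = 1` is noise, which is consistent with the `n - x` target. The function signatures, the `velocity_fn` argument, and the `score_from_velocity` conversion are assumptions for illustration, not the package's actual code.

```python
import jax
import jax.numpy as jnp


def sample_t(key, batch, mean=0.0, std=1.0):
    """Logit-normal time: a sigmoid-squashed Gaussian, so t lies in (0, 1)."""
    return jax.nn.sigmoid(mean + std * jax.random.normal(key, (batch,)))


def sample_x_t(x, n, t):
    """Linear path x_t = (1 - t) * x + t * n (assumed convention)."""
    t = t.reshape((-1,) + (1,) * (x.ndim - 1))
    return (1.0 - t) * x + t * n


def target(x, n):
    """Velocity target: the time derivative of the linear path."""
    return n - x


def score_from_velocity(v, x_t, t):
    """For this path, E[n | x_t] = x_t + (1 - t) * v and score = -E[n | x_t] / t."""
    t = t.reshape((-1,) + (1,) * (x_t.ndim - 1))
    return -(x_t + (1.0 - t) * v) / t


def velocity_loss(velocity_fn, key, x):
    """Mean squared error between the predicted velocity and n - x."""
    k_t, k_n = jax.random.split(key)
    t = sample_t(k_t, x.shape[0])
    n = jax.random.normal(k_n, x.shape)
    x_t = sample_x_t(x, n, t)
    return jnp.mean((velocity_fn(x_t, t) - target(x, n)) ** 2)
```

The score conversion divides by `t`, so it is undefined at `t = 0`. This is one reason a score may only be available on part of the time interval.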
All interfaces accept a `train_time_dist_type` argument (`uniform`, `lognormal`, `logitnormal`) and use it to select the matching time-sampling strategy.
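One way such a dispatch could look is sketched below. The function name `sample_train_time` and the `mean` / `std` parameters are hypothetical; only the three distribution names come from this README.

```python
import jax
import jax.numpy as jnp


def sample_train_time(key, batch, dist_type, mean=0.0, std=1.0):
    """Draw one training time per example from the named distribution."""
    if dist_type == "uniform":
        return jax.random.uniform(key, (batch,))
    z = mean + std * jax.random.normal(key, (batch,))
    if dist_type == "lognormal":
        return jnp.exp(z)  # positive and unbounded, e.g. a noise level
    if dist_type == "logitnormal":
        return jax.nn.sigmoid(z)  # squashed into (0, 1)
    raise ValueError(f"unknown train_time_dist_type: {dist_type!r}")
```

Raising on an unknown name, rather than silently falling back to a default, keeps a mistyped config value from changing the training distribution unnoticed.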
- Subclass `Interfaces` and implement the abstract methods. Start by defining the transport path (`sample_x_t`), `target`, and time/noise samplers.
- Reuse helpers such as `bcast_right`, `mean_flat`, and `t_shift` to stay consistent with existing loss formulations.
- Expose new knobs via configs: any additional hyperparameters should come from `ConfigDict` entries so they can be overridden by `--config.*` flags.
- Integrate with samplers: ensure your `pred` and `score` outputs match the velocity / score expected by the sampler in `samplers/`. Add new sampler variants if the interface requires different stepping logic.
- Document the contract: update `interfaces/README.md` (this file) with the new interface, and note any special requirements (e.g., extra RNG streams).
- Add tests mirroring the layout in `tests/interface_tests/` to validate loss values, sampling shapes, and gradient flow.
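The kind of test described in the last item might look like the following sketch. The `bcast_right` and `mean_flat` functions here are hypothetical re-implementations based only on their names (the real helpers may have different signatures), and `toy_loss` uses a two-parameter affine model in place of a real backbone.

```python
import jax
import jax.numpy as jnp


def bcast_right(t, ndim):
    """Assumed behavior: append singleton axes so (B,) broadcasts against (B, ...)."""
    return t.reshape(t.shape + (1,) * (ndim - t.ndim))


def mean_flat(x):
    """Assumed behavior: average over every axis except the batch axis."""
    return jnp.mean(x.reshape(x.shape[0], -1), axis=1)


def toy_loss(params, key, x):
    """Velocity-matching loss on a linear path with an affine stand-in model."""
    k_t, k_n = jax.random.split(key)
    t = bcast_right(jax.random.uniform(k_t, (x.shape[0],)), x.ndim)
    n = jax.random.normal(k_n, x.shape)
    x_t = (1.0 - t) * x + t * n
    pred = params["w"] * x_t + params["b"]
    return jnp.mean(mean_flat((pred - (n - x)) ** 2))


def test_loss_is_finite_scalar_with_gradients():
    params = {"w": jnp.ones(()), "b": jnp.zeros(())}
    x = jnp.ones((4, 8, 8, 3))
    loss, grads = jax.value_and_grad(toy_loss)(params, jax.random.PRNGKey(0), x)
    assert loss.shape == ()  # loss value: a scalar
    assert bool(jnp.isfinite(loss))  # loss value: no NaN or inf
    # Gradient flow: every parameter receives a nonzero gradient.
    assert all(bool(jnp.any(g != 0)) for g in jax.tree_util.tree_leaves(grads))
```

A test for a real interface would replace `toy_loss` with the interface's own `loss` and add a shape check on the sampler output.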