(Nonrecord) Async Prefetching#591

Open
SirSaltySalmon wants to merge 3 commits into openai:main from SirSaltySalmon:async_prefetching
Conversation

@SirSaltySalmon commented Mar 24, 2026

Asynchronous Prefetching — submission notes

1191 steps (with this technique) vs. 1137 steps (default) in 600 s on local compute

Key changes

Same model, optimizer, data layout, and training math as the baseline. This is a general-purpose rework that could be applied to most other approaches for a slight speed boost: overlap CPU data prep and host→device copies with GPU work so the GPU spends less time idle.

| Area | Original (`train_gpt_og_linux.py`) | Improved |
| --- | --- | --- |
| Training batches | Each step: read tokens on CPU, then H2D, all on the main thread before forward. | Background thread (`PrefetchingDistributedTokenLoader`) builds the next pinned CPU batch while the GPU runs the current step. Primary win: CPU work overlaps GPU compute (not GPU-side double-buffering of H2D vs. forward). |
| H2D | Single default stream. | Optional dedicated CUDA copy stream (`TRAIN_COPY_STREAM`, off when timing diagnostics are on). Transfers use pinned memory; the training path still waits for that step's H2D before forward (`wait_stream`). |
| Validation | Simple loop: slice → GPU → forward; BPB byte math on GPU. | Prefetch thread for pinned CPU batches; double-buffered H2D with copy stream + events so the next batch can copy while the current forward runs. `VAL_BYTECOUNT_DEVICE=cpu` (the default) moves BPB byte counting off the GPU; set `cuda` to mirror the baseline's GPU LUT math. |
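The training-side pattern above can be sketched minimally in pure Python. This is an illustrative skeleton, not the PR's actual `PrefetchingDistributedTokenLoader`: the class and argument names here are mine, and the real loader additionally pins host memory and issues the H2D copy (optionally on a side stream) before the forward pass.

```python
import queue
import threading

class PrefetchingLoader:
    """Background-thread prefetcher: a worker builds the next batch
    while the consumer (the GPU step, simulated here) processes the
    current one. `depth` bounds how far ahead the worker may run."""

    def __init__(self, make_batch, num_batches, depth=2):
        self._q = queue.Queue(maxsize=depth)
        self._thread = threading.Thread(
            target=self._worker, args=(make_batch, num_batches), daemon=True)
        self._thread.start()

    def _worker(self, make_batch, num_batches):
        for i in range(num_batches):
            self._q.put(make_batch(i))  # blocks once `depth` batches are queued
        self._q.put(None)               # sentinel: no more batches

    def __iter__(self):
        while (batch := self._q.get()) is not None:
            yield batch

# Usage: batch construction (the CPU work) runs concurrently with
# whatever the consuming loop does between iterations.
batches = list(PrefetchingLoader(lambda i: [i] * 4, num_batches=3))
```

The bounded queue is the key design choice: it keeps at most `depth` batches of host memory alive while still hiding CPU batch-prep latency behind compute.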

Diagnostics

To measure how much time this actually saves, I added TRAINING_TIMING_BREAKDOWN (batch CPU vs. H2D vs. FWD/BWD/opt vs. val; adds syncs). When enabled, timing lines are logged every TRAINING_TIMING_EVERY steps (default 200) and for the first 10 steps. Extra logs: train/val I/O mode, val_stage_time_ms, and the train vs. val wall-time split.

VAL_BYTECOUNT_DEVICE defaults to cpu in the improved script (not an extra flag you must set). Use cuda if you want validation byte math on the GPU like the original.

Optional VAL_PROGRESS_LOG_EVERY (default 0): set to a positive value to log per-batch validation progress (val_progress:...).

Defaults & toggles

Overlap features are on by default (TRAIN_PREFETCH, TRAIN_COPY_STREAM, VAL_PREFETCH, VAL_COPY_STREAM, etc.) and can be turned off via env vars if needed. TRAINING_TIMING_BREAKDOWN defaults to 0, so no timing output is shown unless you enable it. Prefetch/overlap are automatically disabled when TRAINING_TIMING_BREAKDOWN=1 so the timings stay interpretable.
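A minimal sketch of the env-var toggle convention described above (the `env_flag` helper name is mine, not the script's; the flag names match the ones listed, and the last stanza shows the "breakdown forces overlap off" rule):

```python
import os

def env_flag(name, default):
    """Read a 0/1 feature toggle from the environment, falling back
    to `default` when the variable is unset."""
    return bool(int(os.environ.get(name, str(int(default)))))

# Overlap features default on; the timing breakdown defaults off.
TRAIN_PREFETCH = env_flag("TRAIN_PREFETCH", True)
VAL_PREFETCH = env_flag("VAL_PREFETCH", True)
TRAINING_TIMING_BREAKDOWN = env_flag("TRAINING_TIMING_BREAKDOWN", False)

# Keep timings interpretable: breakdown mode disables overlap.
if TRAINING_TIMING_BREAKDOWN:
    TRAIN_PREFETCH = VAL_PREFETCH = False
```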

Idea

Prefetch training and validation batches asynchronously and parallelize CPU ↔ GPU transfers with compute to minimize pipeline bubbles under a fixed wall-clock budget.
This is an intuitive, general-purpose optimization rather than a research contribution; it could help submissions with real research and architectural advancements place slightly higher.

Why this may be unimpactful in some cases

With TRAINING_TIMING_BREAKDOWN=1, early-step lines look like this (same hardware / config as above; grad_accum_steps=8, per-micro averages for batch/forward/backward):

```
timing_breakdown step:1 micro_steps:8 batch_cpu_ms:0.29 batch_h2d_ms:0.35 forward_ms:30.54 backward_ms:64.93 grad_clip_ms:0.00 optimizer_ms:55.37 val_ms:121092.09 explicit_sync_ms:0.16 (per_optimizer_step; forward/backward/batch averaged over micro_steps; grad_accum_steps=8)
timing_breakdown step:2 micro_steps:8 batch_cpu_ms:0.29 batch_h2d_ms:0.35 forward_ms:30.29 backward_ms:64.72 grad_clip_ms:0.00 optimizer_ms:55.08 val_ms:0.00 explicit_sync_ms:0.00 (per_optimizer_step; forward/backward/batch averaged over micro_steps; grad_accum_steps=8)
timing_breakdown step:3 micro_steps:8 batch_cpu_ms:0.28 batch_h2d_ms:0.37 forward_ms:30.66 backward_ms:65.18 grad_clip_ms:0.00 optimizer_ms:54.45 val_ms:0.00 explicit_sync_ms:0.00 (per_optimizer_step; forward/backward/batch averaged over micro_steps; grad_accum_steps=8)
timing_breakdown step:4 micro_steps:8 batch_cpu_ms:0.31 batch_h2d_ms:0.34 forward_ms:30.34 backward_ms:64.43 grad_clip_ms:0.00 optimizer_ms:55.19 val_ms:0.00 explicit_sync_ms:0.00 (per_optimizer_step; forward/backward/batch averaged over micro_steps; grad_accum_steps=8)
```

How to read this: batch_cpu_ms and batch_h2d_ms are ~0.3 ms per micro-step; forward_ms and backward_ms are ~30 ms and ~65 ms per micro-step. Scaled by 8 micro-steps, batch prep + H2D is on the order of ~5 ms per optimizer step, while forward + backward + optimizer is on the order of ~800+ ms. So data movement is a tiny slice of the step; overlapping it cannot move wall-clock much when the GPU is already busy with compute for almost the whole step.
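The back-of-the-envelope above can be checked directly from the step-2 log line (the numbers are copied from the log; this is just the arithmetic, not a profiler run):

```python
# Per-micro-step averages from the step:2 timing_breakdown line (ms).
micro_steps = 8
batch_cpu_ms, batch_h2d_ms = 0.29, 0.35
forward_ms, backward_ms = 30.29, 64.72
optimizer_ms = 55.08  # per optimizer step, not per micro-step

# Total data-movement time per optimizer step: ~5 ms.
data_ms = (batch_cpu_ms + batch_h2d_ms) * micro_steps

# Total compute time per optimizer step: ~815 ms.
compute_ms = (forward_ms + backward_ms) * micro_steps + optimizer_ms

# Data movement is well under 1% of the step, so perfectly
# overlapping it bounds the possible speedup at roughly that much.
data_fraction = data_ms / (data_ms + compute_ms)
```

This is why the measured gain (1191 vs. 1137 steps, ~4.7%) has to come mostly from overlap around validation and step boundaries rather than from hiding the per-micro-step H2D copies.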

Caveat: On a much faster GPU (or smaller model / larger batch so steps are shorter), the same CPU+H2D work could become a larger fraction of the step, and prefetch or val overlap might show up more in profiles. The breakdown above is not universal; it only shows why the optimization can be a no-op when compute is the bottleneck.

@SirSaltySalmon changed the title from Async Prefetching to (Nonrecord) Async Prefetching on Mar 24, 2026

@MatoTeziTanka commented Apr 11, 2026

[RETRACTED 2026-04-11] — This IMPORT_FAIL was a false positive. Root cause: Py3.10 @DataClass + spec_from_file_location harness bug. Your code is not broken. See correction below: #591 (comment)


Community Review — (Nonrecord) Async Prefetching

Compliance: NEEDS AUTHOR ACTION — train_gpt.py fails to import on CT2038 (Python 3.10 / torch 2.10.0+cpu)

What I found: The CPU smoke test on CT2038 (proteus-engine, 128 GB RAM, Triton 3.6.0, flash_attn stub, cutlass_evt_fusion stub) failed at the import step with:

```
AttributeError: 'NoneType' object has no attribute '__dict__'
```

(I've seen a few common patterns for this class of error in the 2026-04-11 sweep.)

Recommendation: Could you run python3 -c "import py_compile; py_compile.compile('train_gpt.py')" on your records-folder train_gpt.py under Python 3.10 specifically? The eval image is Python 3.10 per Issue #17 / the README, so any parse error on 3.10 blocks the submission at import time before any of the scored-eval logic runs.

Once the parse/import issue is fixed, I'll re-run the compliance audit through the normal pipeline. No other flags identified yet because the audit halts at the import step.


Reviewed by @MatoTeziTanka (The Agora). CPU smoke test (CT2038 proteus-engine, 2026-04-11): IMPORT_FAIL — AttributeError: 'NoneType' object has no attribute '__dict__'. Classification via classify_prs.py AST-based classifier; full compliance audit deferred until the import issue is resolved. Auto-drafted from a template and spot-checked before posting.

@MatoTeziTanka

Retraction — this IMPORT_FAIL was a Python 3.10 @dataclass loader bug in my harness

Sorry @SirSaltySalmon, this one's on me. I re-audited the AttributeError: 'NoneType' object has no attribute '__dict__' I reported above and confirmed it's a harness bug, not a bug in your code.

Root cause:

My smoke harness loaded your file with:

```python
spec = importlib.util.spec_from_file_location("train_module", script_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)  # ← crashed here
```

Python 3.10's dataclasses.py (line 711) does sys.modules.get(cls.__module__).__dict__ while processing your @dataclass class Hyperparameters. When a module is loaded via spec_from_file_location without first registering it in sys.modules, cls.__module__ resolves to "train_module", but sys.modules.get("train_module") returns None, so the .__dict__ access raises the AttributeError. This is a well-known interaction between importlib.util.spec_from_file_location and @dataclass on 3.10: fixed in 3.11+, and worked around on 3.10 by registering the module name in sys.modules before exec_module.
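The workaround can be reproduced end-to-end with the stdlib alone. The temp file and the "train_module" name here are illustrative stand-ins for the real script; on 3.11+ the unpatched path no longer crashes, but the sys.modules registration is harmless there too:

```python
import importlib.util
import sys
import tempfile
import textwrap

# A minimal module using @dataclass, like the PR's Hyperparameters class.
src = textwrap.dedent("""
    from dataclasses import dataclass

    @dataclass
    class Hyperparameters:
        lr: float = 3e-4
""")

with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(src)
    path = f.name

spec = importlib.util.spec_from_file_location("train_module", path)
mod = importlib.util.module_from_spec(spec)
sys.modules["train_module"] = mod  # the fix: register BEFORE exec_module
spec.loader.exec_module(mod)       # would crash on 3.10 without the line above

hp = mod.Hyperparameters()
```

This mirrors the "Importing a source file directly" recipe in the importlib docs, which registers the module in sys.modules before executing it for exactly this reason.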

Verified at head 4ce0096:

Running your records/track_non_record_16mb/2026-03-24-AsyncTrainingValidation/train_gpt.py with the fix:

```python
sys.modules["train_module"] = mod
spec.loader.exec_module(mod)
```

…produces IMPORT_OK, HAS_HYPERPARAMETERS=True, HAS_GPT=True. Your code imports cleanly on Python 3.10.

Your PR is not broken by this error. I'm retracting the IMPORT_FAIL classification. I'll re-queue the full compliance audit (BPB check, n-gram / TTT / SLOT flags, etc.) and post findings separately.

Again — sorry for the noise. Harness bug, not your code.
