Preliminary: 11L VRL + Full GPTQ + Parallel Muon + Legal TTT — val_bpb 1.1882 (ADIITJ)#960
ADIITJ wants to merge 7 commits into openai:main
Conversation
…g H100 validation)

Combines thwu1's SOTA training stack (10L, Int5-MLP, Int6-Attn, BigramHash, SWA) with document-isolated LoRA TTT at evaluation. LoRA adapters (rank=8) target Q and V in all 10 attention layers, initialized fresh per document at eval time — zero artifact cost.

- swa_start_frac=0.35 (vs SOTA 0.40), warmdown_iters=3500 (vs 3000)
- Score-first TTT: chunk scored before LoRA step, no information leakage
- Expected bpb: 1.137–1.140 (SOTA 1.1428 + TTT delta ~0.003–0.005)
- Artifact: ~14.3MB (same quantization as SOTA, LoRA weights not stored)
- train_gpt.py: exactly 1500 lines, 64281 bytes, AST-clean

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
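The score-first ordering described above can be sketched as a small loop; `score_fn` and `adapt_fn` are hypothetical stand-ins for the submission's chunk-scoring and LoRA-update steps, not functions from `train_gpt.py`:

```python
# Minimal sketch of score-first TTT: each chunk is scored with the adapter
# state learned from *previous* chunks only, and the LoRA step on that chunk
# happens strictly afterwards, so the reported NLL never depends on having
# trained on its own chunk (no information leakage).
def score_first_ttt(chunks, score_fn, adapt_fn):
    total_nll, total_tokens = 0.0, 0
    for chunk in chunks:
        nll, n = score_fn(chunk)  # evaluate first, with the current adapters
        total_nll += nll
        total_tokens += n
        adapt_fn(chunk)           # only then take a LoRA training step
    return total_nll / max(total_tokens, 1)
```

The invariant that makes this "legal" is purely the ordering: scoring always precedes adaptation within each chunk.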
Multi-epoch cosine LR schedule on rank-8 LoRA adapters per document. 50 epochs, lr=0.001 with cosine decay to ~0. Score-first per chunk within each epoch (backward-looking). NLL accumulated in final epoch only. Expected bpb: ~1.05–1.10 vs single-pass ~1.137.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
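The "lr=0.001 with cosine decay to ~0" schedule can be written as a small helper; the function name and signature are illustrative, not taken from the submission:

```python
import math

def cosine_lr(step, total_steps, base_lr=1e-3, min_lr=0.0):
    # Cosine decay from base_lr at step 0 down to min_lr at the final step,
    # matching the multi-epoch TTT schedule described in the commit message.
    t = step / max(total_steps - 1, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t))
```

With 50 epochs, `cosine_lr(epoch, 50)` starts at 0.001 and reaches ~0 at epoch 49.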
11L, XSA all layers, partial RoPE 16/64, LN scale, VE128 (layers 9,10), LeakyReLU(0.5)² activation, BigramHash(2048), INT6+zstd-22. Legal score-first TTT: 32K chunks, all blocks, SGD(0.002, mom=0.9), 3ep. Base: PR openai#503 (EthanYangTW) + LeakyReLU² from openai#518/openai#549 + SGD from openai#549.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Stack of all verified improvements on current SOTA (PR openai#414, 1.1228 bpb):

- VRL (Value Residual Learning, arXiv:2410.17897): layer 0 V shared via sigmoid gates
- Full GPTQ (Hessian-aware Cholesky int6): -0.0026 bpb over GPTQ-lite
- LeakyReLU(0.5)²: -0.0015 bpb
- Batched LoRA TTT: rank=8 Q+V+LMhead all 11 layers, 2 epochs cosine LR
- Score-before-train every chunk every epoch (backward-looking, fully legal)
- EMA(0.997) + Tight SWA + Late QAT@0.15 + XSA-all(11) + Partial RoPE(16/64)
- LN Scale + VE128(9,10) + SmearGate + BigramHash(2048) + Prune(2%)

Expected: ~1.08–1.10 bpb (non-record pending 3-seed H100 validation)

Attribution: signalrush (PR openai#414), gowtham0992 (PR openai#569), MatoTeziTanka (PR openai#512), LoquiAuris (PR openai#548)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
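The VRL gating idea (layer 0's values mixed into each layer via a sigmoid gate, per arXiv:2410.17897) looks roughly like the sketch below. Plain Python lists stand in for torch tensors, and the gate parameter name `vr_lambda` is assumed from this PR's README excerpt, not verified against the paper:

```python
import math

def vrl_mix(v_layer, v0, vr_lambda):
    # Sigmoid gate in (0, 1): gate -> 1 keeps this layer's own values,
    # gate -> 0 falls back entirely to layer 0's shared values.
    gate = 1.0 / (1.0 + math.exp(-vr_lambda))
    return [gate * a + (1.0 - gate) * b for a, b in zip(v_layer, v0)]
```

At `vr_lambda = 0` the gate is exactly 0.5, i.e. an even mix of the layer's values and layer 0's.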
Preliminary non-record run: val_bpb=1.1882 (seed 1337, 2002 steps, no torch.compile). Artifact 18.8MB (over 16MB limit) — proper rerun with torch.compile pending.

Additions over PR openai#549 (SOTA 1.1194):

- VRL: Value Residual Learning on all 11 layers via sigmoid gates
- Full GPTQ: Hessian Cholesky int6 with 256-batch calibration
- BigramHash 1536 → 3072
- Tight SWA preferred over EMA when snapshots exist

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pull request overview
Adds several new experiment/record bundles under records/track_10min_16mb, culminating in a preliminary run that combines VRL + full GPTQ + Parallel Muon + legal score-first TTT, along with supporting logs/metadata and prior related variants.
Changes:
- Add a 2026-03-27 preliminary submission bundle (log + README + submission metadata) for VRL + Full GPTQ + Parallel Muon + Legal TTT.
- Add additional 2026-03-23 and 2026-03-24 experiment bundles (code + README + submission metadata + requirements where applicable) exploring VRL/Full-GPTQ and LoRA-TTT variants.
- Introduce/iterate custom training, quantization, and evaluation implementations inside the added `train_gpt.py` scripts.
Reviewed changes
Copilot reviewed 17 out of 19 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| records/track_10min_16mb/2026-03-27_VRL_FullGPTQ_ParallelMuon_LegalTTT/README.md | Documents the preliminary VRL + Full GPTQ + Parallel Muon + legal TTT run and reproduction steps. |
| records/track_10min_16mb/2026-03-27_VRL_FullGPTQ_ParallelMuon_LegalTTT/submission.json | Metadata describing the preliminary run (seed/results/config). |
| records/track_10min_16mb/2026-03-27_VRL_FullGPTQ_ParallelMuon_LegalTTT/train_gpt.py | Main training/eval/quantization script for the preliminary run (Parallel Muon, VRL, full GPTQ, legal TTT). |
| records/track_10min_16mb/2026-03-27_VRL_FullGPTQ_ParallelMuon_LegalTTT/train_seed1337.log | Captured output log for the preliminary run (training, quantization, sliding eval, TTT). |
| records/track_10min_16mb/2026-03-24_VRL_FullGPTQ_LoRATTT/train_gpt.py | VRL + Full GPTQ + batched LoRA-TTT experiment script. |
| records/track_10min_16mb/2026-03-24_VRL_FullGPTQ_LoRATTT/submission.json | Metadata for the 2026-03-24 LoRA-TTT experiment (pending validation). |
| records/track_10min_16mb/2026-03-24_VRL_FullGPTQ_LoRATTT/requirements.txt | Python deps for the 2026-03-24 experiment bundle. |
| records/track_10min_16mb/2026-03-24_VRL_FullGPTQ_LoRATTT/README.md | Documentation for the 2026-03-24 LoRA-TTT experiment. |
| records/track_10min_16mb/2026-03-23_LoRA_TTT_Int5Int6/train_gpt.py | Int5/Int6 base + multi-epoch LoRA-TTT experiment script. |
| records/track_10min_16mb/2026-03-23_LoRA_TTT_Int5Int6/submission.json | Metadata for the 2026-03-23 multi-epoch LoRA-TTT experiment (pending validation). |
| records/track_10min_16mb/2026-03-23_LoRA_TTT_Int5Int6/requirements.txt | Python deps for the 2026-03-23 LoRA-TTT experiment bundle. |
| records/track_10min_16mb/2026-03-23_LoRA_TTT_Int5Int6/README.md | Documentation for the 2026-03-23 LoRA-TTT experiment. |
| records/track_10min_16mb/2026-03-23_ADIITJ_ProteusPlus/train_gpt.py | PROTEUS+ experiment script with cosine LR LoRA-TTT tweaks. |
| records/track_10min_16mb/2026-03-23_ADIITJ_ProteusPlus/submission.json | Metadata for PROTEUS+ experiment (pending validation). |
| records/track_10min_16mb/2026-03-23_ADIITJ_ProteusPlus/requirements.txt | Python deps for PROTEUS+ experiment bundle. |
| records/track_10min_16mb/2026-03-23_ADIITJ_ProteusPlus/README.md | Documentation for PROTEUS+ experiment. |
| records/track_10min_16mb/2026-03-23_11L_XSA_LeakyTTT/submission.json | Metadata for 11L XSA + LeakyReLU² + legal TTT experiment (pending validation). |
| records/track_10min_16mb/2026-03-23_11L_XSA_LeakyTTT/requirements.txt | Python deps for the 11L XSA + LeakyTTT bundle. |
| records/track_10min_16mb/2026-03-23_11L_XSA_LeakyTTT/train_gpt.py | Training/eval script for 11L XSA + LeakyTTT experiment. |
```python
value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "1")))
full_gptq = bool(int(os.environ.get("FULL_GPTQ", "1")))
full_gptq_calib_batches = int(os.environ.get("FULL_GPTQ_CALIB_BATCHES", 256))
full_gptq_damp = float(os.environ.get("FULL_GPTQ_DAMP", 0.01))
```
full_gptq_damp is defined as a hyperparameter but the GPTQ implementation hard-codes damping (0.01 * mean(diag(H))). This is easy to mis-tune because changing FULL_GPTQ_DAMP has no effect. Either wire args.full_gptq_damp into the damping calculation or remove the unused hyperparameter.
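One minimal way to wire the hyperparameter in, sketched here with plain-Python matrices (the helper name `damp_hessian` is invented; the real code operates on torch tensors):

```python
def damp_hessian(H, full_gptq_damp=0.01):
    # GPTQ-style diagonal damping: damp = full_gptq_damp * mean(diag(H)),
    # with the mean floored at 1e-6 for numerical safety. Passing
    # args.full_gptq_damp here replaces the hard-coded 0.01 flagged above.
    n = len(H)
    mean_diag = max(sum(H[i][i] for i in range(n)) / n, 1e-6)
    damp = full_gptq_damp * mean_diag
    return [[H[i][j] + (damp if i == j else 0.0) for j in range(n)]
            for i in range(n)]
```

Changing `FULL_GPTQ_DAMP` then actually changes the regularization, instead of being silently ignored.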
```python
seq_len = eval_seq_len or args.train_seq_len
total_tokens = val_tokens.numel() - 1
window_starts = [ws for ws in range(0, total_tokens, stride)
                 if min(ws + seq_len, total_tokens) - ws >= 1]
total_windows = len(window_starts)
```
eval_val_sliding builds window_starts using range(0, total_tokens, stride), which includes partial tail windows. Combined with scoring only the suffix, this causes the final tokens to be scored multiple times (overlapping windows), making val_loss/val_bpb incorrect. Consider generating only full-length windows (plus a single final window ending at total_tokens) so each token is scored exactly once.
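A non-overlapping construction along the lines this comment suggests could look like the following (hypothetical helper mirroring the "full windows plus a single final window ending at total_tokens" idea):

```python
def full_window_starts(total_tokens, seq_len, stride):
    # Full-length windows stepped by `stride`; one extra final window is
    # appended so coverage ends exactly at total_tokens. No partial tail
    # windows are ever created, so fixed-suffix scoring (first window scores
    # everything, later windows their last `stride` tokens) stays consistent.
    if total_tokens <= seq_len:
        return [0]
    max_start = total_tokens - seq_len
    starts = list(range(0, max_start + 1, stride))
    if starts[-1] != max_start:
        starts.append(max_start)
    return starts
```

Every start satisfies `ws + seq_len <= total_tokens`, so `wlen == seq_len` always holds for multi-window inputs.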
```python
for i, ws in enumerate(batch_ws):
    wlen = wlens[i]
    s = 0 if ws == 0 else max(wlen - stride, 0)
    scored_nll = nll[i, s:wlen].to(torch.float64)
    loss_sum += scored_nll.sum()
```
The scoring offset s = 0 if ws == 0 else max(wlen - stride, 0) depends on wlen, so for shorter tail windows it scores more than stride tokens and double-counts tokens already scored by the previous window. Use a fixed offset for non-first windows (typically seq_len - stride) and avoid creating partial windows in the first place.
| "val_bpb": 1.18817, | ||
| "val_bpb_seeds": [1.18817], | ||
| "seeds": [1337], | ||
| "status": "non_record_preliminary", |
This adds a new status value (non_record_preliminary) that isn't used by other submission.json entries (which consistently use non_record_pending_validation). If any downstream tooling expects a known set of statuses, this new value could break parsing/filtering; consider reusing the existing status string and capturing "preliminary" in the blurb/README instead.
| "status": "non_record_preliminary", | |
| "status": "non_record_pending_validation", |
```python
# Normalize and add damping
for name in hessians:
    H = hessians[name]
    H /= num_batches  # average
    damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6)
    H += damp * torch.eye(H.shape[0])
```
collect_hessians normalizes and adds damping (H += damp * I) to each Hessian, but quantize_int6_gptq also adds its own damping before Cholesky. That results in double-damping the Hessian (and changes the effective regularization) compared to standard GPTQ. Consider applying damping in exactly one place (either collection or quantization), ideally controlled by a single hyperparameter.
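To see the compounding numerically (a pure-Python sketch with hypothetical values): applying `damp = 0.01 * mean(diag)` twice does not simply add 0.02, because the second damping is computed from the already-damped diagonal:

```python
def damp_once(diag_mean, coeff=0.01):
    # One damping pass adds coeff * mean(diag) to every diagonal entry.
    return diag_mean + coeff * diag_mean

d1 = damp_once(1.0)  # after collection-time damping: 1.01
d2 = damp_once(d1)   # after quantization-time damping: 1.0201, not 1.02
```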
```diff
-# Normalize and add damping
-for name in hessians:
-    H = hessians[name]
-    H /= num_batches  # average
-    damp = 0.01 * torch.diag(H).mean().clamp_min(1e-6)
-    H += damp * torch.eye(H.shape[0])
+# Normalize Hessians; damping is applied later during quantization
+for name in hessians:
+    H = hessians[name]
+    H /= num_batches  # average
```
```python
# Pre-compute all window starts
window_starts = [ws for ws in range(0, total_tokens, stride)
                 if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
```
eval_val_sliding_ttt precomputes window_starts in a way that still allows tail windows and overlaps. This will over-count loss/bytes during the SCORE phase and can also assign windows to the wrong TTT chunk, which invalidates the reported TTT bpb. Recommend reusing the same non-overlapping sliding-window construction as the non-TTT eval (full windows + fixed suffix scoring).
```diff
-# Pre-compute all window starts
-window_starts = [ws for ws in range(0, total_tokens, stride)
-                 if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
+# Pre-compute all window starts using full-length sliding windows,
+# matching the non-TTT eval scheme (full windows + fixed suffix scoring).
+# Windows have length seq_len and their scored suffixes advance by `stride`
+# tokens without overlap. We also ensure there is a final window whose end
+# aligns exactly with total_tokens.
+if total_tokens <= seq_len:
+    # Single (possibly shorter) window when the sequence is shorter than seq_len
+    window_starts = [0]
+else:
+    max_start = total_tokens - seq_len  # last start that gives a full window
+    window_starts = list(range(0, max_start + 1, stride))
+    if window_starts[-1] != max_start:
+        window_starts.append(max_start)
```
```python
# Pre-compute all window starts
window_starts = [ws for ws in range(0, total_tokens, stride)
                 if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]

# Assign each window to a chunk based on the first token it scores
num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk
chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)]
for ws in window_starts:
    end = min(ws + seq_len, total_tokens)
    wlen = end - ws
    s = 0 if ws == 0 else max(wlen - stride, 0)
```
Chunk assignment uses s = 0 if ws == 0 else max(wlen - stride, 0) derived from wlen, so partial windows near the end will map a larger scored span than intended and can be counted multiple times across windows. If the intent is "score only the last stride tokens" per window, s should be constant (seq_len - stride) for all non-first windows and windows should be constructed to always have wlen == seq_len.
```diff
-# Pre-compute all window starts
-window_starts = [ws for ws in range(0, total_tokens, stride)
-                 if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
+# Pre-compute all window starts.
+# Windows are constructed to have fixed length seq_len whenever possible,
+# so that non-first windows always score exactly the last `stride` tokens.
+if total_tokens <= seq_len:
+    window_starts = [0]
+else:
+    window_starts = list(range(0, total_tokens - seq_len + 1, stride))
 # Assign each window to a chunk based on the first token it scores
 num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk
 chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)]
 for ws in window_starts:
-    end = min(ws + seq_len, total_tokens)
-    wlen = end - ws
-    s = 0 if ws == 0 else max(wlen - stride, 0)
+    end = ws + seq_len
+    s = 0 if ws == 0 else seq_len - stride
```
```python
s = 0 if ws == 0 else max(wlen - stride, 0)
scored_nll = nll[i, s:wlen].to(torch.float64)
loss_sum += scored_nll.sum()
token_count += float(wlen - s)
tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen]
```
The SCORE loop repeats the same s = ... max(wlen - stride, 0) logic as the non-TTT sliding eval, so shorter windows (especially at the tail) will score more than stride tokens and double-count tokens already scored by earlier windows. This will skew loss_sum/token_count/byte_count and the final val_bpb reported for legal TTT.
```diff
-s = 0 if ws == 0 else max(wlen - stride, 0)
-scored_nll = nll[i, s:wlen].to(torch.float64)
-loss_sum += scored_nll.sum()
-token_count += float(wlen - s)
-tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen]
+# In TTT SCORE phase, each window in `windows` is already legal
+# (non-overlapping for scoring), so we score the entire window
+# instead of reusing the sliding-eval `stride` logic, which can
+# double-count tokens at shorter/tail windows.
+scored_nll = nll[i, :wlen].to(torch.float64)
+loss_sum += scored_nll.sum()
+token_count += float(wlen)
+tgt, prev = y_batch[i, :wlen], x_batch[i, :wlen]
```
## Preliminary Non-Record Submission: val_bpb 1.1882
### Results
SOTA comparison: PR #549 gets 1.1194 at 7185 steps (83ms/step, torch.compile, ~16MB).
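For reference, a bpb figure like 1.1882 is typically total validation NLL converted from nats to bits and normalized by the byte count; a sketch assuming that usual convention (the helper name is invented):

```python
import math

def bits_per_byte(nll_sum_nats, byte_count):
    # Total cross-entropy in nats, divided by UTF-8 bytes, converted to bits.
    return nll_sum_nats / (byte_count * math.log(2))
```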
### Architecture
Forked from PR #549 (abaybektursun, 1.1194 SOTA). Adds:
vr_lambda)

Everything else unchanged from PR #549: Parallel Muon, Legal TTT, LeakyReLU(0.5)², XSA4, VE128, etc.
### Next Steps

Rerun on an instance with `torch.compile` support (PyTorch 2.9+) to get: