
Single H100 10 min 16 MB < 1.24 bpb #1559

Open
adityasasidhar wants to merge 1 commit into openai:main from adityasasidhar:main

Conversation


@adityasasidhar adityasasidhar commented Apr 12, 2026

Hey,

Single-H100 run -> final sliding-window val_bpb: 1.2498

This is my second PR, and it moves me closer to my actual goal of reaching the naive baseline bpb on a single H100.

Here's what I have found:

  1. Size doesn't matter: we are under-training the model, so it's better to shrink it, which I did. Initially everyone jumped to increasing depth and quantization to fit more parameters (I did too); my earlier PR tried 11L with a mix of int8 attention and int6 MLP mult, which didn't scale well under my goals, so I'm back to 8L. That buys me an extra 300 steps.
  2. Used the sliding eval method with a stride of 128 (could have used 96, but I'm self-funded and broke).
  3. I'm using FlashAttention-2, which is easier to install since no FA3 wheels are available on the service I'm using. FA3 will be a game changer here, with an expected boost of 3% to 8% in my use case.
  4. Reduced the XSA layers from 4 to 2, which also gave about 3% of extra computational headroom.
  5. Tuned the partial RoPE to 32 dims and saw a great val_bpb improvement of 0.15 bpb.
  6. Pushed the QAT to the last 10 percent, which allows higher-precision training for most of the session and lower-precision training only once spikes start around step 900; this helps by 0.01 bpb.
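The late-QAT idea in point 6 can be sketched roughly as below. All names here are hypothetical, and the per-row int8 scaling and the 10% cutoff are assumptions for illustration, not the PR's exact code:

```python
import torch

def ste_fake_quant_int8(w: torch.Tensor) -> torch.Tensor:
    """Straight-through-estimator (STE) fake quantization to int8, per row.

    Forward pass sees the quantize->dequantize roundtrip; backward pass
    passes gradients through unchanged via the detach trick.
    """
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (w / scale).round().clamp(-127, 127)
    w_q = q * scale
    # STE: forward uses w_q, backward treats the op as identity.
    return w + (w_q - w).detach()

def qat_active(step: int, total_steps: int, qat_last_frac: float = 0.10) -> bool:
    """Enable fake quantization only in the final fraction of training."""
    return step >= int(total_steps * (1.0 - qat_last_frac))
```

With `qat_last_frac=0.10`, the first 90% of steps train at full precision, and the quantize-dequantize roundtrip is only inserted into the forward pass for the final stretch.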

Also, I'm well under the size limit with a submission size of 15.2 MB, and in the future I'll use this headroom for extra improvements without computational overhead.

What Changed vs Base train_gpt.py

Model and training setup

  • shrinks depth from 9 layers to 8
  • increases sequence length from 1024 to 2048
  • increases MLP expansion from 2x to 3x
  • keeps GQA with 8 query heads and 4 KV heads
  • adds partial RoPE via rope_dims=32
  • raises qk_gain_init above the base script's 1.5
  • increases planned iterations from 20000 to 35000
  • adds xsa_last_n, with XSA enabled on the final 2 layers in this run
  • switches attention layout to [batch, seq, heads, dim] internally and uses the flash_attn 2 interface when available; future work targets FlashAttention-3
  • applies RoPE only to the first rope_dims channels and leaves the rest of each head unrotated
  • changes block init by orthogonally initializing large linear layers and scaling projection weights by depth
  • tightens the attention RMSNorm to eps=1e-6
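The partial-RoPE bullet (rotate only the first rope_dims channels, leave the rest unrotated) can be sketched like this. The function name and the cos/sin shapes are assumptions, not the PR's exact code:

```python
import torch

def apply_partial_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       rope_dims: int = 32) -> torch.Tensor:
    """Rotate only the first `rope_dims` channels of each head.

    x is [batch, seq, heads, head_dim]; cos/sin are [seq, rope_dims // 2].
    Channels beyond rope_dims pass through untouched.
    """
    x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
    x1, x2 = x_rope.chunk(2, dim=-1)              # paired half-channels
    c = cos.view(1, -1, 1, rope_dims // 2)        # broadcast over batch/heads
    s = sin.view(1, -1, 1, rope_dims // 2)
    rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)
```

With head_dim 64 and rope_dims=32, half of each head carries positional rotation and the other half stays position-free.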

Optimization and schedule

  • lowers tied-embedding LR from 0.05 to 0.04
  • lowers matrix/scalar LR from 0.04 to 0.032
  • changes token/scalar/head optimizers from Adam to AdamW
  • adds decoupled weight decay: adam_wd=0.04, muon_wd=0.02
  • extends Muon itself to apply decoupled weight decay
  • adds warmdown_last_frac so warmdown can be driven by wallclock fraction instead of only warmdown_iters
  • adds qat_last_frac so fake quantization is turned on only near the end of training
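Decoupled weight decay, as used for both adam_wd and the Muon extension above, shrinks the weights independently of the gradient-derived update instead of adding wd*w to the gradient. A generic sketch (not the PR's optimizer code; the function name is made up):

```python
import torch

def decoupled_wd_step(param: torch.Tensor, update: torch.Tensor,
                      lr: float, wd: float) -> None:
    """One AdamW-style decoupled weight-decay step.

    `update` is whatever the optimizer produced (Adam moment estimate,
    Muon orthogonalized momentum, ...); the decay multiplies the weights
    directly and never enters the update computation.
    """
    with torch.no_grad():
        param.mul_(1.0 - lr * wd)     # decoupled decay
        param.add_(update, alpha=-lr)  # then the usual descent step
```

The same two lines apply whether the update came from AdamW (adam_wd=0.04) or Muon (muon_wd=0.02); only the wd constant differs.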

Evaluation

  • adds eval_stride and eval_batch_seqs
  • adds eval_val_sliding(...) for sliding-window validation
  • refactors the model to expose forward_logits(...) so sliding evaluation can reuse logits directly
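A rough sketch of what sliding-window validation looks like, assuming a `forward_logits(idx) -> [1, T, vocab]` interface as described above. Details (scoring only the new tail of each window, the loop structure) are my illustration, not the PR's exact `eval_val_sliding(...)`:

```python
import torch

@torch.no_grad()
def eval_val_sliding(forward_logits, tokens: torch.Tensor,
                     window: int = 2048, stride: int = 128) -> float:
    """Slide a context window over the token stream with step `stride`,
    scoring each target token exactly once with (close to) a full left
    context. Returns the mean loss in nats per token."""
    total_loss, total_count = 0.0, 0
    T = tokens.numel()
    done = 0  # highest target position scored so far
    for start in range(0, T - 1, stride):
        end = min(start + window, T - 1)
        idx = tokens[start:end].unsqueeze(0)
        targets = tokens[start + 1:end + 1].unsqueeze(0)
        lp = torch.log_softmax(forward_logits(idx).float(), dim=-1)
        nll = -lp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        keep = max(0, done - start)  # skip targets the previous window scored
        total_loss += nll[:, keep:].sum().item()
        total_count += nll[:, keep:].numel()
        done = end
        if end == T - 1:
            break
    return total_loss / max(total_count, 1)
```

Smaller strides give every scored token more left context at the cost of proportionally more forward passes, which is the stride-128-vs-96 tradeoff from the findings above.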

Quantization and export

  • uses a mixed-precision export format: mixed_int6_int8_per_row_v1
  • adds packed int6 storage for selected tensors using pack_lowbit_tensor(...) / unpack_lowbit_tensor(...)
  • adds INT6_NAME_PATTERNS and INT8_QAT_NAME_PATTERNS to control export and QAT targeting by parameter name
  • adds STE fake quantization during training for selected CastedLinear weights
  • removes the baseline small-tensor fp16 passthrough heuristic and instead quantizes float tensors into int6 or int8 unless they are non-float passthrough tensors
  • adds compress_quant_payload(...) / decompress_quant_payload(...), currently using zlib
  • renames final log lines from final_int8_zlib_roundtrip... to final_mixed_quant_zlib_roundtrip...
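Int6 packing stores 4 six-bit values in 3 bytes, a 25% saving over int8 storage before zlib even runs. A hypothetical layout for illustration; the PR's `pack_lowbit_tensor(...)` / `unpack_lowbit_tensor(...)` may differ in bit order and padding:

```python
import numpy as np

def pack_int6(vals: np.ndarray) -> bytes:
    """Pack signed int6 values (-32..31) into 3 bytes per 4 values."""
    assert vals.size % 4 == 0, "pad to a multiple of 4 before packing"
    u = (vals.astype(np.int64) + 32).astype(np.uint32)  # bias to 0..63
    v = u.reshape(-1, 4)
    # concatenate four 6-bit fields into one 24-bit word, then split bytes
    word = (v[:, 0] << 18) | (v[:, 1] << 12) | (v[:, 2] << 6) | v[:, 3]
    out = np.empty((word.size, 3), dtype=np.uint8)
    out[:, 0] = (word >> 16) & 0xFF
    out[:, 1] = (word >> 8) & 0xFF
    out[:, 2] = word & 0xFF
    return out.tobytes()

def unpack_int6(buf: bytes, n: int) -> np.ndarray:
    """Inverse of pack_int6: recover n signed int6 values."""
    b = np.frombuffer(buf, dtype=np.uint8).reshape(-1, 3).astype(np.uint32)
    word = (b[:, 0] << 16) | (b[:, 1] << 8) | b[:, 2]
    v = np.stack([(word >> 18) & 0x3F, (word >> 12) & 0x3F,
                  (word >> 6) & 0x3F, word & 0x3F], axis=1)
    return v.reshape(-1)[:n].astype(np.int64) - 32
```

The name-pattern lists (INT6_NAME_PATTERNS / INT8_QAT_NAME_PATTERNS) then decide per parameter whether it goes through this 6-bit path or the 8-bit one.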

Logged artifact sizes and roundtrip metrics

  • raw serialized model: 76650381 bytes
  • code size: 59176 bytes
  • raw total submission size: 76709557 bytes
  • mixed-quant artifact: 15142148 bytes
  • quantized payload bytes before torch serialization overhead: 19476680
  • quantized raw torch object bytes: 19529207
  • payload compression ratio vs raw tensor bytes: 3.93x
  • final submission size with mixed quant + zlib: 15201324 bytes
  • sliding-window exact: val_loss=2.11022315, val_bpb=1.24979460
  • sliding-window eval time: 657418 ms (≈11 minutes)
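A quick sanity check on the logged sizes. The "raw tensor bytes" base for the 3.93x ratio isn't logged separately; the assumption below is that it is close to the raw serialized model size, which reproduces the ratio once torch serialization overhead is left out:

```python
# Logged artifact sizes (bytes)
raw_total = 76_709_557   # raw total submission size
raw_model = 76_650_381   # raw serialized model
payload   = 19_476_680   # quantized payload before torch serialization
final     = 15_201_324   # final submission, mixed quant + zlib

payload_ratio = raw_model / payload  # ~3.94x, vs the logged 3.93x
overall_ratio = raw_total / final    # ~5.05x end-to-end shrink
print(f"payload {payload_ratio:.2f}x, overall {overall_ratio:.2f}x")
```

So mixed int6/int8 plus zlib shrinks the whole submission roughly 5x, with the quantization itself contributing the bulk of that.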

Apologies for the lack of multi-seed runs: I'm a university student funding myself while waiting for a grant, and I usually experiment on cheap services that don't offer many options beyond A5000-class GPUs. I'll update with multi-seed logs as soon as I find some compute.

Signed-off-by: Aditya Sasidhar <telikicherlaadityasasidhar@gmail.com>
