ICLR 2025
🎉 This is a PyTorch/GPU implementation of the paper Kolmogorov–Arnold Transformer (KAT), which replaces the MLP layers in transformers with KAN layers.
For more technical details, please refer to our ICLR'25 paper.
Kolmogorov–Arnold Transformer
📝[Paper] </>[code] </>[Triton/CUDA kernel]
Xingyi Yang, Xinchao Wang
National University of Singapore
International Conference on Learning Representations (ICLR'25)
Vanilla ViT + KAN struggle to scale effectively. We introduce the KAT model, which integrates GR-KANs into transformers for large-scale training scenarios like ImageNet, achieving significant performance improvements.
- Base Function: Replace B-splines with CUDA-implemented rational functions.
- Group KAN: Share weights among groups of edges for efficiency.
- Initialization: Maintain activation magnitudes across layers.
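The three changes above can be illustrated in a few lines of plain Python. This is a minimal sketch, not the CUDA/Triton kernel in rational_kat_cu. It makes three assumptions. First, it uses the "safe" Padé form F(x) = P(x) / (1 + |Q(x)|) from rational_activations. Second, channels are split into contiguous groups, and each group shares one set of rational coefficients. Third, the initialization follows a Kaiming-style variance argument for y = W F(x). The helper names (`rational`, `group_rational`, `second_moment`, `init_std`) are hypothetical.

```python
import math
from typing import Callable, List, Sequence


def rational(x: float, a: Sequence[float], b: Sequence[float]) -> float:
    """Safe rational function F(x) = P(x) / (1 + |Q(x)|).

    a = [a0, a1, ..., am] are the numerator coefficients;
    b = [b1, ..., bn] are the denominator coefficients. Q has no constant
    term, so the denominator is always >= 1 and never divides by zero.
    """
    numerator = sum(coef * x ** i for i, coef in enumerate(a))
    q = sum(coef * x ** (i + 1) for i, coef in enumerate(b))
    return numerator / (1.0 + abs(q))


def group_rational(
    x: Sequence[float],
    a_groups: Sequence[Sequence[float]],
    b_groups: Sequence[Sequence[float]],
) -> List[float]:
    """Apply one shared rational function per contiguous group of channels."""
    num_groups = len(a_groups)
    if len(b_groups) != num_groups or len(x) % num_groups != 0:
        raise ValueError("channels must split evenly into the coefficient groups")
    group_size = len(x) // num_groups
    return [
        rational(xi, a_groups[i // group_size], b_groups[i // group_size])
        for i, xi in enumerate(x)
    ]


def second_moment(f: Callable[[float], float], lo: float = -8.0,
                  hi: float = 8.0, steps: int = 4000) -> float:
    """Estimate E[f(x)^2] for x ~ N(0, 1) with the trapezoidal rule."""
    h = (hi - lo) / steps
    total = 0.0
    for k in range(steps + 1):
        x = lo + k * h
        weight = 0.5 if k in (0, steps) else 1.0
        density = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
        total += weight * f(x) ** 2 * density
    return total * h


def init_std(d_in: int, a: Sequence[float], b: Sequence[float]) -> float:
    """Weight std for y = W F(x) such that Var(y) ~= Var(x) = 1.

    With zero-mean i.i.d. weights, Var(y) = d_in * Var(w) * E[F(x)^2].
    """
    return math.sqrt(1.0 / (d_in * second_moment(lambda x: rational(x, a, b))))
```

Take an identity rational (P(x) = x, Q = 0). Here E[F(x)^2] = 1, so the variance-preserving std reduces to 1/sqrt(d_in). By the same argument, a ReLU, with E[F(x)^2] = 1/2, gives Kaiming's sqrt(2/d_in).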
- Release the KAT paper, CUDA implementation and IN-1k training code.
- 🎉🎉🎉🎉 Triton implementation, on 1D and 2D tasks. This is much easier to install than the CUDA version. Please see https://github.com/Adamdad/rational_kat_cu.
- KAT Detection and segmentation code.
- KAT on NLP tasks.
Please find our CUDA implementation at https://github.com/Adamdad/rational_kat_cu.git.
```bash
# install torch and other things
pip install timm==1.0.3
pip install wandb # I personally use wandb for results visualizations
git clone https://github.com/Adamdad/rational_kat_cu.git
cd rational_kat_cu
pip install -e .
```

📦 Data preparation: ImageNet with the following folder structure. You can extract ImageNet with this script.
```
│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
```

Refer to example.py for a detailed use case demonstrating how to use KAT with timm to classify an image.
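Before launching a long training run, it can help to confirm that the extracted dataset matches the ImageNet layout above. The helper below is hypothetical and not part of this repository. It only checks that train/ and val/ exist and contain the same class folders. It does not inspect the images themselves.

```python
from pathlib import Path
from typing import Dict


def check_imagenet_layout(root: str) -> Dict[str, int]:
    """Return the number of class folders found under train/ and val/.

    Raises FileNotFoundError if either split directory is missing and
    ValueError if the two splits do not contain the same class folders.
    """
    root_path = Path(root)
    classes = {}
    for split in ("train", "val"):
        split_dir = root_path / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Each class is one sub-folder, e.g. train/n01440764/
        classes[split] = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    if classes["train"] != classes["val"]:
        raise ValueError("train/ and val/ contain different class folders")
    return {split: len(names) for split, names in classes.items()}
```

For the full ImageNet-1k release, both counts should be 1000.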
Download pre-trained models or access training checkpoints:
| 🏷️ Model | ⚙️ Setup | 📦 Param | 📈 Top1 | 🔗 Link |
|---|---|---|---|---|
| KAT-T | From Scratch | 5.7M | 74.6 | link/huggingface |
| KAT-T | From ViT | 5.7M | 75.7 | link/huggingface |
| KAT-S | From Scratch | 22.1M | 81.2 | link/huggingface |
| KAT-S | From ViT | 22.1M | 82.0 | link/huggingface |
| KAT-B | From Scratch | 86.6M | 82.3 | link/huggingface |
| KAT-B | From ViT | 86.6M | 82.8 | link/huggingface |
All training scripts are under scripts/
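The tiny recipe (scripts/train_kat_tiny_8x128.sh) launches dist_train.sh with 8 processes and -b 128. In timm, -b is the per-process batch size, so the effective global batch is 8 × 128 = 1024. This is presumably what "8x128" in the file name refers to. The plain-Python sketch below illustrates, under simplifying assumptions, what --sched cosine and --model-ema-decay 0.9999 do to the learning rate and to each EMA weight. timm's actual scheduler may also apply warmup and a minimum learning rate, and both are omitted here. The helper names are hypothetical.

```python
import math


def global_batch_size(num_processes: int, per_process_batch: int) -> int:
    """timm's -b is per process, so the effective batch size multiplies."""
    return num_processes * per_process_batch


def cosine_lr(base_lr: float, epoch: float, total_epochs: int) -> float:
    """Cosine decay from base_lr to 0 over total_epochs (no warmup, no floor)."""
    progress = min(max(epoch / total_epochs, 0.0), 1.0)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


def ema_update(ema_value: float, new_value: float, decay: float = 0.9999) -> float:
    """One exponential-moving-average step for a single weight."""
    return decay * ema_value + (1.0 - decay) * new_value


if __name__ == "__main__":
    print("global batch:", global_batch_size(8, 128))
    for epoch in (0, 75, 150, 225, 300):
        print(f"epoch {epoch:3d}: lr = {cosine_lr(1e-3, epoch, 300):.2e}")
```

With the recipe's --lr 1e-3 and --epochs 300, this simplified schedule reaches half the base learning rate at epoch 150. With a decay of 0.9999, each EMA weight moves only 0.01% of the way toward the current weight per update. The EMA therefore averages over roughly the last 10,000 updates.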
```bash
bash scripts/train_kat_tiny_8x128.sh
```

If you want to change the hyper-parameters, you can edit the script:
```bash
#!/bin/bash
DATA_PATH=/local_home/dataset/imagenet/

# Rationals in kat_tiny_swish_patch16_224 are initialized to be swish functions.
bash ./dist_train.sh 8 $DATA_PATH \
    --model kat_tiny_swish_patch16_224 \
    -b 128 \
    --opt adamw \
    --lr 1e-3 \
    --weight-decay 0.05 \
    --epochs 300 \
    --mixup 0.8 \
    --cutmix 1.0 \
    --sched cosine \
    --smoothing 0.1 \
    --drop-path 0.1 \
    --aa rand-m9-mstd0.5 \
    --remode pixel --reprob 0.25 \
    --amp \
    --crop-pct 0.875 \
    --mean 0.485 0.456 0.406 \
    --std 0.229 0.224 0.225 \
    --model-ema \
    --model-ema-decay 0.9999 \
    --output output/kat_tiny_swish_patch16_224 \
    --log-wandb
```

To evaluate our kat_tiny_patch16_224 model, run:
```bash
DATA_PATH=/local_home/dataset/imagenet/
CHECKPOINT_PATH=kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth
python validate.py $DATA_PATH --model kat_tiny_patch16_224 \
    --checkpoint $CHECKPOINT_PATH -b 512

###################
Validating in float32. AMP not enabled.
Loaded state_dict from checkpoint 'kat_tiny_patch16_224_1f3ad3b2e69821f3d412f2924cf159a0e266f142d739cb68f68f796f5a0fe289.pth'
Model kat_tiny_patch16_224 created, param count: 5718328
Data processing configuration for current model + dataset:
    input_size: (3, 224, 224)
    interpolation: bicubic
    mean: (0.485, 0.456, 0.406)
    std: (0.229, 0.224, 0.225)
    crop_pct: 0.875
    crop_mode: center
Test: [ 0/98] Time: 3.453s (3.453s, 148.28/s) Loss: 0.6989 (0.6989) Acc@1: 84.375 ( 84.375) Acc@5: 96.875 ( 96.875)
.......
Test: [ 90/98] Time: 0.212s (0.592s, 864.23/s) Loss: 1.1640 (1.1143) Acc@1: 71.875 ( 74.270) Acc@5: 93.750 ( 92.220)
 * Acc@1 74.558 (25.442) Acc@5 92.390 (7.610)
--result
{
    "model": "kat_tiny_patch16_224",
    "top1": 74.558,
    "top1_err": 25.442,
    "top5": 92.39,
    "top5_err": 7.61,
    "param_count": 5.72,
    "img_size": 224,
    "crop_pct": 0.875,
    "interpolation": "bicubic"
}
```

We extend our gratitude to the authors of rational_activations for their contributions to CUDA rational function implementations that inspired parts of this work. We thank @yuweihao, @florinshen, @Huage001 and @yu-rp for valuable discussions.
If you use this repository, please cite:
```bibtex
@inproceedings{yang2025kolmogorovarnold,
  title={Kolmogorov-Arnold Transformer},
  author={Xingyi Yang and Xinchao Wang},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=BCeock53nt}
}
```

