Bug Report: keras.optimizers.Muon Fails with AttributeError on variable.path in Keras 3 / TF 2.16+
1. Bug Description
The keras.optimizers.Muon optimizer is currently incompatible with modern Keras 3 and TensorFlow (2.16+) environments. When a model is compiled with the Muon optimizer and training begins, it fails during the first training step with the error: AttributeError: 'ResourceVariable' object has no attribute 'path'.
The issue stems from the optimizer's internal _adamw_update_step method, which attempts to access a .path attribute on a TensorFlow ResourceVariable. The .path attribute is defined on Keras 3's own variable wrapper (keras.Variable), not on the underlying TensorFlow ResourceVariable, so the lookup fails whenever the optimizer is handed the raw backend variable.
This bug makes the Muon optimizer unusable in its current state for anyone using an up-to-date TensorFlow backend.
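The failure mode described above can be illustrated without TensorFlow. The following is only a minimal sketch, not the Keras implementation: `FakeResourceVariable`, `FakeKerasVariable`, `variable_identifier`, and `is_excluded` are hypothetical names introduced for illustration. It shows why an unguarded `.path` lookup raises on an object that only has `.name`, and one possible defensive pattern (falling back to `.name`) that a fix might resemble. The actual fix in Keras may look quite different.

```python
class FakeResourceVariable:
    """Hypothetical stand-in for a backend variable: has `name` but no `path`."""

    def __init__(self, name):
        self.name = name


class FakeKerasVariable:
    """Hypothetical stand-in for a Keras-level variable that exposes `path`."""

    def __init__(self, path):
        self.path = path
        self.name = path.split("/")[-1]


def variable_identifier(variable):
    """Return `variable.path` when it exists, otherwise fall back to `variable.name`."""
    path = getattr(variable, "path", None)
    return path if path is not None else variable.name


def is_excluded(variable, exclude_layers):
    """Hypothetical check: does any excluded layer name occur in the identifier?"""
    identifier = variable_identifier(variable)
    return any(layer in identifier for layer in exclude_layers)


# An unguarded attribute access fails in the same way as the reported error.
backend_var = FakeResourceVariable("last/kernel")
try:
    backend_var.path
except AttributeError as e:
    print(f"Unguarded access failed: {e}")

# The guarded helper works for both kinds of object.
print(is_excluded(backend_var, ["last"]))
print(is_excluded(FakeKerasVariable("sequential/dense/kernel"), ["last"]))
```

Whether falling back to `.name` is appropriate depends on how the optimizer matches `exclude_layers` against variable identifiers, which this sketch does not attempt to reproduce.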
2. Reproducible Example
The following minimal code consistently reproduces the error. It uses a simple Sequential model, dummy data, and the Muon optimizer. The error occurs immediately upon calling model.fit().
    import tensorflow as tf
    import keras
    import numpy as np

    print(f"Keras version: {keras.__version__}")
    print(f"TensorFlow version: {tf.__version__}")

    # 1. Define a simple model
    # A named layer is included to demonstrate the `exclude_layers` argument.
    model = keras.Sequential([
        keras.layers.Input(shape=(10,)),
        keras.layers.Dense(5),
        keras.layers.Dense(1, name="last")
    ])

    # 2. Create dummy data
    x_train = np.random.rand(100, 10).astype(np.float32)
    y_train = np.random.rand(100, 1).astype(np.float32)

    # 3. Instantiate the Muon optimizer
    # The error occurs with or without `exclude_layers`.
    optimizer = keras.optimizers.Muon(learning_rate=1e-3, exclude_layers=["last"])

    # 4. Compile the model and attempt to fit
    model.compile(optimizer=optimizer, loss='mse')

    print("\nAttempting to start model.fit()... The error is expected here.")
    try:
        model.fit(x_train, y_train, epochs=1, batch_size=10)
    except AttributeError as e:
        print(f"\nSuccessfully reproduced the error: {e}")