I have the following code that defines an abstract class and its final subclass. Both classes are subclasses of equinox.Module, which registers class attributes as the leaves of a PyTree container.
# === IMPORTS === from abc import ABC, abstractmethod import jax from jax.typing import ArrayLike import jax.numpy as jnp import equinox as eqx from quadax import quadgk jax.config.update("jax_enable_x64", True) class MyClass(eqx.Module): # Works if I toggle to MyClass(ABC) rtol = 1e-12 atol = 1e-12 param: ArrayLike def __init__(self): self.param = self._integral_moment(3) # Fails, but works if I toggle to something like "self.param = self.func(1.)" @abstractmethod def func(self, tau): pass def func_abs(self, tau): return jnp.abs(self.func(tau)) def _integral_moment(self, order): return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0] def _integrand_moment(self, tau, order): return self.func_abs(tau) * jnp.abs(tau)**order class MySubClass(MyClass): gamma: ArrayLike kappa: ArrayLike w0: ArrayLike def __init__(self, gamma, kappa, w0): self.gamma = jnp.asarray(gamma) self.kappa = jnp.asarray(kappa) self.w0 = jnp.asarray(w0) super().__init__() def func(self, tau): return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2) # Test test = MySubClass(gamma=1., kappa=1., w0=1.) test.param This code produces the AttributeError message:
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[21], line 52
     48 return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2)
     51 # Test
---> 52 test = MySubClass(gamma=1., kappa=1., w0=1.)
     53 test.param

    [... skipping hidden 2 frame]

Cell In[21], line 45
     43 self.kappa = jnp.asarray(kappa)
     44 self.w0 = jnp.asarray(w0)
---> 45 super().__init__()

Cell In[21], line 19
     18 def __init__(self):
---> 19 self.param = self._integral_moment(3)

    [... skipping hidden 1 frame]

Cell In[21], line 29
     28 def _integral_moment(self, order):
---> 29 return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0]

...
    659 and isinstance(out, types.MethodType)
    660 and out.__self__ is self
    661 ):

AttributeError: 'MySubClass' object has no attribute 'param'

This error clearly comes from a restriction of equinox.Module, since if I change the parent class to ABC, the code runs fine.
First, I thought that maybe equinox did not allow me to use methods to initialize attributes. But if I use the func() method instead of the _integral_moment() method to initialize param, the code works fine.
So I just don't understand what is going on here. I thought it would be better to ask here before asking the developers at equinox.
This uses equinox version 0.13.1 with jax version 0.7.2.