I have the following code that defines an abstract class and its final subclass. Both classes are subclasses of equinox.Module, which registers class attributes as the leaves of a PyTree container.

```python
# === IMPORTS ===
from abc import ABC, abstractmethod
import jax
from jax.typing import ArrayLike
import jax.numpy as jnp
import equinox as eqx
from quadax import quadgk

jax.config.update("jax_enable_x64", True)


class MyClass(eqx.Module):  # Works if I toggle to MyClass(ABC)
    rtol = 1e-12
    atol = 1e-12
    param: ArrayLike

    def __init__(self):
        self.param = self._integral_moment(3)  # Fails, but works if I toggle to something like "self.param = self.func(1.)"

    @abstractmethod
    def func(self, tau):
        pass

    def func_abs(self, tau):
        return jnp.abs(self.func(tau))

    def _integral_moment(self, order):
        return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0]

    def _integrand_moment(self, tau, order):
        return self.func_abs(tau) * jnp.abs(tau)**order


class MySubClass(MyClass):
    gamma: ArrayLike
    kappa: ArrayLike
    w0: ArrayLike

    def __init__(self, gamma, kappa, w0):
        self.gamma = jnp.asarray(gamma)
        self.kappa = jnp.asarray(kappa)
        self.w0 = jnp.asarray(w0)
        super().__init__()

    def func(self, tau):
        return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2)


# Test
test = MySubClass(gamma=1., kappa=1., w0=1.)
test.param
```

This code produces the AttributeError message:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[21], line 52
     48         return self.gamma * jnp.exp(-1j * self.w0 * tau) * jnp.exp(-self.kappa*jnp.abs(tau)/2)
     51 # Test
---> 52 test = MySubClass(gamma=1., kappa=1., w0=1.)
     53 test.param

    [... skipping hidden 2 frame]

Cell In[21], line 45
     43         self.kappa = jnp.asarray(kappa)
     44         self.w0 = jnp.asarray(w0)
---> 45         super().__init__()

Cell In[21], line 19
     18     def __init__(self):
---> 19         self.param = self._integral_moment(3)

    [... skipping hidden 1 frame]

Cell In[21], line 29
     28     def _integral_moment(self, order):
---> 29         return quadgk(self._integrand_moment, [0, jnp.inf], args=(order,), epsrel=self.rtol, epsabs=self.atol)[0]

...
    659     and isinstance(out, types.MethodType)
    660     and out.__self__ is self
    661 ):

AttributeError: 'MySubClass' object has no attribute 'param'
```

This error clearly comes from a restriction of equinox.Module, since the code runs fine if I change the parent class to ABC.

First, I thought that maybe equinox did not allow me to use methods to initialize attributes. But if I use the func() method instead of the _integral_moment() method to initialize param, the code works fine.

So I just don't understand what is going on here. I thought it would be better to ask here before asking the developers at equinox.

This uses equinox version 0.13.1 with jax version 0.7.2.

2 Answers

The issue here is that when traced, eqx.Module attempts to access all the declared attributes of the Module, so the module cannot be traced before those attributes are created. Here's a simpler repro of the same problem:

```python
import jax
import equinox as eqx
from jax.typing import ArrayLike


class MyClass(eqx.Module):
    param: ArrayLike

    def __init__(self):
        # jax.jit traces self.func before self.param has been assigned
        self.param = jax.jit(self.func)()

    def func(self):
        return 4


MyClass()
# AttributeError: 'MyClass' object has no attribute 'param'
```

The quadgk function traces its input, and since you call it before setting param, you get this error. With this issue in mind, you can fix your problem by setting the missing param to a placeholder value before you call a function that traces the object's methods:

```python
class MyClass(eqx.Module):
    ...
    def __init__(self):
        self.param = 0  # set to a placeholder to allow tracing
        self.param = self._integral_moment(3)
    ...
```
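To make the mechanism concrete without depending on equinox or JAX, here is a plain-Python toy model of it. It is only an analogy: `ToyModule`, `toy_trace`, and the `fields` tuple are hypothetical names invented for this sketch, not equinox or JAX APIs. The assumption is that "tracing" a bound method first reads every declared field of the object that owns it.

```python
def toy_trace(bound_method):
    """Hypothetical stand-in for a tracing transform such as jax.jit."""
    owner = bound_method.__self__
    # Reading every declared field mimics flattening the owner into leaves.
    # This raises AttributeError if any declared field is still unset.
    _leaves = [getattr(owner, name) for name in owner.fields]
    return bound_method()


class ToyModule:
    fields = ("param",)  # names of the declared attributes

    def __init__(self, use_placeholder):
        if use_placeholder:
            self.param = 0.0  # placeholder so the field exists during tracing
        self.param = toy_trace(self.compute)

    def compute(self):
        return 4.0


try:
    ToyModule(use_placeholder=False)
    failed = False
except AttributeError:
    failed = True  # tracing happened before param was assigned

ok = ToyModule(use_placeholder=True)  # placeholder lets tracing succeed
```

In this sketch the placeholder is overwritten as soon as the traced call returns, so its value never matters; it only has to exist while the fields are being read.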

To follow up on @jakevdp's answer, a completely equivalent but perhaps slightly more elegant way of systematically pre-empting this issue in equinox is to assign a value directly in the attribute definition:

```python
class MyClass(eqx.Module):
    ...
    param: float = 0  # set to a placeholder to allow tracing

    def __init__(self):
        self.param = self._integral_moment(3)
    ...
```

EDIT: Importantly, note that dataclasses such as equinox modules do NOT allow mutable objects or JAX arrays as class-level defaults; attempting this raises a ValueError telling you to use default_factory. For the above code, all instances of the class will initially share the same object for the field, which is undesirable in a dataclass if the attribute can later be modified in some way. This is probably why the previous answer chose to set the placeholder in __init__, which will always work.
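The default-value restriction comes from Python's standard `dataclasses` machinery, so it can be sketched with the standard library alone. The class names `Bad` and `Good` and the `values` field below are made up for illustration, and the sketch assumes equinox modules follow the same rule, as described above.

```python
from dataclasses import dataclass, field


def define_with_mutable_default():
    # Defining this class fails: a list is a mutable class-level default.
    @dataclass
    class Bad:
        values: list = []

    return Bad


try:
    define_with_mutable_default()
    error_message = None
except ValueError as exc:
    error_message = str(exc)  # the message points to default_factory


@dataclass
class Good:
    param: float = 0.0  # immutable default: allowed at class level
    values: list = field(default_factory=list)  # fresh list per instance


a, b = Good(), Good()
a.values.append(1)  # does not affect b, because each instance has its own list
```

An immutable placeholder such as `0.0` is safe to share between instances, which is why the class-level default works for `param` in the answer above.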
