Skip to content

Commit 8768c31

Browse files
Aaryan-549claude
andcommitted
Fix pedestal model tests to use ModelConfig API and add pedestal module to __init__.py
- Update register_model_test.py to test model instantiation through ToraxConfig.from_dict() instead of direct class instantiation - Add pedestal module import to torax/__init__.py for better API accessibility - Tests now properly verify that registered models work through the pydantic discriminator system Addresses maintainer feedback on PR. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent e353edd commit 8768c31

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

torax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pylint: disable=g-importing-member
2121
import jax
22+
from torax import pedestal
2223
from torax import transport
2324
from torax._src import version
2425
from torax._src.config.config_loader import build_torax_config_from_file
@@ -48,6 +49,7 @@
4849
__version_info__ = version.TORAX_VERSION_INFO
4950

5051
__all__ = [
52+
'pedestal',
5153
'transport',
5254
'build_torax_config_from_file',
5355
'import_module',

torax/_src/pedestal_model/tests/register_model_test.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torax._src.pedestal_model import pydantic_model
2424
from torax._src.pedestal_model import register_model
2525
from torax._src.pedestal_model import runtime_params as pedestal_runtime_params
26+
from torax._src.torax_pydantic import model_config
2627
from torax._src.torax_pydantic import torax_pydantic
2728
import jax.numpy as jnp
2829

@@ -85,16 +86,31 @@ def build_runtime_params(
8586
# Register the model
8687
register_model.register_pedestal_model(TestPedestal)
8788

88-
# Verify it can be instantiated
89-
config = TestPedestal(test_value=99.0)
90-
self.assertEqual(config.test_value, 99.0)
91-
self.assertEqual(config.model_name, 'test_pedestal')
89+
# Verify it can be instantiated through the ModelConfig API
90+
# Create a minimal ToraxConfig using from_dict to test the registration
91+
minimal_config_dict = {
92+
'profile_conditions': {},
93+
'numerics': {},
94+
'plasma_composition': {},
95+
'geometry': {'geometry_type': 'circular'},
96+
'sources': {},
97+
'pedestal': {
98+
'model_name': 'test_pedestal',
99+
'test_value': 99.0,
100+
},
101+
}
102+
torax_config = model_config.ToraxConfig.from_dict(minimal_config_dict)
103+
104+
# Verify the pedestal config is the correct type
105+
self.assertIsInstance(torax_config.pedestal, TestPedestal)
106+
self.assertEqual(torax_config.pedestal.test_value, 99.0)
107+
self.assertEqual(torax_config.pedestal.model_name, 'test_pedestal')
92108

93109
# Verify it can build the model and runtime params
94-
model = config.build_pedestal_model()
110+
model = torax_config.pedestal.build_pedestal_model()
95111
self.assertIsInstance(model, TestPedestalModel)
96112

97-
runtime_params = config.build_runtime_params(t=0.0)
113+
runtime_params = torax_config.pedestal.build_runtime_params(t=0.0)
98114
self.assertEqual(runtime_params.test_value, 99.0)
99115

100116

@@ -163,12 +179,34 @@ def build_runtime_params(self, t):
163179
register_model.register_pedestal_model(Config1)
164180
register_model.register_pedestal_model(Config2)
165181

166-
# Verify both can be instantiated
167-
config1 = Config1()
168-
config2 = Config2()
169-
170-
self.assertEqual(config1.model_name, 'model1')
171-
self.assertEqual(config2.model_name, 'model2')
182+
# Verify both can be instantiated through the ModelConfig API
183+
minimal_config_dict_1 = {
184+
'profile_conditions': {},
185+
'numerics': {},
186+
'plasma_composition': {},
187+
'geometry': {'geometry_type': 'circular'},
188+
'sources': {},
189+
'pedestal': {
190+
'model_name': 'model1',
191+
},
192+
}
193+
torax_config_1 = model_config.ToraxConfig.from_dict(minimal_config_dict_1)
194+
self.assertIsInstance(torax_config_1.pedestal, Config1)
195+
self.assertEqual(torax_config_1.pedestal.model_name, 'model1')
196+
197+
minimal_config_dict_2 = {
198+
'profile_conditions': {},
199+
'numerics': {},
200+
'plasma_composition': {},
201+
'geometry': {'geometry_type': 'circular'},
202+
'sources': {},
203+
'pedestal': {
204+
'model_name': 'model2',
205+
},
206+
}
207+
torax_config_2 = model_config.ToraxConfig.from_dict(minimal_config_dict_2)
208+
self.assertIsInstance(torax_config_2.pedestal, Config2)
209+
self.assertEqual(torax_config_2.pedestal.model_name, 'model2')
172210

173211

174212
if __name__ == '__main__':

0 commit comments

Comments
 (0)