|
23 | 23 | from torax._src.pedestal_model import pydantic_model |
24 | 24 | from torax._src.pedestal_model import register_model |
25 | 25 | from torax._src.pedestal_model import runtime_params as pedestal_runtime_params |
| 26 | +from torax._src.torax_pydantic import model_config |
26 | 27 | from torax._src.torax_pydantic import torax_pydantic |
27 | 28 | import jax.numpy as jnp |
28 | 29 |
|
@@ -85,16 +86,31 @@ def build_runtime_params( |
85 | 86 | # Register the model |
86 | 87 | register_model.register_pedestal_model(TestPedestal) |
87 | 88 |
|
88 | | - # Verify it can be instantiated |
89 | | - config = TestPedestal(test_value=99.0) |
90 | | - self.assertEqual(config.test_value, 99.0) |
91 | | - self.assertEqual(config.model_name, 'test_pedestal') |
| 89 | + # Verify it can be instantiated through the ModelConfig API |
| 90 | + # Create a minimal ToraxConfig using from_dict to test the registration |
| 91 | + minimal_config_dict = { |
| 92 | + 'profile_conditions': {}, |
| 93 | + 'numerics': {}, |
| 94 | + 'plasma_composition': {}, |
| 95 | + 'geometry': {'geometry_type': 'circular'}, |
| 96 | + 'sources': {}, |
| 97 | + 'pedestal': { |
| 98 | + 'model_name': 'test_pedestal', |
| 99 | + 'test_value': 99.0, |
| 100 | + }, |
| 101 | + } |
| 102 | + torax_config = model_config.ToraxConfig.from_dict(minimal_config_dict) |
| 103 | + |
| 104 | + # Verify the pedestal config is the correct type |
| 105 | + self.assertIsInstance(torax_config.pedestal, TestPedestal) |
| 106 | + self.assertEqual(torax_config.pedestal.test_value, 99.0) |
| 107 | + self.assertEqual(torax_config.pedestal.model_name, 'test_pedestal') |
92 | 108 |
|
93 | 109 | # Verify it can build the model and runtime params |
94 | | - model = config.build_pedestal_model() |
| 110 | + model = torax_config.pedestal.build_pedestal_model() |
95 | 111 | self.assertIsInstance(model, TestPedestalModel) |
96 | 112 |
|
97 | | - runtime_params = config.build_runtime_params(t=0.0) |
| 113 | + runtime_params = torax_config.pedestal.build_runtime_params(t=0.0) |
98 | 114 | self.assertEqual(runtime_params.test_value, 99.0) |
99 | 115 |
|
100 | 116 |
|
@@ -163,12 +179,34 @@ def build_runtime_params(self, t): |
163 | 179 | register_model.register_pedestal_model(Config1) |
164 | 180 | register_model.register_pedestal_model(Config2) |
165 | 181 |
|
166 | | - # Verify both can be instantiated |
167 | | - config1 = Config1() |
168 | | - config2 = Config2() |
169 | | - |
170 | | - self.assertEqual(config1.model_name, 'model1') |
171 | | - self.assertEqual(config2.model_name, 'model2') |
| 182 | + # Verify both can be instantiated through the ModelConfig API |
| 183 | + minimal_config_dict_1 = { |
| 184 | + 'profile_conditions': {}, |
| 185 | + 'numerics': {}, |
| 186 | + 'plasma_composition': {}, |
| 187 | + 'geometry': {'geometry_type': 'circular'}, |
| 188 | + 'sources': {}, |
| 189 | + 'pedestal': { |
| 190 | + 'model_name': 'model1', |
| 191 | + }, |
| 192 | + } |
| 193 | + torax_config_1 = model_config.ToraxConfig.from_dict(minimal_config_dict_1) |
| 194 | + self.assertIsInstance(torax_config_1.pedestal, Config1) |
| 195 | + self.assertEqual(torax_config_1.pedestal.model_name, 'model1') |
| 196 | + |
| 197 | + minimal_config_dict_2 = { |
| 198 | + 'profile_conditions': {}, |
| 199 | + 'numerics': {}, |
| 200 | + 'plasma_composition': {}, |
| 201 | + 'geometry': {'geometry_type': 'circular'}, |
| 202 | + 'sources': {}, |
| 203 | + 'pedestal': { |
| 204 | + 'model_name': 'model2', |
| 205 | + }, |
| 206 | + } |
| 207 | + torax_config_2 = model_config.ToraxConfig.from_dict(minimal_config_dict_2) |
| 208 | + self.assertIsInstance(torax_config_2.pedestal, Config2) |
| 209 | + self.assertEqual(torax_config_2.pedestal.model_name, 'model2') |
172 | 210 |
|
173 | 211 |
|
174 | 212 | if __name__ == '__main__': |
|
0 commit comments