2222from typing import Callable , Optional , Sequence , Tuple , Union
2323
2424import numpy as np
25+ import opcode
2526
2627from aesara import tensor as at
2728from aesara .compile .builders import OpFromGraph
@@ -164,6 +165,45 @@ def fn(*args, **kwargs):
164165 return fn
165166
166167
168+ # Helper function from pyprob
169+ def _extract_target_of_assignment (depth ):
170+ frame = sys ._getframe (depth )
171+ code = frame .f_code
172+ next_instruction = code .co_code [frame .f_lasti + 2 ]
173+ instruction_arg = code .co_code [frame .f_lasti + 3 ]
174+ instruction_name = opcode .opname [next_instruction ]
175+ if instruction_name == "STORE_FAST" :
176+ return code .co_varnames [instruction_arg ]
177+ elif instruction_name in ["STORE_NAME" , "STORE_GLOBAL" ]:
178+ return code .co_names [instruction_arg ]
179+ elif (
180+ instruction_name in ["LOAD_FAST" , "LOAD_NAME" , "LOAD_GLOBAL" ]
181+ and opcode .opname [code .co_code [frame .f_lasti + 4 ]] in ["LOAD_CONST" , "LOAD_FAST" ]
182+ and opcode .opname [code .co_code [frame .f_lasti + 6 ]] == "STORE_SUBSCR"
183+ ):
184+ if instruction_name == "LOAD_FAST" :
185+ base_name = code .co_varnames [instruction_arg ]
186+ else :
187+ base_name = code .co_names [instruction_arg ]
188+
189+ second_instruction = opcode .opname [code .co_code [frame .f_lasti + 4 ]]
190+ second_arg = code .co_code [frame .f_lasti + 5 ]
191+ if second_instruction == "LOAD_CONST" :
192+ value = code .co_consts [second_arg ]
193+ elif second_instruction == "LOAD_FAST" :
194+ var_name = code .co_varnames [second_arg ]
195+ value = frame .f_locals [var_name ]
196+ else :
197+ value = None
198+ if value is not None :
199+ index_name = repr (value )
200+ return base_name + "[" + index_name + "]"
201+ else :
202+ return None
203+ else :
204+ return None
205+
206+
167207class SymbolicRandomVariable (OpFromGraph ):
168208 """Symbolic Random Variable
169209
@@ -216,7 +256,6 @@ class Distribution(metaclass=DistributionMeta):
216256
217257 def __new__ (
218258 cls ,
219- name : str ,
220259 * args ,
221260 rng = None ,
222261 dims : Optional [Dims ] = None ,
@@ -234,8 +273,6 @@ def __new__(
234273 ----------
235274 cls : type
236275 A PyMC distribution.
237- name : str
238- Name for the new model variable.
239276 rng : optional
240277 Random number generator to use with the RandomVariable.
241278 dims : tuple, optional
@@ -277,6 +314,19 @@ def __new__(
277314 "for a standalone distribution."
278315 )
279316
317+ if "name" in kwargs :
318+ name = kwargs .pop ("name" )
319+ elif len (args ) > 0 and isinstance (args [0 ], string_types ):
320+ name = args [0 ]
321+ args = args [1 :]
322+ else :
323+ name = _extract_target_of_assignment (2 )
324+ if name is None :
325+ raise TypeError ("Name could not be inferred for variable" )
326+
327+ if not isinstance (name , string_types ):
328+ raise TypeError (f"Name needs to be a string but got: { name } " )
329+
280330 if "testval" in kwargs :
281331 initval = kwargs .pop ("testval" )
282332 warnings .warn (
@@ -285,9 +335,6 @@ def __new__(
285335 stacklevel = 2 ,
286336 )
287337
288- if not isinstance (name , string_types ):
289- raise TypeError (f"Name needs to be a string but got: { name } " )
290-
291338 dims = convert_dims (dims )
292339 if observed is not None :
293340 observed = convert_observed_data (observed )
0 commit comments