1+ import warnings
12from collections .abc import Sequence
23from copy import copy
34from textwrap import dedent
1920from pytensor .misc .frozendict import frozendict
2021from pytensor .printing import Printer , pprint
2122from pytensor .scalar import get_scalar_type
23+ from pytensor .scalar .basic import Composite , transfer_type , upcast
2224from pytensor .scalar .basic import bool as scalar_bool
2325from pytensor .scalar .basic import identity as scalar_identity
24- from pytensor .scalar .basic import transfer_type , upcast
2526from pytensor .tensor import elemwise_cgen as cgen
2627from pytensor .tensor import get_vector_length
2728from pytensor .tensor .basic import _get_vector_length , as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
364365 self .name = name
365366 self .scalar_op = scalar_op
366367 self .inplace_pattern = inplace_pattern
368+ self .ufunc = None
367369 self .destroy_map = {o : [i ] for o , i in self .inplace_pattern .items ()}
368370
369371 if nfunc_spec is None :
@@ -375,14 +377,13 @@ def __init__(
def __getstate__(self):
    """Return the state to pickle for this Op.

    The state is a shallow copy of the instance ``__dict__`` with the
    entries that must not be serialized removed, so the live instance is
    left untouched.

    Returns
    -------
    dict
        Copy of ``self.__dict__`` without the ``ufunc`` and
        ``__epydoc_asRoutine`` entries.
    """
    state = copy(self.__dict__)
    # The ufunc is never pickled; `__setstate__` resets it to None.
    # Use a default so an instance that never had the attribute set
    # does not fail with a KeyError while being pickled or copied.
    state.pop("ufunc", None)
    # Dropped when present.
    state.pop("__epydoc_asRoutine", None)
    return state
381382
def __setstate__(self, d):
    """Restore this Op from the pickled state ``d``.

    Parameters
    ----------
    d : dict
        State dictionary. A ``"nfunc"`` entry, if present, is removed
        from it before the state is handed to the parent class.
    """
    # This used to be stored in the Op, not anymore
    d.pop("nfunc", None)
    super().__setstate__(d)
    # The ufunc is not part of the pickled state; start without one.
    self.ufunc = None
    # Convert the restored mapping to a frozendict.
    restored_pattern = self.inplace_pattern
    self.inplace_pattern = frozendict(restored_pattern)
387388
388389 def get_output_info (self , * inputs ):
@@ -623,31 +624,47 @@ def transform(r):
623624
624625 return ret
625626
626- def prepare_node (self , node , storage_map , compute_map , impl ):
627- # Postpone the ufunc building to the last minutes due to:
628- # - NumPy ufunc support only up to 32 operands (inputs and outputs)
629- # But our c code support more.
630- # - nfunc is reused for scipy and scipy is optional
631- if (len (node .inputs ) + len (node .outputs )) > 32 and impl == "py" :
632- impl = "c"
633-
634- if getattr (self , "nfunc_spec" , None ) and impl != "c" :
635- self .nfunc = import_func_from_string (self .nfunc_spec [0 ])
636-
627+ def _create_node_ufunc (self , node ) -> None :
637628 if (
638- ( len ( node . inputs ) + len ( node . outputs )) <= 32
639- and ( self . nfunc is None or self . scalar_op . nin != len ( node . inputs ))
640- and self . ufunc is None
641- and impl == "py"
629+ self . nfunc_spec is not None
630+ # Some scalar Ops like `Add` allow for a variable number of inputs,
631+ # whereas the numpy counterpart does not.
632+ and len ( node . inputs ) == self . nfunc_spec [ 1 ]
642633 ):
634+ ufunc = import_func_from_string (self .nfunc_spec [0 ])
635+ if ufunc is None :
636+ raise ValueError (
637+ f"Could not import ufunc { self .nfunc_spec [0 ]} for { self } "
638+ )
639+
640+ elif self .ufunc is not None :
641+ # Cached before
642+ ufunc = self .ufunc
643+
644+ else :
645+ if (len (node .inputs ) + len (node .outputs )) > 32 :
646+ if isinstance (self .scalar_op , Composite ):
647+ warnings .warn (
648+ "Trying to create a Python Composite Elemwise function with more than 32 operands.\n "
649+ "This operation should not have been introduced if the C-backend is not properly setup. "
650+ 'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n '
651+ "Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
652+ '`pytensor.config.mode = "NUMBA" (or "JAX").'
653+ )
654+ else :
655+ warnings .warn (
656+ f"Trying to create a Python Elemwise function for the scalar Op { self .scalar_op } "
657+ f"with more than 32 operands. This will likely fail."
658+ )
659+
643660 ufunc = np .frompyfunc (
644661 self .scalar_op .impl , len (node .inputs ), self .scalar_op .nout
645662 )
646- if self .scalar_op .nin > 0 :
647- # We can reuse it for many nodes
663+ if self .scalar_op .nin > 0 : # Default in base class is -1
664+ # Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
648665 self .ufunc = ufunc
649- else :
650- node .tag .ufunc = ufunc
666+
667+ node .tag .ufunc = ufunc
651668
652669 # Numpy ufuncs will sometimes perform operations in
653670 # float16, in particular when the input is int8.
@@ -660,15 +677,23 @@ def prepare_node(self, node, storage_map, compute_map, impl):
660677
661678 # NumPy 1.10.1 raise an error when giving the signature
662679 # when the input is complex. So add it only when inputs is int.
663- out_dtype = node . outputs [ 0 ]. dtype
680+ ufunc_kwargs = {}
664681 if (
665- out_dtype in float_dtypes
666- and isinstance ( self . nfunc , np . ufunc )
682+ isinstance ( ufunc , np . ufunc )
683+ # TODO: Why check for the dtype of the first input only?
667684 and node .inputs [0 ].dtype in discrete_dtypes
685+ and len (node .outputs ) == 1
686+ and node .outputs [0 ].dtype in float_dtypes
668687 ):
669- char = np .sctype2char (out_dtype )
670- sig = char * node .nin + "->" + char * node .nout
671- node .tag .sig = sig
688+ char = np .sctype2char (node .outputs [0 ].dtype )
689+ ufunc_kwargs ["sig" ] = char * node .nin + "->" + char * node .nout
690+
691+ node .tag .ufunc_kwargs = ufunc_kwargs
692+
693+ def prepare_node (self , node , storage_map , compute_map , impl ):
694+ if impl == "py" :
695+ self ._create_node_ufunc (node )
696+
672697 node .tag .fake_node = Apply (
673698 self .scalar_op ,
674699 [
@@ -684,71 +709,32 @@ def prepare_node(self, node, storage_map, compute_map, impl):
684709 self .scalar_op .prepare_node (node .tag .fake_node , None , None , impl )
685710
def perform(self, node, inputs, output_storage):
    """Evaluate the elementwise operation in Python and store the results.

    Parameters
    ----------
    node
        The node being evaluated. ``node.tag.ufunc`` and
        ``node.tag.ufunc_kwargs`` are used when present; otherwise they
        are created by calling ``self._create_node_ufunc(node)``.
    inputs
        The input arrays, passed positionally to the ufunc.
    output_storage
        One single-element list per output; element 0 of each list
        receives the corresponding result.
    """
    ufunc = getattr(node.tag, "ufunc", None)
    if ufunc is None:
        # `perform` can be called on a node that was never prepared,
        # so build the ufunc lazily.
        self._create_node_ufunc(node)
        ufunc = node.tag.ufunc

    self._check_runtime_broadcast(node, inputs)

    outputs = ufunc(*inputs, **node.tag.ufunc_kwargs)

    # Single-output ufuncs return a bare array instead of a tuple.
    if not isinstance(outputs, tuple):
        outputs = (outputs,)

    for i, (out, out_storage, node_out) in enumerate(
        zip(outputs, output_storage, node.outputs)
    ):
        # Numpy frompyfunc always returns object arrays
        out = np.asarray(out, dtype=node_out.dtype)

        if i in self.inplace_pattern:
            # The output must be the destroyed input buffer itself.
            # Write the result into it and never replace it with a copy,
            # even if the ufunc result happened to be a view.
            inp = inputs[self.inplace_pattern[i]]
            inp[...] = out
            out = inp
        elif not out.flags.owndata:
            # numpy.real return a view!
            out = out.copy()

        out_storage[0] = out
752738
753739 @staticmethod
754740 def _check_runtime_broadcast (node , inputs ):
0 commit comments