Open
Description
Trying to compile, in NUMBA mode, a graph that was vectorized with `vectorize_graph` and that contains an `OpFromGraph` results in an error:
# Minimal, self-contained reproducer: compiling a vectorized OpFromGraph
# with mode="NUMBA" fails during Numba funcification.
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.replace import vectorize_graph

# Core (unbatched) input.
X = pt.dmatrix("X", shape=(None, None))
# Batched replacement for X. It has a distinct name so graph printouts are
# unambiguous. Its dtype is pinned to float64 to match the core input; without
# an explicit dtype, pt.tensor uses config.floatX.
X_batched = pt.tensor("X_batched", shape=(None, None, None), dtype="float64")

# A trivial graph wrapped in an OpFromGraph.
z = X + 1
results = OpFromGraph(
    inputs=[X],
    outputs=[z],
)(X)

# Vectorizing the OpFromGraph output over the extra leading dimension.
z_vec = vectorize_graph(results, {X: X_batched})

# Expected to raise:
#   TypeError: pytensor.link.numba.dispatch.basic.numba_funcify() got multiple
#   values for keyword argument 'parent_node'
# The traceback suggests the cause:
# - The Blockwise dispatcher calls numba_funcify(core_op, parent_node=node, ...).
# - numba_funcify_OpFromGraph forwards its **kwargs (still holding parent_node)
#   into numba_funcify(op.fgraph, **kwargs).
# - The Elemwise dispatcher then passes parent_node=node together with the
#   forwarded **kwargs.
fn = pytensor.function(
    [X_batched],
    [z_vec],
    mode="NUMBA",
)
Full traceback
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[36], line 12
6 results = OpFromGraph(
7 inputs=[X],
8 outputs=[z],
9 )(X)
11 z_vec = vectorize_graph(results, {X: X_batched})
---> 12 fn = pytensor.function(
13 [X_batched],
14 [z_vec],
15 mode='NUMBA',
16 )
File ~/Documents/Python/pytensor/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
321 fn = orig_function(
322 inputs,
323 outputs,
(...) 327 trust_input=trust_input,
328 )
329 else:
330 # note: pfunc will also call orig_function -- orig_function is
331 # a choke point that all compilation must pass through
--> 332 fn = pfunc(
333 params=inputs,
334 outputs=outputs,
335 mode=mode,
336 updates=updates,
337 givens=givens,
338 no_default_updates=no_default_updates,
339 accept_inplace=accept_inplace,
340 name=name,
341 rebuild_strict=rebuild_strict,
342 allow_input_downcast=allow_input_downcast,
343 on_unused_input=on_unused_input,
344 profile=profile,
345 output_keys=output_keys,
346 trust_input=trust_input,
347 )
348 return fn
File ~/Documents/Python/pytensor/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
452 profile = ProfileStats(message=profile)
454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
455 params,
456 outputs,
(...) 463 fgraph=fgraph,
464 )
--> 466 return orig_function(
467 inputs,
468 cloned_outputs,
469 mode,
470 accept_inplace=accept_inplace,
471 name=name,
472 profile=profile,
473 on_unused_input=on_unused_input,
474 output_keys=output_keys,
475 fgraph=fgraph,
476 trust_input=trust_input,
477 )
File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1835, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
1822 m = Maker(
1823 inputs,
1824 outputs,
(...) 1832 trust_input=trust_input,
1833 )
1834 with config.change_flags(compute_test_value="off"):
-> 1835 fn = m.create(defaults)
1836 finally:
1837 if profile and fn:
File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1719, in FunctionMaker.create(self, input_storage, storage_map)
1716 start_import_time = pytensor.link.c.cmodule.import_time
1718 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1719 _fn, _i, _o = self.linker.make_thunk(
1720 input_storage=input_storage_lists, storage_map=storage_map
1721 )
1723 end_linker = time.perf_counter()
1725 linker_time = end_linker - start_linker
File ~/Documents/Python/pytensor/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
238 def make_thunk(
239 self,
240 input_storage: Optional["InputStorageType"] = None,
(...) 243 **kwargs,
244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245 return self.make_all(
246 input_storage=input_storage,
247 output_storage=output_storage,
248 storage_map=storage_map,
249 )[:3]
File ~/Documents/Python/pytensor/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
692 for k in storage_map:
693 compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
696 compute_map, nodes, input_storage, output_storage, storage_map
697 )
699 [fn] = thunks
700 fn.jit_fn = jit_fn
File ~/Documents/Python/pytensor/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
644 # This is a bit hackish, but we only return one of the output nodes
645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
648 self.fgraph,
649 order=order,
650 input_storage=input_storage,
651 output_storage=output_storage,
652 storage_map=storage_map,
653 )
655 thunk_inputs = self.create_thunk_inputs(storage_map)
656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
File ~/Documents/Python/pytensor/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
7 def fgraph_convert(self, fgraph, **kwargs):
8 from pytensor.link.numba.dispatch import numba_funcify
---> 10 return numba_funcify(fgraph, **kwargs)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/basic.py:379, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
372 @numba_funcify.register(FunctionGraph)
373 def numba_funcify_FunctionGraph(
374 fgraph,
(...) 377 **kwargs,
378 ):
--> 379 return fgraph_to_python(
380 fgraph,
381 numba_funcify,
382 type_conversion_fn=numba_typify,
383 fgraph_name=fgraph_name,
384 **kwargs,
385 )
File ~/Documents/Python/pytensor/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
734 body_assigns = []
735 for node in order:
--> 736 compiled_func = op_conversion_fn(
737 node.op, node=node, storage_map=storage_map, **kwargs
738 )
740 # Create a local alias with a unique name
741 local_compiled_func_name = unique_name(compiled_func)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/blockwise.py:31, in numba_funcify_Blockwise(op, node, **kwargs)
26 core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])
28 core_node = blockwise_op._create_dummy_core_node(
29 cast(tuple[TensorVariable], blockwise_node.inputs)
30 )
---> 31 core_op_fn = numba_funcify(
32 core_op,
33 node=core_node,
34 parent_node=node,
35 **kwargs,
36 )
37 core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
39 batch_ndim = blockwise_op.batch_ndim(node)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/basic.py:355, in numba_funcify_OpFromGraph(op, node, **kwargs)
349 add_supervisor_to_fgraph(
350 fgraph=fgraph,
351 input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
352 accept_inplace=True,
353 )
354 NUMBA.optimizer(fgraph)
--> 355 fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
357 if len(op.fgraph.outputs) == 1:
359 @numba_njit
360 def opfromgraph(*inputs):
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/basic.py:379, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
372 @numba_funcify.register(FunctionGraph)
373 def numba_funcify_FunctionGraph(
374 fgraph,
(...) 377 **kwargs,
378 ):
--> 379 return fgraph_to_python(
380 fgraph,
381 numba_funcify,
382 type_conversion_fn=numba_typify,
383 fgraph_name=fgraph_name,
384 **kwargs,
385 )
File ~/Documents/Python/pytensor/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
734 body_assigns = []
735 for node in order:
--> 736 compiled_func = op_conversion_fn(
737 node.op, node=node, storage_map=storage_map, **kwargs
738 )
740 # Create a local alias with a unique name
741 local_compiled_func_name = unique_name(compiled_func)
File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/elemwise.py:270, in numba_funcify_Elemwise(op, node, **kwargs)
267 scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
268 scalar_node = op.scalar_op.make_node(*scalar_inputs)
--> 270 scalar_op_fn = numba_funcify(
271 op.scalar_op,
272 node=scalar_node,
273 parent_node=node,
274 **kwargs,
275 )
277 nin = len(node.inputs)
278 nout = len(node.outputs)
TypeError: pytensor.link.numba.dispatch.basic.numba_funcify() got multiple values for keyword argument 'parent_node'