Description
Analyzing graphs with reshape operations is rather complex because Reshape represents what we want, but not "what it means".
Except for esoteric cases where the Reshape shapes come from a complex computation or from the shapes of other variables, a reshape usually multiplies some dimensions together (merging) and divides others (splitting). We could represent these cases with some sort of symbolic mapping:
```python
from pytensor.tensor import tensor

x = tensor(shape=(4, 3, 2))
x.reshape((4, 6))  # JoinDims(0, (1, 2)) -- hypothetical notation
```
It almost begs for an extension of DimShuffle, which was brought up before: Theano/Theano#4640
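To make the `JoinDims(0, (1, 2))` notation more concrete, here is a minimal pure-Python sketch of the shape bookkeeping such an Op could do. `join_dims_shape` is a hypothetical helper, not PyTensor API. It assumes each group lists consecutive input axes in order, so the grouping is written `((0,), (1, 2))` rather than `(0, (1, 2))`.

```python
from math import prod


# Hypothetical helper: output shape of a "JoinDims" op, where each group
# lists consecutive input axes that get merged (multiplied) into one axis.
def join_dims_shape(shape, groups):
    # Every input axis must appear exactly once, in order.
    flat = [axis for group in groups for axis in group]
    if flat != list(range(len(shape))):
        raise ValueError(f"groups {groups} must cover axes 0..{len(shape) - 1} in order")
    return tuple(prod(shape[axis] for axis in group) for group in groups)


# x.reshape((4, 6)) on a (4, 3, 2) tensor keeps axis 0 and merges axes 1 and 2.
print(join_dims_shape((4, 3, 2), ((0,), (1, 2))))  # (4, 6)
```

Because the groups are stored explicitly, the meaning of the reshape is available without inspecting whatever graph computed the target shape.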
Splitting dims is trickier because there are many choices: we can split in different orders and sizes.
```python
from pytensor.tensor import tensor

x = tensor(shape=(12,))
x.reshape((2, 2, 3))
x.reshape((2, 3, 2))
x.reshape((4, 3))
...
```
Still, an Op that achieves the same as splitting via reshape, but knows which dims are going where (and in what quantities), would be more readable.
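A hypothetical `SplitDims` counterpart could store the axis being split together with the sizes of the new axes. The sketch below is pure Python, not an existing PyTensor Op; it only computes output shapes, assumes all sizes are known, and does not handle a `-1` entry.

```python
from math import prod


# Hypothetical helper: output shape of a "SplitDims" op that replaces one
# axis with several new axes whose sizes are given explicitly.
def split_dims_shape(shape, axis, sizes):
    if prod(sizes) != shape[axis]:
        raise ValueError(f"sizes {sizes} do not multiply to {shape[axis]}")
    return tuple(shape[:axis]) + tuple(sizes) + tuple(shape[axis + 1:])


# The three reshapes of a length-12 vector, each stating exactly how
# axis 0 is divided.
for sizes in [(2, 2, 3), (2, 3, 2), (4, 3)]:
    print(split_dims_shape((12,), 0, sizes))
```

Each call prints the shape given in `sizes`, but unlike a plain reshape the Op parameters record that all of them come from axis 0.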
An example where Reshape is currently hard to work with is vectorization. If we have a common graph like `reshape(x, (x.shape[0] * x.shape[1], -1))`, we cannot eagerly return the desired output `reshape(new_x, (new_x.shape[0], new_x.shape[1] * new_x.shape[2], -1))`, because there is a chain of complex operations we must vectorize before we get to the Reshape node (Shape -> Subtensor -> Mul -> MakeVector). So we need to put it in a costly Blockwise and try our best to remove it during rewrites. This came up in #702 when vectorizing tensordot to get a batched_tensordot.

Such a problem wouldn't exist with a symbolic reshape that is told which dims are being joined/split.
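With a symbolic join, vectorization could reduce to index arithmetic: new batch axes are prepended as untouched singleton groups, and the original axis indices shift by the number of batch dims. This is a sketch of how such a rule might look, with hypothetical helper names. It assumes a 3D `x`, so the `-1` in `reshape(x, x.shape[0] * x.shape[1], -1)` resolves to `x.shape[2]`.

```python
from math import prod


def join_dims_shape(shape, groups):
    # Hypothetical JoinDims semantics: multiply the sizes within each group of axes.
    return tuple(prod(shape[axis] for axis in group) for group in groups)


def vectorize_join_groups(groups, n_batch):
    # Hypothetical vectorization rule: leading batch axes are left alone
    # (one singleton group each), and every original axis index shifts by n_batch.
    batch = tuple((axis,) for axis in range(n_batch))
    shifted = tuple(tuple(axis + n_batch for axis in group) for group in groups)
    return batch + shifted


# Core graph on a 3D x: join axes (0, 1) and keep axis 2 on its own.
groups = ((0, 1), (2,))
print(join_dims_shape((2, 3, 4), groups))  # (6, 4)

# Adding one batch axis of size 5 requires no shape graph to be vectorized.
batched = vectorize_join_groups(groups, n_batch=1)
print(batched)  # ((0,), (1, 2), (3,))
print(join_dims_shape((5, 2, 3, 4), batched))  # (5, 6, 4)
```

The batched result keeps the batch axis intact and merges the next two axes, which is the output the vectorized Reshape is supposed to produce.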
It also makes rewrites to remove/lift reshapes much simpler than they currently are:
pytensor/pytensor/tensor/rewriting/shape.py
Lines 798 to 895 in bf73f8a
```python
def local_useless_reshape(fgraph, node):
    """Remove two kinds of useless `Reshape`.

    - Remove `Reshape` when both the input and output have a single dimension.
    - Remove `Reshape` when reshaping to the shape of the input.

    """
    inp = node.inputs[0]
    output = node.outputs[0]
    output_shape = node.inputs[1]

    if inp.type.ndim != output.type.ndim:
        return False

    # Simple case: both input and output have a single dimension.
    # TODO FIXME XXX: This could hide errors if the user provides inconsistent
    # shapes.
    if (
        inp.type.ndim == 1
        and output.type.ndim == 1
        and all(
            s1 == s2
            for s1, s2 in zip(inp.type.shape, output.type.shape)
            if s1 == 1 or s2 == 1
        )
    ):
        return [inp]

    # Second case: all the shapes match the input shape
    # Match Reshape(x, x.shape)
    if output_shape.owner and isinstance(output_shape.owner.op, Shape):
        shape_input = output_shape.owner.inputs[0]
        if shape_input == inp:
            return [inp]

    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
    # broadcastable and constant dimensions
    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
        output_shape_is = output_shape.owner.inputs

        shape_feature = getattr(fgraph, "shape_feature", None)

        nb_m1 = 0
        shape_match = [False] * inp.type.ndim
        for dim in range(inp.type.ndim):
            outshp_i = output_shape_is[dim]
            # Match Shape_i{dim}(input)
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Shape_i)
                and outshp_i.owner.op.i == dim
                and outshp_i.owner.inputs[0] == inp
            ):
                shape_match[dim] = True
                continue

            # Match Shape(input)[dim]
            if (
                outshp_i.owner
                and isinstance(outshp_i.owner.op, Subtensor)
                and len(outshp_i.owner.inputs) == 2
                and extract_constant(outshp_i.owner.inputs[1]) == dim
            ):
                subtensor_inp = outshp_i.owner.inputs[0]
                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
                    shape_input_i = subtensor_inp.owner.inputs[0]
                    if shape_input_i == inp:
                        shape_match[dim] = True
                        continue

            # Match 1 if input.type.shape[dim] == 1
            cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
            if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
                shape_match[dim] = True
                continue

            # Match -1
            if cst_outshp_i == -1:
                shape_match[dim] = True
                nb_m1 += 1
                continue

            # Match shape_of[input][dim] or its constant equivalent
            if shape_feature:
                inpshp_i = shape_feature.get_shape(inp, dim)
                if inpshp_i == outshp_i or (
                    extract_constant(inpshp_i, only_process_constants=1)
                    == extract_constant(outshp_i, only_process_constants=1)
                ):
                    shape_match[dim] = True
                    continue

        if all(shape_match) and nb_m1 <= 1:
            return [inp]

        # TODO later: if all the shapes except one match, we may want to
        # consider it useless as well, like we do in the 1-dim case.
        return False
```
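For comparison, if the reshape carried its groups of axes explicitly, the checks that `local_useless_reshape` performs by pattern-matching shape graphs could shrink to inspecting the Op's parameters. The sketch below uses hypothetical helpers on an assumed tuple-of-groups representation (e.g. `((0,), (1, 2))` for `JoinDims(0, (1, 2))`). It also shows that merging two consecutive joins, a typical reshape-lifting rewrite, becomes a tuple manipulation.

```python
# Hypothetical check: a JoinDims-style op is a no-op exactly when every
# group contains a single axis (nothing is merged, no axis is added).
def is_useless_join(groups):
    return all(len(group) == 1 for group in groups)


# Hypothetical check: splitting an axis into a single piece leaves it unchanged.
def is_useless_split(sizes):
    return len(sizes) == 1


# Hypothetical rewrite: fuse join(join(x, inner), outer) into one join.
# `outer` indexes the axes produced by `inner`, so each outer group is
# replaced by the original axes its inner groups refer to.
def compose_joins(inner, outer):
    return tuple(tuple(a for i in group for a in inner[i]) for group in outer)


print(is_useless_join(((0,), (1,), (2,))))  # True
print(compose_joins(((0,), (1, 2), (3,)), ((0, 1), (2,))))  # ((0, 1, 2), (3,))
```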
This is somewhat related to why we have Second and Alloc. The first one is easier to reason about because it tells us more immediately that we are broadcasting with the shape of a variable, whereas Alloc specifies the desired output without its meaning (especially after some rewrites, where the shape may become dissociated from the original variable).
pytensor/pytensor/tensor/rewriting/basic.py
Lines 3 to 23 in d62f4b1
```python
Notes
-----

There are two ways of broadcasting arrays:
second(x, y) == alloc(y, broadcast_shapes(x.shape, y.shape))

The second can be more efficient because x doesn't usually need to be computed when we only want its shape.
It may also allow other rewrites that don't try to modify x when it has multiple clients (for fear of duplicating computation).

However, the first one is easier to reason about.
Knowing we have such a graph allows to do certain rewrites such as "sinking" broadcasting operations below Elemwise.
The same rewrites with alloc would be more complicated as we would need to symbolically combine the shapes of each one.

As an example contrast rewriting the following two equivalent graphs

alloc(x, broadcast_shapes(x.shape, y.shape)) + alloc(y, broadcast_shapes(x.shape, y.shape)) -> x + y
second(y, x) + second(x, y) -> x + y

Theano developers (mostly) preferred to use the first form during canonicalization and introduce the second form later,
via rewrites like `local_fill_to_alloc`, and using the `alloc_like` helper inside rewrites.
Many stabilize and stabilization rewrites refuse to be applied when a variable has multiple clients, so this is important.
"""
```