Skip to content

Commit 4af61b5

Browse files
Correct gradients for minimize
1 parent 92db737 commit 4af61b5

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

pytensor/tensor/optimize.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
from scipy.optimize import root as scipy_root
66

77
from pytensor import Variable, function, graph_replace
8-
from pytensor.gradient import grad, jacobian
8+
from pytensor.gradient import DisconnectedType, grad, jacobian
99
from pytensor.graph import Apply, Constant, FunctionGraph
1010
from pytensor.graph.basic import truncated_graph_inputs
1111
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
1212
from pytensor.scalar import bool as scalar_bool
13-
from pytensor.tensor.basic import atleast_2d
13+
from pytensor.tensor.basic import atleast_2d, concatenate
1414
from pytensor.tensor.slinalg import solve
1515
from pytensor.tensor.variable import TensorVariable
1616

@@ -146,18 +146,40 @@ def L_op(self, inputs, outputs, output_grads):
146146
inner_x, *inner_args = self.fgraph.inputs
147147
inner_fx = self.fgraph.outputs[0]
148148

149-
inner_grads = grad(inner_fx, [inner_x, *inner_args])
149+
implicit_f = grad(inner_fx, inner_x)
150150

151-
replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
151+
df_dx = atleast_2d(concatenate(jacobian(implicit_f, [inner_x]), axis=-1))
152152

153-
grad_f_wrt_x_star, *grad_f_wrt_args = graph_replace(
154-
inner_grads, replace=replace
153+
df_dtheta = concatenate(
154+
[
155+
atleast_2d(x, left=False)
156+
for x in jacobian(implicit_f, inner_args, disconnected_inputs="ignore")
157+
],
158+
axis=-1,
155159
)
156160

157-
grad_wrt_args = [
158-
-grad_f_wrt_arg / grad_f_wrt_x_star * output_grad
159-
for grad_f_wrt_arg in grad_f_wrt_args
160-
]
161+
replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
162+
163+
df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
164+
165+
grad_wrt_args_vector = solve(-df_dtheta_star, df_dx_star)
166+
167+
cursor = 0
168+
grad_wrt_args = []
169+
170+
for output_grad, arg in zip(output_grads, args, strict=True):
171+
arg_shape = arg.shape
172+
arg_size = arg_shape.prod()
173+
arg_grad = grad_wrt_args_vector[cursor : cursor + arg_size].reshape(
174+
arg_shape
175+
)
176+
177+
grad_wrt_args.append(
178+
arg_grad * output_grad
179+
if not isinstance(output_grad.type, DisconnectedType)
180+
else DisconnectedType()
181+
)
182+
cursor += arg_size
161183

162184
return [x.zeros_like(), *grad_wrt_args]
163185

tests/tensor/test_optimize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def test_simple_minimize():
2020
out = (x - b * c) ** 2
2121

2222
minimized_x, success = minimize(out, x)
23-
minimized_x.dprint()
2423

2524
a_val = 2.0
2625
c_val = 3.0

0 commit comments

Comments (0)