 from scipy.optimize import root as scipy_root
 
 from pytensor import Variable, function, graph_replace
-from pytensor.gradient import grad, jacobian
+from pytensor.gradient import DisconnectedType, grad, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
-from pytensor.tensor.basic import atleast_2d
+from pytensor.tensor.basic import atleast_2d, concatenate
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.variable import TensorVariable
 
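The hunk below replaces the scalar quotient-rule gradient in `L_op` with an implicit-differentiation graph. For reference (this context is not in the diff itself, and it assumes the hunk belongs to a minimize-style Op whose solution satisfies the stationarity condition of a scalar objective f(x, θ)), the identity being encoded is the implicit function theorem:

    F(x^{*}(\theta), \theta) = 0, \qquad F(x, \theta) := \nabla_{x} f(x, \theta)
    \quad \Longrightarrow \quad
    \frac{\partial x^{*}}{\partial \theta}
        = -\left(\frac{\partial F}{\partial x}\right)^{-1}
          \frac{\partial F}{\partial \theta}.

In the new code, `implicit_f` plays the role of F, `df_dx` is its Jacobian with respect to `inner_x`, and `df_dtheta` stacks its Jacobians with respect to the remaining inner inputs, so the inverse is applied through a linear `solve` rather than being formed explicitly.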
@@ -146,18 +146,40 @@ def L_op(self, inputs, outputs, output_grads):
         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]
 
-        inner_grads = grad(inner_fx, [inner_x, *inner_args])
+        implicit_f = grad(inner_fx, inner_x)
 
-        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+        df_dx = atleast_2d(concatenate(jacobian(implicit_f, [inner_x]), axis=-1))
 
-        grad_f_wrt_x_star, *grad_f_wrt_args = graph_replace(
-            inner_grads, replace=replace
+        df_dtheta = concatenate(
+            [
+                atleast_2d(x, left=False)
+                for x in jacobian(implicit_f, inner_args, disconnected_inputs="ignore")
+            ],
+            axis=-1,
         )
 
-        grad_wrt_args = [
-            -grad_f_wrt_arg / grad_f_wrt_x_star * output_grad
-            for grad_f_wrt_arg in grad_f_wrt_args
-        ]
+        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+
+        df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
+
+        grad_wrt_args_vector = solve(-df_dtheta_star, df_dx_star)
+
+        cursor = 0
+        grad_wrt_args = []
+
+        for output_grad, arg in zip(output_grads, args, strict=True):
+            arg_shape = arg.shape
+            arg_size = arg_shape.prod()
+            arg_grad = grad_wrt_args_vector[cursor : cursor + arg_size].reshape(
+                arg_shape
+            )
+
+            grad_wrt_args.append(
+                arg_grad * output_grad
+                if not isinstance(output_grad.type, DisconnectedType)
+                else DisconnectedType()
+            )
+            cursor += arg_size
 
         return [x.zeros_like(), *grad_wrt_args]
 
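The old body divided by a single scalar `grad_f_wrt_x_star`, which only covers a one-dimensional `x`; the new body builds full Jacobians and pushes the cotangent through a linear `solve`, so vector-valued `x` and several parameter inputs are handled in one pass. A minimal standalone sketch of that pattern in plain NumPy/SciPy (the quadratic objective and every name here are invented for illustration and are not pytensor API):

# Standalone illustration (not the commit's code): implicit differentiation of
# an argmin.  The objective, the names f, x_star, theta, and the cotangent g
# are all invented for this example.
import numpy as np
from scipy.optimize import minimize

A = np.array([[3.0, 1.0], [1.0, 2.0]])  # symmetric positive-definite Hessian


def f(x, theta):
    # f(x, theta) = 0.5 * x^T A x - theta^T x, minimized at x* = A^{-1} theta
    return 0.5 * x @ A @ x - theta @ x


theta = np.array([0.5, -1.0])
x_star = minimize(f, x0=np.zeros(2), args=(theta,)).x

# Stationarity condition F(x, theta) = grad_x f = A x - theta = 0 at x*.
dF_dx = A                # Jacobian of F with respect to x (Hessian of f)
dF_dtheta = -np.eye(2)   # Jacobian of F with respect to theta

# Implicit function theorem: dx*/dtheta = -(dF/dx)^{-1} dF/dtheta
dxstar_dtheta = np.linalg.solve(dF_dx, -dF_dtheta)

# Reverse-mode pullback of an output cotangent g onto theta
g = np.array([1.0, 2.0])
grad_theta = dxstar_dtheta.T @ g

print(np.allclose(x_star, np.linalg.solve(A, theta), atol=1e-5))  # True
print(np.allclose(grad_theta, np.linalg.solve(A, g)))             # True

Reusing one linear solve against the Jacobian of the stationarity condition for every extra input is the main practical gain over the per-argument division it replaces.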