From 0e67692a56dd5f302d4091e375a9a162fc77ef68 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 28 Jul 2020 14:25:00 -0700 Subject: [PATCH] Add explicit `wrt:` clause to `@derivative(of:)` attribute. Add explicit `wrt:` clause to `@derivative(of:)` attribute for `Dense.init(weight:bias:activation:)`. This is necessary now that `Optional` conforms to `Differentiable` so that `bias` is not inferred as a differentiability parameter. --- Sources/TensorFlow/Layers/Dense.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/TensorFlow/Layers/Dense.swift b/Sources/TensorFlow/Layers/Dense.swift index 1583c19f4..3a86c13cf 100644 --- a/Sources/TensorFlow/Layers/Dense.swift +++ b/Sources/TensorFlow/Layers/Dense.swift @@ -62,7 +62,7 @@ public struct Dense: Layer { } // TODO(TF-433): Remove custom derivative after `try_apply` differentiation is supported. - @derivative(of: init) + @derivative(of: init, wrt: weight) @usableFromInline static func vjpInit( weight: Tensor,