Skip to content

Commit 7bd421f

Browse files
j-fudevmotion
authored andcommitted
Support specializing on functions (#615)
* specialize on function in jacobian * specilize on function parameters for derivative, gradient, hessian
1 parent 7dd4911 commit 7bd421f

File tree

5 files changed

+32
-33
lines changed

5 files changed

+32
-33
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticM
3737
end
3838

3939
# Gradient
40-
@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
41-
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
42-
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
40+
@inline ForwardDiff.gradient(f::F, x::StaticArray) where F = vector_mode_gradient(f, x)
41+
@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig) where F = gradient(f, x)
42+
@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient(f, x)
4343

44-
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
45-
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
46-
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
44+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_gradient!(result, f, x)
45+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig) where F = gradient!(result, f, x)
46+
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig, ::Val) where F = gradient!(result, f, x)
4747

4848
@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
4949
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
@@ -65,13 +65,13 @@ end
6565
end
6666

6767
# Jacobian
68-
@inline ForwardDiff.jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
69-
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
70-
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
68+
@inline ForwardDiff.jacobian(f::F, x::StaticArray) where F = vector_mode_jacobian(f, x)
69+
@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian(f, x)
70+
@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian(f, x)
7171

72-
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
73-
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
74-
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
72+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where F = vector_mode_jacobian!(result, f, x)
73+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where F = jacobian!(result, f, x)
74+
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where F = jacobian!(result, f, x)
7575

7676
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
7777
M, N = length(ydual), length(x)
@@ -110,18 +110,18 @@ end
110110
end
111111

112112
# Hessian
113-
ForwardDiff.hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
114-
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
115-
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
113+
ForwardDiff.hessian(f::F, x::StaticArray) where F = jacobian(y -> gradient(f, y), x)
114+
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig) where F = hessian(f, x)
115+
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian(f, x)
116116

117-
ForwardDiff.hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
117+
ForwardDiff.hessian!(result::AbstractArray, f::F, x::StaticArray) where F = jacobian!(result, y -> gradient(f, y), x)
118118

119-
ForwardDiff.hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
119+
ForwardDiff.hessian!(result::MutableDiffResult, f::F, x::StaticArray) where F = hessian!(result, f, x, HessianConfig(f, result, x))
120120

121-
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
122-
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
121+
ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig) where F = hessian!(result, f, x)
122+
ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where F = hessian!(result, f, x)
123123

124-
function ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray)
124+
function ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where F
125125
T = typeof(Tag(f, eltype(x)))
126126
d1 = dualize(T, x)
127127
d2 = dualize(T, d1)

src/derivative.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ stored in `y`.
2222
2323
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
2424
"""
25-
@inline function derivative(f!, y::AbstractArray, x::Real,
26-
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
25+
@inline function derivative(f!::F, y::AbstractArray, x::Real,
26+
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
2727
require_one_based_indexing(y)
2828
CHK && checktag(T, f!, x)
2929
ydual = cfg.duals
@@ -60,8 +60,8 @@ called as `f!(y, x)` where the result is stored in `y`.
6060
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
6161
"""
6262
@inline function derivative!(result::Union{AbstractArray,DiffResult},
63-
f!, y::AbstractArray, x::Real,
64-
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
63+
f!::F, y::AbstractArray, x::Real,
64+
cfg::DerivativeConfig{T} = DerivativeConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
6565
result isa DiffResult ? require_one_based_indexing(y) : require_one_based_indexing(result, y)
6666
CHK && checktag(T, f!, x)
6767
ydual = cfg.duals

src/gradient.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ This method assumes that `isa(f(x), Real)`.
1313
1414
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
1515
"""
16-
function gradient(f, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
16+
function gradient(f::F, x::AbstractArray, cfg::GradientConfig{T} = GradientConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T, CHK}
1717
require_one_based_indexing(x)
1818
CHK && checktag(T, f, x)
1919
if chunksize(cfg) == length(x)

src/hessian.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ This method assumes that `isa(f(x), Real)`.
1111
1212
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
1313
"""
14-
function hessian(f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
14+
function hessian(f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {F, T,CHK}
1515
require_one_based_indexing(x)
1616
CHK && checktag(T, f, x)
1717
∇f = y -> gradient(f, y, cfg.gradient_config, Val{false}())
@@ -28,7 +28,7 @@ This method assumes that `isa(f(x), Real)`.
2828
2929
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
3030
"""
31-
function hessian!(result::AbstractArray, f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
31+
function hessian!(result::AbstractArray, f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
3232
require_one_based_indexing(result, x)
3333
CHK && checktag(T, f, x)
3434
∇f = y -> gradient(f, y, cfg.gradient_config, Val{false}())
@@ -63,8 +63,7 @@ because `isa(result, DiffResult)`, `cfg` is constructed as `HessianConfig(f, res
6363
6464
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
6565
"""
66-
function hessian!(result::DiffResult, f, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, result, x), ::Val{CHK}=Val{true}()) where {T,CHK}
67-
require_one_based_indexing(x)
66+
function hessian!(result::DiffResult, f::F, x::AbstractArray, cfg::HessianConfig{T} = HessianConfig(f, result, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
6867
CHK && checktag(T, f, x)
6968
∇f! = InnerGradientForHess(result, cfg, f)
7069
jacobian!(DiffResults.hessian(result), ∇f!, DiffResults.gradient(result), x, cfg.jacobian_config, Val{false}())

src/jacobian.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
1515
1616
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
1717
"""
18-
function jacobian(f, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {T,CHK}
18+
function jacobian(f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
1919
require_one_based_indexing(x)
2020
CHK && checktag(T, f, x)
2121
if chunksize(cfg) == length(x)
@@ -33,7 +33,7 @@ stored in `y`.
3333
3434
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
3535
"""
36-
function jacobian(f!, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T, CHK}
36+
function jacobian(f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
3737
require_one_based_indexing(y, x)
3838
CHK && checktag(T, f!, x)
3939
if chunksize(cfg) == length(x)
@@ -54,7 +54,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
5454
5555
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
5656
"""
57-
function jacobian!(result::Union{AbstractArray,DiffResult}, f, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {T, CHK}
57+
function jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f, x), ::Val{CHK}=Val{true}()) where {F,T, CHK}
5858
result isa DiffResult ? require_one_based_indexing(x) : require_one_based_indexing(result, x)
5959
CHK && checktag(T, f, x)
6060
if chunksize(cfg) == length(x)
@@ -75,7 +75,7 @@ This method assumes that `isa(f(x), AbstractArray)`.
7575
7676
Set `check` to `Val{false}()` to disable tag checking. This can lead to perturbation confusion, so should be used with care.
7777
"""
78-
function jacobian!(result::Union{AbstractArray,DiffResult}, f!, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {T,CHK}
78+
function jacobian!(result::Union{AbstractArray,DiffResult}, f!::F, y::AbstractArray, x::AbstractArray, cfg::JacobianConfig{T} = JacobianConfig(f!, y, x), ::Val{CHK}=Val{true}()) where {F,T,CHK}
7979
result isa DiffResult ? require_one_based_indexing(y, x) : require_one_based_indexing(result, y, x)
8080
CHK && checktag(T, f!, x)
8181
if chunksize(cfg) == length(x)

0 commit comments

Comments
 (0)