Skip to content

Commit 23ed88e

Browse files
Make seeds and extract_jacobian gpu-friendly
Use broadcast/macro consistently Fix jac Add AllocationsTest.jl
1 parent 4c7495d commit 23ed88e

File tree

5 files changed

+54
-25
lines changed

5 files changed

+54
-25
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "0.10.12"
3+
version = "0.10.13"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
@@ -28,7 +28,8 @@ DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
2828
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
2929
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
3030
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
31+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3132
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3233

3334
[targets]
34-
test = ["Calculus", "DiffTests", "LinearAlgebra", "SparseArrays", "Test", "InteractiveUtils"]
35+
test = ["Calculus", "DiffTests", "LinearAlgebra", "SparseArrays", "Test", "InteractiveUtils", "BenchmarkTools"]

src/apiutils.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,30 @@ end
5555

5656
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
5757
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
58-
for i in eachindex(duals)
59-
duals[i] = Dual{T,V,N}(x[i], seed)
60-
end
58+
duals .= Dual{T,V,N}.(x, Ref(seed))
6159
return duals
6260
end
6361

6462
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
6563
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
66-
for i in 1:N
67-
duals[i] = Dual{T,V,N}(x[i], seeds[i])
68-
end
64+
dual_inds = 1:N
65+
duals[dual_inds] .= Dual{T,V,N}.(view(x,dual_inds), seeds)
6966
return duals
7067
end
7168

7269
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
7370
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
7471
offset = index - 1
75-
for i in 1:N
76-
j = i + offset
77-
duals[j] = Dual{T,V,N}(x[j], seed)
78-
end
72+
dual_inds = (1:N) .+ offset
73+
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), Ref(seed))
7974
return duals
8075
end
8176

8277
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
8378
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
8479
offset = index - 1
85-
for i in 1:chunksize
86-
j = i + offset
87-
duals[j] = Dual{T,V,N}(x[j], seeds[i])
88-
end
80+
seed_inds = 1:chunksize
81+
dual_inds = seed_inds .+ offset
82+
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), getindex.(Ref(seeds), seed_inds))
8983
return duals
9084
end

src/jacobian.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,10 @@ end
111111

112112
function extract_jacobian!(::Type{T}, result::AbstractArray, ydual::AbstractArray, n) where {T}
113113
out_reshaped = reshape(result, length(ydual), n)
114-
for col in 1:size(out_reshaped, 2), row in 1:size(out_reshaped, 1)
115-
out_reshaped[row, col] = partials(T, ydual[row], col)
116-
end
114+
ydual_reshaped = vec(ydual)
115+
# Use closure to avoid GPU broadcasting with Type
116+
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
117+
out_reshaped .= partials_wrap.(ydual_reshaped, transpose(1:n))
117118
return result
118119
end
119120

@@ -123,13 +124,13 @@ function extract_jacobian!(::Type{T}, result::MutableDiffResult, ydual::Abstract
123124
end
124125

125126
function extract_jacobian_chunk!(::Type{T}, result, ydual, index, chunksize) where {T}
127+
ydual_reshaped = vec(ydual)
126128
offset = index - 1
127-
for i in 1:chunksize
128-
col = i + offset
129-
for row in eachindex(ydual)
130-
result[row, col] = partials(T, ydual[row], i)
131-
end
132-
end
129+
irange = 1:chunksize
130+
col = irange .+ offset
131+
# Use closure to avoid GPU broadcasting with Type
132+
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
133+
result[:, col] .= partials_wrap.(ydual_reshaped, transpose(irange))
133134
return result
134135
end
135136

test/AllocationsTest.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module AllocationsTest
2+
3+
using ForwardDiff
4+
using BenchmarkTools
5+
6+
include(joinpath(dirname(@__FILE__), "utils.jl"))
7+
8+
@testset "Test seed! allocations" begin
9+
x = rand(1000)
10+
cfg = ForwardDiff.GradientConfig(nothing, x)
11+
12+
balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $(cfg.seeds))
13+
@test balloc == 0
14+
15+
balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $(cfg.seeds[1]))
16+
@test balloc == 0
17+
18+
index = 1
19+
balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $index, $(cfg.seeds))
20+
@test balloc == 0
21+
22+
index = 1
23+
balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $index, $(cfg.seeds[1]))
24+
@test balloc == 0
25+
end
26+
27+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ println("done (took $t seconds).")
3131
println("Testing miscellaneous functionality...")
3232
t = @elapsed include("MiscTest.jl")
3333
println("done (took $t seconds).")
34+
35+
if VERSION >= v"1.5-"
36+
println("Testing allocations...")
37+
t = @elapsed include("AllocationsTest.jl")
38+
println("done (took $t seconds).")
39+
end

0 commit comments

Comments
 (0)