Skip to content

Commit b4cb1ac

Browse files
Make seeds and extract_jacobian gpu-friendly
Use broadcast/macro consistently. Fix jac.
1 parent 4c7495d commit b4cb1ac

File tree

3 files changed

+19
-24
lines changed

3 files changed

+19
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ForwardDiff"
22
uuid = "f6369f11-7733-5829-9624-2563aa707210"
3-
version = "0.10.12"
3+
version = "0.10.13"
44

55
[deps]
66
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"

src/apiutils.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,30 @@ end
5555

5656
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
5757
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
58-
for i in eachindex(duals)
59-
duals[i] = Dual{T,V,N}(x[i], seed)
60-
end
58+
duals .= Dual{T,V,N}.(x, Ref(seed))
6159
return duals
6260
end
6361

6462
function seed!(duals::AbstractArray{Dual{T,V,N}}, x,
6563
seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
66-
for i in 1:N
67-
duals[i] = Dual{T,V,N}(x[i], seeds[i])
68-
end
64+
dual_inds = 1:N
65+
duals[dual_inds] .= Dual{T,V,N}.(view(x,dual_inds), seeds)
6966
return duals
7067
end
7168

7269
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
7370
seed::Partials{N,V} = zero(Partials{N,V})) where {T,V,N}
7471
offset = index - 1
75-
for i in 1:N
76-
j = i + offset
77-
duals[j] = Dual{T,V,N}(x[j], seed)
78-
end
72+
dual_inds = (1:N) .+ offset
73+
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), Ref(seed))
7974
return duals
8075
end
8176

8277
function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index,
8378
seeds::NTuple{N,Partials{N,V}}, chunksize = N) where {T,V,N}
8479
offset = index - 1
85-
for i in 1:chunksize
86-
j = i + offset
87-
duals[j] = Dual{T,V,N}(x[j], seeds[i])
88-
end
80+
seed_inds = 1:chunksize
81+
dual_inds = seed_inds .+ offset
82+
duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), seeds[seed_inds])
8983
return duals
9084
end

src/jacobian.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,10 @@ end
111111

112112
function extract_jacobian!(::Type{T}, result::AbstractArray, ydual::AbstractArray, n) where {T}
113113
out_reshaped = reshape(result, length(ydual), n)
114-
for col in 1:size(out_reshaped, 2), row in 1:size(out_reshaped, 1)
115-
out_reshaped[row, col] = partials(T, ydual[row], col)
116-
end
114+
ydual_reshaped = vec(ydual)
115+
# Use closure to avoid GPU broadcasting with Type
116+
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
117+
out_reshaped .= partials_wrap.(ydual_reshaped, transpose(1:n))
117118
return result
118119
end
119120

@@ -123,13 +124,13 @@ function extract_jacobian!(::Type{T}, result::MutableDiffResult, ydual::Abstract
123124
end
124125

125126
function extract_jacobian_chunk!(::Type{T}, result, ydual, index, chunksize) where {T}
127+
ydual_reshaped = vec(ydual)
126128
offset = index - 1
127-
for i in 1:chunksize
128-
col = i + offset
129-
for row in eachindex(ydual)
130-
result[row, col] = partials(T, ydual[row], i)
131-
end
132-
end
129+
irange = 1:chunksize
130+
col = irange .+ offset
131+
# Use closure to avoid GPU broadcasting with Type
132+
partials_wrap(ydual, nrange) = partials(T, ydual, nrange)
133+
result[:, col] .= partials_wrap.(ydual_reshaped, transpose(irange))
133134
return result
134135
end
135136

0 commit comments

Comments
 (0)