From 74eab0aaab34da6f89e50e62adee6eeb1afad46a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Nov 2022 21:45:21 +0100 Subject: [PATCH] Fix issues with ForwardDiff 0.10.33 --- Project.toml | 5 +++-- src/logsumexp.jl | 57 ++++++++++++++++++++++++++++++----------------- test/basicfuns.jl | 13 +++++++++++ test/runtests.jl | 1 + 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/Project.toml b/Project.toml index 6a8a26f9..34ffe204 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogExpFunctions" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" authors = ["StatsFun.jl contributors, Tamas K. Papp "] -version = "0.3.18" +version = "0.3.19" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -21,10 +21,11 @@ julia = "1" [extras] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesTestUtils", "FiniteDifferences", "OffsetArrays", "Random", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "ForwardDiff", "OffsetArrays", "Random", "Test"] diff --git a/src/logsumexp.jl b/src/logsumexp.jl index 6e57f2d8..0540ccae 100644 --- a/src/logsumexp.jl +++ b/src/logsumexp.jl @@ -84,17 +84,22 @@ _logsumexp_onepass_reduce(X, ::Base.EltypeUnknown) = reduce(_logsumexp_onepass_o # reduce two numbers function _logsumexp_onepass_op(x1::T, x2::T) where {T<:Number} - xmax, a = if x1 == x2 - # handle `x1 = x2 = ±Inf` correctly - x2, zero(x1 - x2) - elseif isnan(x1) || isnan(x2) + xmax, a = if isnan(x1) || isnan(x2) # ensure that `NaN` is propagated correctly for complex numbers z = oftype(x1, NaN) z, exp(z) - elseif real(x1) > real(x2) - x1, x2 - x1 else - x2, x1 - x2 + real_x1 = real(x1) + real_x2 = real(x2) + if real_x1 > real_x2 + x1, x2 - x1 + elseif real_x1 < real_x2 + x2, x1 - x2 + else + # handle `x1 = x2 = ±Inf` correctly + # checking inequalities above instead of equality fixes issue #59 + x2, zero(x1 - x2) + end end r = exp(a) return xmax, r @@ -109,17 +114,22 @@ _logsumexp_onepass_op((xmax, r)::Tuple{<:Number,<:Number}, x::Number) = _logsumexp_onepass_op(x::Number, xmax::Number, r::Number) = _logsumexp_onepass_op(promote(x, xmax)..., r) function _logsumexp_onepass_op(x::T, xmax::T, r::Number) where {T<:Number} - _xmax, _r = if x == xmax - # handle `x = xmax = ±Inf` correctly - xmax, r + exp(zero(x - xmax)) - elseif isnan(x) || isnan(xmax) + _xmax, _r = if isnan(x) || isnan(xmax) # ensure that `NaN` is propagated correctly for complex numbers z = oftype(x, NaN) z, r + exp(z) - elseif real(x) > real(xmax) - x, (r + one(r)) * exp(xmax - x) else - xmax, r + exp(x - xmax) + real_x = real(x) + real_xmax = real(xmax) + if real_x > real_xmax + x, (r + one(r)) * exp(xmax - x) + elseif real_x < real_xmax + xmax, r + exp(x - xmax) + else + # handle `x = xmax = ±Inf` correctly + # checking inequalities above instead of equality fixes issue #59 + xmax, r + exp(zero(x - xmax)) + end end return _xmax, _r end @@ -134,17 +144,22 @@ function _logsumexp_onepass_op(xmax1::Number, xmax2::Number, r1::Number, r2::Num return _logsumexp_onepass_op(promote(xmax1, xmax2)..., promote(r1, r2)...) end function _logsumexp_onepass_op(xmax1::T, xmax2::T, r1::R, r2::R) where {T<:Number,R<:Number} - xmax, r = if xmax1 == xmax2 - # handle `xmax1 = xmax2 = ±Inf` correctly - xmax2, r2 + (r1 + one(r1)) * exp(zero(xmax1 - xmax2)) - elseif isnan(xmax1) || isnan(xmax2) + xmax, r = if isnan(xmax1) || isnan(xmax2) # ensure that `NaN` is propagated correctly for complex numbers z = oftype(xmax1, NaN) z, r1 + exp(z) - elseif real(xmax1) > real(xmax2) - xmax1, r1 + (r2 + one(r2)) * exp(xmax2 - xmax1) else - xmax2, r2 + (r1 + one(r1)) * exp(xmax1 - xmax2) + real_xmax1 = real(xmax1) + real_xmax2 = real(xmax2) + if real_xmax1 > real_xmax2 + xmax1, r1 + (r2 + one(r2)) * exp(xmax2 - xmax1) + elseif real_xmax1 < real_xmax2 + xmax2, r2 + (r1 + one(r1)) * exp(xmax1 - xmax2) + else + # handle `xmax1 = xmax2 = ±Inf` correctly + # checking inequalities above instead of equality fixes issue #59 + xmax2, r2 + (r1 + one(r1)) * exp(zero(xmax1 - xmax2)) + end end return xmax, r end diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 3d993dd2..b133b1e9 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -342,6 +342,19 @@ end expected = logsumexp(xs; dims=2) @test logsumexp!(out, xs) ≈ expected @test out ≈ expected + + @testset "ForwardDiff" begin + # vector with finite numbers + x = randn(10) + ∇x = unthunk(rrule(logsumexp, x)[2](1)[2]) + @test ForwardDiff.gradient(logsumexp, x) ≈ ∇x + + # issue #59 + x = vcat(-Inf, randn(9)) + ∇x = unthunk(rrule(logsumexp, x)[2](1)[2]) + @assert all(isfinite, ∇x) + @test ForwardDiff.gradient(logsumexp, x) ≈ ∇x + end end @testset "softmax" begin diff --git a/test/runtests.jl b/test/runtests.jl index e780556e..b9665e71 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using ChainRulesTestUtils using ChainRulesCore using ChangesOfVariables using FiniteDifferences +using ForwardDiff using InverseFunctions using OffsetArrays