-
Notifications
You must be signed in to change notification settings - Fork 23
sumlog #48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
sumlog #48
Changes from 13 commits
581b9df
5725aa9
77aa3d9
88d6fb1
9ecf589
9db732f
76533e1
5747205
0f5a927
977723d
4d488cd
afa5d94
e400483
cc1aaac
07809b7
16ee153
1af518b
0eaf8d2
1f478d0
0807f7a
2a0004d
e5809d1
eb1b524
6fe8bb1
0def97d
3ad95b2
a0a9348
fa667ec
989a111
a54a024
dc48433
39ca989
207fce2
55d125e
bef4728
9572e48
3848848
c4c3e89
e0f410e
23b5bf1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "LogExpFunctions" | ||
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" | ||
authors = ["StatsFun.jl contributors, Tamas K. Papp <[email protected]>"] | ||
version = "0.3.14" | ||
version = "0.3.15" | ||
|
||
[deps] | ||
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
using IrrationalConstants: logtwo | ||
|
||
cscherrer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
$(SIGNATURES) | ||
|
||
Compute `sum(log.(X))` with a single `log` evaluation. | ||
|
||
This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in | ||
particular as the size of `X` increases. | ||
|
||
This works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``, | ||
allowing us to write | ||
```math | ||
\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j | ||
``` | ||
Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` | ||
evaluation. | ||
""" | ||
function sumlog(x::AbstractArray{<:Real}) | ||
T = float(eltype(x)) | ||
_sumlog(T, x) | ||
end | ||
|
||
@inline function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} | ||
sig = one(T) | ||
ex = zero(exponent(sig)) | ||
bound = floatmax(T) / 2 | ||
for xj in x | ||
float_xj = float(xj) | ||
sig *= significand(float_xj) | ||
ex += exponent(float_xj) | ||
|
||
# Significands are in the range [1,2), so multiplication will eventually overflow | ||
if sig > bound | ||
(a, b) = (significand(sig), exponent(sig)) | ||
sig = a | ||
ex += b | ||
end | ||
end | ||
cscherrer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
log(sig) + logtwo * ex | ||
end | ||
|
||
# `T` might be a `Symbolics.Num`, which is not an `AbstractFloat` | ||
cscherrer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@inline _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T} = sum(log, x) | ||
|
||
sumlog(x) = sum(log, x) | ||
cscherrer marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
@testset "sumlog" begin | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some surprises: julia> sumlog([1,2,-0.1,-0.2])
-3.2188758248682
julia> sumlog([1,2,NaN])
ERROR: DomainError with NaN:
Cannot be NaN or Inf.
Stacktrace:
[1] (::Base.Math.var"#throw1#5")(x::Float64)
@ Base.Math ./math.jl:845
[2] exponent
@ ./math.jl:848 [inlined]
julia> sumlog([-0.0])
ERROR: DomainError with -0.0:
Cannot be ±0.0.
julia> sum(log, -0.0)
-Inf There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Calling Adding There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, note that tests right now only test Float64 |
||
for T in [Int, Float16, Float32, Float64, BigFloat] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I noticed that you removed the type restriction. Thus we should extend the tests and eg. check more general iterables (also with different types, abstract eltype etc since There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
julia> Core.Compiler.return_type(gen.f, Tuple{eltype(gen.iter)})
Float64 We could instead have it fall back on the default, but I'd guess that will sacrifice performance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I bet you could write an equally fast version which explicitly calls One reason to keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should check more carefully, but this appears to work & is as fast as current version: function sumlog(x)
iter = iterate(x)
if isnothing(iter)
return eltype(x) <: Number ? zero(float(eltype(x))) : 0.0
end
x1 = float(iter[1])
x1 isa AbstractFloat || return sum(log, x)
sig, ex = significand(x1), exponent(x1)
iter = iterate(x, iter[2])
while iter !== nothing
xj = float(iter[1])
x1 isa AbstractFloat || return sum(log, x) # maybe not ideal, re-starts iterator
sig, ex = _sumlog_op((sig, ex), (significand(xj), exponent(xj)))
iter = iterate(x, iter[2])
end
return log(sig) + IrrationalConstants.logtwo * ex
end
sumlog(f, x) = sumlog(Iterators.map(f, x))
sumlog(f, x, ys...) = sumlog(f(xy...) for xy in zip(x, ys...)) And for dims: sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x)
function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat}
sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj
float_xj = float(xj)
significand(float_xj), exponent(float_xj)
end
return log(sig) + IrrationalConstants.logtwo * ex
end
function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat}
sig_ex = mapreduce(_sumlog_op, x; dims=dims, init=(one(T), zero(exponent(one(T))))) do xj
float_xj = float(xj)
significand(float_xj), exponent(float_xj)
end
map(sig_ex) do (sig, ex)
log(sig) + IrrationalConstants.logtwo * ex
end
end Should I make a PR to the PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cscherrer#1 is a tidier version of the above. |
||
for x in [10 .* rand(1000), repeat([nextfloat(1.0)], 1000), repeat([prevfloat(2.0)], 1000)] | ||
@test (@inferred sumlog(x)) ≈ sum(log, x) | ||
end | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.