Skip to content

Commit eb1b524

Browse files
authored
Merge pull request #1 from mcabbott/iterate
Fiddling with `sumlog`
2 parents 2a0004d + e5809d1 commit eb1b524

File tree

2 files changed

+120
-30
lines changed

2 files changed

+120
-30
lines changed

src/sumlog.jl

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,94 @@
11
"""
2-
$(SIGNATURES)
2+
sumlog(X::AbstractArray{T}; dims)
33
4-
Compute `sum(log.(X))` with a single `log` evaluation.
4+
Compute `sum(log.(X))` with a single `log` evaluation,
5+
provided `float(T) <: AbstractFloat`.
56
6-
This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in
7-
particular as the size of `X` increases.
8-
9-
This works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``,
7+
This is faster than computing `sum(log, X)`, especially for large `X`.
8+
It works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``,
109
allowing us to write
1110
```math
1211
\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j
1312
```
14-
Since ``\\log{2}`` is constant, `sumlog` only requires a single `log`
15-
evaluation.
1613
"""
17-
function sumlog(x)
18-
T = float(eltype(x))
19-
_sumlog(T, values(x))
14+
function sumlog(x::AbstractArray{T}; dims=:) where {T}
    # Dispatch on the floating-point counterpart of the element type; the matching
    # `_sumlog` method decides how the reduction over `dims` is carried out.
    return _sumlog(float(T), dims, x)
end
15+
16+
# Whole-array reduction for element types whose `float` is an `AbstractFloat`.
# Every element is split into `(significand, exponent)`; the pairs are folded with
# `_sumlog_op`, so `log` is evaluated only once, on the accumulated significand.
# Negative elements raise the same domain error as `log` itself.
function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat}
    split_parts = function (elem)
        elem < 0 && Base.Math.throw_complex_domainerror(:log, elem)
        felem = float(elem)
        return significand(felem), _exponent(felem)
    end
    mantissa, power = mapreduce(split_parts, _sumlog_op, x; init=(one(T), 0))
    return log(mantissa) + IrrationalConstants.logtwo * T(power)
end
2124

22-
@inline function _sumlog(::Type{T}, x) where {T<:AbstractFloat}
23-
sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj
25+
# Reduction along `dims` for element types whose `float` is an `AbstractFloat`.
# `mapreduce` yields an array of `(significand, exponent)` pairs, one per output
# slice; each pair is then turned into a log-sum with a single `log` call.
# Negative elements raise the same domain error as `log` itself.
function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat}
    split_parts = function (elem)
        elem < 0 && Base.Math.throw_complex_domainerror(:log, elem)
        felem = float(elem)
        return significand(felem), _exponent(felem)
    end
    reduced = mapreduce(split_parts, _sumlog_op, x; dims=dims, init=(one(T), 0))
    finish = ((mantissa, power),) -> log(mantissa) + IrrationalConstants.logtwo * T(power)
    return map(finish, reduced)
end
2935

36+
# Generic fallback: `float(T)` need not be an `AbstractFloat` (complex numbers, dual
# numbers, symbolic types, ...), in which case the sum of logs is computed directly.
function _sumlog(::Type, dims, x)
    return sum(log, x; dims=dims)
end
38+
3039
# Combine two partial results, each a `(significand, exponent)` pair, into the pair
# describing their product: significands multiply, exponents add.
@inline function _sumlog_op(acc, term)
    sig_acc, ex_acc = acc
    sig_term, ex_term = term
    product = sig_acc * sig_term
    shift = ex_acc + ex_term
    # Significands are in the range [1,2), so repeated multiplication will eventually
    # overflow; once the product passes half of `floatmax`, move its exponent into
    # `shift` and keep only its significand.
    product > floatmax(typeof(product)) / 2 || return product, shift
    return significand(product), shift + _exponent(product)
end
4050

41-
# `float(T)` is not always `isa AbstractFloat`, e.g. dual numbers or symbolics
42-
@inline _sumlog(::Type{T}, x) where {T} = sum(log, x)
51+
# Exponent helper for the `(significand, exponent)` decomposition.
# The exported `exponent(x)` checks for zero, `NaN` and `Inf` and throws; this helper
# never does. The value it returns for such inputs carries no meaning, which is fine
# because the accumulated significand keeps track of them.
_exponent(x::Base.IEEEFloat) = Base.Math._exponent_finite_nonzero(x)

# Generic method, e.g. for `BigFloat`. `exponent` throws a `DomainError` for zero and
# non-finite values, so those are short-circuited to `0` here rather than asserting
# `:nothrow` on a call that can actually throw.
function _exponent(x::AbstractFloat)
    (iszero(x) || !isfinite(x)) && return 0
    return Int(exponent(x))
end
54+
55+
"""
56+
sumlog(x)
57+
sumlog(f, x, ys...)
58+
59+
For any iterator which produces `AbstractFloat` elements,
60+
this can use `sumlog`'s fast reduction strategy.
61+
62+
Signature with `f` is equivalent to `sum(log, map(f, x, ys...))`
63+
or `mapreduce(log∘f, +, x, ys...)`, without intermediate allocations.
64+
65+
Does not accept a `dims` keyword.
66+
"""
67+
function sumlog(f, x)
    # Apply `f` lazily, then hand the resulting iterator to the one-argument method.
    mapped = Iterators.map(f, x)
    return sumlog(mapped)
end

function sumlog(f, x, ys...)
    # Walk all collections in lockstep and splat each tuple of elements into `f`.
    mapped = (f(args...) for args in zip(x, ys...))
    return sumlog(mapped)
end
69+
70+
# Iterator version: folds `(significand, exponent)` pairs with the same `_sumlog_op`
# as the array methods, so `log` is evaluated once for all `AbstractFloat` elements.
# Elements whose `float` is not an `AbstractFloat` have their `log` added separately.
# Negative real elements raise the same domain error as `log`. The input is traversed
# exactly once, so single-pass (stateful) iterators are handled correctly.
function sumlog(x)
    iter = iterate(x)
    if isnothing(iter)
        # Empty input: the element type can only be obtained from inference.
        T = Base._return_type(first, Tuple{typeof(x)})
        # `Union{} <: Number` holds (e.g. for an empty tuple), but `float(Union{})`
        # is an error, so the bottom type must take the `0.0` branch.
        return (T !== Union{} && T <: Number) ? zero(float(T)) : 0.0
    end
    x1 = float(iter[1])
    if !(x1 isa AbstractFloat)
        # Generic fallback. Continue from the current iteration state instead of
        # restarting with `sum(log, x)`: the first element has already been consumed,
        # and a stateful iterator would not produce it again.
        return mapfoldl(log, +, Iterators.rest(x, iter[2]); init=log(iter[1]))
    end
    x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1)
    sig, ex = significand(x1), _exponent(x1)
    nonfloat = zero(x1)
    iter = iterate(x, iter[2])
    while iter !== nothing
        xj = float(iter[1])
        if xj isa AbstractFloat
            xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
            sig, ex = _sumlog_op((sig, ex), (significand(xj), _exponent(xj)))
        else
            # Mixed-type input: this element cannot join the significand product.
            nonfloat += log(xj)
        end
        iter = iterate(x, iter[2])
    end
    return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
end

test/sumlog.jl

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,63 @@
11
@testset "sumlog" begin
    @testset for T in [Float16, Float32, Float64, BigFloat]
        for x in (
            T[1,2,3],
            10 .* rand(T, 1000),
            fill(nextfloat(T(1.0)), 1000),
            fill(prevfloat(T(2.0)), 1000),
        )
            @test sumlog(x) isa T

            @test (@inferred sumlog(x)) ≈ sum(log, x)

            y = @view x[1:min(end, 100)]
            @test (@inferred sumlog(y')) ≈ sum(log, y)

            tup = tuple(y...)
            @test (@inferred sumlog(tup)) ≈ sum(log, tup)

            # Disabled in this revision:
            # gen = (sqrt(a) for a in y)
            # # `eltype` of a `Base.Generator` returns `Any`
            # @test_broken (@inferred sumlog(gen)) ≈ sum(log, gen)

            # nt = NamedTuple{tuple(Symbol.(1:100)...)}(tup)
            # @test (@inferred sumlog(y)) ≈ sum(log, y)

            z = x .+ im .* Random.shuffle(x)
            @test (@inferred sumlog(z)) ≈ sum(log, z)
        end

        # With dims (wrapped in `@test` so the comparisons are actually checked)
        m = 1 .+ rand(T, 10, 10)
        @test sumlog(m; dims=1) ≈ sum(log, m; dims=1)
        @test sumlog(m; dims=2) ≈ sum(log, m; dims=2)

        # Iterator
        @test sumlog(x^2 for x in m) ≈ sumlog(abs2, m) ≈ sumlog(*, m, m) ≈ sum(log.(m.^2))
        @test sumlog(x for x in Any[1, 2, 3+im, 4]) ≈ sum(log, Any[1, 2, 3+im, 4])

        # NaN, Inf
        if T != BigFloat  # exponent fails here
            @test isnan(sumlog(T[1, 2, NaN]))
            @test isinf(sumlog(T[1, 2, Inf]))
            @test sumlog(T[1, 2, 0.0]) == -Inf
            @test sumlog(T[1, 2, -0.0]) == -Inf
        end

        # Empty
        @test sumlog(T[]) isa T
        @test eltype(sumlog(T[]; dims=1)) == T
        @test sumlog(x for x in T[]) isa T

        # Negative
        @test_throws DomainError sumlog(T[1, -2, 3])  # easy
        @test_throws DomainError sumlog(T[1, -2, -3])  # harder

    end
    @testset "Int" begin
        @test sumlog([1,2,3]) isa Float64
        @test sumlog([1,2,3]) ≈ sum(log, [1,2,3])
        @test sumlog([1 2; 3 4]; dims=1) ≈ sum(log, [1 2; 3 4]; dims=1)
        @test sumlog(Int(x) for x in Float64[1,2,3]) ≈ sum(log, [1,2,3])
    end
end

0 commit comments

Comments
 (0)