|
1 | 1 | """
|
2 |
| -$(SIGNATURES) |
| 2 | + sumlog(X::AbstractArray{T}; dims) |
3 | 3 |
|
4 |
| -Compute `sum(log.(X))` with a single `log` evaluation. |
| 4 | +Compute `sum(log.(X))` with a single `log` evaluation, |
| 5 | +provided `float(T) <: AbstractFloat`. |
5 | 6 |
|
6 |
| -This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in |
7 |
| -particular as the size of `X` increases. |
8 |
| -
|
9 |
| -This works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``, |
| 7 | +This is faster than computing `sum(log, X)`, especially for large `X`. |
| 8 | +It works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``, |
10 | 9 | allowing us to write
|
11 | 10 | ```math
|
12 | 11 | \\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j
|
13 | 12 | ```
|
14 |
| -Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` |
15 |
| -evaluation. |
16 | 13 | """
|
17 |
| -function sumlog(x) |
18 |
| - T = float(eltype(x)) |
19 |
| - _sumlog(T, values(x)) |
| 14 | +sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x) |
| 15 | + |
| 16 | +function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat} |
| 17 | + sig, ex = mapreduce(_sumlog_op, x; init=(one(T), 0)) do xj |
| 18 | + xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) |
| 19 | + float_xj = float(xj) |
| 20 | + significand(float_xj), _exponent(float_xj) |
| 21 | + end |
| 22 | + return log(sig) + IrrationalConstants.logtwo * T(ex) |
20 | 23 | end
|
21 | 24 |
|
22 |
| -@inline function _sumlog(::Type{T}, x) where {T<:AbstractFloat} |
23 |
| - sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj |
| 25 | +function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat} |
| 26 | + sig_ex = mapreduce(_sumlog_op, x; dims=dims, init=(one(T), 0)) do xj |
| 27 | + xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) |
24 | 28 | float_xj = float(xj)
|
25 |
| - significand(float_xj), exponent(float_xj) |
| 29 | + significand(float_xj), _exponent(float_xj) |
| 30 | + end |
| 31 | + map(sig_ex) do (sig, ex) |
| 32 | + log(sig) + IrrationalConstants.logtwo * T(ex) |
26 | 33 | end
|
27 |
| - return log(sig) + IrrationalConstants.logtwo * ex |
28 | 34 | end
|
29 | 35 |
|
| 36 | +# Fallback: `float(T)` is not always `<: AbstractFloat`, e.g. complex, dual numbers or symbolics |
| 37 | +_sumlog(::Type, dims, x) = sum(log, x; dims) |
| 38 | + |
30 | 39 | @inline function _sumlog_op((sig1, ex1), (sig2, ex2))
|
31 | 40 | sig = sig1 * sig2
|
| 41 | + # sig = ifelse(sig2<0, sig2, sig1 * sig2) |
32 | 42 | ex = ex1 + ex2
|
33 | 43 | # Significands are in the range [1,2), so multiplication will eventually overflow
|
34 | 44 | if sig > floatmax(typeof(sig)) / 2
|
35 |
| - ex += exponent(sig) |
| 45 | + ex += _exponent(sig) |
36 | 46 | sig = significand(sig)
|
37 | 47 | end
|
38 | 48 | return sig, ex
|
39 | 49 | end
|
40 | 50 |
|
41 |
| -# `float(T)` is not always `isa AbstractFloat`, e.g. dual numbers or symbolics |
42 |
| -@inline _sumlog(::Type{T}, x) where {T} = sum(log, x) |
| 51 | +# The exported `exponent(x)` checks for `NaN` etc, this function doesn't, which is fine as `sig` keeps track. |
| 52 | +_exponent(x::Base.IEEEFloat) = Base.Math._exponent_finite_nonzero(x) |
| 53 | +Base.@assume_effects :nothrow _exponent(x::AbstractFloat) = Int(exponent(x)) # e.g. for BigFloat |
| 54 | + |
| 55 | +""" |
| 56 | + sumlog(x) |
| 57 | + sumlog(f, x, ys...) |
| 58 | +
|
| 59 | +For any iterator which produces `AbstractFloat` elements, |
| 60 | +this can use `sumlog`'s fast reduction strategy. |
| 61 | +
|
| 62 | +Signature with `f` is equivalent to `sum(log, map(f, x, ys...))` |
| 63 | +or `mapreduce(log∘f, +, x, ys...)`, without intermediate allocations. |
| 64 | +
|
| 65 | +Does not accept a `dims` keyword. |
| 66 | +""" |
| 67 | +sumlog(f, x) = sumlog(Iterators.map(f, x)) |
| 68 | +sumlog(f, x, ys...) = sumlog(f(xy...) for xy in zip(x, ys...)) |
| 69 | + |
| 70 | +# Iterator version, uses the same `_sumlog_op`, should be the same speed. |
| 71 | +function sumlog(x) |
| 72 | + iter = iterate(x) |
| 73 | + if isnothing(iter) |
| 74 | + T = Base._return_type(first, Tuple{typeof(x)}) |
| 75 | + return T <: Number ? zero(float(T)) : 0.0 |
| 76 | + end |
| 77 | + x1 = float(iter[1]) |
| 78 | + x1 isa AbstractFloat || return sum(log, x) |
| 79 | + x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1) |
| 80 | + sig, ex = significand(x1), _exponent(x1) |
| 81 | + nonfloat = zero(x1) |
| 82 | + iter = iterate(x, iter[2]) |
| 83 | + while iter !== nothing |
| 84 | + xj = float(iter[1]) |
| 85 | + if xj isa AbstractFloat |
| 86 | + xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) |
| 87 | + sig, ex = _sumlog_op((sig, ex), (significand(xj), _exponent(xj))) |
| 88 | + else |
| 89 | + nonfloat += log(xj) |
| 90 | + end |
| 91 | + iter = iterate(x, iter[2]) |
| 92 | + end |
| 93 | + return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat |
| 94 | +end |
0 commit comments