Skip to content

Commit

Permalink
Merge in main
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Jan 6, 2025
2 parents a729318 + 5eab9a6 commit 1af280e
Show file tree
Hide file tree
Showing 26 changed files with 637 additions and 434 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.63"
version = "0.4.70"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
22 changes: 13 additions & 9 deletions docs/src/understanding_mooncake/algorithmic_differentiation.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,20 @@ Using the usual definition of the inner product between matrices,
```
we can rearrange the inner product as follows:
```math
\begin{align}
\langle \bar{Y}, D f [X] (\dot{X}) \rangle &= \langle \bar{Y}, \dot{X}^\top X + X^\top \dot{X} \rangle \nonumber \\
&= \textrm{tr} (\bar{Y}^\top \dot{X}^\top X) + \textrm{tr}(\bar{Y}^\top X^\top \dot{X}) \nonumber \\
&= \textrm{tr} ( [\bar{Y} X^\top]^\top \dot{X}) + \textrm{tr}( [X \bar{Y}]^\top \dot{X}) \nonumber \\
&= \langle \bar{Y} X^\top + X \bar{Y}, \dot{X} \rangle. \nonumber
\end{align}
```
We can read off the adjoint operator from the first argument to the inner product:
\begin{align*}
\langle\bar{Y},Df[X](\dot{X})\rangle & =\langle\bar{Y},\dot{X}^{\top}X+X^{\top}\dot{X}\rangle\\
& =\textrm{tr}(\bar{Y}^{\top}\left(\dot{X}^{\top}X+X^{\top}\dot{X}\right))\\
& =\textrm{tr}(\dot{X}^{\top}X\bar{Y}^{\top})+\textrm{tr}(\bar{Y}^{\top}X^{\top}\dot{X})\\
& =\langle\dot{X},X\bar{Y}^{\top}\rangle+\langle X\bar{Y},\dot{X}\rangle\\
& =\langle X\bar{Y}^{\top}+X\bar{Y},\dot{X}\rangle.
\end{align*}
```
The linearity of inner products and trace, and the [cyclic property of trace](https://en.wikipedia.org/wiki/Trace_(linear_algebra)#Cyclic_property) was used in the above. We can read off the adjoint operator from the first argument to the inner product:
```math
D f [X]^\ast (\bar{Y}) = \bar{Y} X^\top + X \bar{Y}.
\begin{align*}
Df\left[X\right]^{*}\left(\bar{Y}\right) & =X\bar{Y}^{\top}+X\bar{Y}\\
& =X\left(\bar{Y}^{\top}+\bar{Y}\right).
\end{align*}
```

#### AD of a Julia function: a trivial example
Expand Down
2 changes: 1 addition & 1 deletion ext/MooncakeCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import Mooncake.TestUtils: populate_address_map!, AddressMap, __increment_should

# Tell Mooncake.jl how to handle CuArrays.

tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P
Mooncake.@tt_effects tangent_type(::Type{P}) where {P<:CuArray{<:IEEEFloat}} = P
zero_tangent(x::CuArray{<:IEEEFloat}) = zero(x)
function randn_tangent(rng::AbstractRNG, x::CuArray{Float32})
return cu(randn(rng, Float32, size(x)...))
Expand Down
162 changes: 89 additions & 73 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,13 @@ end
T == NoTangent && return NoFData

# This method can only handle struct types. Tell user to implement their own method.
isprimitivetype(T) &&
throw(error("$T is a primitive type. Implement a method of `fdata_type` for it."))
if isprimitivetype(T)
msg = "$T is a primitive type. Implement a method of `fdata_type` for it."
return :(error($msg))
end

# If the type is a Union, then take the union type of its arguments.
T isa Union && return Union{fdata_type(T.a),fdata_type(T.b)}
T isa Union && return :(Union{fdata_type($(T.a)),fdata_type($(T.b))})

# If `P` is a mutable type, then its forwards data is its tangent.
ismutabletype(T) && return T
Expand All @@ -179,33 +181,37 @@ end
# The same goes for if the type has any undetermined type parameters.
(isabstracttype(T) || !isconcretetype(T)) && return Any

# We should now have a `Tangent`. If not, we do not know what to do, so error.
T <: Tangent || return :(error("Unhandled type $T"))

# If `P` is an immutable type, then some of its fields may not need to be propagated
# on the forwards-pass.
if T <: Tangent
Tfields = fields_type(T)
fwds_data_field_types = map(1:fieldcount(Tfields)) do n
return fdata_type(fieldtype(Tfields, n))
end
all(==(NoFData), fwds_data_field_types) && return NoFData
return FData{NamedTuple{fieldnames(Tfields),Tuple{fwds_data_field_types...}}}
field_names = fieldnames(fields_type(T))
Tfields = fieldtypes(fields_type(T))
fdata_type_exprs = map(n -> :(fdata_type($(Tfields[n]))), 1:length(Tfields))
return quote
fwds_data_field_types = $(Expr(:call, :tuple, fdata_type_exprs...))
stable_all(tuple_map(==(NoFData), fwds_data_field_types)) && return NoFData
return FData{NamedTuple{$field_names,Tuple{fwds_data_field_types...}}}
end

return :(error("Unhandled type $T"))
end

fdata_type(::Type{T}) where {T<:Ptr} = T

@generated function fdata_type(::Type{P}) where {P<:Tuple}
isa(P, Union) && return Union{fdata_type(P.a),fdata_type(P.b)}
isa(P, Union) && return :(Union{fdata_type($(P.a)),fdata_type($(P.b))})
isempty(P.parameters) && return NoFData
isa(last(P.parameters), Core.TypeofVararg) && return Any
nofdata_tt = Tuple{Vararg{NoFData,length(P.parameters)}}
fdata_tt = Tuple{map(fdata_type, fieldtypes(P))...}
fdata_tt <: nofdata_tt && return NoFData
return nofdata_tt <: fdata_tt ? Union{NoFData,fdata_tt} : fdata_tt
fdata_type_exprs = map(_P -> Expr(:call, :fdata_type, _P), P.parameters)
return quote
fdata_tt = $(Expr(:curly, Tuple, fdata_type_exprs...))
fdata_tt <: $nofdata_tt && return NoFData
return $nofdata_tt <: fdata_tt ? Union{NoFData,fdata_tt} : fdata_tt
end
end

@generated function fdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple}
function fdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple}
if fdata_type(T) == NoFData
return NoFData
elseif isconcretetype(fdata_type(T))
Expand All @@ -224,28 +230,28 @@ Returns the type of to the nth field of the fdata type associated to `P`. Will b
function fdata_field_type(::Type{P}, n::Int) where {P}
Tf = tangent_type(fieldtype(P, n))
f = ismutabletype(P) ? Tf : fdata_type(Tf)
return is_always_initialised(P, n) ? f : _wrap_type(f)
return is_always_initialised(P, n) ? f : PossiblyUninitTangent{f}
end

"""
fdata(t)::fdata_type(typeof(t))
Extract the forwards data from tangent `t`.
"""
@generated function fdata(t::T) where {T}
function fdata(t::T) where {T}

# Ask for the forwards-data type. Useful catch-all error checking for unexpected types.
F = fdata_type(T)

# Catch-all for anything with no forwards-data.
F == NoFData && return :(NoFData())
F == NoFData && return NoFData()

# Catch-all for anything where we return the whole object (mutable structs, arrays...).
F == T && return :(t)
F == T && return t

# T must be a `Tangent` by now. If it's not, something has gone wrong.
!(T <: Tangent) && return :(error("Unhandled type $T"))
return :($F(fdata(t.fields)))
T <: Tangent || error("Unhandled type $T")
return F(fdata(t.fields))
end

function fdata(t::T) where {T<:PossiblyUninitTangent}
Expand Down Expand Up @@ -415,11 +421,13 @@ end
T == NoTangent && return NoRData

# This method can only handle struct types. Tell user to implement their own method.
isprimitivetype(T) &&
throw(error("$T is a primitive type. Implement a method of `rdata_type` for it."))
if isprimitivetype(T)
msg = "$T is a primitive type. Implement a method of `rdata_type` for it."
return :(error(msg))
end

# If the type is a Union, then take the union type of its arguments.
T isa Union && return Union{rdata_type(T.a),rdata_type(T.b)}
T isa Union && return :(Union{rdata_type($(T.a)),rdata_type($(T.b))})

# If `P` is a mutable type, then all tangent info is propagated on the forwards-pass.
ismutabletype(T) && return NoRData
Expand All @@ -428,26 +436,31 @@ end
# The same goes for if the type has any undetermined type parameters.
(isabstracttype(T) || !isconcretetype(T)) && return Any

# If `T` is an immutable type, then some of its fields may not have been propagated on
# the forwards-pass.
if T <: Tangent
Tfs = fields_type(T)
rvs_types = map(n -> rdata_type(fieldtype(Tfs, n)), 1:fieldcount(Tfs))
all(==(NoRData), rvs_types) && return NoRData
return RData{NamedTuple{fieldnames(Tfs),Tuple{rvs_types...}}}
# If `T` is an immutable type, then some of its fields may not need to be propagated
# on the forwards-pass.
field_names = fieldnames(fields_type(T))
Tfields = fieldtypes(fields_type(T))
rdata_type_exprs = map(n -> :(rdata_type($(Tfields[n]))), 1:length(Tfields))
return quote
rvs_data_field_types = $(Expr(:call, :tuple, rdata_type_exprs...))
stable_all(tuple_map(==(NoRData), rvs_data_field_types)) && return NoRData
return RData{NamedTuple{$field_names,Tuple{rvs_data_field_types...}}}
end
end

rdata_type(::Type{<:Ptr}) = NoRData

@generated function rdata_type(::Type{P}) where {P<:Tuple}
isa(P, Union) && return Union{rdata_type(P.a),rdata_type(P.b)}
isa(P, Union) && return :(Union{rdata_type($(P.a)),rdata_type($(P.b))})
isempty(P.parameters) && return NoRData
isa(last(P.parameters), Core.TypeofVararg) && return Any
nordata_tt = Tuple{Vararg{NoRData,length(P.parameters)}}
rdata_tt = Tuple{map(rdata_type, fieldtypes(P))...}
rdata_tt <: nordata_tt && return NoRData
return nordata_tt <: rdata_tt ? Union{NoRData,rdata_tt} : rdata_tt
rdata_type_exprs = map(_P -> Expr(:call, :rdata_type, _P), P.parameters)
return quote
rdata_tt = $(Expr(:curly, Tuple, rdata_type_exprs...))
rdata_tt <: $nordata_tt && return NoRData
return $nordata_tt <: rdata_tt ? Union{NoRData,rdata_tt} : rdata_tt
end
end

function rdata_type(::Type{NamedTuple{names,T}}) where {names,T<:Tuple}
Expand All @@ -468,7 +481,7 @@ Returns the type of to the nth field of the rdata type associated to `P`. Will b
"""
function rdata_field_type(::Type{P}, n::Int) where {P}
r = rdata_type(tangent_type(fieldtype(P, n)))
return is_always_initialised(P, n) ? r : _wrap_type(r)
return is_always_initialised(P, n) ? r : PossiblyUninitTangent{r}
end

"""
Expand All @@ -480,20 +493,20 @@ Extract the reverse data from tangent `t`.
See extended help section of [fdata_type](@ref).
"""
@generated function rdata(t::T) where {T}
function rdata(t::T) where {T}

# Ask for the reverse-data type. Useful catch-all error checking for unexpected types.
R = rdata_type(T)

# Catch-all for anything with no reverse-data.
R == NoRData && return :(NoRData())
R == NoRData && return NoRData()

# Catch-all for anything where we return the whole object (Float64, isbits structs, ...)
R == T && return :(t)
R == T && return t

# T must be a `Tangent` by now. If it's not, something has gone wrong.
!(T <: Tangent) && return :(error("Unhandled type $T"))
return :($(rdata_type(T))(rdata(t.fields)))
T <: Tangent || error("Unhandled type $T")
return R(rdata(t.fields))
end

function rdata(t::T) where {T<:PossiblyUninitTangent}
Expand Down Expand Up @@ -604,41 +617,48 @@ constitute a correctness problem, but can be detrimental to performance, so shou
with.
"""
@generated function zero_rdata_from_type(::Type{P}) where {P}
R = rdata_type(tangent_type(P))

# If we know we can't produce a tangent, say so.
can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType()

# Simple case.
R == NoRData && return NoRData()

# If `P` is a struct type, attempt to derive the zero rdata for it. We cannot derive
# the zero rdata if it is not possible to derive the zero rdata for any of its fields.
if isstructtype(P)
# Prepare expressions for manually-unrolled loop to construct zero rdata elements.
if P isa DataType
names = fieldnames(P)
types = fieldtypes(P)
wrapped_field_zeros = tuple_map(ntuple(identity, length(names))) do n
wrapped_field_zeros = map(enumerate(tangent_field_types(P))) do (n, tt)
fzero = :(zero_rdata_from_type($(types[n])))
if tangent_field_type(P, n) <: PossiblyUninitTangent
Q = rdata_type(tangent_type(fieldtype(P, n)))
return :(_wrap_field($Q, $fzero))
if tt <: PossiblyUninitTangent
Q = :(rdata_type(tangent_type($(fieldtype(P, n)))))
return :(PossiblyUninitTangent{$Q}($fzero))
else
return fzero
end
end
wrapped_field_zeros_tuple = Expr(:call, :tuple, wrapped_field_zeros...)
return :($R(NamedTuple{$names}($wrapped_field_zeros_tuple)))
wrapped_expr = :(R(NamedTuple{$names}($wrapped_field_zeros_tuple)))
else
wrapped_expr = nothing
end

# Fallback -- we've not been able to figure out how to produce an instance of zero rdata
# so report that it cannot be done.
return throw(error("Unhandled type $P"))
return quote

# If we know we can't produce a tangent, say so.
can_produce_zero_rdata_from_type($P) || return CannotProduceZeroRDataFromType()

# Simple case.
R = rdata_type(tangent_type($P))
R == NoRData && return NoRData()

$(isstructtype(P)) || error("Unhandled type $P")
return $wrapped_expr
end
end

@generated function zero_rdata_from_type(::Type{P}) where {P<:Tuple}
can_produce_zero_rdata_from_type(P) || return CannotProduceZeroRDataFromType()
rdata_type(tangent_type(P)) == NoRData && return NoRData()
return tuple_map(zero_rdata_from_type, fieldtypes(P))
has_fields = P isa DataType && Base.datatype_fieldcount(P) !== nothing
zero_exprs = has_fields ? map(_P -> :(zero_rdata_from_type($_P)), fieldtypes(P)) : []
return quote
can_produce_zero_rdata_from_type($P) || return CannotProduceZeroRDataFromType()
rdata_type(tangent_type($P)) == NoRData && return NoRData()
return $(Expr(:call, :tuple, zero_exprs...))
end
end

function zero_rdata_from_type(::Type{P}) where {P<:NamedTuple}
Expand Down Expand Up @@ -785,15 +805,14 @@ tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Array} = F

# Tuples
@generated function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple}
return Tuple{tuple_map(tangent_type, Tuple(F.parameters), Tuple(R.parameters))...}
tt_exprs = map((f, r) -> :(tangent_type($f, $r)), fieldtypes(F), fieldtypes(R))
return Expr(:curly, :Tuple, tt_exprs...)
end
function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:Tuple}
F_tuple = Tuple{tuple_fill(NoFData, Val(length(R.parameters)))...}
return tangent_type(F_tuple, R)
return tangent_type(Tuple{tuple_fill(NoFData, Val(length(R.parameters)))...}, R)
end
function tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Tuple}
R_tuple = Tuple{tuple_fill(NoRData, Val(length(F.parameters)))...}
return tangent_type(F, R_tuple)
return tangent_type(F, Tuple{tuple_fill(NoRData, Val(length(F.parameters)))...})
end

# NamedTuples
Expand Down Expand Up @@ -904,10 +923,7 @@ Equivalent to `tangent(fdata, rdata(zero_tangent(primal)))`.
zero_tangent(p, ::NoFData) = zero_tangent(p)

function zero_tangent(p::P, f::F) where {P,F}
T = tangent_type(P)
T == F && return f
r = rdata(zero_tangent(p))
return tangent(f, r)
return tangent_type(P) == F ? f : tangent(f, rdata(zero_tangent(p)))
end

zero_tangent(p::Tuple, f::Union{Tuple,NamedTuple}) = tuple_map(zero_tangent, p, f)
6 changes: 2 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ function __value_and_pullback!!(
return v, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(rdata(ȳ)))
end

function __verify_sig(
rule::DerivedRule{<:Any,<:MistyClosure{<:OpaqueClosure{sig}}}, fx::Tfx
) where {sig,Tfx}
Pfx = typeof(__unflatten_codual_varargs(rule.isva, fx, rule.nargs))
function __verify_sig(rule::DerivedRule{<:Any,sig}, fx::Tfx) where {sig,Tfx}
Pfx = typeof(__unflatten_codual_varargs(_isva(rule), fx, rule.nargs))
if sig != Pfx
msg = "signature of arguments, $Pfx, not equal to signature required by rule, $sig."
throw(ArgumentError(msg))
Expand Down
Loading

0 comments on commit 1af280e

Please sign in to comment.