Skip to content

Commit

Permalink
Cache Rules in AbstractInterpreter (#80)
Browse files Browse the repository at this point in the history
* Cache rrules

* Add failing test

* Fix self-referencing function problem
  • Loading branch information
willtebbutt authored Feb 18, 2024
1 parent aea04bf commit 0376c00
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/interpreter/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,9 @@ end

function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig}

# If we've already constructed this interpreted function, just return it.
sig in keys(in_f.interp.in_f_rrule_cache) && return in_f.interp.in_f_rrule_cache[sig]

return_slot = SlotRef{codual_type(eltype(in_f.return_slot))}()
return_tangent_slot = SlotRef{tangent_type(eltype(in_f.return_slot))}()
arg_info = make_codual_arginfo(in_f.arg_info)
Expand Down Expand Up @@ -472,16 +475,19 @@ function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig}
# Set PhiNodes.
make_phi_instructions!(in_f, __rrule!!)

in_f.interp.in_f_rrule_cache[sig] = __rrule!!

return __rrule!!
end

struct InterpretedFunctionPb{Tret_tangent<:SlotRef, Targ_info, Tbwds_f, V}
struct InterpretedFunctionPb{Tret_tangent<:SlotRef, Targ_info, Tbwds_f, V, Q}
j::Int
bwds_instructions::Tbwds_f
ret_tangent::Tret_tangent
n_stack::Stack{Int}
arg_info::Targ_info
arg_tangent_stacks::V
arg_tangent_stack_refs::Q
end

function (in_f_rrule!!::InterpretedFunctionRRule{sig})(
Expand All @@ -503,15 +509,18 @@ function (in_f_rrule!!::InterpretedFunctionRRule{sig})(
n = 1
j = length(n_stack)

# Get references to top of tangent stacks for use on reverse-pass.
arg_tangent_stack_refs = map(top_ref, arg_tangent_stacks)

# Run instructions until done.
while next_block != -1
push!(n_stack, n)
if !isassigned(in_f_rrule!!.fwds_instructions, n)
fwds, bwds = generate_coinstructions(in_f, in_f_rrule!!, n)
in_f_rrule!!.fwds_instructions[n] = fwds
in_f_rrule!!.bwds_instructions[n] = bwds
end
next_block = in_f_rrule!!.fwds_instructions[n](prev_block)
push!(n_stack, n)
if next_block == 0
n += 1
elseif next_block > 0
Expand All @@ -530,6 +539,7 @@ function (in_f_rrule!!::InterpretedFunctionRRule{sig})(
n_stack,
arg_info,
arg_tangent_stacks,
arg_tangent_stack_refs,
)
return return_val, interpreted_function_pb!!
end
Expand All @@ -538,8 +548,8 @@ function (if_pb!!::InterpretedFunctionPb)(dout, ::NoTangent, dargs::Vararg{Any,

# Update the output cotangent value to whatever is provided.
if_pb!!.ret_tangent[] = dout
tangent_stacks = if_pb!!.arg_tangent_stacks
set_tangent_stacks!(tangent_stacks, dargs, if_pb!!.arg_info)
tangent_stack_refs = if_pb!!.arg_tangent_stack_refs # this can go when we refactor
set_tangent_stacks!(tangent_stack_refs, dargs, if_pb!!.arg_info)

# Run the instructions in reverse. Present assumes linear instruction ordering.
n_stack = if_pb!!.n_stack
Expand All @@ -550,7 +560,7 @@ function (if_pb!!::InterpretedFunctionPb)(dout, ::NoTangent, dargs::Vararg{Any,
end

# Return resulting tangents from slots.
return NoTangent(), assemble_dout(tangent_stacks, if_pb!!.arg_info)...
return NoTangent(), assemble_dout(if_pb!!.arg_tangent_stacks, if_pb!!.arg_info)...
end

function set_tangent_stacks!(tangent_stacks, dargs, ai::ArgInfo{<:Any, is_va}) where {is_va}
Expand Down
6 changes: 6 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,10 @@ end

sr(n) = Xoshiro(n)

@noinline function test_self_reference(a, b)
return a < b ? a * b : test_self_reference(b, a)
end

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down Expand Up @@ -1477,6 +1481,8 @@ function generate_test_functions()
test_union_of_types,
Ref{Union{Type{Float64}, Type{Int}}}(Float64),
),
(false, :allocs, nothing, test_self_reference, 1.1, 1.5),
(false, :allocs, nothing, test_self_reference, 1.5, 1.1),
(
false,
:none,
Expand Down

0 comments on commit 0376c00

Please sign in to comment.