From 0376c004b470e25aac7ba1649c537a0d32a97c6f Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sun, 18 Feb 2024 20:07:17 +0000 Subject: [PATCH] Cache Rules in AbstractInterpreter (#80) * Cache rrules * Add failing test * Fix self-referencing function problem --- src/interpreter/reverse_mode_ad.jl | 20 +++++++++++++++----- src/test_utils.jl | 6 ++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/interpreter/reverse_mode_ad.jl b/src/interpreter/reverse_mode_ad.jl index 8f17d94b2..213885563 100644 --- a/src/interpreter/reverse_mode_ad.jl +++ b/src/interpreter/reverse_mode_ad.jl @@ -444,6 +444,9 @@ end function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig} + # If we've already constructed this interpreted function, just return it. + sig in keys(in_f.interp.in_f_rrule_cache) && return in_f.interp.in_f_rrule_cache[sig] + return_slot = SlotRef{codual_type(eltype(in_f.return_slot))}() return_tangent_slot = SlotRef{tangent_type(eltype(in_f.return_slot))}() arg_info = make_codual_arginfo(in_f.arg_info) @@ -472,16 +475,19 @@ function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig} # Set PhiNodes. make_phi_instructions!(in_f, __rrule!!) + in_f.interp.in_f_rrule_cache[sig] = __rrule!! + return __rrule!! end -struct InterpretedFunctionPb{Tret_tangent<:SlotRef, Targ_info, Tbwds_f, V} +struct InterpretedFunctionPb{Tret_tangent<:SlotRef, Targ_info, Tbwds_f, V, Q} j::Int bwds_instructions::Tbwds_f ret_tangent::Tret_tangent n_stack::Stack{Int} arg_info::Targ_info arg_tangent_stacks::V + arg_tangent_stack_refs::Q end function (in_f_rrule!!::InterpretedFunctionRRule{sig})( @@ -503,15 +509,18 @@ function (in_f_rrule!!::InterpretedFunctionRRule{sig})( n = 1 j = length(n_stack) + # Get references to top of tangent stacks for use on reverse-pass. + arg_tangent_stack_refs = map(top_ref, arg_tangent_stacks) + # Run instructions until done. while next_block != -1 - push!(n_stack, n) if !isassigned(in_f_rrule!!.fwds_instructions, n) fwds, bwds = generate_coinstructions(in_f, in_f_rrule!!, n) in_f_rrule!!.fwds_instructions[n] = fwds in_f_rrule!!.bwds_instructions[n] = bwds end next_block = in_f_rrule!!.fwds_instructions[n](prev_block) + push!(n_stack, n) if next_block == 0 n += 1 elseif next_block > 0 @@ -530,6 +539,7 @@ function (in_f_rrule!!::InterpretedFunctionRRule{sig})( n_stack, arg_info, arg_tangent_stacks, + arg_tangent_stack_refs, ) return return_val, interpreted_function_pb!! end @@ -538,8 +548,8 @@ function (if_pb!!::InterpretedFunctionPb)(dout, ::NoTangent, dargs::Vararg{Any, # Update the output cotangent value to whatever is provided. if_pb!!.ret_tangent[] = dout - tangent_stacks = if_pb!!.arg_tangent_stacks - set_tangent_stacks!(tangent_stacks, dargs, if_pb!!.arg_info) + tangent_stack_refs = if_pb!!.arg_tangent_stack_refs # this can go when we refactor + set_tangent_stacks!(tangent_stack_refs, dargs, if_pb!!.arg_info) # Run the instructions in reverse. Present assumes linear instruction ordering. n_stack = if_pb!!.n_stack @@ -550,7 +560,7 @@ function (if_pb!!::InterpretedFunctionPb)(dout, ::NoTangent, dargs::Vararg{Any, end # Return resulting tangents from slots. - return NoTangent(), assemble_dout(tangent_stacks, if_pb!!.arg_info)... + return NoTangent(), assemble_dout(if_pb!!.arg_tangent_stacks, if_pb!!.arg_info)... end function set_tangent_stacks!(tangent_stacks, dargs, ai::ArgInfo{<:Any, is_va}) where {is_va} diff --git a/src/test_utils.jl b/src/test_utils.jl index 8ea688726..689ac61d5 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1402,6 +1402,10 @@ end sr(n) = Xoshiro(n) +@noinline function test_self_reference(a, b) + return a < b ? a * b : test_self_reference(b, a) +end + function generate_test_functions() return Any[ (false, :allocs, nothing, const_tester), @@ -1477,6 +1481,8 @@ function generate_test_functions() test_union_of_types, Ref{Union{Type{Float64}, Type{Int}}}(Float64), ), + (false, :allocs, nothing, test_self_reference, 1.1, 1.5), + (false, :allocs, nothing, test_self_reference, 1.5, 1.1), ( false, :none,