diff --git a/src/interpreter/reverse_mode_ad.jl b/src/interpreter/reverse_mode_ad.jl index 06d629129..8f17d94b2 100644 --- a/src/interpreter/reverse_mode_ad.jl +++ b/src/interpreter/reverse_mode_ad.jl @@ -209,7 +209,7 @@ function build_coinsts(ir_inst::Expr, P, in_f, _rrule!!, n::Int, b::Int, is_blk_ arg_slots = map(arg -> _get_slot(arg, _rrule!!), (__args..., )) # Construct signature, and determine how the rrule is to be computed. - primal_sig = _typeof(map(primal ∘ get_codual, arg_slots)) + primal_sig = Tuple{map(arg -> eltype(_get_slot(arg, in_f)), (__args..., ))...} evaluator = get_evaluator(in_f.ctx, primal_sig, in_f.interp, is_invoke) __rrule!! = get_rrule!!_evaluator(evaluator) @@ -320,7 +320,7 @@ function rrule!!(_f::CoDual{<:DelayedInterpretedFunction{C, F}}, args::CoDual... f = primal(_f) s = _typeof(map(primal, args)) if is_primitive(C, s) - return rrule!!(zero_codual(f.f), args...) + return rrule!!(zero_codual(_eval), args...) else in_f = InterpretedFunction(f.ctx, s, f.interp) return build_rrule!!(in_f)(zero_codual(in_f), args...) diff --git a/src/test_utils.jl b/src/test_utils.jl index ffd17065e..8ea688726 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1376,6 +1376,30 @@ function test_union_of_types(x::Ref{Union{Type{Float64}, Type{Int}}}) return x[] end +# Only one of these is a primitive. Lots of methods to prevent the compiler from +# over-specialising. +@noinline edge_case_tester(x::Float64) = 5x +@noinline edge_case_tester(x::Any) = 5.0 +@noinline edge_case_tester(x::Float32) = 6.0 +@noinline edge_case_tester(x::Int) = 10 +@noinline edge_case_tester(x::String) = "hi" +@is_primitive MinimalCtx Tuple{typeof(edge_case_tester), Float64} +function Taped.rrule!!(::CoDual{typeof(edge_case_tester)}, x::CoDual{Float64}) + edge_case_tester_pb!!(dy, df, dx) = df, dx + 5 * dy + return CoDual(5 * primal(x), 0.0), edge_case_tester_pb!! +end + +# To test the edge case properly, call this with x = Any[5.0, false] +function test_primitive_dynamic_dispatch(x::Vector{Any}) + i = 0 + y = 0.0 + while i < 2 + i += 1 + y += edge_case_tester(x[i]) + end + return y +end + sr(n) = Xoshiro(n) function generate_test_functions() @@ -1402,6 +1426,7 @@ function generate_test_functions() (false, :none, nothing, type_unstable_tester, Ref{Any}(5.0)), (false, :none, nothing, type_unstable_tester_2, Ref{Real}(5.0)), (false, :none, (lb=1, ub=1000), type_unstable_tester_3, Ref{Any}(5.0)), + (false, :none, (lb=1, ub=1000), test_primitive_dynamic_dispatch, Any[5.0, false]), (false, :none, nothing, type_unstable_function_eval, Ref{Any}(sin), 5.0), (false, :allocs, nothing, phi_const_bool_tester, 5.0), (false, :allocs, nothing, phi_const_bool_tester, -5.0),