Skip to content

Commit

Permalink
Issue 313 (#475)
Browse files Browse the repository at this point in the history
* WIP

* add chainrules frule, rrule

* correct merge
  • Loading branch information
jverzani authored Mar 7, 2023
1 parent ea0323d commit a6ba8fa
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 2 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@ author = "JuliaMath"
version = "3.2.5"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[compat]
RecipesBase = "0.7, 0.8, 1"
ChainRulesCore = "1"
MakieCore = "0.6"
RecipesBase = "0.7, 0.8, 1"
julia = "1.6"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand Down
2 changes: 1 addition & 1 deletion src/Polynomials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ include("rational-functions/plot-recipes.jl")

# compat; opt-in with `using Polynomials.PolyCompat`
include("polynomials/Poly.jl")

include("chain_rules.jl")
include("precompiles.jl")

end # module
16 changes: 16 additions & 0 deletions src/chain_rules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import ChainRulesCore

function ChainRulesCore.frule(
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
(_, Δx),
p::AbstractPolynomial,
x
)
p(x), derivative(p)(x)*Δx
end


function ChainRulesCore.rrule(p::AbstractPolynomial, x)
_pullback(ΔΩ) = (ChainRulesCore.NoTangent(), derivative(p)(x))
return (p(x), _pullback)
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand Down
11 changes: 11 additions & 0 deletions test/StandardBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1603,3 +1603,14 @@ end
@test Polynomials.minimumexponent(LaurentPolynomial{Float64}) == typemin(Int)
@test Polynomials.minimumexponent(LaurentPolynomial{Float64, :y}) == typemin(Int)
end


# Chain rules
using ChainRulesTestUtils

@testset "Test frule and rrule" begin
p = Polynomial([1,2,3,4])
dp = derivative(p)

test_scalar(p, 1.0; check_inferred=true)
end

0 comments on commit a6ba8fa

Please sign in to comment.