Skip to content

Commit

Permalink
Regression Tests Cleanup (#110)
Browse files Browse the repository at this point in the history
* use exact equality for tests

all of the values involved are integers or simple fractions
so we might as well use exact equality checks to rule out issues

* use higher-order combinators for comparison

* remove unnecessary eta expansions

* supply diff' to tests

* add HasCallStack constraint to expect for better test failure location

* add forward mode tests

* add notes about test limitations

* adjust type synonym name to be consistent with usage
  • Loading branch information
julmb authored Mar 11, 2024
1 parent 84d4f8f commit 2e37c7d
Showing 1 changed file with 60 additions and 65 deletions.
125 changes: 60 additions & 65 deletions tests/Regression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,111 +3,106 @@

module Main (main) where

import qualified Numeric.AD.Mode.Forward as F
import qualified Numeric.AD.Mode.Forward.Double as FD
import qualified Numeric.AD.Mode.Reverse as R
import qualified Numeric.AD.Mode.Reverse.Double as RD

import Text.Printf
import Test.Tasty
import Test.Tasty.HUnit

type Diff = (forall a. Floating a => a -> a) -> Double -> Double
type Diff' = (forall a. Floating a => a -> a) -> Double -> (Double, Double)
type Grad = (forall a. Floating a => [a] -> a) -> [Double] -> [Double]
type Jacobian = (forall a. Floating a => [a] -> [a]) -> [Double] -> [[Double]]
type Hessian = (forall a. Floating a => [a] -> a) -> [Double] -> [[Double]]

main :: IO ()
main = defaultMain tests

-- TODO: the forward-double tests are currently failing due to discrepancies between the modes
-- see also https://github.com/ekmett/ad/issues/109 and https://github.com/ekmett/ad/pull/110
tests :: TestTree
tests = testGroup "tests" [
mode "reverse" (\ f -> R.diff f) (\ f -> R.grad f) (\ f -> R.jacobian f) (\ f -> R.hessian f),
mode "reverse-double" (\ f -> RD.diff f) (\ f -> RD.grad f) (\ f -> RD.jacobian f) (\ f -> RD.hessian f)]
mode "forward" (\ f -> F.diff' f) (\ f -> F.grad f) (\ f -> F.jacobian f) (\ f -> F.jacobian $ F.grad f),
--mode "forward-double" (\ f -> FD.diff' f) (\ f -> FD.grad f) (\ f -> FD.jacobian f) (\ f -> FD.jacobian $ F.grad f),
mode "reverse" (\ f -> R.diff' f) (\ f -> R.grad f) (\ f -> R.jacobian f) (\ f -> R.hessian f),
mode "reverse-double" (\ f -> RD.diff' f) (\ f -> RD.grad f) (\ f -> RD.jacobian f) (\ f -> RD.hessian f)]

mode :: String -> Diff -> Grad -> Jacobian -> Hessian -> TestTree
mode :: String -> Diff' -> Grad -> Jacobian -> Hessian -> TestTree
mode name diff grad jacobian hessian = testGroup name [basic diff grad jacobian hessian, issue97 diff, issue104 diff grad]

basic :: Diff -> Grad -> Jacobian -> Hessian -> TestTree
basic :: Diff' -> Grad -> Jacobian -> Hessian -> TestTree
basic diff grad jacobian hessian = testGroup "basic" [tdiff, tgrad, tjacobian, thessian] where
tdiff = testCase "diff" $ do
assertNearList [11, 5.5, 3, 3.5, 7, 13.5, 23, 35.5, 51] $ diff p <$> [-2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2]
assertNearList [nan, inf, 1, 0.5, 0.25] $ diff sqrt <$> [-1, 0, 0.25, 1, 4]
assertNearList [1, 0, 1] $ [diff sin, diff cos, diff tan] <*> [0]
assertNearList [-1, 0, 1] $ diff abs <$> [-1, 0, 1]
assertNearList [1, exp 1, inf, 1] $ [diff exp, diff log] <*> [0, 1]
expect (list eq) [11, 5.5, 3, 3.5, 7, 13.5, 23, 35.5, 51] $ snd . diff p <$> [-2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2]
expect (list eq) [nan, inf, 1, 0.5, 0.25] $ snd . diff sqrt <$> [-1, 0, 0.25, 1, 4]
expect (list eq) [1, 0, 1] $ [snd . diff sin, snd . diff cos, snd . diff tan] <*> [0]
expect (list eq) [-1, 0, 1] $ snd . diff abs <$> [-1, 0, 1]
expect (list eq) [1, exp 1, inf, 1] $ [snd . diff exp, snd . diff log] <*> [0, 1]
tgrad = testCase "grad" $ do
assertNearList [2, 1, 1] $ grad f [1, 2, 3]
assertNearList [1, 0.25] $ grad h [2, 8]
assertNearList [0, nan] $ grad power [0, 2]
expect (list eq) [2, 1, 1] $ grad f [1, 2, 3]
expect (list eq) [1, 0.25] $ grad h [2, 8]
expect (list eq) [0, nan] $ grad power [0, 2]
tjacobian = testCase "jacobian" $ do
assertNearMatrix [[0, 1], [1, 0], [1, 2]] $ jacobian g [2, 1]
expect (list $ list eq) [[0, 1], [1, 0], [1, 2]] $ jacobian g [2, 1]
thessian = testCase "hessian" $ do
assertNearMatrix [[0, 1, 0], [1, 0, 0], [0, 0, 0]] $ hessian f [1, 2, 3]
assertNearMatrix [[0, 0], [0, 0]] $ hessian sum [1, 2]
assertNearMatrix [[0, 1], [1, 0]] $ hessian product [1, 2]
assertNearMatrix [[2, 1], [1, 0]] $ hessian power [1, 2]
sum = \ [x, y] -> x + y
product = \ [x, y] -> x * y
power = \ [x, y] -> x ** y
f = \ [x, y, z] -> x * y + z
g = \ [x, y] -> [y, x, x * y]
h = \ [x, y] -> sqrt $ x * y
p = \ x -> 12 + 7 * x + 5 * x ^ 2 + 2 * x ^ 3
expect (list $ list eq) [[0, 1, 0], [1, 0, 0], [0, 0, 0]] $ hessian f [1, 2, 3]
expect (list $ list eq) [[0, 0], [0, 0]] $ hessian sum [1, 2]
expect (list $ list eq) [[0, 1], [1, 0]] $ hessian product [1, 2]
expect (list $ list eq) [[2, 1], [1, 0]] $ hessian power [1, 2]
sum [x, y] = x + y
product [x, y] = x * y
power [x, y] = x ** y
f [x, y, z] = x * y + z
g [x, y] = [y, x, x * y]
h [x, y] = sqrt $ x * y
p x = 12 + 7 * x + 5 * x ^ 2 + 2 * x ^ 3

-- Reverse.Double +ffi initializes the tape with a block of size 4096
-- The large term in this function forces the allocation of an additional block
issue97 :: Diff -> TestTree
issue97 diff = testCase "issue-97" $ assertNear 5000 $ diff f 0 where f = sum . replicate 5000
issue97 :: Diff' -> TestTree
issue97 diff = testCase "issue-97" $ expect eq 5000 $ snd $ diff f 0 where f = sum . replicate 5000

issue104 :: Diff -> Grad -> TestTree
issue104 :: Diff' -> Grad -> TestTree
issue104 diff grad = testGroup "issue-104" [inside, outside] where
inside = testGroup "inside" [tdiff, tgrad] where
tdiff = testCase "diff" $ do
assertNearList [nan, nan] $ diff (0 `f`) <$> [0, 1]
assertNearList [inf, 0.5] $ diff (1 `f`) <$> [0, 1]
assertNearList [nan, nan] $ diff (`f` 0) <$> [0, 1]
assertNearList [inf, 0.5] $ diff (`f` 1) <$> [0, 1]
expect (list eq) [nan, nan] $ snd . diff (0 `f`) <$> [0, 1]
expect (list eq) [inf, 0.5] $ snd . diff (1 `f`) <$> [0, 1]
expect (list eq) [nan, nan] $ snd . diff (`f` 0) <$> [0, 1]
expect (list eq) [inf, 0.5] $ snd . diff (`f` 1) <$> [0, 1]
tgrad = testCase "grad" $ do
assertNearList [nan, nan] $ grad (binary f) [0, 0]
assertNearList [nan, inf] $ grad (binary f) [1, 0]
assertNearList [inf, nan] $ grad (binary f) [0, 1]
assertNearList [0.5, 0.5] $ grad (binary f) [1, 1]
expect (list eq) [nan, nan] $ grad (binary f) [0, 0]
expect (list eq) [nan, inf] $ grad (binary f) [1, 0]
expect (list eq) [inf, nan] $ grad (binary f) [0, 1]
expect (list eq) [0.5, 0.5] $ grad (binary f) [1, 1]
f x y = sqrt $ x * y -- grad f [x, y] = [y / (2 * f x y), x / (2 * f x y)]
outside = testGroup "outside" [tdiff, tgrad] where
tdiff = testCase "diff" $ do
assertNearList [nan, 0.0] $ diff (0 `f`) <$> [0, 1]
assertNearList [inf, 0.5] $ diff (1 `f`) <$> [0, 1]
assertNearList [nan, 0.0] $ diff (`f` 0) <$> [0, 1]
assertNearList [inf, 0.5] $ diff (`f` 1) <$> [0, 1]
expect (list eq) [nan, 0.0] $ snd . diff (0 `f`) <$> [0, 1]
expect (list eq) [inf, 0.5] $ snd . diff (1 `f`) <$> [0, 1]
expect (list eq) [nan, 0.0] $ snd . diff (`f` 0) <$> [0, 1]
expect (list eq) [inf, 0.5] $ snd . diff (`f` 1) <$> [0, 1]
tgrad = testCase "grad" $ do
assertNearList [nan, nan] $ grad (binary f) [0, 0]
assertNearList [0.0, inf] $ grad (binary f) [1, 0]
assertNearList [inf, 0.0] $ grad (binary f) [0, 1]
assertNearList [0.5, 0.5] $ grad (binary f) [1, 1]
expect (list eq) [nan, nan] $ grad (binary f) [0, 0]
expect (list eq) [0.0, inf] $ grad (binary f) [1, 0]
expect (list eq) [inf, 0.0] $ grad (binary f) [0, 1]
expect (list eq) [0.5, 0.5] $ grad (binary f) [1, 1]
f x y = sqrt x * sqrt y -- grad f [x, y] = [sqrt y / 2 sqrt x, sqrt x / 2 sqrt y]
binary f = \ [x, y] -> f x y
binary f [x, y] = f x y

near :: Double -> Double -> Bool
near a b = bothNaN || bothInfinite || abs (a - b) <= 1e-12 where
bothNaN = isNaN a && isNaN b
bothInfinite = signum a == signum b && isInfinite a && isInfinite b
-- TODO: ideally, we would consider `0` and `-0` to be different
-- however, zero signedness is currently not reliably propagated through some modes
-- see also https://github.com/ekmett/ad/issues/109 and https://github.com/ekmett/ad/pull/110
eq :: Double -> Double -> Bool
eq a b = isNaN a && isNaN b || a == b

nearList :: [Double] -> [Double] -> Bool
nearList as bs = length as == length bs && and (zipWith near as bs)
list :: (a -> a -> Bool) -> [a] -> [a] -> Bool
list eq as bs = length as == length bs && and (zipWith eq as bs)

nearMatrix :: [[Double]] -> [[Double]] -> Bool
nearMatrix as bs = length as == length bs && and (zipWith nearList as bs)

assertNear :: Double -> Double -> Assertion
assertNear a b = near a b @? expect a b

assertNearList :: [Double] -> [Double] -> Assertion
assertNearList a b = nearList a b @? expect a b

assertNearMatrix :: [[Double]] -> [[Double]] -> Assertion
assertNearMatrix a b = nearMatrix a b @? expect a b

expect :: Show a => a -> a -> String
expect a b = printf "expected %s but got %s" (show a) (show b)
expect :: HasCallStack => Show a => (a -> a -> Bool) -> a -> a -> Assertion
expect eq a b = eq a b @? printf "expected %s but got %s" (show a) (show b)

nan :: Double
nan = 0 / 0
Expand Down

0 comments on commit 2e37c7d

Please sign in to comment.