Skip to content

Commit

Permalink
Minimal test for expected values
Browse files Browse the repository at this point in the history
  • Loading branch information
milancurcic committed Apr 18, 2024
1 parent 6033ec1 commit 634ce92
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions test/test_loss.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,42 @@ program test_loss

logical :: ok = .true.

block

type(mse) :: loss
real :: true(2) = [1., 2.]
real :: pred(2) = [3., 4.]

if (.not. loss % eval(true, pred) == 4) then
write(stderr, '(a)') 'expected output of mse % eval().. failed'
ok = .false.
end if

if (.not. all(loss % derivative(true, pred) == [2, 2])) then
write(stderr, '(a)') 'expected output of mse % derivative().. failed'
ok = .false.
end if

end block

block

type(quadratic) :: loss
real :: true(4) = [1., 2., 3., 4.]
real :: pred(4) = [3., 4., 5., 6.]

if (.not. loss % eval(true, pred) == 8) then
write(stderr, '(a)') 'expected output of quadratic % eval().. failed'
ok = .false.
end if

if (.not. all(loss % derivative(true, pred) == [2, 2, 2, 2])) then
write(stderr, '(a)') 'expected output of quadratic % derivative().. failed'
ok = .false.
end if

end block

if (ok) then
print '(a)', 'test_loss: All tests passed.'
else
Expand Down

0 comments on commit 634ce92

Please sign in to comment.