Propose an API for the network % evaluate method (#182)
* Propose an API for the evaluate method

* nf_metric -> nf_metrics for consistency with Python frameworks

* Add nf_metrics.f90 to the CMake build

* Make corr metric public

* Formatting

* Bump minor version

* Make metrics accessible via nf

* Evaluate metrics in MNIST example

* Add simple tests for metrics

* Add the maxabs metric

* Update example

* Remove multi-metric variant of net % evaluate

* Mention metrics in README

---------

Co-authored-by: Vandenplas, Jeremie <[email protected]>
Co-authored-by: milancurcic <[email protected]>
3 people authored Jun 14, 2024
1 parent 6dfaed0 commit e82d565
Showing 11 changed files with 204 additions and 12 deletions.
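
In brief, the new net % evaluate method runs the network over a batch and returns a rank-2 array of metric values: the loss (quadratic by default) in column 1 and an optional user-supplied metric in column 2. A minimal, hypothetical sketch of a call on an already-trained network (validation_inputs and validation_outputs are assumed names for rank-2 batch arrays, shaped as in the diffs below):

  real, allocatable :: metrics(:,:)

  ! Assumes `net` is a trained network and `corr` is imported from nf.
  ! Column 1: loss values; column 2: Pearson correlation values.
  metrics = net % evaluate(validation_inputs, validation_outputs, metric=corr())
  print *, 'mean loss:', sum(metrics(:,1)) / size(metrics, 1)
  print *, 'mean corr:', sum(metrics(:,2)) / size(metrics, 1)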
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -50,6 +50,7 @@ add_library(neural-fortran
src/nf/nf_loss_submodule.f90
src/nf/nf_maxpool2d_layer.f90
src/nf/nf_maxpool2d_layer_submodule.f90
src/nf/nf_metrics.f90
src/nf/nf_network.f90
src/nf/nf_network_submodule.f90
src/nf/nf_optimizers.f90
1 change: 1 addition & 0 deletions README.md
@@ -20,6 +20,7 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
* Stochastic gradient descent optimizers: Classic, momentum, Nesterov momentum,
RMSProp, Adagrad, Adam, AdamW
* More than a dozen activation functions and their derivatives
* Loss functions and metrics: Quadratic, Mean Squared Error, Pearson Correlation, etc.
* Loading dense and convolutional models from Keras HDF5 (.h5) files
* Data-based parallelism

16 changes: 12 additions & 4 deletions example/dense_mnist.f90
@@ -1,6 +1,6 @@
program dense_mnist

use nf, only: dense, input, network, sgd, label_digits, load_mnist
use nf, only: dense, input, network, sgd, label_digits, load_mnist, corr

implicit none

@@ -38,9 +38,17 @@ program dense_mnist
optimizer=sgd(learning_rate=3.) &
)

if (this_image() == 1) &
print '(a,i2,a,f5.2,a)', 'Epoch ', n, ' done, Accuracy: ', accuracy( &
net, validation_images, label_digits(validation_labels)) * 100, ' %'
block
real, allocatable :: output_metrics(:,:)
real, allocatable :: mean_metrics(:)
! Two metrics: the first is the default loss function (quadratic); the second is Pearson correlation.
output_metrics = net % evaluate(validation_images, label_digits(validation_labels), metric=corr())
mean_metrics = sum(output_metrics, 1) / size(output_metrics, 1)
if (this_image() == 1) &
print '(a,i2,3(a,f6.3))', 'Epoch ', n, ' done, Accuracy: ', &
accuracy(net, validation_images, label_digits(validation_labels)) * 100, &
'%, Loss: ', mean_metrics(1), ', Pearson correlation: ', mean_metrics(2)
end block

end do epochs

2 changes: 1 addition & 1 deletion fpm.toml
@@ -1,5 +1,5 @@
name = "neural-fortran"
version = "0.16.1"
version = "0.17.0"
license = "MIT"
author = "Milan Curcic"
maintainer = "[email protected]"
1 change: 1 addition & 0 deletions src/nf.f90
@@ -5,6 +5,7 @@ module nf
use nf_layer_constructors, only: &
conv2d, dense, flatten, input, maxpool2d, reshape
use nf_loss, only: mse, quadratic
use nf_metrics, only: corr, maxabs
use nf_network, only: network
use nf_optimizers, only: sgd, rmsprop, adam, adagrad
use nf_activation, only: activation_function, elu, exponential, &
9 changes: 2 additions & 7 deletions src/nf/nf_loss.f90
@@ -7,25 +7,20 @@ module nf_loss
!! loss type that extends the abstract loss derived type, and that
!! implements concrete eval and derivative methods that accept vectors.

use nf_metrics, only: metric_type
implicit none

private
public :: loss_type
public :: mse
public :: quadratic

type, abstract :: loss_type
type, extends(metric_type), abstract :: loss_type
contains
procedure(loss_interface), nopass, deferred :: eval
procedure(loss_derivative_interface), nopass, deferred :: derivative
end type loss_type

abstract interface
pure function loss_interface(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res
end function loss_interface
pure function loss_derivative_interface(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
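
Since loss_type now simply extends metric_type and shares its deferred eval interface, user code can define new metrics the same way. A hypothetical sketch (not part of this commit), assuming only the metric_interface declared in nf_metrics.f90 below:

  module my_metrics
    ! Illustrative only: a mean-absolute-error metric built on metric_type.
    use nf_metrics, only: metric_type
    implicit none
    private
    public :: mae

    type, extends(metric_type) :: mae
    contains
      procedure, nopass :: eval => mae_eval
    end type mae

  contains

    pure function mae_eval(true, predicted) result(res)
      real, intent(in) :: true(:)
      real, intent(in) :: predicted(:)
      real :: res
      res = sum(abs(true - predicted)) / size(true)
    end function mae_eval

  end module my_metrics

Such a type could then be passed as the optional metric argument to net % evaluate.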
72 changes: 72 additions & 0 deletions src/nf/nf_metrics.f90
@@ -0,0 +1,72 @@
module nf_metrics

!! This module provides a collection of metric functions.

implicit none

private
public :: metric_type
public :: corr
public :: maxabs

type, abstract :: metric_type
contains
procedure(metric_interface), nopass, deferred :: eval
end type metric_type

abstract interface
pure function metric_interface(true, predicted) result(res)
real, intent(in) :: true(:)
real, intent(in) :: predicted(:)
real :: res
end function metric_interface
end interface

type, extends(metric_type) :: corr
!! Pearson correlation
contains
procedure, nopass :: eval => corr_eval
end type corr

type, extends(metric_type) :: maxabs
!! Maximum absolute difference
contains
procedure, nopass :: eval => maxabs_eval
end type maxabs

contains

pure function corr_eval(true, predicted) result(res)
!! Pearson correlation function:
!! corr = sum((x - mean(x)) * (y - mean(y))) /
!!        sqrt(sum((x - mean(x))**2) * sum((y - mean(y))**2))
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting correlation value
real :: m_true, m_pred

m_true = sum(true) / size(true)
m_pred = sum(predicted) / size(predicted)

res = dot_product(true - m_true, predicted - m_pred) / &
sqrt(sum((true - m_true)**2)*sum((predicted - m_pred)**2))

end function corr_eval

pure function maxabs_eval(true, predicted) result(res)
!! Maximum absolute difference function:
!! maxabs = maxval(abs(true - predicted))
real, intent(in) :: true(:)
!! True values, i.e. labels from training datasets
real, intent(in) :: predicted(:)
!! Values predicted by the network
real :: res
!! Resulting maximum absolute difference value

res = maxval(abs(true - predicted))

end function maxabs_eval

end module nf_metrics
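
For reference, a minimal usage sketch (not part of this commit) that evaluates the two new metrics directly on a pair of vectors; eval is nopass, but is still invoked through an instance of the metric type:

  program metrics_demo
    use nf, only: corr, maxabs
    implicit none
    type(corr) :: pearson
    type(maxabs) :: maxdiff
    real :: y_true(4), y_pred(4)

    y_true = [0.0, 1.0, 2.0, 3.0]
    y_pred = [0.1, 0.9, 2.2, 2.8]

    ! Hypothetical values; corr lies in [-1, 1], and maxabs here is ~0.2.
    print '(a,f6.3)', 'Pearson correlation: ', pearson % eval(y_true, y_pred)
    print '(a,f6.3)', 'Max abs difference:  ', maxdiff % eval(y_true, y_pred)
  end program metrics_demo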
13 changes: 13 additions & 0 deletions src/nf/nf_network.f90
@@ -3,6 +3,7 @@ module nf_network
!! This module provides the network type to create new models.

use nf_layer, only: layer
use nf_metrics, only: metric_type
use nf_loss, only: loss_type
use nf_optimizers, only: optimizer_base_type

@@ -28,13 +29,15 @@ module nf_network
procedure :: train
procedure :: update

procedure, private :: evaluate_batch_1d
procedure, private :: forward_1d
procedure, private :: forward_3d
procedure, private :: predict_1d
procedure, private :: predict_3d
procedure, private :: predict_batch_1d
procedure, private :: predict_batch_3d

generic :: evaluate => evaluate_batch_1d
generic :: forward => forward_1d, forward_3d
generic :: predict => predict_1d, predict_3d, predict_batch_1d, predict_batch_3d

@@ -62,6 +65,16 @@ end function network_from_keras

end interface network

interface evaluate
module function evaluate_batch_1d(self, input_data, output_data, metric) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input_data(:,:)
real, intent(in) :: output_data(:,:)
class(metric_type), intent(in), optional :: metric
real, allocatable :: res(:,:)
end function evaluate_batch_1d
end interface evaluate

interface forward

pure module subroutine forward_1d(self, input)
30 changes: 30 additions & 0 deletions src/nf/nf_network_submodule.f90
@@ -337,6 +337,36 @@ pure module subroutine backward(self, output, loss)
end subroutine backward


module function evaluate_batch_1d(self, input_data, output_data, metric) result(res)
class(network), intent(in out) :: self
real, intent(in) :: input_data(:,:)
real, intent(in) :: output_data(:,:)
class(metric_type), intent(in), optional :: metric
real, allocatable :: res(:,:)

integer :: i, n
real, allocatable :: output(:,:)

output = self % predict(input_data)

! The first column of the result always holds the loss;
! a second column is added only when the optional metric is supplied.
n = 1
if (present(metric)) n = n + 1

allocate(res(size(output, dim=1), n))

do concurrent (i = 1:size(output, dim=1))
res(i,1) = self % loss % eval(output_data(i,:), output(i,:))
end do

if (.not. present(metric)) return

do concurrent (i = 1:size(output, dim=1))
res(i,2) = metric % eval(output_data(i,:), output(i,:))
end do

end function evaluate_batch_1d


pure module subroutine forward_1d(self, input)
class(network), intent(in out) :: self
real, intent(in) :: input(:)
1 change: 1 addition & 0 deletions test/CMakeLists.txt
@@ -17,6 +17,7 @@ foreach(execid
conv2d_network
optimizers
loss
metrics
)
add_executable(test_${execid} test_${execid}.f90)
target_link_libraries(test_${execid} PRIVATE neural-fortran h5fortran::h5fortran jsonfortran::jsonfortran ${LIBS})
70 changes: 70 additions & 0 deletions test/test_metrics.f90
@@ -0,0 +1,70 @@
program test_metrics
use iso_fortran_env, only: stderr => error_unit
use nf, only: dense, input, network, sgd, mse
implicit none
type(network) :: net
logical :: ok = .true.

! Minimal 2-layer network
net = network([ &
input(1), &
dense(1) &
])

training: block
real :: x(1), y(1)
real :: tolerance = 1e-3
integer :: n
integer, parameter :: num_iterations = 1000
real :: quadratic_loss, mse_metric
real, allocatable :: metrics(:,:)

x = [0.1234567]
y = [0.7654321]

do n = 1, num_iterations
call net % forward(x)
call net % backward(y)
call net % update(sgd(learning_rate=1.))
if (all(abs(net % predict(x) - y) < tolerance)) exit
end do

! Returns only one metric, based on the default loss function (quadratic).
metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]))
quadratic_loss = metrics(1,1)

if (.not. all(shape(metrics) == [1, 1])) then
write(stderr, '(a)') 'metrics array is the correct shape (1, 1).. failed'
ok = .false.
end if

! Returns two metrics, one from the loss function and another specified by the user.
metrics = net % evaluate(reshape(x, [1, 1]), reshape(y, [1, 1]), metric=mse())

if (.not. all(shape(metrics) == [1, 2])) then
write(stderr, '(a)') 'metrics array is the correct shape (1, 2).. failed'
ok = .false.
end if

mse_metric = metrics(1,2)

if (.not. all(metrics < 1e-5)) then
write(stderr, '(a)') 'all metric values are small (< 1e-5).. failed'
ok = .false.
end if

if (metrics(1,1) /= quadratic_loss) then
write(stderr, '(a)') 'first metric should be the same as that of the loss function.. failed'
ok = .false.
end if

end block training

if (ok) then
print '(a)', 'test_metrics: All tests passed.'
else
write(stderr, '(a)') 'test_metrics: One or more tests failed.'
stop 1
end if

end program test_metrics
