diff --git a/src/nf/nf_dense_layer_submodule.f90 b/src/nf/nf_dense_layer_submodule.f90
index a424cf9c..53ca1b1d 100644
--- a/src/nf/nf_dense_layer_submodule.f90
+++ b/src/nf/nf_dense_layer_submodule.f90
@@ -2,7 +2,7 @@
 
   use nf_activation, only: activation_function
   use nf_base_layer, only: base_layer
-  use nf_random, only: random_normal
+  use nf_random, only: random_normal, random_xavier, random_he
 
   implicit none
 
@@ -125,8 +125,20 @@ module subroutine init(self, input_shape)
     ! Weights are a 2-d array of shape previous layer size
     ! times this layer size.
     allocate(self % weights(self % input_size, self % output_size))
-    call random_normal(self % weights)
-    self % weights = self % weights / self % input_size
+    ! Pick the weight initialization scheme from the activation name;
+    ! NOTE(review): confirm 'tanhf' is the name stored for the tanh activation.
+    if (&
+      self % activation_name == 'relu' &
+      .or. self % activation_name == 'leaky_relu' &
+      .or. self % activation_name == 'celu' &
+    ) then
+      call random_he(self % weights, self % input_size)
+    elseif (self % activation_name == 'sigmoid' .or. self % activation_name == 'tanhf') then
+      call random_xavier(self % weights, self % input_size)
+    else
+      call random_normal(self % weights)
+      self % weights = self % weights / self % input_size
+    end if
 
     ! Broadcast weights to all other images, if any.
 #ifdef PARALLEL
diff --git a/src/nf/nf_random.f90 b/src/nf/nf_random.f90
index 57c5d11f..6ed08626 100644
--- a/src/nf/nf_random.f90
+++ b/src/nf/nf_random.f90
@@ -6,7 +6,7 @@ module nf_random
   implicit none
 
   private
-  public :: random_normal
+  public :: random_normal, random_he, random_xavier
 
   real, parameter :: pi = 4 * atan(1.d0)
 
@@ -23,4 +23,31 @@ impure elemental subroutine random_normal(x)
     x = sqrt(- 2 * log(u(1))) * cos(2 * pi * u(2))
   end subroutine random_normal
 
+  impure elemental subroutine random_he(x, n_prev)
+    !! Kaiming (He) weight initialization: a zero-mean normal sample
+    !! from random_normal, scaled by sqrt(2 / n_prev).
+    real, intent(in out) :: x
+      !! Overwritten with the random sample
+    integer, intent(in) :: n_prev
+      !! Scaling size; the dense layer passes its input_size
+    ! Sample with random_normal rather than random_number so the result
+    ! is centred on zero instead of being confined to non-negative values.
+    call random_normal(x)
+    x = x * sqrt(2. / n_prev)
+  end subroutine random_he
+
+  impure elemental subroutine random_xavier(x, n_prev)
+    !! Xavier (Glorot) weight initialization: a uniform sample on
+    !! [-1 / sqrt(n_prev), 1 / sqrt(n_prev)).
+    real, intent(in out) :: x
+      !! Overwritten with the random sample
+    integer, intent(in) :: n_prev
+      !! Scaling size; the dense layer passes its input_size
+    real :: lower, upper
+    lower = -(1. / sqrt(real(n_prev)))
+    upper = 1. / sqrt(real(n_prev))
+    ! random_number yields a value in [0, 1); rescale it to [lower, upper).
+    call random_number(x)
+    x = lower + x * (upper - lower)
+  end subroutine random_xavier
 end module nf_random