
Commit d6bbbac (parent 11d5e94)

embedding_layer: add absolute positional encoding

3 files changed: 74 additions, 16 deletions

src/nf/nf_embedding_layer.f90: 12 additions, 5 deletions

@@ -15,7 +15,7 @@ module nf_embedding_layer
     !! This layer converts them into a table of shape
     !! (`sequence_length`, `model_dimension`)
     integer :: sequence_length, vocab_size, model_dimension
-    logical :: positional
+    integer :: positional

     real, allocatable :: weights(:, :)
     real, allocatable :: output(:, :)
@@ -25,7 +25,8 @@ module nf_embedding_layer

     procedure :: backward
     procedure :: forward
-    procedure :: positional_encoding
+    procedure :: positional_trigonometric
+    procedure :: positional_absolute
     procedure :: init
     procedure :: get_num_params
     procedure :: get_params
@@ -37,7 +38,7 @@ module nf_embedding_layer
   interface embedding_layer
     module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
       integer, intent(in) :: vocab_size, model_dimension
-      logical, optional :: positional
+      integer, optional :: positional
       type(embedding_layer) :: res
     end function embedding_layer_cons
   end interface embedding_layer
@@ -57,11 +58,17 @@ pure module subroutine backward(self, input, gradient)
       real, intent(in) :: gradient(:, :)
     end subroutine backward

-    pure module subroutine positional_encoding(self, pos)
+    pure module subroutine positional_trigonometric(self, pos)
       !! Sum embedding with positional info (trigonometric, not trainable)
       class(embedding_layer), intent(in out) :: self
       integer, intent(in) :: pos
-    end subroutine positional_encoding
+    end subroutine positional_trigonometric
+
+    pure module subroutine positional_absolute(self, pos)
+      !! Sum embedding with absolute position
+      class(embedding_layer), intent(in out) :: self
+      integer, intent(in) :: pos
+    end subroutine positional_absolute

     module subroutine init(self, input_shape)
       class(embedding_layer), intent(in out) :: self

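For orientation, the optional `positional` argument is now an integer mode selector rather than a logical switch. A minimal construction sketch, mirroring how the tests in this commit call the constructor (the mode values correspond to the NONE/TRIGONOMETRIC/ABSOLUTE macros defined in the submodule below; the variable name is illustrative):

  type(embedding_layer) :: embedding

  embedding = embedding_layer(vocab_size=5, model_dimension=4)                ! default: no positional encoding (NONE = 0)
  embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=1)  ! sinusoidal sum (TRIGONOMETRIC = 1)
  embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=2)  ! absolute offsets (ABSOLUTE = 2)
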
src/nf/nf_embedding_layer_submodule.f90: 22 additions, 6 deletions

@@ -1,16 +1,20 @@
+#define NONE 0
+#define TRIGONOMETRIC 1
+#define ABSOLUTE 2
+
 submodule(nf_embedding_layer) nf_embedding_layer_submodule
   use nf_base_layer, only: base_layer
   implicit none
 contains
   module function embedding_layer_cons(vocab_size, model_dimension, positional) result(res)
     integer, intent(in) :: vocab_size, model_dimension
-    logical, optional :: positional
+    integer, optional :: positional
     type(embedding_layer) :: res

     res % vocab_size = vocab_size
     res % model_dimension = model_dimension
     if (.not. present(positional)) then
-      res % positional = .false.
+      res % positional = NONE
     else
       res % positional = positional
     end if
@@ -46,8 +50,10 @@ pure module subroutine forward(self, input)

       self % output(i, :) = self % weights(index, :)

-      if (self % positional) then
-        call self % positional_encoding(i)
+      if (self % positional == TRIGONOMETRIC) then
+        call self % positional_trigonometric(i)
+      elseif (self % positional == ABSOLUTE) then
+        call self % positional_absolute(i)
       end if
     end do
   end subroutine forward
@@ -63,7 +69,7 @@ pure module subroutine backward(self, input, gradient)
     end do
   end subroutine backward

-  pure module subroutine positional_encoding(self, pos)
+  pure module subroutine positional_trigonometric(self, pos)
     class(embedding_layer), intent(in out) :: self
     integer, intent(in) :: pos
     integer :: i
@@ -74,7 +80,17 @@ pure module subroutine positional_encoding(self, pos)
       self % output(pos, 2 * i - 1) = self % output(pos, 2 * i - 1) + sin(theta)
       self % output(pos, 2 * i) = self % output(pos, 2 * i) + cos(theta)
     end do
-  end subroutine positional_encoding
+  end subroutine positional_trigonometric
+
+  pure module subroutine positional_absolute(self, pos)
+    class(embedding_layer), intent(in out) :: self
+    integer, intent(in) :: pos
+    integer :: i
+
+    do concurrent(i = 1: self % model_dimension)
+      self % output(pos, i) = self % output(pos, i) + pos - 1
+    end do
+  end subroutine positional_absolute

   pure module function get_num_params(self) result(num_params)
     class(embedding_layer), intent(in) :: self

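The new positional_absolute update is simply output(pos, i) = output(pos, i) + (pos - 1) for every component i, so each position's output row is shifted by its zero-based index. A self-contained sanity-check sketch of the same loop, independent of the library (array values are hypothetical stand-ins):

  program demo_absolute_offsets
    implicit none
    real :: output(3, 4)
    integer :: pos, i

    output = 0.1  ! stand-in embedding values

    ! Same update as positional_absolute: add the zero-based
    ! position to every component of the row at that position.
    do pos = 1, 3
      do concurrent(i = 1:4)
        output(pos, i) = output(pos, i) + pos - 1
      end do
    end do

    print '(4f6.1)', (output(pos, :), pos = 1, 3)
    ! rows: 0.1 0.1 0.1 0.1 / 1.1 1.1 1.1 1.1 / 2.1 2.1 2.1 2.1
  end program demo_absolute_offsets
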
test/test_embedding_layer.f90: 40 additions, 5 deletions

@@ -6,7 +6,8 @@ program test_embedding_layer
   logical :: ok = .true.

   call test_simple(ok)
-  call test_positional(ok)
+  call test_positional_trigonometric(ok)
+  call test_positional_absolute(ok)

   if (ok) then
     print '(a)', 'test_embedding_layer: All tests passed.'
@@ -47,7 +48,7 @@ subroutine test_simple(ok)
     end if
   end subroutine test_simple

-  subroutine test_positional(ok)
+  subroutine test_positional_trigonometric(ok)
     logical, intent(in out) :: ok

     integer :: sample_input(3) = [2, 1, 3]
@@ -63,7 +64,7 @@ subroutine test_positional(ok)
     real :: theta
     integer :: i, pos

-    embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=.true.)
+    embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=1)
     call embedding % init([3])
     embedding % weights = reshape([&
       0.1, 0.3, 0.5, 0.7, 0.2,&
@@ -77,7 +78,41 @@ subroutine test_positional(ok)
     output_flat = reshape(embedding % output, [12])
     if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
       ok = .false.
-      write(stderr, '(a)') 'positional encoding returned incorrect values.. failed'
+      write(stderr, '(a)') 'trigonometric positional encoding returned incorrect values.. failed'
     end if
-  end subroutine test_positional
+  end subroutine test_positional_trigonometric
+
+  subroutine test_positional_absolute(ok)
+    logical, intent(in out) :: ok
+
+    integer :: sample_input(3) = [2, 1, 3]
+    real :: output_flat(12)
+    real :: expected_output_flat(12) = reshape([&
+      0.3, 1.1, 2.5,&
+      0.3, 1.1, 2.5,&
+      0.3, 1.1, 2.5,&
+      0.3, 1.1, 2.5&
+    ], [12])
+    type(embedding_layer) :: embedding
+
+    real :: theta
+    integer :: i, pos
+
+    embedding = embedding_layer(vocab_size=5, model_dimension=4, positional=2)
+    call embedding % init([3])
+    embedding % weights = reshape([&
+      0.1, 0.3, 0.5, 0.7, 0.2,&
+      0.1, 0.3, 0.5, 0.7, 0.2,&
+      0.1, 0.3, 0.5, 0.7, 0.2,&
+      0.1, 0.3, 0.5, 0.7, 0.2&
+    ], [5, 4])
+
+    call embedding % forward(sample_input)
+
+    output_flat = reshape(embedding % output, [12])
+    if (.not. all(abs(output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs(expected_output_flat)))) then
+      ok = .false.
+      write(stderr, '(a)') 'absolute positional encoding returned incorrect values.. failed'
+    end if
+  end subroutine test_positional_absolute
 end program test_embedding_layer
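For reference, the expected values in test_positional_absolute follow directly from the update above: weights is filled column-major, so each row holds a constant (row 1 is all 0.1, row 2 all 0.3, row 3 all 0.5). With sample_input = [2, 1, 3] the lookups give 0.3, 0.1, and 0.5, and the absolute offsets 0, 1, and 2 yield 0.3 + 0 = 0.3, 0.1 + 1 = 1.1, and 0.5 + 2 = 2.5; flattening the 3x4 output column-major produces the repeated [0.3, 1.1, 2.5] triples in expected_output_flat.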
