diff --git a/src/nf/nf_multihead_attention.f90 b/src/nf/nf_multihead_attention.f90
index 80a59dfb..6ff51dd5 100644
--- a/src/nf/nf_multihead_attention.f90
+++ b/src/nf/nf_multihead_attention.f90
@@ -39,10 +39,25 @@ module nf_multihead_attention_layer
     real, allocatable :: k_input(:, :)
     real, allocatable :: v_input(:, :)
     real, allocatable :: o_input(:, :)
+
+    ! temporary storages for forward and backward passes
+    real, allocatable :: normalized_attention(:, :, :)
+    real, allocatable :: q_or_dq(:, :, :)
+    real, allocatable :: k_or_dk(:, :, :)
+    real, allocatable :: v_or_dv(:, :, :)
+    real, allocatable :: d_output(:, :, :)
+    real, allocatable :: v_heads(:, :, :)
+    real, allocatable :: k_heads(:, :, :)
+    real, allocatable :: q_heads(:, :, :)
+    real, allocatable :: d_sdpa(:, :)
+    real, allocatable :: jacobian(:, :)
+    real, allocatable :: d_normalize(:, :, :)

   contains
     procedure :: common_backward
     procedure :: common_forward
+    procedure :: sdpa_forward
+    procedure :: sdpa_backward
     procedure :: get_num_params
     procedure :: get_params
     procedure :: get_gradients
@@ -68,7 +83,7 @@ end function multihead_attention_layer_cons

   interface

-    pure module subroutine common_backward(self, input, gradient)
+    pure module subroutine common_backward(self, input, gradient, attention_mask)
       !! General backprop for MultiHead Attention mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: sum output gradients
@@ -76,17 +91,30 @@ pure module subroutine common_backward(self, input, gradient)
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: input(:, :)
       real, intent(in) :: gradient(:, :)
+      real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_backward

-    pure module subroutine common_forward(self, query, key, value)
+    pure module subroutine common_forward(self, query, key, value, attention_mask)
       !! General forward propagation for MultiHead Attention Mechanism
       !! Might be used for both Self and Cross Attention
       !! Self Attention: pass the same value thrice
       !! Cross Attention: pass three values for your query, key and value
       class(multihead_attention_layer), intent(in out) :: self
       real, intent(in) :: query(:, :), key(:, :), value(:, :)
+      real, optional, intent(in) :: attention_mask(:, :)
     end subroutine common_forward

+    pure module subroutine sdpa_forward(self, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_forward
+
+    pure module subroutine sdpa_backward(self, gradient, attention_mask)
+      class(multihead_attention_layer), intent(in out) :: self
+      real, intent(in) :: gradient(:, :)
+      real, intent(in), optional :: attention_mask(:, :)
+    end subroutine sdpa_backward
+
     pure module subroutine init(self, input_shape)
       !! Initialize the layer data structures.
       !!
@@ -119,7 +147,7 @@ pure module subroutine normalize_attention_matrix(self, attention_mask)
       !! Output dims: sequence_length, sequence_length, n_heads
       class(multihead_attention_layer), intent(in out) :: self
       !! (sequence_length, sequence_length, n_heads)
-      real, optional, intent(in) :: attention_mask(:, :, :)
+      real, optional, intent(in) :: attention_mask(:, :)
       !! (sequence_length, sequence_length, n_heads)
     end subroutine normalize_attention_matrix

@@ -143,18 +171,18 @@ elemental module function get_num_params(self) result(num_params)
     end function get_num_params

     module function get_params(self) result(params)
-      class(multihead_attention_layer), intent(in), target :: self
+      class(multihead_attention_layer), intent(in) :: self
       real, allocatable :: params(:)
     end function get_params

     module function get_gradients(self) result(gradients)
-      class(multihead_attention_layer), intent(in), target :: self
+      class(multihead_attention_layer), intent(in) :: self
       real, allocatable :: gradients(:)
     end function get_gradients

     module subroutine set_params(self, params)
       class(multihead_attention_layer), intent(in out) :: self
-      real, intent(in), target :: params(:)
+      real, intent(in) :: params(:)
     end subroutine set_params

     module subroutine init_base(self, input_shape)
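Illustration only, not part of the patch: the new optional `attention_mask(:, :)` argument takes an additive mask of shape (sequence_length, sequence_length) that is added to the scaled attention scores before the softmax, so 0. leaves a query/key pair visible and a large negative value (the new test uses -100.) suppresses it. A minimal sketch of how a caller might build a causal mask, assuming a hypothetical 3-token sequence, that the attention matrix is indexed (query, key) as the row-wise softmax in `normalize_attention_matrix` suggests, and a layer setup similar to the test program below:

    program causal_mask_sketch
      use nf_multihead_attention_layer, only: multihead_attention_layer
      implicit none
      integer, parameter :: seq_len = 3, model_dim = 4
      type(multihead_attention_layer) :: attention
      real :: x(seq_len, model_dim), dy(seq_len, model_dim)
      real :: mask(seq_len, seq_len)
      integer :: i, j

      ! rows index queries, columns index keys: hide keys that lie in the future
      do concurrent(i = 1:seq_len, j = 1:seq_len)
        mask(i, j) = merge(-100., 0., j > i)
      end do

      call random_number(x)
      call random_number(dy)
      attention = multihead_attention_layer(n_heads=2)
      call attention % init_base([seq_len, model_dim])

      ! the 2D mask is broadcast over heads inside normalize_attention_matrix;
      ! self attention passes the same array as query, key and value
      call attention % common_forward(x, x, x, attention_mask=mask)
      call attention % common_backward(x, dy, attention_mask=mask)
    end program causal_mask_sketch
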
diff --git a/src/nf/nf_multihead_attention_submodule.f90 b/src/nf/nf_multihead_attention_submodule.f90
index d0e43a2e..f78abafd 100644
--- a/src/nf/nf_multihead_attention_submodule.f90
+++ b/src/nf/nf_multihead_attention_submodule.f90
@@ -1,5 +1,4 @@
 submodule(nf_multihead_attention_layer) nf_multihead_attention_layer_submodule
-! use iso_fortran_env, only: stderr => error_unit
   use nf_activation, only: softmax
   use nf_base_layer, only: base_layer
   use nf_linear2d_layer, only: linear2d_layer
@@ -14,51 +13,91 @@ module function multihead_attention_layer_cons(n_heads) result(res)
     res % n_heads = n_heads
   end function multihead_attention_layer_cons

-  pure module subroutine common_backward(self, input, gradient)
+  pure module subroutine common_backward(self, input, gradient, attention_mask)
     class(multihead_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
     real, intent(in) :: gradient(:, :)
+    real, intent(in), optional :: attention_mask(:, :)

-    real, allocatable :: d_output(:, :, :)
-    real, allocatable :: v_heads(:, :, :)
-    real, allocatable :: k_heads(:, :, :)
-    real, allocatable :: q_heads(:, :, :)
-    real, allocatable :: dv(:, :, :)
-    real, allocatable :: d_sdpa(:, :)
-    real, allocatable :: jacobian(:, :)
-    real, allocatable :: d_normalize(:, :, :)
-    real, allocatable :: dq(:, :, :)
-    real, allocatable :: dk(:, :, :)
     integer :: head, seq, i, j

-    ! allocate temporary storages for backward computation
-    allocate(d_output(self % sequence_length, self % head_size, self % n_heads))
-    allocate(v_heads(self % sequence_length, self % head_size, self % n_heads))
-    allocate(k_heads(self % sequence_length, self % head_size, self % n_heads))
-    allocate(q_heads(self % sequence_length, self % head_size, self % n_heads))
+    self % v_heads = self % split_heads(self % value_layer % output)
+    self % k_heads = self % split_heads(self % key_layer % output)
+    self % q_heads = self % split_heads(self % query_layer % output)

-    allocate(dv(self % sequence_length, self % head_size, self % n_heads))
-    allocate(d_sdpa(self % sequence_length, self % sequence_length))
-    allocate(jacobian(self % sequence_length, self % sequence_length))
-    allocate(d_normalize(self % sequence_length, self % sequence_length, self % n_heads))
-    allocate(dq(self % sequence_length, self % head_size, self % n_heads))
-    allocate(dk(self % sequence_length, self % head_size, self % n_heads))
+    ! backward through attention mechanism
+    call self % sdpa_backward(gradient, attention_mask)
+
+    ! calculate deltas for input layers
+    call self % value_layer % backward(self % v_input, self % combine_heads(self % v_or_dv))
+    call self % key_layer % backward(self % k_input, self % combine_heads(self % k_or_dk))
+    call self % query_layer % backward(self % q_input, self % combine_heads(self % q_or_dq))
+  end subroutine common_backward
+
+  pure module subroutine common_forward(self, query, key, value, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: query(:, :), key(:, :), value(:, :)
+    real, intent(in), optional :: attention_mask(:, :)
+
+    self % q_input = query
+    self % k_input = key
+    self % v_input = value
+
+    ! run inputs through linear layers (trainable params)
+    call self % query_layer % forward(query)
+    call self % key_layer % forward(key)
+    call self % value_layer % forward(value)
+
+    ! split attention heads for more efficient computation
+    self % q_or_dq = self % split_heads(self % query_layer % output)
+    self % k_or_dk = self % split_heads(self % key_layer % output)
+    self % v_or_dv = self % split_heads(self % value_layer % output)
+
+    call self % sdpa_forward(attention_mask)
+  end subroutine common_forward
+
+  pure module subroutine sdpa_forward(self, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in), optional :: attention_mask(:, :)
+
+    ! create key by value matrix
+    call self % create_attention_matrix(self % q_or_dq, self % k_or_dk)
+    ! apply softmax and scaling
+    call self % normalize_attention_matrix(attention_mask)
+    ! multiply attention matrix by value
+    call self % scaled_dot_product_attention(self % v_or_dv)
+
+    self % o_input = self % combine_heads(self % sdpa)
+    call self % output_layer % forward(self % o_input)
+    self % output = self % output_layer % output
+  end subroutine sdpa_forward
+
+  pure module subroutine sdpa_backward(self, gradient, attention_mask)
+    class(multihead_attention_layer), intent(in out) :: self
+    real, intent(in) :: gradient(:, :)
+    real, intent(in), optional :: attention_mask(:, :)
+
+    integer :: head, seq, i, j

     ! calculate output layer delta
     call self % output_layer % backward(self % o_input, gradient)

     ! split heads from output gradient
-    d_output = self % split_heads(self % output_layer % gradient)
-    v_heads = self % split_heads(self % value_layer % output)
-    k_heads = self % split_heads(self % key_layer % output)
-    q_heads = self % split_heads(self % query_layer % output)
+    self % d_output = self % split_heads(self % output_layer % gradient)

     ! iterate over heads to calculate deltas for each of them
     do concurrent(head = 1: self % n_heads)
-      dv(:, :, head) = matmul(transpose(self % attention_matrix(:, :, head)), d_output(:, :, head))
+      self % v_or_dv(:, :, head) = matmul(&
+          transpose(self % attention_matrix(:, :, head)),&
+          self % d_output(:, :, head)&
+      )

       ! calculate delta for attention matrix
-      d_sdpa = matmul(d_output(:, :, head), transpose(v_heads(:, :, head)))
+      self % d_sdpa = matmul(self % d_output(:, :, head), transpose(self % v_heads(:, :, head)))
+
+      if (present(attention_mask)) then
+        self % d_sdpa = self % d_sdpa + attention_mask
+      end if

       ! this monstrosity below is scaled derivative of softmax
       do concurrent(seq = 1: self % sequence_length)
@@ -69,11 +108,11 @@ pure module subroutine common_backward(self, input, gradient)
           ! should be: `softmax(x_i) * (1 - softmax(x_i))`
           ! for off-diagonal: `-softmax(x_i) * softmax(x_j)`
           if (i == j) then
-            jacobian(i, j) = &
+            self % jacobian(i, j) = &
                 self % attention_matrix(seq, i, head) &
                 * (1 - self % attention_matrix(seq, i, head))
           else
-            jacobian(i, j) = &
+            self % jacobian(i, j) = &
                 - self % attention_matrix(seq, i, head) &
                 * self % attention_matrix(seq, j, head)
           end if
@@ -82,79 +121,19 @@ pure module subroutine common_backward(self, input, gradient)
         ! multiply output of softmax by temp jacobian matrix
         ! For computational efficiency (avoid more temp storages), scaling is also done here
         ! reshapes: [3] -> [1, 3] @ [3, 3] = [1, 3] -> [3]
-        d_normalize(seq, :, head) = reshape(matmul(&
-            reshape(d_sdpa(seq, :), [1, self % sequence_length]),&
-            jacobian * self % scaling_factor&
+        self % d_normalize(seq, :, head) = reshape(matmul(&
+            reshape(self % d_sdpa(seq, :), [1, self % sequence_length]),&
+            self % jacobian * self % scaling_factor&
         ), [self % sequence_length])
       end do

       ! calculate delta for query
-      dq(:, :, head) = matmul(d_normalize(:, :, head), k_heads(:, :, head))
+      self % q_or_dq(:, :, head) = matmul(self % d_normalize(:, :, head), self % k_heads(:, :, head))

       ! calculate delta for key, attention matrix should be transposed unlike for query
-      dk(:, :, head) = matmul(transpose(d_normalize(:, :, head)), q_heads(:, :, head))
+      self % k_or_dk(:, :, head) = matmul(transpose(self % d_normalize(:, :, head)), self % q_heads(:, :, head))
     end do
-
-    ! calculate deltas for input layers
-    call self % value_layer % backward(self % v_input, self % combine_heads(dv))
-    call self % key_layer % backward(self % k_input, self % combine_heads(dk))
-    call self % query_layer % backward(self % q_input, self % combine_heads(dq))
-
-    ! free temporary storages
-    deallocate(d_output)
-    deallocate(v_heads)
-    deallocate(k_heads)
-    deallocate(q_heads)
-    deallocate(d_sdpa)
-    deallocate(jacobian)
-    deallocate(d_normalize)
-    deallocate(dq)
-    deallocate(dk)
-  end subroutine common_backward
-
-  pure module subroutine common_forward(self, query, key, value)
-    class(multihead_attention_layer), intent(in out) :: self
-    real, intent(in) :: query(:, :), key(:, :), value(:, :)
-
-    real, allocatable :: q(:, :, :)
-    real, allocatable :: k(:, :, :)
-    real, allocatable :: v(:, :, :)
-
-    ! allocate storage for intermidiate stages
-    allocate(q(self % sequence_length, self % head_size, self % n_heads))
-    allocate(k(self % sequence_length, self % head_size, self % n_heads))
-    allocate(v(self % sequence_length, self % head_size, self % n_heads))
-
-    self % q_input = query
-    self % k_input = key
-    self % v_input = value
-
-    ! run inputs through linear layers (trainable params)
-    call self % query_layer % forward(query)
-    call self % key_layer % forward(key)
-    call self % value_layer % forward(value)
-
-    ! split attention heads for more efficient computation
-    q = self % split_heads(self % query_layer % output)
-    k = self % split_heads(self % key_layer % output)
-    v = self % split_heads(self % value_layer % output)
-
-    ! create key by value matrix
-    call self % create_attention_matrix(q, k)
-    ! apply softmax and scaling
-    call self % normalize_attention_matrix()
-    ! multiply attention matrix by value
-    call self % scaled_dot_product_attention(v)
-
-    self % o_input = self % combine_heads(self % sdpa)
-    call self % output_layer % forward(self % o_input)
-    self % output = self % output_layer % output
-
-    ! free temp vars from memory
-    deallocate(q)
-    deallocate(k)
-    deallocate(v)
-  end subroutine common_forward
+  end subroutine sdpa_backward

   pure module function split_heads(self, input) result(output)
     class(multihead_attention_layer), intent(in) :: self
@@ -176,26 +155,22 @@ end subroutine create_attention_matrix

   pure module subroutine normalize_attention_matrix(self, attention_mask)
     class(multihead_attention_layer), intent(in out) :: self
-    real, optional, intent(in) :: attention_mask(:, :, :)
-    real, allocatable :: output(:, :, :)
+    real, optional, intent(in) :: attention_mask(:, :)

     integer :: head, seq

-    ! temporary storage
-    allocate(output(self % sequence_length, self % sequence_length, self % n_heads))
-
     ! scale dowm by square root of each head's size
     self % attention_matrix = self % attention_matrix * self % scaling_factor
     ! attention mask is used to mask out some of the tokens if necessary
     if (present(attention_mask)) then
-      self % attention_matrix = self % attention_matrix + attention_mask
+      do concurrent(head = 1: self % n_heads)
+        self % attention_matrix(:, :, head) = self % attention_matrix(:, :, head) + attention_mask
+      end do
     end if
     ! softmax by last sequnce_length
     do concurrent(head = 1: self % n_heads, seq = 1: self % sequence_length)
-      output(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head))
+      self % normalized_attention(seq, :, head) = self % softmax_func % eval_1d(self % attention_matrix(seq, :, head))
     end do
-    self % attention_matrix = output
-
-    deallocate(output)
+    self % attention_matrix = self % normalized_attention
   end subroutine normalize_attention_matrix

   pure module subroutine scaled_dot_product_attention(self, value)
@@ -231,7 +206,7 @@ elemental module function get_num_params(self) result(num_params)
   end function get_num_params

   module function get_params(self) result(params)
-    class(multihead_attention_layer), intent(in), target :: self
+    class(multihead_attention_layer), intent(in) :: self
     real, allocatable :: params(:)

     params = [&
@@ -247,7 +222,7 @@ module function get_params(self) result(params)
   end function get_params

   module function get_gradients(self) result(gradients)
-    class(multihead_attention_layer), intent(in), target :: self
+    class(multihead_attention_layer), intent(in) :: self
     real, allocatable :: gradients(:)

     gradients = [ &
@@ -264,8 +239,7 @@ end function get_gradients

   module subroutine set_params(self, params)
     class(multihead_attention_layer), intent(in out) :: self
-    real, intent(in), target :: params(:)
-    real, pointer :: p_(:,:) => null()
+    real, intent(in) :: params(:)
     integer :: i, j, window

     ! check if the number of parameters is correct
@@ -335,9 +309,30 @@ module subroutine init_base(self, input_shape)

     self % scaling_factor = sqrt(1 / real(self % head_size))

-    allocate(self % q_input(self % sequence_length, self % model_dimension))
-    allocate(self % k_input(self % sequence_length, self % model_dimension))
-    allocate(self % v_input(self % sequence_length, self % model_dimension))
-    allocate(self % o_input(self % sequence_length, self % model_dimension))
+    allocate(self % q_input, mold=self % output)
+    allocate(self % k_input, mold=self % output)
+    allocate(self % v_input, mold=self % output)
+    allocate(self % o_input, mold=self % output)
+
+    ! allocate temporary storages
+
+    ! this one is for forward pass
+    allocate(self % normalized_attention, mold=self % attention_matrix)
+
+    ! the following three are used twice:
+    ! Forward pass: As inputs after the corresponding linear layer and head reshape
+    ! Backward pass: As deltas for each input array
+    allocate(self % q_or_dq, mold=self % sdpa)
+    allocate(self % k_or_dk, mold=self % sdpa)
+    allocate(self % v_or_dv, mold=self % sdpa)
+
+    ! the other seven below are for backward pass
+    allocate(self % d_output, mold=self % sdpa)
+    allocate(self % v_heads, mold=self % sdpa)
+    allocate(self % k_heads, mold=self % sdpa)
+    allocate(self % q_heads, mold=self % sdpa)
+    allocate(self % d_sdpa(self % sequence_length, self % sequence_length))
+    allocate(self % jacobian, mold=self % d_sdpa)
+    allocate(self % d_normalize, mold=self % attention_matrix)
   end subroutine init_base
 end submodule nf_multihead_attention_layer_submodule
\ No newline at end of file
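For reference (not part of the patch): the `jacobian` / `d_normalize` loop in `sdpa_backward` above is the usual row-wise softmax vector-Jacobian product. With s one row of the already-normalized attention matrix, g the matching row of `d_sdpa`, and \(\lambda\) the `scaling_factor`, the loop builds

\[
J_{ij} \;=\; s_i\,(\delta_{ij} - s_j) \;=\;
\begin{cases}
s_i\,(1 - s_i), & i = j\\
-\,s_i\,s_j, & i \ne j
\end{cases}
\qquad\text{and stores}\qquad
\texttt{d\_normalize}(seq,\,:,\,head) \;=\; g^{\top}\,(\lambda\,J).
\]
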
diff --git a/src/nf/nf_self_attention_layer.f90 b/src/nf/nf_self_attention_layer.f90
index 15e8f40c..0b5f217d 100644
--- a/src/nf/nf_self_attention_layer.f90
+++ b/src/nf/nf_self_attention_layer.f90
@@ -35,14 +35,15 @@ module function self_attention_layer_cons(n_heads) result(res)
     res % n_heads = n_heads
   end function self_attention_layer_cons

-  pure module subroutine backward(self, input, gradient)
+  pure module subroutine backward(self, input, gradient, attention_mask)
     !! Self Attention back propagation
     !! Returns sum of Query, Key and Value gradients
     class(self_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
     real, intent(in) :: gradient(:, :)
+    real, intent(in), optional :: attention_mask(:, :)

-    call self % common_backward(input, gradient)
+    call self % common_backward(input, gradient, attention_mask)
     self % gradient = &
         self % query_layer % gradient &
         + self % key_layer % gradient &
diff --git a/test/test_multihead_attention_layer.f90 b/test/test_multihead_attention_layer.f90
index fdc6862d..e9845bba 100644
--- a/test/test_multihead_attention_layer.f90
+++ b/test/test_multihead_attention_layer.f90
@@ -27,6 +27,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_backward(attention, ok)
   call test_multihead_attention_update_gradients(attention, ok)
   call test_multihead_attention_forward_reallife_shape(ok)
+  call test_multihead_attention_mask(ok)
   call test_self_attention(ok)
   call test_cross_attention(ok)

@@ -255,11 +256,10 @@ subroutine test_multihead_attention_backward(attention, ok)
     call attention % common_backward(input, gradient)

     ! sample for Self Attention: sum of output gradients
-    ! FIXME: remove reshapes when linear2d situation is resolved
     output = &
-        reshape(attention % query_layer % gradient, [attention % sequence_length, attention % model_dimension]) &
-        + reshape(attention % key_layer % gradient, [attention % sequence_length, attention % model_dimension]) &
-        + reshape(attention % value_layer % gradient, [attention % sequence_length, attention % model_dimension])
+        attention % query_layer % gradient &
+        + attention % key_layer % gradient &
+        + attention % value_layer % gradient

     output_shape = shape(output)
     if (.not. all(output_shape.eq.expected_shape)) then
@@ -315,6 +315,72 @@ subroutine test_multihead_attention_update_gradients(attention, ok)
     end if
   end subroutine test_multihead_attention_update_gradients

+  subroutine test_multihead_attention_mask(ok)
+    logical, intent(in out) :: ok
+    type(multihead_attention_layer) :: attention
+    real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])
+    real :: gradient(3, 4) = reshape([0.1, 3., 2., 0.1, 3., 3., 0.1, 2., 0.1, 3., 0.1, 3.], [3, 4])
+    real :: attention_mask(3, 3) = reshape([&
+        0., 0., 0.,&
+        0., 0., -100.,&
+        -100., 0., -100.&
+    ], [3, 3])
+    real :: output(3, 4)
+    real, volatile :: output_flat(12)
+    real, volatile :: attn_weights_flat(18)
+    real :: expected_output_flat(12) = [&
+        0.94935673, 1.0040786, 0.626,&
+        0.94935673, 1.0040786, 0.626,&
+        0.94935673, 1.0040786, 0.626,&
+        0.94935673, 1.0040786, 0.626&
+    ]
+    real :: expected_attn_weights_flat(18) = [&
+        0.149956360, 2.28110179E-02, 1.0,&
+        0.850043654, 0.464612424, 0.0,&
+        0.0, 0.512576580, 0.0,&
+        0.149956360, 2.28110179E-02, 1.0,&
+        0.850043654, 0.464612424, 0.0,&
+        0.0, 0.512576580, 0.0&
+    ]
+    real :: gradient_flat(12)
+    real :: expected_gradient_flat(12) = [&
+        0.32137412, 0.30436403, 0.1854456,&
+        0.32137412, 0.30436403, 0.1854456,&
+        0.32137412, 0.30436403, 0.1854456,&
+        0.32137412, 0.30436403, 0.1854456&
+    ]
+
+    attention = multihead_attention_layer(n_heads=2)
+    call attention % init_base([3, 4])
+    call set_weights(attention)
+
+    call attention % common_forward(input, input, input, attention_mask=attention_mask)
+
+    output_flat = reshape(attention % output, shape(output_flat))
+    if (.not. allclose(output_flat, expected_output_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'forward w. attention mask returned incorrect values.. failed'
+    end if
+
+    attn_weights_flat = reshape(attention % attention_matrix, shape(attn_weights_flat))
+    if (.not. allclose(attn_weights_flat, expected_attn_weights_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'forward w. attention mask returned incorrect attention weights values.. failed'
+    end if
+
+    call attention % common_backward(input, gradient, attention_mask)
+    gradient_flat = reshape(&
+        attention % query_layer % gradient &
+        + attention % key_layer % gradient &
+        + attention % value_layer % gradient,&
+        [12]&
+    )
+    if (.not. allclose(gradient_flat, expected_gradient_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward w. attention mask returned incorrect gradient values.. failed'
+    end if
+  end subroutine test_multihead_attention_mask
+
   subroutine test_self_attention(ok)
     logical, intent(in out) :: ok
     type(self_attention_layer) :: attention