@@ -8831,7 +8831,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8849,14 +8849,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
     const float eps    = adamw_params_ptr[3];
-    const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = adamw_params_ptr[7];
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8879,7 +8879,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
            // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00]*keep - alpha*mh/vh;
         }
     }
 }
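
The change above is a behavior-preserving refactor of the decoupled weight-decay step: the replaced line computed `1.0f - alpha*wd` inline for every element, so `keep` is evidently that same factor, now precomputed once per step by the caller. A minimal sketch of how a host-side caller might fill the now 8-element parameter tensor (the helper name is hypothetical, and `beta1h`/`beta2h` are assumed to be supplied by the optimizer driver; only the index layout follows the reads above):

```c
// Hypothetical host-side packing of the 8-float optimizer parameter
// tensor; indices mirror adamw_params_ptr[0..7] in the kernel above.
static void pack_opt_params(float p[8], float alpha, float beta1, float beta2,
                            float eps, float wd, float beta1h, float beta2h) {
    p[0] = alpha;
    p[1] = beta1;
    p[2] = beta2;
    p[3] = eps;
    p[4] = wd;              // raw decay; the CPU kernel above no longer reads it
    p[5] = beta1h;          // per-step bias-correction terms, computed by the driver
    p[6] = beta2h;
    p[7] = 1.0f - alpha*wd; // "keep": decoupled weight-decay factor, folded once
}
```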
@@ -8901,3 +8901,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using the adamw param subset we care about - alpha, and keep = 1 - alpha*wd - could have a separate struct
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha = adamw_params_ptr[0];
+    const float keep  = adamw_params_ptr[7];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * w = (float       *) ((char *)       src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00]*keep - alpha*g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
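
Because the SGD kernel reads the same `keep` slot as the AdamW kernel, it applies the same decoupled weight decay (i.e. SGDW): with `wd == 0`, `keep == 1.0f` and the update reduces to plain `w -= alpha*g`. A scalar reference for one step, assuming `keep = 1.0f - alpha*wd` as packed above:

```c
// Scalar reference for one SGD(W) step, assuming keep = 1.0f - alpha*wd:
//     w <- w*(1 - alpha*wd) - alpha*g
// e.g. alpha = 0.01f, wd = 0.1f, w = 1.0f, g = 0.5f:
//     keep = 1 - 0.01*0.1         = 0.999
//     w    = 1.0*0.999 - 0.01*0.5 = 0.994
static inline float sgd_step(float w, float g, float alpha, float keep) {
    return w*keep - alpha*g;
}
```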