11package us .ihmc .perception .cuda ;
22
3+ import org .bytedeco .cuda .cudart .CUevent_st ;
34import org .bytedeco .cuda .cudart .CUfunc_st ;
45import org .bytedeco .cuda .cudart .CUmod_st ;
56import org .bytedeco .cuda .cudart .CUstream_st ;
1112import org .bytedeco .javacpp .LongPointer ;
1213import org .bytedeco .javacpp .Pointer ;
1314import org .bytedeco .javacpp .PointerPointer ;
15+ import us .ihmc .log .LogTools ;
1416
1517import java .util .ArrayList ;
18+ import java .util .LinkedList ;
1619import java .util .List ;
20+ import java .util .Optional ;
1721
18- import static org .bytedeco .cuda .global .cudart .cuLaunchKernel ;
19- import static org .bytedeco .cuda .global .cudart .cuModuleGetFunction ;
22+ import static org .bytedeco .cuda .global .cudart .*;
2023import static us .ihmc .perception .cuda .CUDATools .throwCUDAError ;
2124
2225@ SuppressWarnings ("resource" )
2326public class CUDAKernel implements AutoCloseable
2427{
28+ private final String name ;
2529 private final CUfunc_st kernelFunction = new CUfunc_st ();
2630 private final List <Pointer > parameters = new ArrayList <>();
2731 private boolean retainParameters = false ;
32+ private boolean enableKernelTimings = false ;
33+
34+ private CUDAKernelTimings kernelTimings ;
35+ private final CUevent_st start = new CUevent_st ();
36+ private final CUevent_st end = new CUevent_st ();
2837
2938 private int error ;
3039
3140 public CUDAKernel (String name , CUmod_st kernelModule ) throws Exception
3241 {
42+ this .name = name ;
3343 error = cuModuleGetFunction (kernelFunction , kernelModule , name );
3444 throwCUDAError (error );
3545 }
3646
47+ /**
48+ * Setting this to true enables the ability to run timings on the specific kernel.
49+ * The timing checks perform synchronization calls.
50+ */
51+ public void enableKernelTimings (boolean enableKernelTimings )
52+ {
53+ this .enableKernelTimings = enableKernelTimings ;
54+ kernelTimings = new CUDAKernelTimings ();
55+ }
56+
3757 public void retainParameters (boolean retainParameters )
3858 {
3959 this .retainParameters = retainParameters ;
@@ -53,6 +73,13 @@ public void run(CUstream_st stream, dim3 gridSize, dim3 blockSize, int sharedMem
5373 for (int i = 0 ; i < parameters .size (); ++i )
5474 parametersPointer .put (i , parameters .get (i ));
5575
76+ if (enableKernelTimings )
77+ {
78+ cudaEventCreate (start );
79+ cudaEventCreate (end );
80+ cudaEventRecord (start );
81+ }
82+
5683 error = cuLaunchKernel (kernelFunction ,
5784 gridSize .x (),
5885 gridSize .y (),
@@ -64,6 +91,16 @@ public void run(CUstream_st stream, dim3 gridSize, dim3 blockSize, int sharedMem
6491 stream ,
6592 parametersPointer ,
6693 new PointerPointer <>());
94+
95+ if (enableKernelTimings )
96+ {
97+ cudaEventRecord (end );
98+ cudaEventSynchronize (end );
99+
100+ kernelTimings .addExecutionTime (start , end );
101+ kernelTimings .printTimesForKernel ();
102+ }
103+
67104 CUDATools .checkCUDAError (error );
68105
69106 if (!retainParameters )
@@ -123,4 +160,96 @@ public void close()
123160 clearParameters ();
124161 kernelFunction .close ();
125162 }
163+
164+ /**
165+ * This class handles the kernel timings.
166+ * With options to compute the min/max, average, and variance of the dataset
167+ */
168+ private class CUDAKernelTimings
169+ {
170+ private static final int MAX_ENTRIES = 250 ;
171+ private final LinkedList <Float > executionTimes = new LinkedList <>();
172+
173+ private void addExecutionTime (CUevent_st start , CUevent_st end )
174+ {
175+ float [] milliseconds = new float [1 ];
176+ milliseconds [0 ] = 0.0f ;
177+ cudaEventElapsedTime (milliseconds , start , end );
178+ executionTimes .add (milliseconds [0 ]);
179+
180+ if (executionTimes .size () > MAX_ENTRIES )
181+ {
182+ executionTimes .pollFirst ();
183+ }
184+ }
185+
186+ public double getAverageTime (String kernelName )
187+ {
188+ if (executionTimes .isEmpty ())
189+ {
190+ LogTools .info ("No recorded times for " + kernelName );
191+ return Float .NaN ;
192+ }
193+ else
194+ {
195+ return executionTimes .stream ().mapToDouble (Float ::doubleValue ).average ().orElse (0.0 );
196+ }
197+ }
198+
199+ public Float getMinTime (String kernelName )
200+ {
201+ if (executionTimes .isEmpty ())
202+ {
203+ LogTools .info ("No recorded times for " + kernelName );
204+ return Float .NaN ;
205+ }
206+ Optional <Float > min = executionTimes .stream ().min (Float ::compareTo );
207+ return min .orElse (null );
208+ }
209+
210+ public Float getMaxTime (String kernelName )
211+ {
212+ if (executionTimes .isEmpty ())
213+ {
214+ LogTools .info ("No recorded times for " + kernelName );
215+ return Float .NaN ;
216+ }
217+
218+ Optional <Float > max = executionTimes .stream ().max (Float ::compareTo );
219+ return max .orElse (null );
220+ }
221+
222+ public double getStandardDeviation (String kernelName )
223+ {
224+ if (executionTimes .isEmpty ())
225+ {
226+ LogTools .info ("No recorded times for " + kernelName );
227+ return Float .NaN ;
228+ }
229+
230+ double average = executionTimes .stream ().mapToDouble (Float ::doubleValue ).average ().orElse (0.0 );
231+ double variance = executionTimes .stream ().mapToDouble (time -> Math .pow (time - average , 2 )).average ().orElse (0.0 );
232+ return Math .sqrt (variance );
233+ }
234+
235+ public void printTimesForKernel ()
236+ {
237+ if (executionTimes .isEmpty ())
238+ {
239+ LogTools .info ("No recorded times for " + CUDAKernel .this .name );
240+ }
241+
242+ double average = getAverageTime (CUDAKernel .this .name );
243+ double variance = getStandardDeviation (CUDAKernel .this .name );
244+ double min = getMinTime (CUDAKernel .this .name );
245+ double max = getMaxTime (CUDAKernel .this .name );
246+
247+ LogTools .info ("Timings for kernel " + CUDAKernel .this .name + " in milliseconds!" );
248+ LogTools .info ("| Average time: " + average );
249+ LogTools .info ("| Variance time: " + variance );
250+ LogTools .info ("| Min time: " + min );
251+ LogTools .info ("| Max time: " + max );
252+ LogTools .warn ("******************************************" );
253+ }
254+ }
126255}
0 commit comments