From e25c16948f94ec3f3039c1ebc428c141b6624a5b Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 1 Dec 2021 15:46:21 -0800 Subject: [PATCH 01/39] [WIP] Add video GPU decoder --- setup.py | 43 ++ torchvision/csrc/io/decoder/gpu/decoder.cpp | 467 ++++++++++++++++++ torchvision/csrc/io/decoder/gpu/decoder.h | 101 ++++ torchvision/csrc/io/decoder/gpu/demuxer.h | 193 ++++++++ .../csrc/io/decoder/gpu/gpu_decoder.cpp | 75 +++ torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 21 + torchvision/io/__init__.py | 4 + torchvision/io/gpu_decoder.py | 22 + 8 files changed, 926 insertions(+) create mode 100644 torchvision/csrc/io/decoder/gpu/decoder.cpp create mode 100644 torchvision/csrc/io/decoder/gpu/decoder.h create mode 100644 torchvision/csrc/io/decoder/gpu/demuxer.h create mode 100644 torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp create mode 100644 torchvision/csrc/io/decoder/gpu/gpu_decoder.h create mode 100644 torchvision/io/gpu_decoder.py diff --git a/setup.py b/setup.py index 0b151988d1b..aad7a7be598 100644 --- a/setup.py +++ b/setup.py @@ -426,6 +426,49 @@ def get_extensions(): ) ) + # Locating video codec + # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI + video_codec_found = ( + extension is CUDAExtension + and CUDA_HOME is not None + and os.path.exists("/usr/local/include/cuviddec.h") + and os.path.exists("/usr/local/include/nvcuvid.h") + ) + + print(f"video codec found: {video_codec_found}") + + gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu") + print(f"DECODER PATH {gpu_decoder_path}") + gpu_decoder_src = ( + glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) + ) + cuda_libs = os.path.join(CUDA_HOME, 'lib64') + cuda_inc = os.path.join(CUDA_HOME, 'include') + + if video_codec_found and has_ffmpeg: + ext_modules.append( + extension( + "torchvision.Decoder", + gpu_decoder_src, + include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc], + library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs] + ['/usr/local/lib'], + libraries=[ + "avcodec", + "avformat", + "avutil", + "swresample", + "swscale", + "nvcuvid", + "cuda", + "cudart", + "z", + "pthread", + "dl", + ], + extra_compile_args=extra_compile_args, + ) + ) + return ext_modules diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp new file mode 100644 index 00000000000..62fbad11d74 --- /dev/null +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -0,0 +1,467 @@ +#include +#include +#include +#include +#include +#include "decoder.h" + + +static float GetChromaHeightFactor(cudaVideoSurfaceFormat surfaceFormat) +{ + return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 1.0 : 0.5; +} + +static int GetChromaPlaneCount(cudaVideoSurfaceFormat surfaceFormat) +{ + return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 2 : 1; +} + +void Decoder::init(CUcontext context, cudaVideoCodec codec, bool useDevFrame, const Rect *rect, + const Dim *dim, int64_t clockRate, bool lowLatency, + bool forceZeroLat, int64_t maxWidth, int64_t maxHeight) +{ + cuContext = context; + forceZeroLatency = forceZeroLat; + useDeviceFrame = useDevFrame; + videoCodec = codec; + nMaxHeight = maxHeight; + nMaxWidth = maxWidth; + if (rect) { + cropRect = *rect; + } + if (dim) { + resizeDim = *dim; + } + CheckForCudaErrors( + cuvidCtxLockCreate(&ctxLock, cuContext), + __LINE__); + + CUVIDPARSERPARAMS parserParams = { + .CodecType = codec, + .ulMaxNumDecodeSurfaces = 1, + .ulClockRate = clockRate, + .ulMaxDisplayDelay = lowLatency ? 0u : 1u, + .pUserData = this, + .pfnSequenceCallback = HandleVideoSequenceProc, + .pfnDecodePicture = HandlePictureDecodeProc, + .pfnDisplayPicture = forceZeroLatency ? NULL : HandlePictureDisplayProc, + .pfnGetOperatingPoint = HandleOperatingPointProc, + }; + CheckForCudaErrors( + cuvidCreateVideoParser(&parser, &parserParams), + __LINE__); +} + +Decoder::~Decoder() +{ + if (parser) { + cuvidDestroyVideoParser(parser); + } + cuvidCtxLockDestroy(ctxLock); +} + +void Decoder::release() const +{ + cuCtxPushCurrent(cuContext); + if (decoder) { + cuvidDestroyDecoder(decoder); + } + cuCtxPopCurrent(NULL); +} + +torch::Tensor Decoder::Decode(const uint8_t *data, int64_t size, int64_t flags, int64_t pts) +{ + numDecodedFrames = 0; + frameTensor = torch::empty({GetFrameSize()}, + torch::dtype(torch::kU8).device(useDeviceFrame? torch::kCUDA : torch::kCPU)); + CUVIDSOURCEDATAPACKET pkt = { + .flags = flags | CUVID_PKT_TIMESTAMP, + .payload_size = size, + .payload = data, + .timestamp = pts + }; + if (!data || size == 0) { + pkt.flags |= CUVID_PKT_ENDOFSTREAM; + } + CheckForCudaErrors( + cuvidParseVideoData(parser, &pkt), + __LINE__); + cuvidStream = 0; + return frameTensor; +} + +int Decoder::HandlePictureDecode(CUVIDPICPARAMS *picParams) +{ + if (!decoder) { + throw std::runtime_error("Uninitialised decoder."); + } + picNumInDecodeOrder[picParams->CurrPicIdx] = decodePicCount++; + CheckForCudaErrors( + cuCtxPushCurrent(cuContext), + __LINE__); + CheckForCudaErrors( + cuvidDecodePicture(decoder, picParams), + __LINE__); + if (forceZeroLatency && ((!picParams->field_pic_flag) || (picParams->second_field))) { + CUVIDPARSERDISPINFO dispInfo { + .picture_index = picParams->CurrPicIdx, + .progressive_frame = !picParams->field_pic_flag, + .top_field_first = picParams->bottom_field_flag ^ 1, + }; + HandlePictureDisplay(&dispInfo); + } + CheckForCudaErrors( + cuCtxPopCurrent(NULL), + __LINE__); + return 1; +} + +int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) +{ + CUVIDPROCPARAMS procParams = { + .progressive_frame = dispInfo->progressive_frame, + .second_field = dispInfo->repeat_first_field + 1, + .top_field_first = dispInfo->top_field_first, + .unpaired_field = dispInfo->repeat_first_field < 0, + .output_stream = cuvidStream, + }; + + CUdeviceptr dpSrcFrame = 0; + unsigned int nSrcPitch = 0; + CheckForCudaErrors( + cuCtxPushCurrent(cuContext), + __LINE__); + CheckForCudaErrors( + cuvidMapVideoFrame(decoder, dispInfo->picture_index, &dpSrcFrame, &nSrcPitch, &procParams), + __LINE__); + + CUVIDGETDECODESTATUS decodeStatus; + memset(&decodeStatus, 0, sizeof(decodeStatus)); + CUresult result = cuvidGetDecodeStatus(decoder, dispInfo->picture_index, &decodeStatus); + if (result == CUDA_SUCCESS && + (decodeStatus.decodeStatus == cuvidDecodeStatus_Error || decodeStatus.decodeStatus == cuvidDecodeStatus_Error_Concealed)) { + printf("Decode Error occurred for picture %d\n", picNumInDecodeOrder[dispInfo->picture_index]); + } + + uint8_t *decodedFrame = frameTensor.data_ptr(); + numDecodedFrames++; + + // Copy luma plane + CUDA_MEMCPY2D m = { 0 }; + m.srcMemoryType = CU_MEMORYTYPE_DEVICE; + m.srcDevice = dpSrcFrame; + m.srcPitch = nSrcPitch; + m.dstMemoryType = useDeviceFrame ? CU_MEMORYTYPE_DEVICE : CU_MEMORYTYPE_HOST; + m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame); + m.dstPitch = deviceFramePitch ? deviceFramePitch : GetWidth() * bytesPerPixel; + m.WidthInBytes = GetWidth() * bytesPerPixel; + m.Height = lumaHeight; + CheckForCudaErrors( + cuMemcpy2DAsync(&m, cuvidStream), + __LINE__); + + // Copy chroma plane + // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning height + m.srcDevice = (CUdeviceptr)((uint8_t *)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1)); + m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight); + m.Height = chromaHeight; + CheckForCudaErrors( + cuMemcpy2DAsync(&m, cuvidStream), + __LINE__); + + if (numChromaPlanes == 2) { + m.srcDevice = (CUdeviceptr)((uint8_t *)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1) * 2); + m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight * 2); + m.Height = chromaHeight; + CheckForCudaErrors( + cuMemcpy2DAsync(&m, cuvidStream), + __LINE__); + } + CheckForCudaErrors( + cuStreamSynchronize(cuvidStream), + __LINE__); + CheckForCudaErrors( + cuCtxPopCurrent(NULL), + __LINE__); + + CheckForCudaErrors( + cuvidUnmapVideoFrame(decoder, dpSrcFrame), + __LINE__); + return 1; +} + +void Decoder::queryHardware(CUVIDEOFORMAT *videoFormat) +{ + CUVIDDECODECAPS decodeCaps = { + .eCodecType = videoFormat->codec, + .eChromaFormat = videoFormat->chroma_format, + .nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8, + }; + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors(cuvidGetDecoderCaps(&decodeCaps), __LINE__); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + + if(!decodeCaps.bIsSupported) { + throw std::runtime_error("Codec not supported on this GPU"); + } + if ((videoFormat->coded_width > decodeCaps.nMaxWidth) || + (videoFormat->coded_height > decodeCaps.nMaxHeight)) { + std::ostringstream errorString; + errorString << std::endl + << "Resolution : " << videoFormat->coded_width << "x" << videoFormat->coded_height << std::endl + << "Max Supported (wxh) : " << decodeCaps.nMaxWidth << "x" << decodeCaps.nMaxHeight << std::endl + << "Resolution not supported on this GPU"; + + const std::string cErr = errorString.str(); + throw std::runtime_error(cErr); + } + if ((videoFormat->coded_width>>4)*(videoFormat->coded_height>>4) > decodeCaps.nMaxMBCount) { + std::ostringstream errorString; + errorString << std::endl + << "MBCount : " << (videoFormat->coded_width >> 4)*(videoFormat->coded_height >> 4) << std::endl + << "Max Supported mbcnt : " << decodeCaps.nMaxMBCount << std::endl + << "MBCount not supported on this GPU"; + + const std::string cErr = errorString.str(); + throw std::runtime_error(cErr); + } + // Check if output format supported. If not, check fallback options + if (!(decodeCaps.nOutputFormatMask & (1 << videoOutputFormat))) { + if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) + videoOutputFormat = cudaVideoSurfaceFormat_NV12; + else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) + videoOutputFormat = cudaVideoSurfaceFormat_P016; + else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) + videoOutputFormat = cudaVideoSurfaceFormat_YUV444; + else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) + videoOutputFormat = cudaVideoSurfaceFormat_YUV444_16Bit; + else + throw std::runtime_error("No supported output format found"); + } +} + +int Decoder::HandleVideoSequence(CUVIDEOFORMAT *vidFormat) +{ + // videoCodec has been set in the constructor (for parser). Here it's set again for potential correction + videoCodec = vidFormat->codec; + videoChromaFormat = vidFormat->chroma_format; + bitDepthMinus8 = vidFormat->bit_depth_luma_minus8; + bytesPerPixel = bitDepthMinus8 > 0 ? 2 : 1; + // Set the output surface format same as chroma format + switch (videoChromaFormat) { + case cudaVideoChromaFormat_Monochrome: + case cudaVideoChromaFormat_420: + videoOutputFormat = vidFormat->bit_depth_luma_minus8 ? cudaVideoSurfaceFormat_P016 : cudaVideoSurfaceFormat_NV12; + break; + case cudaVideoChromaFormat_444: + videoOutputFormat = vidFormat->bit_depth_luma_minus8 ? cudaVideoSurfaceFormat_YUV444_16Bit : cudaVideoSurfaceFormat_YUV444; + break; + case cudaVideoChromaFormat_422: + videoOutputFormat = cudaVideoSurfaceFormat_NV12; // no 4:2:2 output format supported yet so make 420 default + } + + queryHardware(vidFormat); + if (width && lumaHeight && chromaHeight) { + // cuvidCreateDecoder() has been called before, and now there's possible config change + return ReconfigureDecoder(vidFormat); + } + + videoFormat = *vidFormat; + int nDecodeSurface = vidFormat->min_num_decode_surfaces; + cudaVideoDeinterlaceMode deinterlaceMode = cudaVideoDeinterlaceMode_Adaptive; + if (vidFormat->progressive_sequence) { + deinterlaceMode = cudaVideoDeinterlaceMode_Weave; + } + + CUVIDDECODECREATEINFO videoDecodeCreateInfo = { + .ulWidth = vidFormat->coded_width, + .ulHeight = vidFormat->coded_height, + .ulNumDecodeSurfaces = nDecodeSurface, + .CodecType = vidFormat->codec, + .ChromaFormat = vidFormat->chroma_format, + // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded by NVDEC hardware + .ulCreationFlags = cudaVideoCreate_PreferCUVID, + .bitDepthMinus8 = vidFormat->bit_depth_luma_minus8, + .OutputFormat = videoOutputFormat, + .DeinterlaceMode = deinterlaceMode, + .ulNumOutputSurfaces = 2, + .vidLock = ctxLock, + }; + // AV1 has max width/height of sequence in sequence header + if (vidFormat->codec == cudaVideoCodec_AV1 && vidFormat->seqhdr_data_length > 0) { + // don't overwrite if it is already set from cmdline or reconfig.txt + if (!(nMaxWidth > vidFormat->coded_width || nMaxHeight > vidFormat->coded_height)) { + CUVIDEOFORMATEX *vidFormatEx = (CUVIDEOFORMATEX *)vidFormat; + nMaxWidth = vidFormatEx->av1.max_width; + nMaxHeight = vidFormatEx->av1.max_height; + } + } + if (nMaxWidth < (int)vidFormat->coded_width) + nMaxWidth = vidFormat->coded_width; + if (nMaxHeight < (int)vidFormat->coded_height) + nMaxHeight = vidFormat->coded_height; + videoDecodeCreateInfo.ulMaxWidth = nMaxWidth; + videoDecodeCreateInfo.ulMaxHeight = nMaxHeight; + + if (!(cropRect.right && cropRect.bottom) && !(resizeDim.width && resizeDim.height)) { + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + videoDecodeCreateInfo.ulTargetWidth = vidFormat->coded_width; + videoDecodeCreateInfo.ulTargetHeight = vidFormat->coded_height; + } else { + if (resizeDim.width && resizeDim.height) { + videoDecodeCreateInfo.display_area.left = vidFormat->display_area.left; + videoDecodeCreateInfo.display_area.top = vidFormat->display_area.top; + videoDecodeCreateInfo.display_area.right = vidFormat->display_area.right; + videoDecodeCreateInfo.display_area.bottom = vidFormat->display_area.bottom; + width = resizeDim.width; + lumaHeight = resizeDim.height; + } + + if (cropRect.right && cropRect.bottom) { + videoDecodeCreateInfo.display_area.left = cropRect.left; + videoDecodeCreateInfo.display_area.top = cropRect.top; + videoDecodeCreateInfo.display_area.right = cropRect.right; + videoDecodeCreateInfo.display_area.bottom = cropRect.bottom; + width = cropRect.right - cropRect.left; + lumaHeight = cropRect.bottom - cropRect.top; + } + videoDecodeCreateInfo.ulTargetWidth = width; + videoDecodeCreateInfo.ulTargetHeight = lumaHeight; + } + + chromaHeight = (int)(ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat))); + numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + surfaceHeight = videoDecodeCreateInfo.ulTargetHeight; + surfaceWidth = videoDecodeCreateInfo.ulTargetWidth; + displayRect.bottom = videoDecodeCreateInfo.display_area.bottom; + displayRect.top = videoDecodeCreateInfo.display_area.top; + displayRect.left = videoDecodeCreateInfo.display_area.left; + displayRect.right = videoDecodeCreateInfo.display_area.right; + + + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors(cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + return nDecodeSurface; +} + +int Decoder::ReconfigureDecoder(CUVIDEOFORMAT *vidFormat) +{ + if (vidFormat->bit_depth_luma_minus8 != videoFormat.bit_depth_luma_minus8 || vidFormat->bit_depth_chroma_minus8 != videoFormat.bit_depth_chroma_minus8) { + throw std::runtime_error("Reconfigure Not supported for bit depth change"); + } + if (vidFormat->chroma_format != videoFormat.chroma_format) { + throw std::runtime_error("Reconfigure Not supported for chroma format change"); + } + + bool bDecodeResChange = !(vidFormat->coded_width == videoFormat.coded_width && vidFormat->coded_height == videoFormat.coded_height); + bool bDisplayRectChange = !(vidFormat->display_area.bottom == videoFormat.display_area.bottom && vidFormat->display_area.top == videoFormat.display_area.top \ + && vidFormat->display_area.left == videoFormat.display_area.left && vidFormat->display_area.right == videoFormat.display_area.right); + + int nDecodeSurface = vidFormat->min_num_decode_surfaces; + + if ((vidFormat->coded_width > nMaxWidth) || (vidFormat->coded_height > nMaxHeight)) { + // For VP9, let driver handle the change if new width/height > maxwidth/maxheight + if ((videoCodec != cudaVideoCodec_VP9) || m_bReconfigExternal) { + throw std::runtime_error("Reconfigure Not supported when width/height > maxwidth/maxheight"); + } + return 1; + } + + if (!bDecodeResChange && !m_bReconfigExtPPChange) { + // if the coded_width/coded_height hasn't changed but display resolution has changed, then need to update width/height for + // correct output without cropping. Example : 1920x1080 vs 1920x1088 + if (bDisplayRectChange) { + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + chromaHeight = (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); + numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + } + + // no need for reconfigureDecoder(). Just return + return 1; + } + + CUVIDRECONFIGUREDECODERINFO reconfigParams = { 0 }; + + reconfigParams.ulWidth = videoFormat.coded_width = vidFormat->coded_width; + reconfigParams.ulHeight = videoFormat.coded_height = vidFormat->coded_height; + + // Dont change display rect and get scaled output from decoder. This will help display app to present apps smoothly + reconfigParams.display_area.bottom = displayRect.bottom; + reconfigParams.display_area.top = displayRect.top; + reconfigParams.display_area.left = displayRect.left; + reconfigParams.display_area.right = displayRect.right; + reconfigParams.ulTargetWidth = surfaceWidth; + reconfigParams.ulTargetHeight = surfaceHeight; + + // If external reconfigure is called along with resolution change even if post processing params is not changed, + // do full reconfigure params update + if ((m_bReconfigExternal && bDecodeResChange) || m_bReconfigExtPPChange) { + // update display rect and target resolution if requested explicitely + m_bReconfigExternal = false; + m_bReconfigExtPPChange = false; + videoFormat = *vidFormat; + if (!(cropRect.right && cropRect.bottom) && !(resizeDim.width && resizeDim.height)) { + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + reconfigParams.ulTargetWidth = vidFormat->coded_width; + reconfigParams.ulTargetHeight = vidFormat->coded_height; + } + else { + if (resizeDim.width && resizeDim.height) { + reconfigParams.display_area.left = vidFormat->display_area.left; + reconfigParams.display_area.top = vidFormat->display_area.top; + reconfigParams.display_area.right = vidFormat->display_area.right; + reconfigParams.display_area.bottom = vidFormat->display_area.bottom; + width = resizeDim.width; + lumaHeight = resizeDim.height; + } + + if (cropRect.right && cropRect.bottom) { + reconfigParams.display_area.left = cropRect.left; + reconfigParams.display_area.top = cropRect.top; + reconfigParams.display_area.right = cropRect.right; + reconfigParams.display_area.bottom = cropRect.bottom; + width = cropRect.right - cropRect.left; + lumaHeight = cropRect.bottom - cropRect.top; + } + reconfigParams.ulTargetWidth = width; + reconfigParams.ulTargetHeight = lumaHeight; + } + + chromaHeight = (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); + numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + surfaceHeight = reconfigParams.ulTargetHeight; + surfaceWidth = reconfigParams.ulTargetWidth; + displayRect.bottom = reconfigParams.display_area.bottom; + displayRect.top = reconfigParams.display_area.top; + displayRect.left = reconfigParams.display_area.left; + displayRect.right = reconfigParams.display_area.right; + } + + reconfigParams.ulNumDecodeSurfaces = nDecodeSurface; + + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors(cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + + return nDecodeSurface; +} + +int Decoder::GetOperatingPoint(CUVIDOPERATINGPOINTINFO *operPointInfo) +{ + if (operPointInfo->codec == cudaVideoCodec_AV1) { + if (operPointInfo->av1.operating_points_cnt > 1) { + // clip has SVC enabled + if (operatingPoint >= operPointInfo->av1.operating_points_cnt) + operatingPoint = 0; + + printf("AV1 SVC clip: operating point count %d ", operPointInfo->av1.operating_points_cnt); + printf("Selected operating point: %d, IDC 0x%x bOutputAllLayers %d\n", operatingPoint, operPointInfo->av1.operating_points_idc[operatingPoint], dispAllLayers); + return (operatingPoint | (dispAllLayers << 10)); + } + } + return -1; +} diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h new file mode 100644 index 00000000000..ee79786eb7e --- /dev/null +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -0,0 +1,101 @@ +#include +#include +#include +#include +#include +#include + +static auto CheckForCudaErrors = [](CUresult result, int lineNum) +{ + if (CUDA_SUCCESS != result) { + std::stringstream errorStream; + const char *errorName = nullptr, *errorDesc = nullptr; + + errorStream << __FILE__ << ":" << lineNum << std::endl; + if (CUDA_SUCCESS != cuGetErrorName(result, &errorName)) { + errorStream << "CUDA error with code " << result << std::endl; + } else { + errorStream << "CUDA error: " << errorName << std::endl; + } + throw std::runtime_error(errorStream.str()); + } +}; + +struct Rect { + int left, top, right, bottom; +}; + +struct Dim { + int width, height; +}; + +class Decoder { + public: + Decoder() {} + ~Decoder(); + void init(CUcontext, cudaVideoCodec, bool, const Rect * = NULL, const Dim * = NULL, int64_t = 1000, bool = false, bool = false, int64_t = 0, int64_t = 0); + torch::Tensor Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); + cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } + void release() const; + int64_t GetNumDecodedFrames() const { return numDecodedFrames; } + + private: + CUcontext cuContext = NULL; + CUvideoctxlock ctxLock; + CUvideoparser parser = NULL; + CUvideodecoder decoder = NULL; + bool forceZeroLatency = false; + bool useDeviceFrame; + CUstream cuvidStream = 0; + int numDecodedFrames = 0; + unsigned int numChromaPlanes = 0; + torch::Tensor frameTensor; + // dimension of the output + unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; + cudaVideoCodec videoCodec = cudaVideoCodec_NumCodecs; + cudaVideoChromaFormat videoChromaFormat = cudaVideoChromaFormat_420; + cudaVideoSurfaceFormat videoOutputFormat = cudaVideoSurfaceFormat_NV12; + int bitDepthMinus8 = 0; + int bytesPerPixel = 1; + CUVIDEOFORMAT videoFormat = {}; + unsigned int nMaxWidth = 0, nMaxHeight = 0; + Rect cropRect = {}; + Dim resizeDim = {}; + // height of the mapped surface + int surfaceHeight = 0; + int surfaceWidth = 0; + Rect displayRect = {}; + unsigned int operatingPoint = 0; + bool dispAllLayers = false; + int decodePicCount = 0, picNumInDecodeOrder[32]; + bool m_bReconfigExternal = false; + bool m_bReconfigExtPPChange = false; + size_t deviceFramePitch = 0; + bool m_bDeviceFramePitched = false; + int m_nFrameAlloc = 0; + + static int CUDAAPI HandleVideoSequenceProc(void *pUserData, CUVIDEOFORMAT *pVideoFormat) { return ((Decoder *)pUserData)->HandleVideoSequence(pVideoFormat); } + static int CUDAAPI HandlePictureDecodeProc(void *pUserData, CUVIDPICPARAMS *pPicParams) { return ((Decoder *)pUserData)->HandlePictureDecode(pPicParams); } + static int CUDAAPI HandlePictureDisplayProc(void *pUserData, CUVIDPARSERDISPINFO *pDispInfo) { return ((Decoder *)pUserData)->HandlePictureDisplay(pDispInfo); } + static int CUDAAPI HandleOperatingPointProc(void *pUserData, CUVIDOPERATINGPOINTINFO *pOPInfo) { return ((Decoder *)pUserData)->GetOperatingPoint(pOPInfo); } + + void queryHardware(CUVIDEOFORMAT *videoFormat); + int ReconfigureDecoder(CUVIDEOFORMAT *pVideoFormat); + int HandleVideoSequence(CUVIDEOFORMAT *pVideoFormat); + int HandlePictureDecode(CUVIDPICPARAMS *pPicParams); + int HandlePictureDisplay(CUVIDPARSERDISPINFO *pDispInfo); + int GetOperatingPoint(CUVIDOPERATINGPOINTINFO *pOPInfo); + + int GetWidth() + { + assert(width); + return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || videoOutputFormat == cudaVideoSurfaceFormat_P016) ? (width + 1) & ~1 : width; + } + + int GetFrameSize() + { + assert(width); + return GetWidth() * (lumaHeight + (chromaHeight * numChromaPlanes)) * bytesPerPixel; + } +}; + diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h new file mode 100644 index 00000000000..f1abb51c209 --- /dev/null +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -0,0 +1,193 @@ +extern "C" { +#include +#include +#include +#include +} +#include + +inline bool check(int ret, int line) { + if (ret < 0) { + printf("Error %d at line %d in file %s.\n", ret, line); + return false; + } + return true; +} + +#define check_for_errors(call) check(call, __LINE__) + +class Demuxer { + private: + AVFormatContext *fmtCtx = NULL; + AVBSFContext *bsfCtx = NULL; + AVPacket pkt, pktFiltered; + AVCodecID eVideoCodec; + uint8_t *pDataWithHeader = NULL; + bool bMp4H264, bMp4HEVC, bMp4MPEG4; + unsigned int frameCount = 0; + int iVideoStream; + int64_t userTimeScale = 0; + double timeBase = 0.0; + + public: + Demuxer(const char *filePath, int64_t timeScale = 1000 /*Hz*/) + { + avformat_network_init(); + check_for_errors(avformat_open_input(&fmtCtx, filePath, NULL, NULL)); + if (!fmtCtx) { + printf("No AVFormatContext provided.\n"); + return; + } + + check_for_errors(avformat_find_stream_info(fmtCtx, NULL)); + iVideoStream = av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0); + if (iVideoStream < 0) { + printf("FFmpeg error: %d, could not find stream in input file\n", __LINE__); + return; + } + + eVideoCodec = fmtCtx->streams[iVideoStream]->codecpar->codec_id; + AVRational rTimeBase = fmtCtx->streams[iVideoStream]->time_base; + timeBase = av_q2d(rTimeBase); + userTimeScale = timeScale; + + bMp4H264 = eVideoCodec == AV_CODEC_ID_H264 && ( + !strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") + || !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") + || !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); + bMp4HEVC = eVideoCodec == AV_CODEC_ID_HEVC && ( + !strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") + || !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") + || !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); + bMp4MPEG4 = eVideoCodec == AV_CODEC_ID_MPEG4 && ( + !strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") + || !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") + || !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); + + av_init_packet(&pkt); + pkt.data = NULL; + pkt.size = 0; + av_init_packet(&pktFiltered); + pktFiltered.data = NULL; + pktFiltered.size = 0; + + if (bMp4H264) { + const AVBitStreamFilter *bsf = av_bsf_get_by_name("h264_mp4toannexb"); + if (!bsf) { + printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); + return; + } + check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); + avcodec_parameters_copy(bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); + check_for_errors(av_bsf_init(bsfCtx)); + } + if (bMp4HEVC) { + const AVBitStreamFilter *bsf = av_bsf_get_by_name("hevc_mp4toannexb"); + if (!bsf) { + printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); + return; + } + check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); + avcodec_parameters_copy(bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); + check_for_errors(av_bsf_init(bsfCtx)); + } + } + ~Demuxer() + { + if (!fmtCtx) { + return; + } + if (pkt.data) { + av_packet_unref(&pkt); + } + if (pktFiltered.data) { + av_packet_unref(&pktFiltered); + } + if (bsfCtx) { + av_bsf_free(&bsfCtx); + } + avformat_close_input(&fmtCtx); + if (pDataWithHeader) { + av_free(pDataWithHeader); + } + } + + AVCodecID GetVideoCodec() + { + return eVideoCodec; + } + + bool Demux(uint8_t **ppVideo, int64_t *pnVideoBytes, int64_t *pts = NULL) + { + if (!fmtCtx) { + return false; + } + *pnVideoBytes = 0; + + if (pkt.data) { + av_packet_unref(&pkt); + } + int e = 0; + while ((e = av_read_frame(fmtCtx, &pkt)) >= 0 && pkt.stream_index != iVideoStream) { + av_packet_unref(&pkt); + } + if (e < 0) { + return false; + } + + if (bMp4H264 || bMp4HEVC) { + if (pktFiltered.data) { + av_packet_unref(&pktFiltered); + } + check_for_errors(av_bsf_send_packet(bsfCtx, &pkt)); + check_for_errors(av_bsf_receive_packet(bsfCtx, &pktFiltered)); + *ppVideo = pktFiltered.data; + *pnVideoBytes = pktFiltered.size; + if (pts) { + *pts = (int64_t) (pktFiltered.pts * userTimeScale * timeBase); + } + } else { + if (bMp4MPEG4 && (frameCount == 0)) { + int extraDataSize = fmtCtx->streams[iVideoStream]->codecpar->extradata_size; + + if (extraDataSize > 0) { + pDataWithHeader = (uint8_t *)av_malloc(extraDataSize + pkt.size - 3 * sizeof(uint8_t)); + if (!pDataWithHeader) { + printf("FFmpeg error: %d\n", __LINE__); + return false; + } + memcpy(pDataWithHeader, fmtCtx->streams[iVideoStream]->codecpar->extradata, extraDataSize); + memcpy(pDataWithHeader+extraDataSize, pkt.data+3, pkt.size - 3 * sizeof(uint8_t)); + *ppVideo = pDataWithHeader; + *pnVideoBytes = extraDataSize + pkt.size - 3 * sizeof(uint8_t); + } + } else { + *ppVideo = pkt.data; + *pnVideoBytes = pkt.size; + } + + if (pts) { + *pts = (int64_t)(pkt.pts * userTimeScale * timeBase); + } + } + frameCount++; + return true; + } +}; + +inline cudaVideoCodec FFmpeg2NvCodecId(AVCodecID id) { + switch (id) { + case AV_CODEC_ID_MPEG1VIDEO : return cudaVideoCodec_MPEG1; + case AV_CODEC_ID_MPEG2VIDEO : return cudaVideoCodec_MPEG2; + case AV_CODEC_ID_MPEG4 : return cudaVideoCodec_MPEG4; + case AV_CODEC_ID_WMV3 : + case AV_CODEC_ID_VC1 : return cudaVideoCodec_VC1; + case AV_CODEC_ID_H264 : return cudaVideoCodec_H264; + case AV_CODEC_ID_HEVC : return cudaVideoCodec_HEVC; + case AV_CODEC_ID_VP8 : return cudaVideoCodec_VP8; + case AV_CODEC_ID_VP9 : return cudaVideoCodec_VP9; + case AV_CODEC_ID_MJPEG : return cudaVideoCodec_JPEG; + case AV_CODEC_ID_AV1 : return cudaVideoCodec_AV1; + default : return cudaVideoCodec_NumCodecs; + } +} diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp new file mode 100644 index 00000000000..fd2a429a18d --- /dev/null +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -0,0 +1,75 @@ +#include +#include "gpu_decoder.h" + +GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame) : demuxer(src_file.c_str()) +{ + if (cudaSuccess != cudaSetDevice(0)) { + printf("Error setting device\n"); + return; + } + CheckForCudaErrors( + cuCtxCreate(&ctx, CU_CTX_SCHED_SPIN, 0), + __LINE__); + dec.init(ctx, FFmpeg2NvCodecId(demuxer.GetVideoCodec()), useDevFrame); + demux_time = 0.0; + decode_time = 0.0; + totalFrames = 0; +} + +GPUDecoder::~GPUDecoder() +{ + dec.release(); + cuCtxDestroy(ctx); +} + +torch::Tensor GPUDecoder::decode() +{ + uint8_t *video; + torch::Tensor framesReturned; + int64_t numFrames; + clock_t start, end; + double cpu_time_used; + start = clock(); + demuxer.Demux(&video, &videoBytes); + end = clock(); + cpu_time_used = ((double) (end - start)) / CLOCKS_PER_SEC; + demux_time += cpu_time_used; + framesReturned = dec.Decode(video, videoBytes); + numFrames = dec.GetNumDecodedFrames(); + end = clock(); + cpu_time_used = ((double) (end - start)) / CLOCKS_PER_SEC; + decode_time += cpu_time_used; + totalFrames += numFrames; + return framesReturned; +} + +int64_t GPUDecoder::getDemuxedBytes() +{ + return videoBytes; +} + +double GPUDecoder::getDecodeTime() +{ + return decode_time; +} + +double GPUDecoder::getDemuxTime() +{ + return demux_time; +} + +int64_t GPUDecoder::getTotalFramesDecoded() +{ + return totalFrames; +} + +TORCH_LIBRARY(torchvision, m) { + m.class_("GPUDecoder") + .def(torch::init()) + .def("decode", &GPUDecoder::decode) + .def("getDecodeTime", &GPUDecoder::getDecodeTime) + .def("getDemuxTime", &GPUDecoder::getDemuxTime) + .def("getDemuxedBytes", &GPUDecoder::getDemuxedBytes) + .def("getTotalFramesDecoded", &GPUDecoder::getTotalFramesDecoded) + ; + } diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h new file mode 100644 index 00000000000..fea9974dd43 --- /dev/null +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -0,0 +1,21 @@ +#include +#include "decoder.h" +#include "demuxer.h" + +class GPUDecoder : public torch::CustomClassHolder { + public: + GPUDecoder(std::string, bool = false); + ~GPUDecoder(); + torch::Tensor decode(); + double getDecodeTime(); + double getDemuxTime(); + int64_t getDemuxedBytes(); + int64_t getTotalFramesDecoded(); + + private: + Demuxer demuxer; + CUcontext ctx; + Decoder dec; + double demux_time, decode_time; + int64_t totalFrames, videoBytes; +}; diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 382e06fb4f2..d2250e4a85b 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -32,6 +32,10 @@ write_video, ) +from .gpu_decoder import ( + GPUDecoder +) + if _HAS_VIDEO_OPT: diff --git a/torchvision/io/gpu_decoder.py b/torchvision/io/gpu_decoder.py new file mode 100644 index 00000000000..743c4379d52 --- /dev/null +++ b/torchvision/io/gpu_decoder.py @@ -0,0 +1,22 @@ +import torch + + +class GPUDecoder: + def __init__(self, src_file: str): + torch.ops.load_library('build/lib.linux-x86_64-3.8/torchvision/Decoder.so') + self.decoder = torch.classes.torchvision.GPUDecoder(src_file) + + def decode_frame(self): + return self.decoder.decode() + + def get_total_decoding_time(self): + return self.decoder.getDecodeTime() + + def get_total_demuxing_time(self): + return self.decoder.getDemuxTime() + + def get_demuxed_bytes(self): + return self.decoder.getDemuxedBytes() + + def get_total_frames_decoded(self): + return self.decoder.getTotalFramesDecoded() From 58dd405d20ff69c4438c3988c9bca728721101ec Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 7 Dec 2021 06:52:32 -0800 Subject: [PATCH 02/39] Expose use_dev_frame to python class and handle it internally --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 18 +++++++++++++++--- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 2 +- torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 2 +- torchvision/io/gpu_decoder.py | 4 ++-- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 62fbad11d74..0578b0e8b5f 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -72,8 +72,6 @@ void Decoder::release() const torch::Tensor Decoder::Decode(const uint8_t *data, int64_t size, int64_t flags, int64_t pts) { numDecodedFrames = 0; - frameTensor = torch::empty({GetFrameSize()}, - torch::dtype(torch::kU8).device(useDeviceFrame? torch::kCUDA : torch::kCPU)); CUVIDSOURCEDATAPACKET pkt = { .flags = flags | CUVID_PKT_TIMESTAMP, .payload_size = size, @@ -143,7 +141,16 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) printf("Decode Error occurred for picture %d\n", picNumInDecodeOrder[dispInfo->picture_index]); } - uint8_t *decodedFrame = frameTensor.data_ptr(); + uint8_t *decodedFrame = nullptr; + if (useDeviceFrame) { + cuMemAlloc((CUdeviceptr *)&decodedFrame, GetFrameSize()); + } + else { + frameTensor = torch::empty({GetFrameSize()}, + torch::dtype(torch::kU8).device(torch::kCPU)); + decodedFrame = frameTensor.data_ptr(); + } + numDecodedFrames++; // Copy luma plane @@ -180,6 +187,11 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) CheckForCudaErrors( cuStreamSynchronize(cuvidStream), __LINE__); + if (useDeviceFrame) { + auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); + frameTensor = torch::from_blob( + decodedFrame, {GetFrameSize()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); + } CheckForCudaErrors( cuCtxPopCurrent(NULL), __LINE__); diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index fd2a429a18d..8dcff90de4d 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -65,7 +65,7 @@ int64_t GPUDecoder::getTotalFramesDecoded() TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") - .def(torch::init()) + .def(torch::init()) .def("decode", &GPUDecoder::decode) .def("getDecodeTime", &GPUDecoder::getDecodeTime) .def("getDemuxTime", &GPUDecoder::getDemuxTime) diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index fea9974dd43..ec32e9259e5 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -4,7 +4,7 @@ class GPUDecoder : public torch::CustomClassHolder { public: - GPUDecoder(std::string, bool = false); + GPUDecoder(std::string, bool); ~GPUDecoder(); torch::Tensor decode(); double getDecodeTime(); diff --git a/torchvision/io/gpu_decoder.py b/torchvision/io/gpu_decoder.py index 743c4379d52..9b72cc32c9c 100644 --- a/torchvision/io/gpu_decoder.py +++ b/torchvision/io/gpu_decoder.py @@ -2,9 +2,9 @@ class GPUDecoder: - def __init__(self, src_file: str): + def __init__(self, src_file: str, use_dev_frame: bool = True): torch.ops.load_library('build/lib.linux-x86_64-3.8/torchvision/Decoder.so') - self.decoder = torch.classes.torchvision.GPUDecoder(src_file) + self.decoder = torch.classes.torchvision.GPUDecoder(src_file, use_dev_frame) def decode_frame(self): return self.decoder.decode() From 7bdf86b42b6f15058fab24a8f8fab8174de51909 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 8 Dec 2021 09:21:24 -0800 Subject: [PATCH 03/39] Fixed invalid argument CUDA error --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 1 - torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp | 15 ++++++++++----- torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 5 +++-- torchvision/io/gpu_decoder.py | 4 ++-- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 0578b0e8b5f..12d92cdf76c 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -5,7 +5,6 @@ #include #include "decoder.h" - static float GetChromaHeightFactor(cudaVideoSurfaceFormat surfaceFormat) { return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 1.0 : 0.5; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 8dcff90de4d..373f919b4cb 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -1,16 +1,17 @@ #include #include "gpu_decoder.h" -GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame) : demuxer(src_file.c_str()) +GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame, int64_t dev) : demuxer(src_file.c_str()), device(dev) { - if (cudaSuccess != cudaSetDevice(0)) { + if (cudaSuccess != cudaSetDevice(device)) { printf("Error setting device\n"); return; } CheckForCudaErrors( - cuCtxCreate(&ctx, CU_CTX_SCHED_SPIN, 0), + cuDevicePrimaryCtxRetain(&ctx, device), __LINE__); dec.init(ctx, FFmpeg2NvCodecId(demuxer.GetVideoCodec()), useDevFrame); + initialised = true; demux_time = 0.0; decode_time = 0.0; totalFrames = 0; @@ -19,7 +20,11 @@ GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame) : demuxer(src_fil GPUDecoder::~GPUDecoder() { dec.release(); - cuCtxDestroy(ctx); + if (initialised) { + CheckForCudaErrors( + cuDevicePrimaryCtxRelease(device), + __LINE__); + } } torch::Tensor GPUDecoder::decode() @@ -65,7 +70,7 @@ int64_t GPUDecoder::getTotalFramesDecoded() TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") - .def(torch::init()) + .def(torch::init()) .def("decode", &GPUDecoder::decode) .def("getDecodeTime", &GPUDecoder::getDecodeTime) .def("getDemuxTime", &GPUDecoder::getDemuxTime) diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index ec32e9259e5..42080e53f38 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -4,7 +4,7 @@ class GPUDecoder : public torch::CustomClassHolder { public: - GPUDecoder(std::string, bool); + GPUDecoder(std::string, bool, int64_t); ~GPUDecoder(); torch::Tensor decode(); double getDecodeTime(); @@ -17,5 +17,6 @@ class GPUDecoder : public torch::CustomClassHolder { CUcontext ctx; Decoder dec; double demux_time, decode_time; - int64_t totalFrames, videoBytes; + int64_t totalFrames, videoBytes, device; + bool initialised = false; }; diff --git a/torchvision/io/gpu_decoder.py b/torchvision/io/gpu_decoder.py index 9b72cc32c9c..d05f8bf3af2 100644 --- a/torchvision/io/gpu_decoder.py +++ b/torchvision/io/gpu_decoder.py @@ -2,9 +2,9 @@ class GPUDecoder: - def __init__(self, src_file: str, use_dev_frame: bool = True): + def __init__(self, src_file: str, use_dev_frame: bool = True, dev: int = 0): torch.ops.load_library('build/lib.linux-x86_64-3.8/torchvision/Decoder.so') - self.decoder = torch.classes.torchvision.GPUDecoder(src_file, use_dev_frame) + self.decoder = torch.classes.torchvision.GPUDecoder(src_file, use_dev_frame, dev) def decode_frame(self): return self.decoder.decode() From 86e373eca81d9161a7449ee6c7c7501724de7601 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Mon, 13 Dec 2021 13:15:37 -0800 Subject: [PATCH 04/39] Fixed empty and missing frames --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 27 ++++++----- torchvision/csrc/io/decoder/gpu/decoder.h | 23 +++++----- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 45 +++++++++++++------ 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 12d92cdf76c..6708391c723 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -68,7 +68,7 @@ void Decoder::release() const cuCtxPopCurrent(NULL); } -torch::Tensor Decoder::Decode(const uint8_t *data, int64_t size, int64_t flags, int64_t pts) +int Decoder::Decode(const uint8_t *data, int64_t size, int64_t flags, int64_t pts) { numDecodedFrames = 0; CUVIDSOURCEDATAPACKET pkt = { @@ -84,7 +84,17 @@ torch::Tensor Decoder::Decode(const uint8_t *data, int64_t size, int64_t flags, cuvidParseVideoData(parser, &pkt), __LINE__); cuvidStream = 0; - return frameTensor; + return numDecodedFrames; +} + +uint8_t * Decoder::FetchFrame() +{ + if (decodedFrames.empty()) { + return nullptr; + } + uint8_t *frame = decodedFrames.front(); + decodedFrames.pop(); + return frame; } int Decoder::HandlePictureDecode(CUVIDPICPARAMS *picParams) @@ -143,11 +153,8 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) uint8_t *decodedFrame = nullptr; if (useDeviceFrame) { cuMemAlloc((CUdeviceptr *)&decodedFrame, GetFrameSize()); - } - else { - frameTensor = torch::empty({GetFrameSize()}, - torch::dtype(torch::kU8).device(torch::kCPU)); - decodedFrame = frameTensor.data_ptr(); + } else { + decodedFrame = (uint8_t *)malloc(GetFrameSize() * sizeof(uint8_t)); } numDecodedFrames++; @@ -186,11 +193,7 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) CheckForCudaErrors( cuStreamSynchronize(cuvidStream), __LINE__); - if (useDeviceFrame) { - auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); - frameTensor = torch::from_blob( - decodedFrame, {GetFrameSize()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); - } + decodedFrames.push(decodedFrame); CheckForCudaErrors( cuCtxPopCurrent(NULL), __LINE__); diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index ee79786eb7e..78fb2d32c80 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -1,15 +1,15 @@ #include #include #include +#include #include #include -#include static auto CheckForCudaErrors = [](CUresult result, int lineNum) { if (CUDA_SUCCESS != result) { std::stringstream errorStream; - const char *errorName = nullptr, *errorDesc = nullptr; + const char *errorName = nullptr; errorStream << __FILE__ << ":" << lineNum << std::endl; if (CUDA_SUCCESS != cuGetErrorName(result, &errorName)) { @@ -34,10 +34,15 @@ class Decoder { Decoder() {} ~Decoder(); void init(CUcontext, cudaVideoCodec, bool, const Rect * = NULL, const Dim * = NULL, int64_t = 1000, bool = false, bool = false, int64_t = 0, int64_t = 0); - torch::Tensor Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); + int Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } void release() const; - int64_t GetNumDecodedFrames() const { return numDecodedFrames; } + bool UseDeviceFrame() const { return useDeviceFrame; } + int GetFrameSize() const + { + return GetWidth() * (lumaHeight + (chromaHeight * numChromaPlanes)) * bytesPerPixel; + } + uint8_t * FetchFrame(); private: CUcontext cuContext = NULL; @@ -49,7 +54,6 @@ class Decoder { CUstream cuvidStream = 0; int numDecodedFrames = 0; unsigned int numChromaPlanes = 0; - torch::Tensor frameTensor; // dimension of the output unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; cudaVideoCodec videoCodec = cudaVideoCodec_NumCodecs; @@ -73,6 +77,7 @@ class Decoder { size_t deviceFramePitch = 0; bool m_bDeviceFramePitched = false; int m_nFrameAlloc = 0; + std::queue decodedFrames; static int CUDAAPI HandleVideoSequenceProc(void *pUserData, CUVIDEOFORMAT *pVideoFormat) { return ((Decoder *)pUserData)->HandleVideoSequence(pVideoFormat); } static int CUDAAPI HandlePictureDecodeProc(void *pUserData, CUVIDPICPARAMS *pPicParams) { return ((Decoder *)pUserData)->HandlePictureDecode(pPicParams); } @@ -86,16 +91,10 @@ class Decoder { int HandlePictureDisplay(CUVIDPARSERDISPINFO *pDispInfo); int GetOperatingPoint(CUVIDOPERATINGPOINTINFO *pOPInfo); - int GetWidth() + int GetWidth() const { - assert(width); return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || videoOutputFormat == cudaVideoSurfaceFormat_P016) ? (width + 1) & ~1 : width; } - int GetFrameSize() - { - assert(width); - return GetWidth() * (lumaHeight + (chromaHeight * numChromaPlanes)) * bytesPerPixel; - } }; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 373f919b4cb..425ab69afde 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -1,4 +1,5 @@ #include +#include #include "gpu_decoder.h" GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame, int64_t dev) : demuxer(src_file.c_str()), device(dev) @@ -30,22 +31,40 @@ GPUDecoder::~GPUDecoder() torch::Tensor GPUDecoder::decode() { uint8_t *video; - torch::Tensor framesReturned; + torch::Tensor frameTensor; int64_t numFrames; clock_t start, end; - double cpu_time_used; - start = clock(); - demuxer.Demux(&video, &videoBytes); - end = clock(); - cpu_time_used = ((double) (end - start)) / CLOCKS_PER_SEC; - demux_time += cpu_time_used; - framesReturned = dec.Decode(video, videoBytes); - numFrames = dec.GetNumDecodedFrames(); - end = clock(); - cpu_time_used = ((double) (end - start)) / CLOCKS_PER_SEC; - decode_time += cpu_time_used; + double cpu_time_used1, cpu_time_used2; + uint8_t *frame = nullptr; + do + { + start = clock(); + demuxer.Demux(&video, &videoBytes); + end = clock(); + cpu_time_used1 = ((double) (end - start)) / CLOCKS_PER_SEC; + start = clock(); + numFrames = dec.Decode(video, videoBytes); + end = clock(); + frame = dec.FetchFrame(); + cpu_time_used2 = ((double) (end - start)) / CLOCKS_PER_SEC; + } while (frame == nullptr && videoBytes > 0); + demux_time += cpu_time_used1; + decode_time += cpu_time_used2; totalFrames += numFrames; - return framesReturned; + if (frame == nullptr) { + auto options = torch::TensorOptions().dtype(torch::kU8).device(dec.UseDeviceFrame() ? torch::kCUDA : torch::kCPU); + return torch::zeros({0}, options); + } + if (dec.UseDeviceFrame()) { + auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); + frameTensor = torch::from_blob( + frame, {dec.GetFrameSize()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); + } else { + auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCPU); + frameTensor = torch::from_blob( + frame, {dec.GetFrameSize()}, [](auto p) { free(p); }, options); + } + return frameTensor; } int64_t GPUDecoder::getDemuxedBytes() From 717ee016c33e96e92584d5e62c01f4d074594f9d Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 14 Dec 2021 04:45:42 -0800 Subject: [PATCH 05/39] Free remaining frames in the queue --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 11 ++++++++++- torchvision/csrc/io/decoder/gpu/decoder.h | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 6708391c723..2e6f63f1bd4 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -59,12 +59,21 @@ Decoder::~Decoder() cuvidCtxLockDestroy(ctxLock); } -void Decoder::release() const +void Decoder::release() { cuCtxPushCurrent(cuContext); if (decoder) { cuvidDestroyDecoder(decoder); } + while (!decodedFrames.empty()) { + uint8_t *frame = decodedFrames.front(); + decodedFrames.pop(); + if (useDeviceFrame) { + cuMemFree((CUdeviceptr)frame); + } else { + free(frame); + } + } cuCtxPopCurrent(NULL); } diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 78fb2d32c80..7158f8a53bf 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -36,7 +36,7 @@ class Decoder { void init(CUcontext, cudaVideoCodec, bool, const Rect * = NULL, const Dim * = NULL, int64_t = 1000, bool = false, bool = false, int64_t = 0, int64_t = 0); int Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } - void release() const; + void release(); bool UseDeviceFrame() const { return useDeviceFrame; } int GetFrameSize() const { From 75f8f48a72536fe8568a4a44d843c0ee631d7423 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 14 Dec 2021 14:45:09 -0800 Subject: [PATCH 06/39] Added nv12 to yuv420 conversion support for host frames --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 10 +++--- torchvision/csrc/io/decoder/gpu/decoder.h | 21 +++++------- torchvision/csrc/io/decoder/gpu/demuxer.h | 7 ++-- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 7 ++-- torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 34 ++++++++++++++++++- torchvision/io/gpu_decoder.py | 6 ++-- 6 files changed, 59 insertions(+), 26 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 2e6f63f1bd4..1e3d312b256 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -386,13 +386,13 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT *vidFormat) if ((vidFormat->coded_width > nMaxWidth) || (vidFormat->coded_height > nMaxHeight)) { // For VP9, let driver handle the change if new width/height > maxwidth/maxheight - if ((videoCodec != cudaVideoCodec_VP9) || m_bReconfigExternal) { + if ((videoCodec != cudaVideoCodec_VP9) || reconfigExternal) { throw std::runtime_error("Reconfigure Not supported when width/height > maxwidth/maxheight"); } return 1; } - if (!bDecodeResChange && !m_bReconfigExtPPChange) { + if (!bDecodeResChange && !reconfigExtPPChange) { // if the coded_width/coded_height hasn't changed but display resolution has changed, then need to update width/height for // correct output without cropping. Example : 1920x1080 vs 1920x1088 if (bDisplayRectChange) { @@ -421,10 +421,10 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT *vidFormat) // If external reconfigure is called along with resolution change even if post processing params is not changed, // do full reconfigure params update - if ((m_bReconfigExternal && bDecodeResChange) || m_bReconfigExtPPChange) { + if ((reconfigExternal && bDecodeResChange) || reconfigExtPPChange) { // update display rect and target resolution if requested explicitely - m_bReconfigExternal = false; - m_bReconfigExtPPChange = false; + reconfigExternal = false; + reconfigExtPPChange = false; videoFormat = *vidFormat; if (!(cropRect.right && cropRect.bottom) && !(resizeDim.width && resizeDim.height)) { width = vidFormat->display_area.right - vidFormat->display_area.left; diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 7158f8a53bf..c214ab9c8c6 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -35,14 +35,19 @@ class Decoder { ~Decoder(); void init(CUcontext, cudaVideoCodec, bool, const Rect * = NULL, const Dim * = NULL, int64_t = 1000, bool = false, bool = false, int64_t = 0, int64_t = 0); int Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); - cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } void release(); + uint8_t * FetchFrame(); bool UseDeviceFrame() const { return useDeviceFrame; } + cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } int GetFrameSize() const { return GetWidth() * (lumaHeight + (chromaHeight * numChromaPlanes)) * bytesPerPixel; } - uint8_t * FetchFrame(); + int GetWidth() const + { + return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || videoOutputFormat == cudaVideoSurfaceFormat_P016) ? (width + 1) & ~1 : width; + } + int GetHeight() const { return lumaHeight; } private: CUcontext cuContext = NULL; @@ -72,11 +77,9 @@ class Decoder { unsigned int operatingPoint = 0; bool dispAllLayers = false; int decodePicCount = 0, picNumInDecodeOrder[32]; - bool m_bReconfigExternal = false; - bool m_bReconfigExtPPChange = false; + bool reconfigExternal = false; + bool reconfigExtPPChange = false; size_t deviceFramePitch = 0; - bool m_bDeviceFramePitched = false; - int m_nFrameAlloc = 0; std::queue decodedFrames; static int CUDAAPI HandleVideoSequenceProc(void *pUserData, CUVIDEOFORMAT *pVideoFormat) { return ((Decoder *)pUserData)->HandleVideoSequence(pVideoFormat); } @@ -90,11 +93,5 @@ class Decoder { int HandlePictureDecode(CUVIDPICPARAMS *pPicParams); int HandlePictureDisplay(CUVIDPARSERDISPINFO *pDispInfo); int GetOperatingPoint(CUVIDOPERATINGPOINTINFO *pOPInfo); - - int GetWidth() const - { - return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || videoOutputFormat == cudaVideoSurfaceFormat_P016) ? (width + 1) & ~1 : width; - } - }; diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index f1abb51c209..bff1a0e4f00 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -4,9 +4,9 @@ extern "C" { #include #include } -#include -inline bool check(int ret, int line) { +inline bool check(int ret, int line) +{ if (ret < 0) { printf("Error %d at line %d in file %s.\n", ret, line); return false; @@ -175,7 +175,8 @@ class Demuxer { } }; -inline cudaVideoCodec FFmpeg2NvCodecId(AVCodecID id) { +inline cudaVideoCodec FFmpeg2NvCodecId(AVCodecID id) +{ switch (id) { case AV_CODEC_ID_MPEG1VIDEO : return cudaVideoCodec_MPEG1; case AV_CODEC_ID_MPEG2VIDEO : return cudaVideoCodec_MPEG2; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 425ab69afde..028888b7a14 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -2,7 +2,7 @@ #include #include "gpu_decoder.h" -GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame, int64_t dev) : demuxer(src_file.c_str()), device(dev) +GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame, int64_t dev, std::string out_format) : demuxer(src_file.c_str()), device(dev), output_format(out_format) { if (cudaSuccess != cudaSetDevice(device)) { printf("Error setting device\n"); @@ -60,6 +60,9 @@ torch::Tensor GPUDecoder::decode() frameTensor = torch::from_blob( frame, {dec.GetFrameSize()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); } else { + if (output_format == "yuv420") { + NV12ToYUV420(frame, dec.GetWidth(), dec.GetHeight()); + } auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCPU); frameTensor = torch::from_blob( frame, {dec.GetFrameSize()}, [](auto p) { free(p); }, options); @@ -89,7 +92,7 @@ int64_t GPUDecoder::getTotalFramesDecoded() TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") - .def(torch::init()) + .def(torch::init()) .def("decode", &GPUDecoder::decode) .def("getDecodeTime", &GPUDecoder::getDecodeTime) .def("getDemuxTime", &GPUDecoder::getDemuxTime) diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index 42080e53f38..8b9631fb14b 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -4,7 +4,7 @@ class GPUDecoder : public torch::CustomClassHolder { public: - GPUDecoder(std::string, bool, int64_t); + GPUDecoder(std::string, bool, int64_t, std::string); ~GPUDecoder(); torch::Tensor decode(); double getDecodeTime(); @@ -19,4 +19,36 @@ class GPUDecoder : public torch::CustomClassHolder { double demux_time, decode_time; int64_t totalFrames, videoBytes, device; bool initialised = false; + std::string output_format; + + void NV12ToYUV420(uint8_t *frame, int width, int height) + { + int pitch = width; + uint8_t *ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)]; + + // sizes of source surface plane + int sizePlaneY = pitch * height; + int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2); + int sizePlaneV = sizePlaneU; + + uint8_t *uv = frame + sizePlaneY; + uint8_t *u = uv; + uint8_t *v = uv + sizePlaneU; + + // split chroma from interleave to planar + for (int y = 0; y < (height + 1) / 2; y++) { + for (int x = 0; x < (width + 1) / 2; x++) { + u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2]; + ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1]; + } + } + if (pitch == width) { + memcpy(v, ptr, sizePlaneV * sizeof(uint8_t)); + } else { + for (int i = 0; i < (height + 1) / 2; i++) { + memcpy(v + ((pitch + 1) / 2) * i, ptr + ((width + 1) / 2) * i, ((width + 1) / 2) * sizeof(uint8_t)); + } + } + delete[] ptr; + } }; diff --git a/torchvision/io/gpu_decoder.py b/torchvision/io/gpu_decoder.py index d05f8bf3af2..8ac46b5fe6b 100644 --- a/torchvision/io/gpu_decoder.py +++ b/torchvision/io/gpu_decoder.py @@ -2,9 +2,9 @@ class GPUDecoder: - def __init__(self, src_file: str, use_dev_frame: bool = True, dev: int = 0): - torch.ops.load_library('build/lib.linux-x86_64-3.8/torchvision/Decoder.so') - self.decoder = torch.classes.torchvision.GPUDecoder(src_file, use_dev_frame, dev) + def __init__(self, src_file: str, use_dev_frame: bool = True, dev: int = 0, output_format = "nv12"): + torch.ops.load_library("build/lib.linux-x86_64-3.8/torchvision/Decoder.so") + self.decoder = torch.classes.torchvision.GPUDecoder(src_file, use_dev_frame, dev, output_format) def decode_frame(self): return self.decoder.decode() From 34b82057cfbe70b69d900c028b1b88c19adc1af4 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 15 Dec 2021 09:21:51 -0800 Subject: [PATCH 07/39] Added unit test and cleaned up code --- test/test_video_gpu_decoder.py | 41 +++++++++++ torchvision/csrc/io/decoder/gpu/demuxer.h | 23 +++---- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 46 +------------ torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 6 +- torchvision/io/__init__.py | 69 +++++++++++++++++-- torchvision/io/_video_opt.py | 7 ++ torchvision/io/gpu_decoder.py | 22 ------ 7 files changed, 124 insertions(+), 90 deletions(-) create mode 100644 test/test_video_gpu_decoder.py delete mode 100644 torchvision/io/gpu_decoder.py diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py new file mode 100644 index 00000000000..f6c181fc30a --- /dev/null +++ b/test/test_video_gpu_decoder.py @@ -0,0 +1,41 @@ +import os + +import pytest +import torch +from torchvision.io import _HAS_VIDEO_DECODER, VideoReader + +try: + import av +except ImportError: + av = None + +VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") + +test_videos = [ + "RATRACE_wave_f_nm_np1_fr_goo_37.avi", + "TrumanShow_wave_f_nm_np1_fr_med_26.avi", + "v_SoccerJuggling_g23_c01.avi", + "v_SoccerJuggling_g24_c01.avi", + "R6llTwEh07w.mp4", + "SOX5yA1l24A.mp4", + "WUzgd7C1pWA.mp4", +] + + +@pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder") +class TestVideoGPUDecoder: + @pytest.mark.skipif(av is None, reason="PyAV unavailable") + def test_frame_reading(self): + for test_video in test_videos: + full_path = os.path.join(VIDEO_DIR, test_video) + decoder = VideoReader(full_path, device='cuda:0', output_format="yuv420", use_device_frame=False) + with av.open(full_path) as container: + for av_frame in container.decode(container.streams.video[0]): + av_frames = torch.tensor(av_frame.to_ndarray().flatten()) + vision_frames = next(decoder)["data"] + mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.float())) + assert mean_delta < 0.1 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index bff1a0e4f00..02e7c369823 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -117,12 +117,12 @@ class Demuxer { return eVideoCodec; } - bool Demux(uint8_t **ppVideo, int64_t *pnVideoBytes, int64_t *pts = NULL) + bool Demux(uint8_t **video, int64_t *videoBytes) { if (!fmtCtx) { return false; } - *pnVideoBytes = 0; + *videoBytes = 0; if (pkt.data) { av_packet_unref(&pkt); @@ -141,11 +141,8 @@ class Demuxer { } check_for_errors(av_bsf_send_packet(bsfCtx, &pkt)); check_for_errors(av_bsf_receive_packet(bsfCtx, &pktFiltered)); - *ppVideo = pktFiltered.data; - *pnVideoBytes = pktFiltered.size; - if (pts) { - *pts = (int64_t) (pktFiltered.pts * userTimeScale * timeBase); - } + *video = pktFiltered.data; + *videoBytes = pktFiltered.size; } else { if (bMp4MPEG4 && (frameCount == 0)) { int extraDataSize = fmtCtx->streams[iVideoStream]->codecpar->extradata_size; @@ -158,16 +155,12 @@ class Demuxer { } memcpy(pDataWithHeader, fmtCtx->streams[iVideoStream]->codecpar->extradata, extraDataSize); memcpy(pDataWithHeader+extraDataSize, pkt.data+3, pkt.size - 3 * sizeof(uint8_t)); - *ppVideo = pDataWithHeader; - *pnVideoBytes = extraDataSize + pkt.size - 3 * sizeof(uint8_t); + *video = pDataWithHeader; + *videoBytes = extraDataSize + pkt.size - 3 * sizeof(uint8_t); } } else { - *ppVideo = pkt.data; - *pnVideoBytes = pkt.size; - } - - if (pts) { - *pts = (int64_t)(pkt.pts * userTimeScale * timeBase); + *video = pkt.data; + *videoBytes = pkt.size; } } frameCount++; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 028888b7a14..bdd1ecd311b 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -1,4 +1,3 @@ -#include #include #include "gpu_decoder.h" @@ -13,9 +12,6 @@ GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame, int64_t dev, std: __LINE__); dec.init(ctx, FFmpeg2NvCodecId(demuxer.GetVideoCodec()), useDevFrame); initialised = true; - demux_time = 0.0; - decode_time = 0.0; - totalFrames = 0; } GPUDecoder::~GPUDecoder() @@ -30,27 +26,15 @@ GPUDecoder::~GPUDecoder() torch::Tensor GPUDecoder::decode() { - uint8_t *video; torch::Tensor frameTensor; - int64_t numFrames; - clock_t start, end; - double cpu_time_used1, cpu_time_used2; - uint8_t *frame = nullptr; + int64_t videoBytes, numFrames; + uint8_t *frame = nullptr, *video = nullptr; do { - start = clock(); demuxer.Demux(&video, &videoBytes); - end = clock(); - cpu_time_used1 = ((double) (end - start)) / CLOCKS_PER_SEC; - start = clock(); numFrames = dec.Decode(video, videoBytes); - end = clock(); frame = dec.FetchFrame(); - cpu_time_used2 = ((double) (end - start)) / CLOCKS_PER_SEC; } while (frame == nullptr && videoBytes > 0); - demux_time += cpu_time_used1; - decode_time += cpu_time_used2; - totalFrames += numFrames; if (frame == nullptr) { auto options = torch::TensorOptions().dtype(torch::kU8).device(dec.UseDeviceFrame() ? torch::kCUDA : torch::kCPU); return torch::zeros({0}, options); @@ -70,33 +54,9 @@ torch::Tensor GPUDecoder::decode() return frameTensor; } -int64_t GPUDecoder::getDemuxedBytes() -{ - return videoBytes; -} - -double GPUDecoder::getDecodeTime() -{ - return decode_time; -} - -double GPUDecoder::getDemuxTime() -{ - return demux_time; -} - -int64_t GPUDecoder::getTotalFramesDecoded() -{ - return totalFrames; -} - TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") .def(torch::init()) - .def("decode", &GPUDecoder::decode) - .def("getDecodeTime", &GPUDecoder::getDecodeTime) - .def("getDemuxTime", &GPUDecoder::getDemuxTime) - .def("getDemuxedBytes", &GPUDecoder::getDemuxedBytes) - .def("getTotalFramesDecoded", &GPUDecoder::getTotalFramesDecoded) + .def("next", &GPUDecoder::decode) ; } diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index 8b9631fb14b..95dd55ef95a 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -7,17 +7,13 @@ class GPUDecoder : public torch::CustomClassHolder { GPUDecoder(std::string, bool, int64_t, std::string); ~GPUDecoder(); torch::Tensor decode(); - double getDecodeTime(); - double getDemuxTime(); - int64_t getDemuxedBytes(); - int64_t getTotalFramesDecoded(); private: Demuxer demuxer; CUcontext ctx; Decoder dec; double demux_time, decode_time; - int64_t totalFrames, videoBytes, device; + int64_t device; bool initialised = false; std::string output_format; diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index d79a29cc67a..86e86bf2928 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -6,6 +6,7 @@ Timebase, VideoMetaData, _HAS_VIDEO_OPT, + _HAS_VIDEO_DECODER, _probe_video_from_file, _probe_video_from_memory, _read_video_from_file, @@ -32,10 +33,6 @@ write_video, ) -from .gpu_decoder import ( - GPUDecoder -) - if _HAS_VIDEO_OPT: @@ -49,6 +46,18 @@ def _has_video_opt() -> bool: return False +if _HAS_VIDEO_DECODER: + + def _has_video_decoder() -> bool: + return True + + +else: + + def _has_video_decoder() -> bool: + return False + + class VideoReader: """ Fine-grained video-reading API. @@ -107,9 +116,44 @@ class VideoReader: num_threads (int, optional): number of threads used by the codec to decode video. Default value (0) enables multithreading with codec-dependent heuristic. The performance will depend on the version of FFMPEG codecs supported. + + device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``. + + output_format (str, optional): Format of the output frames. Defaults to ``"rgb"``. + + use_device_frame (bool, optional): Whether to get back device frames(gpu tensors) + or host frames(cpu tensors) after GPU decoding. Defaults to ``True``. + + Currently supported values of (use_device_frame, output_format): + (True, "nv12"), (False, "nv12"), (False, "yuv420"). + """ - def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> None: + def __init__(self, path: str, stream: str = "video", num_threads: int = 0, + device: str = "cpu", output_format: str = "rgb", use_device_frame: bool = True) -> None: + self.is_cuda = False + device = torch.device(device) + if device.type == "cuda": + supported_formats = [ + "nv12", + "yuv420", + ] + if not _has_video_decoder(): + raise RuntimeError("Not compiled with GPU decoder support.") + self.is_cuda = True + if device.index is None: + raise RuntimeError("Invalid cuda device!") + if output_format not in supported_formats: + raise RuntimeError( + f"{output_format} output format not supported with GPU decoding, " + "please use one of {', '.join(supported_formats)}.") + if output_format == "yuv420" and use_device_frame: + raise RuntimeError( + "yuv420 output not yet supported with GPU decoding when use_device_frame=True, " + "please use either nv12 or use_device_frame=False.") + self._c = torch.classes.torchvision.GPUDecoder( + path, use_device_frame, device.index, output_format) + return if not _has_video_opt(): raise RuntimeError( "Not compiled with video_reader support, " @@ -117,6 +161,9 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> No + "ffmpeg (version 4.2 is currently supported) and " + "build torchvision from source." ) + if output_format != "rgb": + raise RuntimeError("Only rgb output is supported with video_reader.") + self._c = torch.classes.torchvision.Video(path, stream, num_threads) def __next__(self) -> Dict[str, Any]: @@ -131,6 +178,11 @@ def __next__(self) -> Dict[str, Any]: and corresponding timestamp (``pts``) in seconds """ + if self.is_cuda: + frame = self._c.next() + if frame.numel() == 0: + raise StopIteration + return {"data": frame} frame, pts = self._c.next() if frame.numel() == 0: raise StopIteration @@ -152,6 +204,8 @@ def seek(self, time_s: float, keyframes_only: bool = False) -> "VideoReader": frame with the exact timestamp if it exists or the first frame with timestamp larger than ``time_s``. """ + if self.is_cuda: + raise RuntimeError("seek() not yet supported with GPU decoding.") self._c.seek(time_s, keyframes_only) return self @@ -161,6 +215,8 @@ def get_metadata(self) -> Dict[str, Any]: Returns: (dict): dictionary containing duration and frame rate for every stream """ + if self.is_cuda: + raise RuntimeError("get_metadata() not yet supported with GPU decoding.") return self._c.get_metadata() def set_current_stream(self, stream: str) -> bool: @@ -180,6 +236,8 @@ def set_current_stream(self, stream: str) -> bool: Returns: (bool): True on succes, False otherwise """ + if self.is_cuda: + print("GPU decoding only works with video stream.") return self._c.set_current_stream(stream) @@ -194,6 +252,7 @@ def set_current_stream(self, stream: str) -> bool: "_read_video_timestamps_from_memory", "_probe_video_from_memory", "_HAS_VIDEO_OPT", + "_HAS_VIDEO_DECODER", "_read_video_clip_from_memory", "_read_video_meta_data", "VideoMetaData", diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index d75dd41534d..1e802486d3e 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -14,6 +14,13 @@ except (ImportError, OSError): _HAS_VIDEO_OPT = False +try: + #_load_library("Decoder") + torch.ops.load_library("build/lib.linux-x86_64-3.8/torchvision/Decoder.so") + _HAS_VIDEO_DECODER = True +except (ImportError, OSError): + _HAS_VIDEO_DECODER = False + default_timebase = Fraction(0, 1) diff --git a/torchvision/io/gpu_decoder.py b/torchvision/io/gpu_decoder.py deleted file mode 100644 index 8ac46b5fe6b..00000000000 --- a/torchvision/io/gpu_decoder.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch - - -class GPUDecoder: - def __init__(self, src_file: str, use_dev_frame: bool = True, dev: int = 0, output_format = "nv12"): - torch.ops.load_library("build/lib.linux-x86_64-3.8/torchvision/Decoder.so") - self.decoder = torch.classes.torchvision.GPUDecoder(src_file, use_dev_frame, dev, output_format) - - def decode_frame(self): - return self.decoder.decode() - - def get_total_decoding_time(self): - return self.decoder.getDecodeTime() - - def get_total_demuxing_time(self): - return self.decoder.getDemuxTime() - - def get_demuxed_bytes(self): - return self.decoder.getDemuxedBytes() - - def get_total_frames_decoded(self): - return self.decoder.getTotalFramesDecoded() From 6672a052cf596230c1e3ba2e2d736ed48bdac59d Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 15 Dec 2021 11:14:07 -0800 Subject: [PATCH 08/39] Use CUDA_HOME inside if --- setup.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index afa565f7bef..daa7ca70638 100644 --- a/setup.py +++ b/setup.py @@ -385,7 +385,8 @@ def get_extensions(): print(f"{library} header files were not found, disabling ffmpeg support") has_ffmpeg = False - if has_ffmpeg: + #if has_ffmpeg: + if False: print(f"ffmpeg include path: {ffmpeg_include_dir}") print(f"ffmpeg library_dir: {ffmpeg_library_dir}") @@ -438,15 +439,14 @@ def get_extensions(): print(f"video codec found: {video_codec_found}") - gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu") - print(f"DECODER PATH {gpu_decoder_path}") - gpu_decoder_src = ( - glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) - ) - cuda_libs = os.path.join(CUDA_HOME, 'lib64') - cuda_inc = os.path.join(CUDA_HOME, 'include') - if video_codec_found and has_ffmpeg: + gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu") + gpu_decoder_src = ( + glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) + ) + cuda_libs = os.path.join(CUDA_HOME, 'lib64') + cuda_inc = os.path.join(CUDA_HOME, 'include') + ext_modules.append( extension( "torchvision.Decoder", From 30af8aca2282bab8cd32795564549d59dd788058 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 15 Dec 2021 11:17:04 -0800 Subject: [PATCH 09/39] Undo commented out code --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index daa7ca70638..1c1a99d74c6 100644 --- a/setup.py +++ b/setup.py @@ -385,8 +385,7 @@ def get_extensions(): print(f"{library} header files were not found, disabling ffmpeg support") has_ffmpeg = False - #if has_ffmpeg: - if False: + if has_ffmpeg: print(f"ffmpeg include path: {ffmpeg_include_dir}") print(f"ffmpeg library_dir: {ffmpeg_library_dir}") From 4b9bab8ae039a43727779dfd912a2f6f45014b34 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 15 Dec 2021 15:49:55 -0800 Subject: [PATCH 10/39] Add Readme --- torchvision/csrc/io/decoder/gpu/README.rst | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 torchvision/csrc/io/decoder/gpu/README.rst diff --git a/torchvision/csrc/io/decoder/gpu/README.rst b/torchvision/csrc/io/decoder/gpu/README.rst new file mode 100644 index 00000000000..b9d0947c6f3 --- /dev/null +++ b/torchvision/csrc/io/decoder/gpu/README.rst @@ -0,0 +1,20 @@ +GPU Decoder +=========== + +GPU decoder depends on ffmpeg for demuxing, uses NVDECODE APIs from the nvidia-video-codec sdk and uses cuda for processing on gpu. In order to use this, please follow the following steps: + +* Download the latest `nvidia-video-codec-sdk `_ +* Extract the zipped file and copy the header files and libraries. + +.. code:: bash + +sudo cp Interface/* /usr/local/include/ +sudo cp Lib/linux/stubs/x86_64/libnv* /usr/local/lib/ + +* Install ffmpeg and make sure ffmpeg headers and libraries are present under /usr/local/include and /usr/local/lib respectively. +* Set CUDA_HOME environment variable to the cuda root directory. +* Build torchvision from source: + +.. code:: bash + + python setup.py install From 5afb6ddb53d1475fce1f74b8c70c83718272af80 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 16 Dec 2021 15:46:58 -0800 Subject: [PATCH 11/39] Remove output_format and use_device_frame optional arguments from the VideoReader API --- test/test_video_gpu_decoder.py | 4 +- torchvision/csrc/io/decoder/gpu/decoder.cpp | 21 +++----- torchvision/csrc/io/decoder/gpu/decoder.h | 4 +- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 53 ++++++++++++++----- torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 36 ++----------- torchvision/io/__init__.py | 39 +++++--------- torchvision/io/_video_opt.py | 3 +- 7 files changed, 65 insertions(+), 95 deletions(-) diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index f6c181fc30a..dc9fb149007 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -28,12 +28,12 @@ class TestVideoGPUDecoder: def test_frame_reading(self): for test_video in test_videos: full_path = os.path.join(VIDEO_DIR, test_video) - decoder = VideoReader(full_path, device='cuda:0', output_format="yuv420", use_device_frame=False) + decoder = VideoReader(full_path, device='cuda:0') with av.open(full_path) as container: for av_frame in container.decode(container.streams.video[0]): av_frames = torch.tensor(av_frame.to_ndarray().flatten()) vision_frames = next(decoder)["data"] - mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.float())) + mean_delta = torch.mean(torch.abs(av_frames.float() - decoder.reformat(vision_frames).float())) assert mean_delta < 0.1 diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 1e3d312b256..507be4ce05a 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -15,13 +15,12 @@ static int GetChromaPlaneCount(cudaVideoSurfaceFormat surfaceFormat) return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 2 : 1; } -void Decoder::init(CUcontext context, cudaVideoCodec codec, bool useDevFrame, const Rect *rect, - const Dim *dim, int64_t clockRate, bool lowLatency, +void Decoder::init(CUcontext context, cudaVideoCodec codec, const Rect *rect, + const Dim *dim, bool lowLatency, bool forceZeroLat, int64_t maxWidth, int64_t maxHeight) { cuContext = context; forceZeroLatency = forceZeroLat; - useDeviceFrame = useDevFrame; videoCodec = codec; nMaxHeight = maxHeight; nMaxWidth = maxWidth; @@ -38,7 +37,7 @@ void Decoder::init(CUcontext context, cudaVideoCodec codec, bool useDevFrame, co CUVIDPARSERPARAMS parserParams = { .CodecType = codec, .ulMaxNumDecodeSurfaces = 1, - .ulClockRate = clockRate, + .ulClockRate = 1000, .ulMaxDisplayDelay = lowLatency ? 0u : 1u, .pUserData = this, .pfnSequenceCallback = HandleVideoSequenceProc, @@ -68,11 +67,7 @@ void Decoder::release() while (!decodedFrames.empty()) { uint8_t *frame = decodedFrames.front(); decodedFrames.pop(); - if (useDeviceFrame) { - cuMemFree((CUdeviceptr)frame); - } else { - free(frame); - } + cuMemFree((CUdeviceptr)frame); } cuCtxPopCurrent(NULL); } @@ -160,11 +155,7 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) } uint8_t *decodedFrame = nullptr; - if (useDeviceFrame) { - cuMemAlloc((CUdeviceptr *)&decodedFrame, GetFrameSize()); - } else { - decodedFrame = (uint8_t *)malloc(GetFrameSize() * sizeof(uint8_t)); - } + cuMemAlloc((CUdeviceptr *)&decodedFrame, GetFrameSize()); numDecodedFrames++; @@ -173,7 +164,7 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) m.srcMemoryType = CU_MEMORYTYPE_DEVICE; m.srcDevice = dpSrcFrame; m.srcPitch = nSrcPitch; - m.dstMemoryType = useDeviceFrame ? CU_MEMORYTYPE_DEVICE : CU_MEMORYTYPE_HOST; + m.dstMemoryType = CU_MEMORYTYPE_DEVICE; m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame); m.dstPitch = deviceFramePitch ? deviceFramePitch : GetWidth() * bytesPerPixel; m.WidthInBytes = GetWidth() * bytesPerPixel; diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index c214ab9c8c6..408156f9e1b 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -33,11 +33,10 @@ class Decoder { public: Decoder() {} ~Decoder(); - void init(CUcontext, cudaVideoCodec, bool, const Rect * = NULL, const Dim * = NULL, int64_t = 1000, bool = false, bool = false, int64_t = 0, int64_t = 0); + void init(CUcontext, cudaVideoCodec, const Rect * = NULL, const Dim * = NULL, bool = false, bool = false, int64_t = 0, int64_t = 0); int Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); void release(); uint8_t * FetchFrame(); - bool UseDeviceFrame() const { return useDeviceFrame; } cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } int GetFrameSize() const { @@ -55,7 +54,6 @@ class Decoder { CUvideoparser parser = NULL; CUvideodecoder decoder = NULL; bool forceZeroLatency = false; - bool useDeviceFrame; CUstream cuvidStream = 0; int numDecodedFrames = 0; unsigned int numChromaPlanes = 0; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index bdd1ecd311b..b65f95e080a 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -1,7 +1,6 @@ -#include #include "gpu_decoder.h" -GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame, int64_t dev, std::string out_format) : demuxer(src_file.c_str()), device(dev), output_format(out_format) +GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) : demuxer(src_file.c_str()), device(dev) { if (cudaSuccess != cudaSetDevice(device)) { printf("Error setting device\n"); @@ -10,7 +9,7 @@ GPUDecoder::GPUDecoder(std::string src_file, bool useDevFrame, int64_t dev, std: CheckForCudaErrors( cuDevicePrimaryCtxRetain(&ctx, device), __LINE__); - dec.init(ctx, FFmpeg2NvCodecId(demuxer.GetVideoCodec()), useDevFrame); + dec.init(ctx, FFmpeg2NvCodecId(demuxer.GetVideoCodec())); initialised = true; } @@ -36,27 +35,53 @@ torch::Tensor GPUDecoder::decode() frame = dec.FetchFrame(); } while (frame == nullptr && videoBytes > 0); if (frame == nullptr) { - auto options = torch::TensorOptions().dtype(torch::kU8).device(dec.UseDeviceFrame() ? torch::kCUDA : torch::kCPU); + auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); return torch::zeros({0}, options); } - if (dec.UseDeviceFrame()) { - auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); - frameTensor = torch::from_blob( - frame, {dec.GetFrameSize()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); + auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); + frameTensor = torch::from_blob( + frame, {dec.GetFrameSize()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); + return frameTensor; +} + +torch::Tensor GPUDecoder::NV12ToYUV420(torch::Tensor frameTensor) +{ + int width = dec.GetWidth(), height = dec.GetHeight(); + int pitch = width; + uint8_t *frame = frameTensor.data_ptr(); + uint8_t *ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)]; + + // sizes of source surface plane + int sizePlaneY = pitch * height; + int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2); + int sizePlaneV = sizePlaneU; + + uint8_t *uv = frame + sizePlaneY; + uint8_t *u = uv; + uint8_t *v = uv + sizePlaneU; + + // split chroma from interleave to planar + for (int y = 0; y < (height + 1) / 2; y++) { + for (int x = 0; x < (width + 1) / 2; x++) { + u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2]; + ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1]; + } + } + if (pitch == width) { + memcpy(v, ptr, sizePlaneV * sizeof(uint8_t)); } else { - if (output_format == "yuv420") { - NV12ToYUV420(frame, dec.GetWidth(), dec.GetHeight()); + for (int i = 0; i < (height + 1) / 2; i++) { + memcpy(v + ((pitch + 1) / 2) * i, ptr + ((width + 1) / 2) * i, ((width + 1) / 2) * sizeof(uint8_t)); } - auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCPU); - frameTensor = torch::from_blob( - frame, {dec.GetFrameSize()}, [](auto p) { free(p); }, options); } + delete[] ptr; return frameTensor; } TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") - .def(torch::init()) + .def(torch::init()) .def("next", &GPUDecoder::decode) + .def("reformat", &GPUDecoder::NV12ToYUV420) ; } diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index 95dd55ef95a..d400c213bf6 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -1,50 +1,20 @@ #include +#include #include "decoder.h" #include "demuxer.h" class GPUDecoder : public torch::CustomClassHolder { public: - GPUDecoder(std::string, bool, int64_t, std::string); + GPUDecoder(std::string, int64_t); ~GPUDecoder(); torch::Tensor decode(); + torch::Tensor NV12ToYUV420(torch::Tensor); private: Demuxer demuxer; CUcontext ctx; Decoder dec; - double demux_time, decode_time; int64_t device; bool initialised = false; std::string output_format; - - void NV12ToYUV420(uint8_t *frame, int width, int height) - { - int pitch = width; - uint8_t *ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)]; - - // sizes of source surface plane - int sizePlaneY = pitch * height; - int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2); - int sizePlaneV = sizePlaneU; - - uint8_t *uv = frame + sizePlaneY; - uint8_t *u = uv; - uint8_t *v = uv + sizePlaneU; - - // split chroma from interleave to planar - for (int y = 0; y < (height + 1) / 2; y++) { - for (int x = 0; x < (width + 1) / 2; x++) { - u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2]; - ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1]; - } - } - if (pitch == width) { - memcpy(v, ptr, sizePlaneV * sizeof(uint8_t)); - } else { - for (int i = 0; i < (height + 1) / 2; i++) { - memcpy(v + ((pitch + 1) / 2) * i, ptr + ((width + 1) / 2) * i, ((width + 1) / 2) * sizeof(uint8_t)); - } - } - delete[] ptr; - } }; diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 86e86bf2928..a4a250f0fa4 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -119,40 +119,18 @@ class VideoReader: device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``. - output_format (str, optional): Format of the output frames. Defaults to ``"rgb"``. - - use_device_frame (bool, optional): Whether to get back device frames(gpu tensors) - or host frames(cpu tensors) after GPU decoding. Defaults to ``True``. - - Currently supported values of (use_device_frame, output_format): - (True, "nv12"), (False, "nv12"), (False, "yuv420"). - """ - def __init__(self, path: str, stream: str = "video", num_threads: int = 0, - device: str = "cpu", output_format: str = "rgb", use_device_frame: bool = True) -> None: + def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None: self.is_cuda = False device = torch.device(device) if device.type == "cuda": - supported_formats = [ - "nv12", - "yuv420", - ] if not _has_video_decoder(): raise RuntimeError("Not compiled with GPU decoder support.") self.is_cuda = True if device.index is None: raise RuntimeError("Invalid cuda device!") - if output_format not in supported_formats: - raise RuntimeError( - f"{output_format} output format not supported with GPU decoding, " - "please use one of {', '.join(supported_formats)}.") - if output_format == "yuv420" and use_device_frame: - raise RuntimeError( - "yuv420 output not yet supported with GPU decoding when use_device_frame=True, " - "please use either nv12 or use_device_frame=False.") - self._c = torch.classes.torchvision.GPUDecoder( - path, use_device_frame, device.index, output_format) + self._c = torch.classes.torchvision.GPUDecoder(path, device.index) return if not _has_video_opt(): raise RuntimeError( @@ -161,8 +139,6 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0, + "ffmpeg (version 4.2 is currently supported) and " + "build torchvision from source." ) - if output_format != "rgb": - raise RuntimeError("Only rgb output is supported with video_reader.") self._c = torch.classes.torchvision.Video(path, stream, num_threads) @@ -240,6 +216,17 @@ def set_current_stream(self, stream: str) -> bool: print("GPU decoding only works with video stream.") return self._c.set_current_stream(stream) + def reformat(self, tensor, format: str = "yuv420"): + supported_formats = [ + "yuv420", + ] + if format not in supported_formats: + raise RuntimeError(f"{format} not supported, please use one of {', '.join(supported_formats)}") + if not isinstance(tensor, torch.Tensor): + raise RuntimeError("Expected tensor as input parameter!") + return self._c.reformat(tensor.cpu()) + + __all__ = [ "write_video", diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index 1e802486d3e..487e28dd168 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -15,8 +15,7 @@ _HAS_VIDEO_OPT = False try: - #_load_library("Decoder") - torch.ops.load_library("build/lib.linux-x86_64-3.8/torchvision/Decoder.so") + _load_library("Decoder") _HAS_VIDEO_DECODER = True except (ImportError, OSError): _HAS_VIDEO_DECODER = False From e8ae42ea6e2c248a7011488f002fc748093a1587 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 17 Dec 2021 04:49:01 -0800 Subject: [PATCH 12/39] Cleaned up init() --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 114 ++++---------------- torchvision/csrc/io/decoder/gpu/decoder.h | 11 +- 2 files changed, 24 insertions(+), 101 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 507be4ce05a..1a43739f861 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -15,21 +15,10 @@ static int GetChromaPlaneCount(cudaVideoSurfaceFormat surfaceFormat) return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 2 : 1; } -void Decoder::init(CUcontext context, cudaVideoCodec codec, const Rect *rect, - const Dim *dim, bool lowLatency, - bool forceZeroLat, int64_t maxWidth, int64_t maxHeight) +void Decoder::init(CUcontext context, cudaVideoCodec codec) { cuContext = context; - forceZeroLatency = forceZeroLat; videoCodec = codec; - nMaxHeight = maxHeight; - nMaxWidth = maxWidth; - if (rect) { - cropRect = *rect; - } - if (dim) { - resizeDim = *dim; - } CheckForCudaErrors( cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__); @@ -38,11 +27,11 @@ void Decoder::init(CUcontext context, cudaVideoCodec codec, const Rect *rect, .CodecType = codec, .ulMaxNumDecodeSurfaces = 1, .ulClockRate = 1000, - .ulMaxDisplayDelay = lowLatency ? 0u : 1u, + .ulMaxDisplayDelay = 0u, .pUserData = this, .pfnSequenceCallback = HandleVideoSequenceProc, .pfnDecodePicture = HandlePictureDecodeProc, - .pfnDisplayPicture = forceZeroLatency ? NULL : HandlePictureDisplayProc, + .pfnDisplayPicture = HandlePictureDisplayProc, .pfnGetOperatingPoint = HandleOperatingPointProc, }; CheckForCudaErrors( @@ -113,14 +102,6 @@ int Decoder::HandlePictureDecode(CUVIDPICPARAMS *picParams) CheckForCudaErrors( cuvidDecodePicture(decoder, picParams), __LINE__); - if (forceZeroLatency && ((!picParams->field_pic_flag) || (picParams->second_field))) { - CUVIDPARSERDISPINFO dispInfo { - .picture_index = picParams->CurrPicIdx, - .progressive_frame = !picParams->field_pic_flag, - .top_field_first = picParams->bottom_field_flag ^ 1, - }; - HandlePictureDisplay(&dispInfo); - } CheckForCudaErrors( cuCtxPopCurrent(NULL), __LINE__); @@ -303,47 +284,22 @@ int Decoder::HandleVideoSequence(CUVIDEOFORMAT *vidFormat) }; // AV1 has max width/height of sequence in sequence header if (vidFormat->codec == cudaVideoCodec_AV1 && vidFormat->seqhdr_data_length > 0) { - // don't overwrite if it is already set from cmdline or reconfig.txt - if (!(nMaxWidth > vidFormat->coded_width || nMaxHeight > vidFormat->coded_height)) { - CUVIDEOFORMATEX *vidFormatEx = (CUVIDEOFORMATEX *)vidFormat; - nMaxWidth = vidFormatEx->av1.max_width; - nMaxHeight = vidFormatEx->av1.max_height; - } + CUVIDEOFORMATEX *vidFormatEx = (CUVIDEOFORMATEX *)vidFormat; + maxWidth = vidFormatEx->av1.max_width; + maxHeight = vidFormatEx->av1.max_height; } - if (nMaxWidth < (int)vidFormat->coded_width) - nMaxWidth = vidFormat->coded_width; - if (nMaxHeight < (int)vidFormat->coded_height) - nMaxHeight = vidFormat->coded_height; - videoDecodeCreateInfo.ulMaxWidth = nMaxWidth; - videoDecodeCreateInfo.ulMaxHeight = nMaxHeight; - - if (!(cropRect.right && cropRect.bottom) && !(resizeDim.width && resizeDim.height)) { - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - videoDecodeCreateInfo.ulTargetWidth = vidFormat->coded_width; - videoDecodeCreateInfo.ulTargetHeight = vidFormat->coded_height; - } else { - if (resizeDim.width && resizeDim.height) { - videoDecodeCreateInfo.display_area.left = vidFormat->display_area.left; - videoDecodeCreateInfo.display_area.top = vidFormat->display_area.top; - videoDecodeCreateInfo.display_area.right = vidFormat->display_area.right; - videoDecodeCreateInfo.display_area.bottom = vidFormat->display_area.bottom; - width = resizeDim.width; - lumaHeight = resizeDim.height; - } - - if (cropRect.right && cropRect.bottom) { - videoDecodeCreateInfo.display_area.left = cropRect.left; - videoDecodeCreateInfo.display_area.top = cropRect.top; - videoDecodeCreateInfo.display_area.right = cropRect.right; - videoDecodeCreateInfo.display_area.bottom = cropRect.bottom; - width = cropRect.right - cropRect.left; - lumaHeight = cropRect.bottom - cropRect.top; - } - videoDecodeCreateInfo.ulTargetWidth = width; - videoDecodeCreateInfo.ulTargetHeight = lumaHeight; + if (maxWidth < (int)vidFormat->coded_width) { + maxWidth = vidFormat->coded_width; } - + if (maxHeight < (int)vidFormat->coded_height) { + maxHeight = vidFormat->coded_height; + } + videoDecodeCreateInfo.ulMaxWidth = maxWidth; + videoDecodeCreateInfo.ulMaxHeight = maxHeight; + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + videoDecodeCreateInfo.ulTargetWidth = vidFormat->coded_width; + videoDecodeCreateInfo.ulTargetHeight = vidFormat->coded_height; chromaHeight = (int)(ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat))); numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); surfaceHeight = videoDecodeCreateInfo.ulTargetHeight; @@ -353,7 +309,6 @@ int Decoder::HandleVideoSequence(CUVIDEOFORMAT *vidFormat) displayRect.left = videoDecodeCreateInfo.display_area.left; displayRect.right = videoDecodeCreateInfo.display_area.right; - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); CheckForCudaErrors(cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__); CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); @@ -375,7 +330,7 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT *vidFormat) int nDecodeSurface = vidFormat->min_num_decode_surfaces; - if ((vidFormat->coded_width > nMaxWidth) || (vidFormat->coded_height > nMaxHeight)) { + if ((vidFormat->coded_width > maxWidth) || (vidFormat->coded_height > maxHeight)) { // For VP9, let driver handle the change if new width/height > maxwidth/maxheight if ((videoCodec != cudaVideoCodec_VP9) || reconfigExternal) { throw std::runtime_error("Reconfigure Not supported when width/height > maxwidth/maxheight"); @@ -417,34 +372,10 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT *vidFormat) reconfigExternal = false; reconfigExtPPChange = false; videoFormat = *vidFormat; - if (!(cropRect.right && cropRect.bottom) && !(resizeDim.width && resizeDim.height)) { - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - reconfigParams.ulTargetWidth = vidFormat->coded_width; - reconfigParams.ulTargetHeight = vidFormat->coded_height; - } - else { - if (resizeDim.width && resizeDim.height) { - reconfigParams.display_area.left = vidFormat->display_area.left; - reconfigParams.display_area.top = vidFormat->display_area.top; - reconfigParams.display_area.right = vidFormat->display_area.right; - reconfigParams.display_area.bottom = vidFormat->display_area.bottom; - width = resizeDim.width; - lumaHeight = resizeDim.height; - } - - if (cropRect.right && cropRect.bottom) { - reconfigParams.display_area.left = cropRect.left; - reconfigParams.display_area.top = cropRect.top; - reconfigParams.display_area.right = cropRect.right; - reconfigParams.display_area.bottom = cropRect.bottom; - width = cropRect.right - cropRect.left; - lumaHeight = cropRect.bottom - cropRect.top; - } - reconfigParams.ulTargetWidth = width; - reconfigParams.ulTargetHeight = lumaHeight; - } - + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + reconfigParams.ulTargetWidth = vidFormat->coded_width; + reconfigParams.ulTargetHeight = vidFormat->coded_height; chromaHeight = (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); surfaceHeight = reconfigParams.ulTargetHeight; @@ -454,7 +385,6 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT *vidFormat) displayRect.left = reconfigParams.display_area.left; displayRect.right = reconfigParams.display_area.right; } - reconfigParams.ulNumDecodeSurfaces = nDecodeSurface; CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 408156f9e1b..82d38bf01cd 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -25,15 +25,11 @@ struct Rect { int left, top, right, bottom; }; -struct Dim { - int width, height; -}; - class Decoder { public: Decoder() {} ~Decoder(); - void init(CUcontext, cudaVideoCodec, const Rect * = NULL, const Dim * = NULL, bool = false, bool = false, int64_t = 0, int64_t = 0); + void init(CUcontext, cudaVideoCodec); int Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); void release(); uint8_t * FetchFrame(); @@ -53,7 +49,6 @@ class Decoder { CUvideoctxlock ctxLock; CUvideoparser parser = NULL; CUvideodecoder decoder = NULL; - bool forceZeroLatency = false; CUstream cuvidStream = 0; int numDecodedFrames = 0; unsigned int numChromaPlanes = 0; @@ -65,9 +60,7 @@ class Decoder { int bitDepthMinus8 = 0; int bytesPerPixel = 1; CUVIDEOFORMAT videoFormat = {}; - unsigned int nMaxWidth = 0, nMaxHeight = 0; - Rect cropRect = {}; - Dim resizeDim = {}; + unsigned int maxWidth = 0, maxHeight = 0; // height of the mapped surface int surfaceHeight = 0; int surfaceWidth = 0; From 643378546cda9dff2bfa7e6f2cdfd1390cb0f95f Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 17 Dec 2021 06:35:06 -0800 Subject: [PATCH 13/39] Fix warnings --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 14 +++++++------- torchvision/csrc/io/decoder/gpu/decoder.h | 3 +-- torchvision/csrc/io/decoder/gpu/demuxer.h | 4 ++-- torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp | 4 ++-- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 1a43739f861..9927ca74023 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -61,14 +61,14 @@ void Decoder::release() cuCtxPopCurrent(NULL); } -int Decoder::Decode(const uint8_t *data, int64_t size, int64_t flags, int64_t pts) +unsigned long Decoder::Decode(const uint8_t *data, unsigned long size) { numDecodedFrames = 0; CUVIDSOURCEDATAPACKET pkt = { - .flags = flags | CUVID_PKT_TIMESTAMP, + .flags = CUVID_PKT_TIMESTAMP, .payload_size = size, .payload = data, - .timestamp = pts + .timestamp = 0 }; if (!data || size == 0) { pkt.flags |= CUVID_PKT_ENDOFSTREAM; @@ -147,7 +147,7 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) m.srcPitch = nSrcPitch; m.dstMemoryType = CU_MEMORYTYPE_DEVICE; m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame); - m.dstPitch = deviceFramePitch ? deviceFramePitch : GetWidth() * bytesPerPixel; + m.dstPitch = GetWidth() * bytesPerPixel; m.WidthInBytes = GetWidth() * bytesPerPixel; m.Height = lumaHeight; CheckForCudaErrors( @@ -262,7 +262,7 @@ int Decoder::HandleVideoSequence(CUVIDEOFORMAT *vidFormat) } videoFormat = *vidFormat; - int nDecodeSurface = vidFormat->min_num_decode_surfaces; + unsigned long nDecodeSurface = vidFormat->min_num_decode_surfaces; cudaVideoDeinterlaceMode deinterlaceMode = cudaVideoDeinterlaceMode_Adaptive; if (vidFormat->progressive_sequence) { deinterlaceMode = cudaVideoDeinterlaceMode_Weave; @@ -288,10 +288,10 @@ int Decoder::HandleVideoSequence(CUVIDEOFORMAT *vidFormat) maxWidth = vidFormatEx->av1.max_width; maxHeight = vidFormatEx->av1.max_height; } - if (maxWidth < (int)vidFormat->coded_width) { + if (maxWidth < vidFormat->coded_width) { maxWidth = vidFormat->coded_width; } - if (maxHeight < (int)vidFormat->coded_height) { + if (maxHeight < vidFormat->coded_height) { maxHeight = vidFormat->coded_height; } videoDecodeCreateInfo.ulMaxWidth = maxWidth; diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 82d38bf01cd..1bbbb88de16 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -30,7 +30,7 @@ class Decoder { Decoder() {} ~Decoder(); void init(CUcontext, cudaVideoCodec); - int Decode(const uint8_t *, int64_t, int64_t = 0, int64_t = 0); + unsigned long Decode(const uint8_t *, unsigned long); void release(); uint8_t * FetchFrame(); cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } @@ -70,7 +70,6 @@ class Decoder { int decodePicCount = 0, picNumInDecodeOrder[32]; bool reconfigExternal = false; bool reconfigExtPPChange = false; - size_t deviceFramePitch = 0; std::queue decodedFrames; static int CUDAAPI HandleVideoSequenceProc(void *pUserData, CUVIDEOFORMAT *pVideoFormat) { return ((Decoder *)pUserData)->HandleVideoSequence(pVideoFormat); } diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index 02e7c369823..3c2e294cf43 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -8,7 +8,7 @@ extern "C" { inline bool check(int ret, int line) { if (ret < 0) { - printf("Error %d at line %d in file %s.\n", ret, line); + printf("Error %d at line %d in demuxer.h.\n", ret, line); return false; } return true; @@ -117,7 +117,7 @@ class Demuxer { return eVideoCodec; } - bool Demux(uint8_t **video, int64_t *videoBytes) + bool Demux(uint8_t **video, unsigned long *videoBytes) { if (!fmtCtx) { return false; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index b65f95e080a..128efbe6922 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -26,12 +26,12 @@ GPUDecoder::~GPUDecoder() torch::Tensor GPUDecoder::decode() { torch::Tensor frameTensor; - int64_t videoBytes, numFrames; + unsigned long videoBytes = 0; uint8_t *frame = nullptr, *video = nullptr; do { demuxer.Demux(&video, &videoBytes); - numFrames = dec.Decode(video, videoBytes); + dec.Decode(video, videoBytes); frame = dec.FetchFrame(); } while (frame == nullptr && videoBytes > 0); if (frame == nullptr) { From 962962a08e1ec2cd17c08629e679ffcd02557c36 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 17 Dec 2021 11:32:33 -0800 Subject: [PATCH 14/39] Fix python linter errors --- test/test_video_gpu_decoder.py | 2 +- torchvision/io/__init__.py | 16 +--------------- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index dc9fb149007..2f91f2e2466 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -28,7 +28,7 @@ class TestVideoGPUDecoder: def test_frame_reading(self): for test_video in test_videos: full_path = os.path.join(VIDEO_DIR, test_video) - decoder = VideoReader(full_path, device='cuda:0') + decoder = VideoReader(full_path, device="cuda:0") with av.open(full_path) as container: for av_frame in container.decode(container.streams.video[0]): av_frames = torch.tensor(av_frame.to_ndarray().flatten()) diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index a4a250f0fa4..139491d63a4 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -39,25 +39,12 @@ def _has_video_opt() -> bool: return True - else: def _has_video_opt() -> bool: return False -if _HAS_VIDEO_DECODER: - - def _has_video_decoder() -> bool: - return True - - -else: - - def _has_video_decoder() -> bool: - return False - - class VideoReader: """ Fine-grained video-reading API. @@ -125,7 +112,7 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0, devic self.is_cuda = False device = torch.device(device) if device.type == "cuda": - if not _has_video_decoder(): + if not _HAS_VIDEO_DECODER: raise RuntimeError("Not compiled with GPU decoder support.") self.is_cuda = True if device.index is None: @@ -227,7 +214,6 @@ def reformat(self, tensor, format: str = "yuv420"): return self._c.reformat(tensor.cpu()) - __all__ = [ "write_video", "read_video", From 4dd1798c919f43e1d163c77e0ff05401997452b3 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 17 Dec 2021 11:37:00 -0800 Subject: [PATCH 15/39] Fix linter issues in setup.py --- setup.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 1c1a99d74c6..5f375083f12 100644 --- a/setup.py +++ b/setup.py @@ -440,18 +440,16 @@ def get_extensions(): if video_codec_found and has_ffmpeg: gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu") - gpu_decoder_src = ( - glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) - ) - cuda_libs = os.path.join(CUDA_HOME, 'lib64') - cuda_inc = os.path.join(CUDA_HOME, 'include') + gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) + cuda_libs = os.path.join(CUDA_HOME, "lib64") + cuda_inc = os.path.join(CUDA_HOME, "include") ext_modules.append( extension( "torchvision.Decoder", gpu_decoder_src, include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc], - library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs] + ['/usr/local/lib'], + library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs] + ["/usr/local/lib"], libraries=[ "avcodec", "avformat", From 116fd025fd359de7c408c9cd4887019c8f130c56 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 17 Dec 2021 12:37:44 -0800 Subject: [PATCH 16/39] clang-format --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 652 +++++++++--------- torchvision/csrc/io/decoder/gpu/decoder.h | 137 ++-- torchvision/csrc/io/decoder/gpu/demuxer.h | 322 +++++---- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 74 +- torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 24 +- 5 files changed, 633 insertions(+), 576 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 9927ca74023..ef2199800d0 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -1,411 +1,435 @@ +#include "decoder.h" #include -#include #include +#include #include #include -#include "decoder.h" -static float GetChromaHeightFactor(cudaVideoSurfaceFormat surfaceFormat) -{ - return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 1.0 : 0.5; +static float GetChromaHeightFactor(cudaVideoSurfaceFormat surfaceFormat) { + return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || + surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) + ? 1.0 + : 0.5; } -static int GetChromaPlaneCount(cudaVideoSurfaceFormat surfaceFormat) -{ - return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 2 : 1; +static int GetChromaPlaneCount(cudaVideoSurfaceFormat surfaceFormat) { + return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || + surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) + ? 2 + : 1; } -void Decoder::init(CUcontext context, cudaVideoCodec codec) -{ +void Decoder::init(CUcontext context, cudaVideoCodec codec) { cuContext = context; videoCodec = codec; - CheckForCudaErrors( - cuvidCtxLockCreate(&ctxLock, cuContext), - __LINE__); + CheckForCudaErrors(cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__); CUVIDPARSERPARAMS parserParams = { - .CodecType = codec, - .ulMaxNumDecodeSurfaces = 1, - .ulClockRate = 1000, - .ulMaxDisplayDelay = 0u, - .pUserData = this, - .pfnSequenceCallback = HandleVideoSequenceProc, - .pfnDecodePicture = HandlePictureDecodeProc, - .pfnDisplayPicture = HandlePictureDisplayProc, - .pfnGetOperatingPoint = HandleOperatingPointProc, + .CodecType = codec, + .ulMaxNumDecodeSurfaces = 1, + .ulClockRate = 1000, + .ulMaxDisplayDelay = 0u, + .pUserData = this, + .pfnSequenceCallback = HandleVideoSequenceProc, + .pfnDecodePicture = HandlePictureDecodeProc, + .pfnDisplayPicture = HandlePictureDisplayProc, + .pfnGetOperatingPoint = HandleOperatingPointProc, }; - CheckForCudaErrors( - cuvidCreateVideoParser(&parser, &parserParams), - __LINE__); + CheckForCudaErrors(cuvidCreateVideoParser(&parser, &parserParams), __LINE__); } -Decoder::~Decoder() -{ +Decoder::~Decoder() { if (parser) { cuvidDestroyVideoParser(parser); } cuvidCtxLockDestroy(ctxLock); } -void Decoder::release() -{ +void Decoder::release() { cuCtxPushCurrent(cuContext); if (decoder) { cuvidDestroyDecoder(decoder); } while (!decodedFrames.empty()) { - uint8_t *frame = decodedFrames.front(); + uint8_t* frame = decodedFrames.front(); decodedFrames.pop(); cuMemFree((CUdeviceptr)frame); } cuCtxPopCurrent(NULL); } -unsigned long Decoder::Decode(const uint8_t *data, unsigned long size) -{ +unsigned long Decoder::Decode(const uint8_t* data, unsigned long size) { numDecodedFrames = 0; CUVIDSOURCEDATAPACKET pkt = { - .flags = CUVID_PKT_TIMESTAMP, - .payload_size = size, - .payload = data, - .timestamp = 0 - }; + .flags = CUVID_PKT_TIMESTAMP, + .payload_size = size, + .payload = data, + .timestamp = 0}; if (!data || size == 0) { pkt.flags |= CUVID_PKT_ENDOFSTREAM; } - CheckForCudaErrors( - cuvidParseVideoData(parser, &pkt), - __LINE__); + CheckForCudaErrors(cuvidParseVideoData(parser, &pkt), __LINE__); cuvidStream = 0; return numDecodedFrames; } -uint8_t * Decoder::FetchFrame() -{ +uint8_t* Decoder::FetchFrame() { if (decodedFrames.empty()) { return nullptr; } - uint8_t *frame = decodedFrames.front(); + uint8_t* frame = decodedFrames.front(); decodedFrames.pop(); return frame; } -int Decoder::HandlePictureDecode(CUVIDPICPARAMS *picParams) -{ - if (!decoder) { - throw std::runtime_error("Uninitialised decoder."); - } - picNumInDecodeOrder[picParams->CurrPicIdx] = decodePicCount++; - CheckForCudaErrors( - cuCtxPushCurrent(cuContext), - __LINE__); - CheckForCudaErrors( - cuvidDecodePicture(decoder, picParams), - __LINE__); - CheckForCudaErrors( - cuCtxPopCurrent(NULL), - __LINE__); - return 1; +int Decoder::HandlePictureDecode(CUVIDPICPARAMS* picParams) { + if (!decoder) { + throw std::runtime_error("Uninitialised decoder."); + } + picNumInDecodeOrder[picParams->CurrPicIdx] = decodePicCount++; + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors(cuvidDecodePicture(decoder, picParams), __LINE__); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + return 1; } -int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO *dispInfo) -{ - CUVIDPROCPARAMS procParams = { +int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO* dispInfo) { + CUVIDPROCPARAMS procParams = { .progressive_frame = dispInfo->progressive_frame, .second_field = dispInfo->repeat_first_field + 1, .top_field_first = dispInfo->top_field_first, .unpaired_field = dispInfo->repeat_first_field < 0, .output_stream = cuvidStream, - }; + }; - CUdeviceptr dpSrcFrame = 0; - unsigned int nSrcPitch = 0; - CheckForCudaErrors( - cuCtxPushCurrent(cuContext), - __LINE__); - CheckForCudaErrors( - cuvidMapVideoFrame(decoder, dispInfo->picture_index, &dpSrcFrame, &nSrcPitch, &procParams), + CUdeviceptr dpSrcFrame = 0; + unsigned int nSrcPitch = 0; + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors( + cuvidMapVideoFrame( + decoder, + dispInfo->picture_index, + &dpSrcFrame, + &nSrcPitch, + &procParams), __LINE__); - CUVIDGETDECODESTATUS decodeStatus; - memset(&decodeStatus, 0, sizeof(decodeStatus)); - CUresult result = cuvidGetDecodeStatus(decoder, dispInfo->picture_index, &decodeStatus); - if (result == CUDA_SUCCESS && - (decodeStatus.decodeStatus == cuvidDecodeStatus_Error || decodeStatus.decodeStatus == cuvidDecodeStatus_Error_Concealed)) { - printf("Decode Error occurred for picture %d\n", picNumInDecodeOrder[dispInfo->picture_index]); - } - - uint8_t *decodedFrame = nullptr; - cuMemAlloc((CUdeviceptr *)&decodedFrame, GetFrameSize()); - - numDecodedFrames++; - - // Copy luma plane - CUDA_MEMCPY2D m = { 0 }; - m.srcMemoryType = CU_MEMORYTYPE_DEVICE; - m.srcDevice = dpSrcFrame; - m.srcPitch = nSrcPitch; - m.dstMemoryType = CU_MEMORYTYPE_DEVICE; - m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame); - m.dstPitch = GetWidth() * bytesPerPixel; - m.WidthInBytes = GetWidth() * bytesPerPixel; - m.Height = lumaHeight; - CheckForCudaErrors( - cuMemcpy2DAsync(&m, cuvidStream), - __LINE__); + CUVIDGETDECODESTATUS decodeStatus; + memset(&decodeStatus, 0, sizeof(decodeStatus)); + CUresult result = + cuvidGetDecodeStatus(decoder, dispInfo->picture_index, &decodeStatus); + if (result == CUDA_SUCCESS && + (decodeStatus.decodeStatus == cuvidDecodeStatus_Error || + decodeStatus.decodeStatus == cuvidDecodeStatus_Error_Concealed)) { + printf( + "Decode Error occurred for picture %d\n", + picNumInDecodeOrder[dispInfo->picture_index]); + } - // Copy chroma plane - // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning height - m.srcDevice = (CUdeviceptr)((uint8_t *)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1)); - m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight); + uint8_t* decodedFrame = nullptr; + cuMemAlloc((CUdeviceptr*)&decodedFrame, GetFrameSize()); + + numDecodedFrames++; + + // Copy luma plane + CUDA_MEMCPY2D m = {0}; + m.srcMemoryType = CU_MEMORYTYPE_DEVICE; + m.srcDevice = dpSrcFrame; + m.srcPitch = nSrcPitch; + m.dstMemoryType = CU_MEMORYTYPE_DEVICE; + m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame); + m.dstPitch = GetWidth() * bytesPerPixel; + m.WidthInBytes = GetWidth() * bytesPerPixel; + m.Height = lumaHeight; + CheckForCudaErrors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + + // Copy chroma plane + // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning + // height + m.srcDevice = + (CUdeviceptr)((uint8_t*)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1)); + m.dstDevice = + (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight); + m.Height = chromaHeight; + CheckForCudaErrors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + + if (numChromaPlanes == 2) { + m.srcDevice = + (CUdeviceptr)((uint8_t*)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1) * 2); + m.dstDevice = + (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight * 2); m.Height = chromaHeight; - CheckForCudaErrors( - cuMemcpy2DAsync(&m, cuvidStream), - __LINE__); - - if (numChromaPlanes == 2) { - m.srcDevice = (CUdeviceptr)((uint8_t *)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1) * 2); - m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight * 2); - m.Height = chromaHeight; - CheckForCudaErrors( - cuMemcpy2DAsync(&m, cuvidStream), - __LINE__); - } - CheckForCudaErrors( - cuStreamSynchronize(cuvidStream), - __LINE__); - decodedFrames.push(decodedFrame); - CheckForCudaErrors( - cuCtxPopCurrent(NULL), - __LINE__); + CheckForCudaErrors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + } + CheckForCudaErrors(cuStreamSynchronize(cuvidStream), __LINE__); + decodedFrames.push(decodedFrame); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); - CheckForCudaErrors( - cuvidUnmapVideoFrame(decoder, dpSrcFrame), - __LINE__); - return 1; + CheckForCudaErrors(cuvidUnmapVideoFrame(decoder, dpSrcFrame), __LINE__); + return 1; } -void Decoder::queryHardware(CUVIDEOFORMAT *videoFormat) -{ - CUVIDDECODECAPS decodeCaps = { +void Decoder::queryHardware(CUVIDEOFORMAT* videoFormat) { + CUVIDDECODECAPS decodeCaps = { .eCodecType = videoFormat->codec, .eChromaFormat = videoFormat->chroma_format, .nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8, - }; - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors(cuvidGetDecoderCaps(&decodeCaps), __LINE__); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + }; + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors(cuvidGetDecoderCaps(&decodeCaps), __LINE__); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); - if(!decodeCaps.bIsSupported) { - throw std::runtime_error("Codec not supported on this GPU"); - } - if ((videoFormat->coded_width > decodeCaps.nMaxWidth) || - (videoFormat->coded_height > decodeCaps.nMaxHeight)) { - std::ostringstream errorString; - errorString << std::endl - << "Resolution : " << videoFormat->coded_width << "x" << videoFormat->coded_height << std::endl - << "Max Supported (wxh) : " << decodeCaps.nMaxWidth << "x" << decodeCaps.nMaxHeight << std::endl - << "Resolution not supported on this GPU"; - - const std::string cErr = errorString.str(); - throw std::runtime_error(cErr); - } - if ((videoFormat->coded_width>>4)*(videoFormat->coded_height>>4) > decodeCaps.nMaxMBCount) { - std::ostringstream errorString; - errorString << std::endl - << "MBCount : " << (videoFormat->coded_width >> 4)*(videoFormat->coded_height >> 4) << std::endl - << "Max Supported mbcnt : " << decodeCaps.nMaxMBCount << std::endl - << "MBCount not supported on this GPU"; - - const std::string cErr = errorString.str(); - throw std::runtime_error(cErr); - } - // Check if output format supported. If not, check fallback options - if (!(decodeCaps.nOutputFormatMask & (1 << videoOutputFormat))) { - if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) - videoOutputFormat = cudaVideoSurfaceFormat_NV12; - else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) - videoOutputFormat = cudaVideoSurfaceFormat_P016; - else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) - videoOutputFormat = cudaVideoSurfaceFormat_YUV444; - else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) - videoOutputFormat = cudaVideoSurfaceFormat_YUV444_16Bit; - else - throw std::runtime_error("No supported output format found"); - } + if (!decodeCaps.bIsSupported) { + throw std::runtime_error("Codec not supported on this GPU"); + } + if ((videoFormat->coded_width > decodeCaps.nMaxWidth) || + (videoFormat->coded_height > decodeCaps.nMaxHeight)) { + std::ostringstream errorString; + errorString << std::endl + << "Resolution : " << videoFormat->coded_width << "x" + << videoFormat->coded_height << std::endl + << "Max Supported (wxh) : " << decodeCaps.nMaxWidth << "x" + << decodeCaps.nMaxHeight << std::endl + << "Resolution not supported on this GPU"; + + const std::string cErr = errorString.str(); + throw std::runtime_error(cErr); + } + if ((videoFormat->coded_width >> 4) * (videoFormat->coded_height >> 4) > + decodeCaps.nMaxMBCount) { + std::ostringstream errorString; + errorString << std::endl + << "MBCount : " + << (videoFormat->coded_width >> 4) * + (videoFormat->coded_height >> 4) + << std::endl + << "Max Supported mbcnt : " << decodeCaps.nMaxMBCount + << std::endl + << "MBCount not supported on this GPU"; + + const std::string cErr = errorString.str(); + throw std::runtime_error(cErr); + } + // Check if output format supported. If not, check fallback options + if (!(decodeCaps.nOutputFormatMask & (1 << videoOutputFormat))) { + if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) + videoOutputFormat = cudaVideoSurfaceFormat_NV12; + else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) + videoOutputFormat = cudaVideoSurfaceFormat_P016; + else if ( + decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) + videoOutputFormat = cudaVideoSurfaceFormat_YUV444; + else if ( + decodeCaps.nOutputFormatMask & + (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) + videoOutputFormat = cudaVideoSurfaceFormat_YUV444_16Bit; + else + throw std::runtime_error("No supported output format found"); + } } -int Decoder::HandleVideoSequence(CUVIDEOFORMAT *vidFormat) -{ - // videoCodec has been set in the constructor (for parser). Here it's set again for potential correction - videoCodec = vidFormat->codec; - videoChromaFormat = vidFormat->chroma_format; - bitDepthMinus8 = vidFormat->bit_depth_luma_minus8; - bytesPerPixel = bitDepthMinus8 > 0 ? 2 : 1; - // Set the output surface format same as chroma format - switch (videoChromaFormat) { - case cudaVideoChromaFormat_Monochrome: - case cudaVideoChromaFormat_420: - videoOutputFormat = vidFormat->bit_depth_luma_minus8 ? cudaVideoSurfaceFormat_P016 : cudaVideoSurfaceFormat_NV12; - break; - case cudaVideoChromaFormat_444: - videoOutputFormat = vidFormat->bit_depth_luma_minus8 ? cudaVideoSurfaceFormat_YUV444_16Bit : cudaVideoSurfaceFormat_YUV444; - break; - case cudaVideoChromaFormat_422: - videoOutputFormat = cudaVideoSurfaceFormat_NV12; // no 4:2:2 output format supported yet so make 420 default - } +int Decoder::HandleVideoSequence(CUVIDEOFORMAT* vidFormat) { + // videoCodec has been set in the constructor (for parser). Here it's set + // again for potential correction + videoCodec = vidFormat->codec; + videoChromaFormat = vidFormat->chroma_format; + bitDepthMinus8 = vidFormat->bit_depth_luma_minus8; + bytesPerPixel = bitDepthMinus8 > 0 ? 2 : 1; + // Set the output surface format same as chroma format + switch (videoChromaFormat) { + case cudaVideoChromaFormat_Monochrome: + case cudaVideoChromaFormat_420: + videoOutputFormat = vidFormat->bit_depth_luma_minus8 + ? cudaVideoSurfaceFormat_P016 + : cudaVideoSurfaceFormat_NV12; + break; + case cudaVideoChromaFormat_444: + videoOutputFormat = vidFormat->bit_depth_luma_minus8 + ? cudaVideoSurfaceFormat_YUV444_16Bit + : cudaVideoSurfaceFormat_YUV444; + break; + case cudaVideoChromaFormat_422: + videoOutputFormat = + cudaVideoSurfaceFormat_NV12; // no 4:2:2 output format supported yet + // so make 420 default + } - queryHardware(vidFormat); - if (width && lumaHeight && chromaHeight) { - // cuvidCreateDecoder() has been called before, and now there's possible config change - return ReconfigureDecoder(vidFormat); - } + queryHardware(vidFormat); + if (width && lumaHeight && chromaHeight) { + // cuvidCreateDecoder() has been called before, and now there's possible + // config change + return ReconfigureDecoder(vidFormat); + } - videoFormat = *vidFormat; - unsigned long nDecodeSurface = vidFormat->min_num_decode_surfaces; - cudaVideoDeinterlaceMode deinterlaceMode = cudaVideoDeinterlaceMode_Adaptive; - if (vidFormat->progressive_sequence) { - deinterlaceMode = cudaVideoDeinterlaceMode_Weave; - } + videoFormat = *vidFormat; + unsigned long nDecodeSurface = vidFormat->min_num_decode_surfaces; + cudaVideoDeinterlaceMode deinterlaceMode = cudaVideoDeinterlaceMode_Adaptive; + if (vidFormat->progressive_sequence) { + deinterlaceMode = cudaVideoDeinterlaceMode_Weave; + } - CUVIDDECODECREATEINFO videoDecodeCreateInfo = { + CUVIDDECODECREATEINFO videoDecodeCreateInfo = { .ulWidth = vidFormat->coded_width, .ulHeight = vidFormat->coded_height, .ulNumDecodeSurfaces = nDecodeSurface, .CodecType = vidFormat->codec, .ChromaFormat = vidFormat->chroma_format, - // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded by NVDEC hardware + // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded + // by NVDEC hardware .ulCreationFlags = cudaVideoCreate_PreferCUVID, .bitDepthMinus8 = vidFormat->bit_depth_luma_minus8, .OutputFormat = videoOutputFormat, .DeinterlaceMode = deinterlaceMode, .ulNumOutputSurfaces = 2, .vidLock = ctxLock, - }; - // AV1 has max width/height of sequence in sequence header - if (vidFormat->codec == cudaVideoCodec_AV1 && vidFormat->seqhdr_data_length > 0) { - CUVIDEOFORMATEX *vidFormatEx = (CUVIDEOFORMATEX *)vidFormat; - maxWidth = vidFormatEx->av1.max_width; - maxHeight = vidFormatEx->av1.max_height; - } - if (maxWidth < vidFormat->coded_width) { - maxWidth = vidFormat->coded_width; - } - if (maxHeight < vidFormat->coded_height) { - maxHeight = vidFormat->coded_height; - } - videoDecodeCreateInfo.ulMaxWidth = maxWidth; - videoDecodeCreateInfo.ulMaxHeight = maxHeight; - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - videoDecodeCreateInfo.ulTargetWidth = vidFormat->coded_width; - videoDecodeCreateInfo.ulTargetHeight = vidFormat->coded_height; - chromaHeight = (int)(ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat))); - numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); - surfaceHeight = videoDecodeCreateInfo.ulTargetHeight; - surfaceWidth = videoDecodeCreateInfo.ulTargetWidth; - displayRect.bottom = videoDecodeCreateInfo.display_area.bottom; - displayRect.top = videoDecodeCreateInfo.display_area.top; - displayRect.left = videoDecodeCreateInfo.display_area.left; - displayRect.right = videoDecodeCreateInfo.display_area.right; - - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors(cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); - return nDecodeSurface; + }; + // AV1 has max width/height of sequence in sequence header + if (vidFormat->codec == cudaVideoCodec_AV1 && + vidFormat->seqhdr_data_length > 0) { + CUVIDEOFORMATEX* vidFormatEx = (CUVIDEOFORMATEX*)vidFormat; + maxWidth = vidFormatEx->av1.max_width; + maxHeight = vidFormatEx->av1.max_height; + } + if (maxWidth < vidFormat->coded_width) { + maxWidth = vidFormat->coded_width; + } + if (maxHeight < vidFormat->coded_height) { + maxHeight = vidFormat->coded_height; + } + videoDecodeCreateInfo.ulMaxWidth = maxWidth; + videoDecodeCreateInfo.ulMaxHeight = maxHeight; + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + videoDecodeCreateInfo.ulTargetWidth = vidFormat->coded_width; + videoDecodeCreateInfo.ulTargetHeight = vidFormat->coded_height; + chromaHeight = + (int)(ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat))); + numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + surfaceHeight = videoDecodeCreateInfo.ulTargetHeight; + surfaceWidth = videoDecodeCreateInfo.ulTargetWidth; + displayRect.bottom = videoDecodeCreateInfo.display_area.bottom; + displayRect.top = videoDecodeCreateInfo.display_area.top; + displayRect.left = videoDecodeCreateInfo.display_area.left; + displayRect.right = videoDecodeCreateInfo.display_area.right; + + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors( + cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + return nDecodeSurface; } -int Decoder::ReconfigureDecoder(CUVIDEOFORMAT *vidFormat) -{ - if (vidFormat->bit_depth_luma_minus8 != videoFormat.bit_depth_luma_minus8 || vidFormat->bit_depth_chroma_minus8 != videoFormat.bit_depth_chroma_minus8) { - throw std::runtime_error("Reconfigure Not supported for bit depth change"); - } - if (vidFormat->chroma_format != videoFormat.chroma_format) { - throw std::runtime_error("Reconfigure Not supported for chroma format change"); - } - - bool bDecodeResChange = !(vidFormat->coded_width == videoFormat.coded_width && vidFormat->coded_height == videoFormat.coded_height); - bool bDisplayRectChange = !(vidFormat->display_area.bottom == videoFormat.display_area.bottom && vidFormat->display_area.top == videoFormat.display_area.top \ - && vidFormat->display_area.left == videoFormat.display_area.left && vidFormat->display_area.right == videoFormat.display_area.right); - - int nDecodeSurface = vidFormat->min_num_decode_surfaces; +int Decoder::ReconfigureDecoder(CUVIDEOFORMAT* vidFormat) { + if (vidFormat->bit_depth_luma_minus8 != videoFormat.bit_depth_luma_minus8 || + vidFormat->bit_depth_chroma_minus8 != + videoFormat.bit_depth_chroma_minus8) { + throw std::runtime_error("Reconfigure Not supported for bit depth change"); + } + if (vidFormat->chroma_format != videoFormat.chroma_format) { + throw std::runtime_error( + "Reconfigure Not supported for chroma format change"); + } - if ((vidFormat->coded_width > maxWidth) || (vidFormat->coded_height > maxHeight)) { - // For VP9, let driver handle the change if new width/height > maxwidth/maxheight - if ((videoCodec != cudaVideoCodec_VP9) || reconfigExternal) { - throw std::runtime_error("Reconfigure Not supported when width/height > maxwidth/maxheight"); - } - return 1; + bool bDecodeResChange = + !(vidFormat->coded_width == videoFormat.coded_width && + vidFormat->coded_height == videoFormat.coded_height); + bool bDisplayRectChange = + !(vidFormat->display_area.bottom == videoFormat.display_area.bottom && + vidFormat->display_area.top == videoFormat.display_area.top && + vidFormat->display_area.left == videoFormat.display_area.left && + vidFormat->display_area.right == videoFormat.display_area.right); + + int nDecodeSurface = vidFormat->min_num_decode_surfaces; + + if ((vidFormat->coded_width > maxWidth) || + (vidFormat->coded_height > maxHeight)) { + // For VP9, let driver handle the change if new width/height > + // maxwidth/maxheight + if ((videoCodec != cudaVideoCodec_VP9) || reconfigExternal) { + throw std::runtime_error( + "Reconfigure Not supported when width/height > maxwidth/maxheight"); } + return 1; + } - if (!bDecodeResChange && !reconfigExtPPChange) { - // if the coded_width/coded_height hasn't changed but display resolution has changed, then need to update width/height for - // correct output without cropping. Example : 1920x1080 vs 1920x1088 - if (bDisplayRectChange) { - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - chromaHeight = (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); - numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); - } - - // no need for reconfigureDecoder(). Just return - return 1; + if (!bDecodeResChange && !reconfigExtPPChange) { + // if the coded_width/coded_height hasn't changed but display resolution has + // changed, then need to update width/height for correct output without + // cropping. Example : 1920x1080 vs 1920x1088 + if (bDisplayRectChange) { + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + chromaHeight = + (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); + numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); } - CUVIDRECONFIGUREDECODERINFO reconfigParams = { 0 }; - - reconfigParams.ulWidth = videoFormat.coded_width = vidFormat->coded_width; - reconfigParams.ulHeight = videoFormat.coded_height = vidFormat->coded_height; - - // Dont change display rect and get scaled output from decoder. This will help display app to present apps smoothly - reconfigParams.display_area.bottom = displayRect.bottom; - reconfigParams.display_area.top = displayRect.top; - reconfigParams.display_area.left = displayRect.left; - reconfigParams.display_area.right = displayRect.right; - reconfigParams.ulTargetWidth = surfaceWidth; - reconfigParams.ulTargetHeight = surfaceHeight; - - // If external reconfigure is called along with resolution change even if post processing params is not changed, - // do full reconfigure params update - if ((reconfigExternal && bDecodeResChange) || reconfigExtPPChange) { - // update display rect and target resolution if requested explicitely - reconfigExternal = false; - reconfigExtPPChange = false; - videoFormat = *vidFormat; - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - reconfigParams.ulTargetWidth = vidFormat->coded_width; - reconfigParams.ulTargetHeight = vidFormat->coded_height; - chromaHeight = (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); - numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); - surfaceHeight = reconfigParams.ulTargetHeight; - surfaceWidth = reconfigParams.ulTargetWidth; - displayRect.bottom = reconfigParams.display_area.bottom; - displayRect.top = reconfigParams.display_area.top; - displayRect.left = reconfigParams.display_area.left; - displayRect.right = reconfigParams.display_area.right; - } - reconfigParams.ulNumDecodeSurfaces = nDecodeSurface; + // no need for reconfigureDecoder(). Just return + return 1; + } - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors(cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + CUVIDRECONFIGUREDECODERINFO reconfigParams = {0}; + + reconfigParams.ulWidth = videoFormat.coded_width = vidFormat->coded_width; + reconfigParams.ulHeight = videoFormat.coded_height = vidFormat->coded_height; + + // Dont change display rect and get scaled output from decoder. This will help + // display app to present apps smoothly + reconfigParams.display_area.bottom = displayRect.bottom; + reconfigParams.display_area.top = displayRect.top; + reconfigParams.display_area.left = displayRect.left; + reconfigParams.display_area.right = displayRect.right; + reconfigParams.ulTargetWidth = surfaceWidth; + reconfigParams.ulTargetHeight = surfaceHeight; + + // If external reconfigure is called along with resolution change even if post + // processing params is not changed, do full reconfigure params update + if ((reconfigExternal && bDecodeResChange) || reconfigExtPPChange) { + // update display rect and target resolution if requested explicitely + reconfigExternal = false; + reconfigExtPPChange = false; + videoFormat = *vidFormat; + width = vidFormat->display_area.right - vidFormat->display_area.left; + lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; + reconfigParams.ulTargetWidth = vidFormat->coded_width; + reconfigParams.ulTargetHeight = vidFormat->coded_height; + chromaHeight = + (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); + numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + surfaceHeight = reconfigParams.ulTargetHeight; + surfaceWidth = reconfigParams.ulTargetWidth; + displayRect.bottom = reconfigParams.display_area.bottom; + displayRect.top = reconfigParams.display_area.top; + displayRect.left = reconfigParams.display_area.left; + displayRect.right = reconfigParams.display_area.right; + } + reconfigParams.ulNumDecodeSurfaces = nDecodeSurface; + + CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); + CheckForCudaErrors( + cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__); + CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); - return nDecodeSurface; + return nDecodeSurface; } -int Decoder::GetOperatingPoint(CUVIDOPERATINGPOINTINFO *operPointInfo) -{ - if (operPointInfo->codec == cudaVideoCodec_AV1) { - if (operPointInfo->av1.operating_points_cnt > 1) { - // clip has SVC enabled - if (operatingPoint >= operPointInfo->av1.operating_points_cnt) - operatingPoint = 0; - - printf("AV1 SVC clip: operating point count %d ", operPointInfo->av1.operating_points_cnt); - printf("Selected operating point: %d, IDC 0x%x bOutputAllLayers %d\n", operatingPoint, operPointInfo->av1.operating_points_idc[operatingPoint], dispAllLayers); - return (operatingPoint | (dispAllLayers << 10)); - } +int Decoder::GetOperatingPoint(CUVIDOPERATINGPOINTINFO* operPointInfo) { + if (operPointInfo->codec == cudaVideoCodec_AV1) { + if (operPointInfo->av1.operating_points_cnt > 1) { + // clip has SVC enabled + if (operatingPoint >= operPointInfo->av1.operating_points_cnt) + operatingPoint = 0; + + printf( + "AV1 SVC clip: operating point count %d ", + operPointInfo->av1.operating_points_cnt); + printf( + "Selected operating point: %d, IDC 0x%x bOutputAllLayers %d\n", + operatingPoint, + operPointInfo->av1.operating_points_idc[operatingPoint], + dispAllLayers); + return (operatingPoint | (dispAllLayers << 10)); } - return -1; + } + return -1; } diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 1bbbb88de16..8079bf93e3b 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -1,15 +1,15 @@ +#include +#include #include #include #include #include -#include -#include +#include -static auto CheckForCudaErrors = [](CUresult result, int lineNum) -{ +static auto CheckForCudaErrors = [](CUresult result, int lineNum) { if (CUDA_SUCCESS != result) { std::stringstream errorStream; - const char *errorName = nullptr; + const char* errorName = nullptr; errorStream << __FILE__ << ":" << lineNum << std::endl; if (CUDA_SUCCESS != cuGetErrorName(result, &errorName)) { @@ -26,62 +26,79 @@ struct Rect { }; class Decoder { - public: - Decoder() {} - ~Decoder(); - void init(CUcontext, cudaVideoCodec); - unsigned long Decode(const uint8_t *, unsigned long); - void release(); - uint8_t * FetchFrame(); - cudaVideoSurfaceFormat GetOutputFormat() const { return videoOutputFormat; } - int GetFrameSize() const - { - return GetWidth() * (lumaHeight + (chromaHeight * numChromaPlanes)) * bytesPerPixel; - } - int GetWidth() const - { - return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || videoOutputFormat == cudaVideoSurfaceFormat_P016) ? (width + 1) & ~1 : width; - } - int GetHeight() const { return lumaHeight; } + public: + Decoder() {} + ~Decoder(); + void init(CUcontext, cudaVideoCodec); + unsigned long Decode(const uint8_t*, unsigned long); + void release(); + uint8_t* FetchFrame(); + cudaVideoSurfaceFormat GetOutputFormat() const { + return videoOutputFormat; + } + int GetFrameSize() const { + return GetWidth() * (lumaHeight + (chromaHeight * numChromaPlanes)) * + bytesPerPixel; + } + int GetWidth() const { + return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || + videoOutputFormat == cudaVideoSurfaceFormat_P016) + ? (width + 1) & ~1 + : width; + } + int GetHeight() const { + return lumaHeight; + } - private: - CUcontext cuContext = NULL; - CUvideoctxlock ctxLock; - CUvideoparser parser = NULL; - CUvideodecoder decoder = NULL; - CUstream cuvidStream = 0; - int numDecodedFrames = 0; - unsigned int numChromaPlanes = 0; - // dimension of the output - unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; - cudaVideoCodec videoCodec = cudaVideoCodec_NumCodecs; - cudaVideoChromaFormat videoChromaFormat = cudaVideoChromaFormat_420; - cudaVideoSurfaceFormat videoOutputFormat = cudaVideoSurfaceFormat_NV12; - int bitDepthMinus8 = 0; - int bytesPerPixel = 1; - CUVIDEOFORMAT videoFormat = {}; - unsigned int maxWidth = 0, maxHeight = 0; - // height of the mapped surface - int surfaceHeight = 0; - int surfaceWidth = 0; - Rect displayRect = {}; - unsigned int operatingPoint = 0; - bool dispAllLayers = false; - int decodePicCount = 0, picNumInDecodeOrder[32]; - bool reconfigExternal = false; - bool reconfigExtPPChange = false; - std::queue decodedFrames; + private: + CUcontext cuContext = NULL; + CUvideoctxlock ctxLock; + CUvideoparser parser = NULL; + CUvideodecoder decoder = NULL; + CUstream cuvidStream = 0; + int numDecodedFrames = 0; + unsigned int numChromaPlanes = 0; + // dimension of the output + unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; + cudaVideoCodec videoCodec = cudaVideoCodec_NumCodecs; + cudaVideoChromaFormat videoChromaFormat = cudaVideoChromaFormat_420; + cudaVideoSurfaceFormat videoOutputFormat = cudaVideoSurfaceFormat_NV12; + int bitDepthMinus8 = 0; + int bytesPerPixel = 1; + CUVIDEOFORMAT videoFormat = {}; + unsigned int maxWidth = 0, maxHeight = 0; + // height of the mapped surface + int surfaceHeight = 0; + int surfaceWidth = 0; + Rect displayRect = {}; + unsigned int operatingPoint = 0; + bool dispAllLayers = false; + int decodePicCount = 0, picNumInDecodeOrder[32]; + bool reconfigExternal = false; + bool reconfigExtPPChange = false; + std::queue decodedFrames; - static int CUDAAPI HandleVideoSequenceProc(void *pUserData, CUVIDEOFORMAT *pVideoFormat) { return ((Decoder *)pUserData)->HandleVideoSequence(pVideoFormat); } - static int CUDAAPI HandlePictureDecodeProc(void *pUserData, CUVIDPICPARAMS *pPicParams) { return ((Decoder *)pUserData)->HandlePictureDecode(pPicParams); } - static int CUDAAPI HandlePictureDisplayProc(void *pUserData, CUVIDPARSERDISPINFO *pDispInfo) { return ((Decoder *)pUserData)->HandlePictureDisplay(pDispInfo); } - static int CUDAAPI HandleOperatingPointProc(void *pUserData, CUVIDOPERATINGPOINTINFO *pOPInfo) { return ((Decoder *)pUserData)->GetOperatingPoint(pOPInfo); } + static int CUDAAPI + HandleVideoSequenceProc(void* pUserData, CUVIDEOFORMAT* pVideoFormat) { + return ((Decoder*)pUserData)->HandleVideoSequence(pVideoFormat); + } + static int CUDAAPI + HandlePictureDecodeProc(void* pUserData, CUVIDPICPARAMS* pPicParams) { + return ((Decoder*)pUserData)->HandlePictureDecode(pPicParams); + } + static int CUDAAPI + HandlePictureDisplayProc(void* pUserData, CUVIDPARSERDISPINFO* pDispInfo) { + return ((Decoder*)pUserData)->HandlePictureDisplay(pDispInfo); + } + static int CUDAAPI + HandleOperatingPointProc(void* pUserData, CUVIDOPERATINGPOINTINFO* pOPInfo) { + return ((Decoder*)pUserData)->GetOperatingPoint(pOPInfo); + } - void queryHardware(CUVIDEOFORMAT *videoFormat); - int ReconfigureDecoder(CUVIDEOFORMAT *pVideoFormat); - int HandleVideoSequence(CUVIDEOFORMAT *pVideoFormat); - int HandlePictureDecode(CUVIDPICPARAMS *pPicParams); - int HandlePictureDisplay(CUVIDPARSERDISPINFO *pDispInfo); - int GetOperatingPoint(CUVIDOPERATINGPOINTINFO *pOPInfo); + void queryHardware(CUVIDEOFORMAT* videoFormat); + int ReconfigureDecoder(CUVIDEOFORMAT* pVideoFormat); + int HandleVideoSequence(CUVIDEOFORMAT* pVideoFormat); + int HandlePictureDecode(CUVIDPICPARAMS* pPicParams); + int HandlePictureDisplay(CUVIDPARSERDISPINFO* pDispInfo); + int GetOperatingPoint(CUVIDOPERATINGPOINTINFO* pOPInfo); }; - diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index 3c2e294cf43..bb72ae3c3ea 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -5,8 +5,7 @@ extern "C" { #include } -inline bool check(int ret, int line) -{ +inline bool check(int ret, int line) { if (ret < 0) { printf("Error %d at line %d in demuxer.h.\n", ret, line); return false; @@ -17,171 +16,190 @@ inline bool check(int ret, int line) #define check_for_errors(call) check(call, __LINE__) class Demuxer { - private: - AVFormatContext *fmtCtx = NULL; - AVBSFContext *bsfCtx = NULL; - AVPacket pkt, pktFiltered; - AVCodecID eVideoCodec; - uint8_t *pDataWithHeader = NULL; - bool bMp4H264, bMp4HEVC, bMp4MPEG4; - unsigned int frameCount = 0; - int iVideoStream; - int64_t userTimeScale = 0; - double timeBase = 0.0; + private: + AVFormatContext* fmtCtx = NULL; + AVBSFContext* bsfCtx = NULL; + AVPacket pkt, pktFiltered; + AVCodecID eVideoCodec; + uint8_t* pDataWithHeader = NULL; + bool bMp4H264, bMp4HEVC, bMp4MPEG4; + unsigned int frameCount = 0; + int iVideoStream; + int64_t userTimeScale = 0; + double timeBase = 0.0; - public: - Demuxer(const char *filePath, int64_t timeScale = 1000 /*Hz*/) - { - avformat_network_init(); - check_for_errors(avformat_open_input(&fmtCtx, filePath, NULL, NULL)); - if (!fmtCtx) { - printf("No AVFormatContext provided.\n"); - return; - } + public: + Demuxer(const char* filePath, int64_t timeScale = 1000 /*Hz*/) { + avformat_network_init(); + check_for_errors(avformat_open_input(&fmtCtx, filePath, NULL, NULL)); + if (!fmtCtx) { + printf("No AVFormatContext provided.\n"); + return; + } - check_for_errors(avformat_find_stream_info(fmtCtx, NULL)); - iVideoStream = av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0); - if (iVideoStream < 0) { - printf("FFmpeg error: %d, could not find stream in input file\n", __LINE__); - return; - } + check_for_errors(avformat_find_stream_info(fmtCtx, NULL)); + iVideoStream = + av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0); + if (iVideoStream < 0) { + printf( + "FFmpeg error: %d, could not find stream in input file\n", __LINE__); + return; + } - eVideoCodec = fmtCtx->streams[iVideoStream]->codecpar->codec_id; - AVRational rTimeBase = fmtCtx->streams[iVideoStream]->time_base; - timeBase = av_q2d(rTimeBase); - userTimeScale = timeScale; + eVideoCodec = fmtCtx->streams[iVideoStream]->codecpar->codec_id; + AVRational rTimeBase = fmtCtx->streams[iVideoStream]->time_base; + timeBase = av_q2d(rTimeBase); + userTimeScale = timeScale; - bMp4H264 = eVideoCodec == AV_CODEC_ID_H264 && ( - !strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") - || !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") - || !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); - bMp4HEVC = eVideoCodec == AV_CODEC_ID_HEVC && ( - !strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") - || !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") - || !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); - bMp4MPEG4 = eVideoCodec == AV_CODEC_ID_MPEG4 && ( - !strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") - || !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") - || !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); + bMp4H264 = eVideoCodec == AV_CODEC_ID_H264 && + (!strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") || + !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") || + !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); + bMp4HEVC = eVideoCodec == AV_CODEC_ID_HEVC && + (!strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") || + !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") || + !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); + bMp4MPEG4 = eVideoCodec == AV_CODEC_ID_MPEG4 && + (!strcmp(fmtCtx->iformat->long_name, "QuickTime / MOV") || + !strcmp(fmtCtx->iformat->long_name, "FLV (Flash Video)") || + !strcmp(fmtCtx->iformat->long_name, "Matroska / WebM")); - av_init_packet(&pkt); - pkt.data = NULL; - pkt.size = 0; - av_init_packet(&pktFiltered); - pktFiltered.data = NULL; - pktFiltered.size = 0; + av_init_packet(&pkt); + pkt.data = NULL; + pkt.size = 0; + av_init_packet(&pktFiltered); + pktFiltered.data = NULL; + pktFiltered.size = 0; - if (bMp4H264) { - const AVBitStreamFilter *bsf = av_bsf_get_by_name("h264_mp4toannexb"); - if (!bsf) { - printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); - return; - } - check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); - avcodec_parameters_copy(bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); - check_for_errors(av_bsf_init(bsfCtx)); - } - if (bMp4HEVC) { - const AVBitStreamFilter *bsf = av_bsf_get_by_name("hevc_mp4toannexb"); - if (!bsf) { - printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); - return; - } - check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); - avcodec_parameters_copy(bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); - check_for_errors(av_bsf_init(bsfCtx)); - } + if (bMp4H264) { + const AVBitStreamFilter* bsf = av_bsf_get_by_name("h264_mp4toannexb"); + if (!bsf) { + printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); + return; + } + check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); + avcodec_parameters_copy( + bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); + check_for_errors(av_bsf_init(bsfCtx)); } - ~Demuxer() - { - if (!fmtCtx) { - return; - } - if (pkt.data) { - av_packet_unref(&pkt); - } - if (pktFiltered.data) { - av_packet_unref(&pktFiltered); - } - if (bsfCtx) { - av_bsf_free(&bsfCtx); - } - avformat_close_input(&fmtCtx); - if (pDataWithHeader) { - av_free(pDataWithHeader); - } + if (bMp4HEVC) { + const AVBitStreamFilter* bsf = av_bsf_get_by_name("hevc_mp4toannexb"); + if (!bsf) { + printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); + return; + } + check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); + avcodec_parameters_copy( + bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); + check_for_errors(av_bsf_init(bsfCtx)); } - - AVCodecID GetVideoCodec() - { - return eVideoCodec; + } + ~Demuxer() { + if (!fmtCtx) { + return; + } + if (pkt.data) { + av_packet_unref(&pkt); + } + if (pktFiltered.data) { + av_packet_unref(&pktFiltered); + } + if (bsfCtx) { + av_bsf_free(&bsfCtx); + } + avformat_close_input(&fmtCtx); + if (pDataWithHeader) { + av_free(pDataWithHeader); } + } - bool Demux(uint8_t **video, unsigned long *videoBytes) - { - if (!fmtCtx) { - return false; - } - *videoBytes = 0; + AVCodecID GetVideoCodec() { + return eVideoCodec; + } - if (pkt.data) { - av_packet_unref(&pkt); - } - int e = 0; - while ((e = av_read_frame(fmtCtx, &pkt)) >= 0 && pkt.stream_index != iVideoStream) { - av_packet_unref(&pkt); - } - if (e < 0) { - return false; - } + bool Demux(uint8_t** video, unsigned long* videoBytes) { + if (!fmtCtx) { + return false; + } + *videoBytes = 0; - if (bMp4H264 || bMp4HEVC) { - if (pktFiltered.data) { - av_packet_unref(&pktFiltered); - } - check_for_errors(av_bsf_send_packet(bsfCtx, &pkt)); - check_for_errors(av_bsf_receive_packet(bsfCtx, &pktFiltered)); - *video = pktFiltered.data; - *videoBytes = pktFiltered.size; - } else { - if (bMp4MPEG4 && (frameCount == 0)) { - int extraDataSize = fmtCtx->streams[iVideoStream]->codecpar->extradata_size; + if (pkt.data) { + av_packet_unref(&pkt); + } + int e = 0; + while ((e = av_read_frame(fmtCtx, &pkt)) >= 0 && + pkt.stream_index != iVideoStream) { + av_packet_unref(&pkt); + } + if (e < 0) { + return false; + } + + if (bMp4H264 || bMp4HEVC) { + if (pktFiltered.data) { + av_packet_unref(&pktFiltered); + } + check_for_errors(av_bsf_send_packet(bsfCtx, &pkt)); + check_for_errors(av_bsf_receive_packet(bsfCtx, &pktFiltered)); + *video = pktFiltered.data; + *videoBytes = pktFiltered.size; + } else { + if (bMp4MPEG4 && (frameCount == 0)) { + int extraDataSize = + fmtCtx->streams[iVideoStream]->codecpar->extradata_size; - if (extraDataSize > 0) { - pDataWithHeader = (uint8_t *)av_malloc(extraDataSize + pkt.size - 3 * sizeof(uint8_t)); - if (!pDataWithHeader) { - printf("FFmpeg error: %d\n", __LINE__); - return false; - } - memcpy(pDataWithHeader, fmtCtx->streams[iVideoStream]->codecpar->extradata, extraDataSize); - memcpy(pDataWithHeader+extraDataSize, pkt.data+3, pkt.size - 3 * sizeof(uint8_t)); - *video = pDataWithHeader; - *videoBytes = extraDataSize + pkt.size - 3 * sizeof(uint8_t); - } - } else { - *video = pkt.data; - *videoBytes = pkt.size; - } + if (extraDataSize > 0) { + pDataWithHeader = (uint8_t*)av_malloc( + extraDataSize + pkt.size - 3 * sizeof(uint8_t)); + if (!pDataWithHeader) { + printf("FFmpeg error: %d\n", __LINE__); + return false; + } + memcpy( + pDataWithHeader, + fmtCtx->streams[iVideoStream]->codecpar->extradata, + extraDataSize); + memcpy( + pDataWithHeader + extraDataSize, + pkt.data + 3, + pkt.size - 3 * sizeof(uint8_t)); + *video = pDataWithHeader; + *videoBytes = extraDataSize + pkt.size - 3 * sizeof(uint8_t); } - frameCount++; - return true; + } else { + *video = pkt.data; + *videoBytes = pkt.size; + } } + frameCount++; + return true; + } }; -inline cudaVideoCodec FFmpeg2NvCodecId(AVCodecID id) -{ - switch (id) { - case AV_CODEC_ID_MPEG1VIDEO : return cudaVideoCodec_MPEG1; - case AV_CODEC_ID_MPEG2VIDEO : return cudaVideoCodec_MPEG2; - case AV_CODEC_ID_MPEG4 : return cudaVideoCodec_MPEG4; - case AV_CODEC_ID_WMV3 : - case AV_CODEC_ID_VC1 : return cudaVideoCodec_VC1; - case AV_CODEC_ID_H264 : return cudaVideoCodec_H264; - case AV_CODEC_ID_HEVC : return cudaVideoCodec_HEVC; - case AV_CODEC_ID_VP8 : return cudaVideoCodec_VP8; - case AV_CODEC_ID_VP9 : return cudaVideoCodec_VP9; - case AV_CODEC_ID_MJPEG : return cudaVideoCodec_JPEG; - case AV_CODEC_ID_AV1 : return cudaVideoCodec_AV1; - default : return cudaVideoCodec_NumCodecs; - } +inline cudaVideoCodec FFmpeg2NvCodecId(AVCodecID id) { + switch (id) { + case AV_CODEC_ID_MPEG1VIDEO: + return cudaVideoCodec_MPEG1; + case AV_CODEC_ID_MPEG2VIDEO: + return cudaVideoCodec_MPEG2; + case AV_CODEC_ID_MPEG4: + return cudaVideoCodec_MPEG4; + case AV_CODEC_ID_WMV3: + case AV_CODEC_ID_VC1: + return cudaVideoCodec_VC1; + case AV_CODEC_ID_H264: + return cudaVideoCodec_H264; + case AV_CODEC_ID_HEVC: + return cudaVideoCodec_HEVC; + case AV_CODEC_ID_VP8: + return cudaVideoCodec_VP8; + case AV_CODEC_ID_VP9: + return cudaVideoCodec_VP9; + case AV_CODEC_ID_MJPEG: + return cudaVideoCodec_JPEG; + case AV_CODEC_ID_AV1: + return cudaVideoCodec_AV1; + default: + return cudaVideoCodec_NumCodecs; + } } diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 128efbe6922..a58bd08f74f 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -1,78 +1,77 @@ #include "gpu_decoder.h" -GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) : demuxer(src_file.c_str()), device(dev) -{ +GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) + : demuxer(src_file.c_str()), device(dev) { if (cudaSuccess != cudaSetDevice(device)) { - printf("Error setting device\n"); - return; + printf("Error setting device\n"); + return; } - CheckForCudaErrors( - cuDevicePrimaryCtxRetain(&ctx, device), - __LINE__); + CheckForCudaErrors(cuDevicePrimaryCtxRetain(&ctx, device), __LINE__); dec.init(ctx, FFmpeg2NvCodecId(demuxer.GetVideoCodec())); initialised = true; } -GPUDecoder::~GPUDecoder() -{ +GPUDecoder::~GPUDecoder() { dec.release(); if (initialised) { - CheckForCudaErrors( - cuDevicePrimaryCtxRelease(device), - __LINE__); + CheckForCudaErrors(cuDevicePrimaryCtxRelease(device), __LINE__); } } -torch::Tensor GPUDecoder::decode() -{ +torch::Tensor GPUDecoder::decode() { torch::Tensor frameTensor; unsigned long videoBytes = 0; uint8_t *frame = nullptr, *video = nullptr; - do - { + do { demuxer.Demux(&video, &videoBytes); dec.Decode(video, videoBytes); frame = dec.FetchFrame(); } while (frame == nullptr && videoBytes > 0); if (frame == nullptr) { - auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); + auto options = + torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); return torch::zeros({0}, options); } auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); frameTensor = torch::from_blob( - frame, {dec.GetFrameSize()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); + frame, + {dec.GetFrameSize()}, + [](auto p) { cuMemFree((CUdeviceptr)p); }, + options); return frameTensor; } -torch::Tensor GPUDecoder::NV12ToYUV420(torch::Tensor frameTensor) -{ +torch::Tensor GPUDecoder::NV12ToYUV420(torch::Tensor frameTensor) { int width = dec.GetWidth(), height = dec.GetHeight(); int pitch = width; - uint8_t *frame = frameTensor.data_ptr(); - uint8_t *ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)]; + uint8_t* frame = frameTensor.data_ptr(); + uint8_t* ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)]; // sizes of source surface plane int sizePlaneY = pitch * height; int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2); int sizePlaneV = sizePlaneU; - uint8_t *uv = frame + sizePlaneY; - uint8_t *u = uv; - uint8_t *v = uv + sizePlaneU; + uint8_t* uv = frame + sizePlaneY; + uint8_t* u = uv; + uint8_t* v = uv + sizePlaneU; // split chroma from interleave to planar for (int y = 0; y < (height + 1) / 2; y++) { - for (int x = 0; x < (width + 1) / 2; x++) { - u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2]; - ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1]; - } + for (int x = 0; x < (width + 1) / 2; x++) { + u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2]; + ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1]; + } } if (pitch == width) { - memcpy(v, ptr, sizePlaneV * sizeof(uint8_t)); + memcpy(v, ptr, sizePlaneV * sizeof(uint8_t)); } else { - for (int i = 0; i < (height + 1) / 2; i++) { - memcpy(v + ((pitch + 1) / 2) * i, ptr + ((width + 1) / 2) * i, ((width + 1) / 2) * sizeof(uint8_t)); - } + for (int i = 0; i < (height + 1) / 2; i++) { + memcpy( + v + ((pitch + 1) / 2) * i, + ptr + ((width + 1) / 2) * i, + ((width + 1) / 2) * sizeof(uint8_t)); + } } delete[] ptr; return frameTensor; @@ -80,8 +79,7 @@ torch::Tensor GPUDecoder::NV12ToYUV420(torch::Tensor frameTensor) TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") - .def(torch::init()) - .def("next", &GPUDecoder::decode) - .def("reformat", &GPUDecoder::NV12ToYUV420) - ; - } + .def(torch::init()) + .def("next", &GPUDecoder::decode) + .def("reformat", &GPUDecoder::NV12ToYUV420); +} diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index d400c213bf6..e8dd6de5d0b 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -4,17 +4,17 @@ #include "demuxer.h" class GPUDecoder : public torch::CustomClassHolder { - public: - GPUDecoder(std::string, int64_t); - ~GPUDecoder(); - torch::Tensor decode(); - torch::Tensor NV12ToYUV420(torch::Tensor); + public: + GPUDecoder(std::string, int64_t); + ~GPUDecoder(); + torch::Tensor decode(); + torch::Tensor NV12ToYUV420(torch::Tensor); - private: - Demuxer demuxer; - CUcontext ctx; - Decoder dec; - int64_t device; - bool initialised = false; - std::string output_format; + private: + Demuxer demuxer; + CUcontext ctx; + Decoder dec; + int64_t device; + bool initialised = false; + std::string output_format; }; From 5a08055ebcfed7c52b50b6a379050646b6aeac26 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Sun, 19 Dec 2021 10:42:36 -0800 Subject: [PATCH 17/39] Make reformat private --- test/test_video_gpu_decoder.py | 2 +- torchvision/io/__init__.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index 2f91f2e2466..84309e3e217 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -33,7 +33,7 @@ def test_frame_reading(self): for av_frame in container.decode(container.streams.video[0]): av_frames = torch.tensor(av_frame.to_ndarray().flatten()) vision_frames = next(decoder)["data"] - mean_delta = torch.mean(torch.abs(av_frames.float() - decoder.reformat(vision_frames).float())) + mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float())) assert mean_delta < 0.1 diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 139491d63a4..f30eed18514 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -5,8 +5,8 @@ from ._video_opt import ( Timebase, VideoMetaData, - _HAS_VIDEO_OPT, _HAS_VIDEO_DECODER, + _HAS_VIDEO_OPT, _probe_video_from_file, _probe_video_from_memory, _read_video_from_file, @@ -203,12 +203,12 @@ def set_current_stream(self, stream: str) -> bool: print("GPU decoding only works with video stream.") return self._c.set_current_stream(stream) - def reformat(self, tensor, format: str = "yuv420"): + def _reformat(self, tensor, output_format: str = "yuv420"): supported_formats = [ "yuv420", ] - if format not in supported_formats: - raise RuntimeError(f"{format} not supported, please use one of {', '.join(supported_formats)}") + if output_format not in supported_formats: + raise RuntimeError(f"{output_format} not supported, please use one of {', '.join(supported_formats)}") if not isinstance(tensor, torch.Tensor): raise RuntimeError("Expected tensor as input parameter!") return self._c.reformat(tensor.cpu()) From fd30a8975321c2fc5c1338a7fd23b093d3ecb84d Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Sun, 19 Dec 2021 12:05:19 -0800 Subject: [PATCH 18/39] Member function naming --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 98 +++++++++---------- torchvision/csrc/io/decoder/gpu/decoder.h | 57 +++++------ torchvision/csrc/io/decoder/gpu/demuxer.h | 6 +- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 22 ++--- torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 5 +- 5 files changed, 94 insertions(+), 94 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index ef2199800d0..ee1142985f2 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -2,17 +2,16 @@ #include #include #include -#include #include -static float GetChromaHeightFactor(cudaVideoSurfaceFormat surfaceFormat) { +static float chroma_height_factor(cudaVideoSurfaceFormat surfaceFormat) { return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 1.0 : 0.5; } -static int GetChromaPlaneCount(cudaVideoSurfaceFormat surfaceFormat) { +static int chroma_plane_count(cudaVideoSurfaceFormat surfaceFormat) { return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) ? 2 @@ -22,7 +21,7 @@ static int GetChromaPlaneCount(cudaVideoSurfaceFormat surfaceFormat) { void Decoder::init(CUcontext context, cudaVideoCodec codec) { cuContext = context; videoCodec = codec; - CheckForCudaErrors(cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__); + check_for_cuda_errors(cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__); CUVIDPARSERPARAMS parserParams = { .CodecType = codec, @@ -30,12 +29,13 @@ void Decoder::init(CUcontext context, cudaVideoCodec codec) { .ulClockRate = 1000, .ulMaxDisplayDelay = 0u, .pUserData = this, - .pfnSequenceCallback = HandleVideoSequenceProc, - .pfnDecodePicture = HandlePictureDecodeProc, - .pfnDisplayPicture = HandlePictureDisplayProc, - .pfnGetOperatingPoint = HandleOperatingPointProc, + .pfnSequenceCallback = video_sequence_handler, + .pfnDecodePicture = picture_decode_handler, + .pfnDisplayPicture = picture_display_handler, + .pfnGetOperatingPoint = operating_point_handler, }; - CheckForCudaErrors(cuvidCreateVideoParser(&parser, &parserParams), __LINE__); + check_for_cuda_errors( + cuvidCreateVideoParser(&parser, &parserParams), __LINE__); } Decoder::~Decoder() { @@ -58,7 +58,7 @@ void Decoder::release() { cuCtxPopCurrent(NULL); } -unsigned long Decoder::Decode(const uint8_t* data, unsigned long size) { +unsigned long Decoder::decode(const uint8_t* data, unsigned long size) { numDecodedFrames = 0; CUVIDSOURCEDATAPACKET pkt = { .flags = CUVID_PKT_TIMESTAMP, @@ -68,12 +68,12 @@ unsigned long Decoder::Decode(const uint8_t* data, unsigned long size) { if (!data || size == 0) { pkt.flags |= CUVID_PKT_ENDOFSTREAM; } - CheckForCudaErrors(cuvidParseVideoData(parser, &pkt), __LINE__); + check_for_cuda_errors(cuvidParseVideoData(parser, &pkt), __LINE__); cuvidStream = 0; return numDecodedFrames; } -uint8_t* Decoder::FetchFrame() { +uint8_t* Decoder::fetch_frame() { if (decodedFrames.empty()) { return nullptr; } @@ -82,18 +82,18 @@ uint8_t* Decoder::FetchFrame() { return frame; } -int Decoder::HandlePictureDecode(CUVIDPICPARAMS* picParams) { +int Decoder::handle_picture_decode(CUVIDPICPARAMS* picParams) { if (!decoder) { throw std::runtime_error("Uninitialised decoder."); } picNumInDecodeOrder[picParams->CurrPicIdx] = decodePicCount++; - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors(cuvidDecodePicture(decoder, picParams), __LINE__); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors(cuvidDecodePicture(decoder, picParams), __LINE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); return 1; } -int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO* dispInfo) { +int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { CUVIDPROCPARAMS procParams = { .progressive_frame = dispInfo->progressive_frame, .second_field = dispInfo->repeat_first_field + 1, @@ -104,8 +104,8 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO* dispInfo) { CUdeviceptr dpSrcFrame = 0; unsigned int nSrcPitch = 0; - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors( + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors( cuvidMapVideoFrame( decoder, dispInfo->picture_index, @@ -127,7 +127,7 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO* dispInfo) { } uint8_t* decodedFrame = nullptr; - cuMemAlloc((CUdeviceptr*)&decodedFrame, GetFrameSize()); + cuMemAlloc((CUdeviceptr*)&decodedFrame, get_frame_size()); numDecodedFrames++; @@ -138,10 +138,10 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO* dispInfo) { m.srcPitch = nSrcPitch; m.dstMemoryType = CU_MEMORYTYPE_DEVICE; m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame); - m.dstPitch = GetWidth() * bytesPerPixel; - m.WidthInBytes = GetWidth() * bytesPerPixel; + m.dstPitch = get_width() * bytesPerPixel; + m.WidthInBytes = get_width() * bytesPerPixel; m.Height = lumaHeight; - CheckForCudaErrors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); // Copy chroma plane // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning @@ -151,7 +151,7 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO* dispInfo) { m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight); m.Height = chromaHeight; - CheckForCudaErrors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); if (numChromaPlanes == 2) { m.srcDevice = @@ -159,25 +159,25 @@ int Decoder::HandlePictureDisplay(CUVIDPARSERDISPINFO* dispInfo) { m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight * 2); m.Height = chromaHeight; - CheckForCudaErrors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); } - CheckForCudaErrors(cuStreamSynchronize(cuvidStream), __LINE__); + check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__); decodedFrames.push(decodedFrame); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); - CheckForCudaErrors(cuvidUnmapVideoFrame(decoder, dpSrcFrame), __LINE__); + check_for_cuda_errors(cuvidUnmapVideoFrame(decoder, dpSrcFrame), __LINE__); return 1; } -void Decoder::queryHardware(CUVIDEOFORMAT* videoFormat) { +void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { CUVIDDECODECAPS decodeCaps = { .eCodecType = videoFormat->codec, .eChromaFormat = videoFormat->chroma_format, .nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8, }; - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors(cuvidGetDecoderCaps(&decodeCaps), __LINE__); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors(cuvidGetDecoderCaps(&decodeCaps), __LINE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); if (!decodeCaps.bIsSupported) { throw std::runtime_error("Codec not supported on this GPU"); @@ -228,7 +228,7 @@ void Decoder::queryHardware(CUVIDEOFORMAT* videoFormat) { } } -int Decoder::HandleVideoSequence(CUVIDEOFORMAT* vidFormat) { +int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { // videoCodec has been set in the constructor (for parser). Here it's set // again for potential correction videoCodec = vidFormat->codec; @@ -254,11 +254,11 @@ int Decoder::HandleVideoSequence(CUVIDEOFORMAT* vidFormat) { // so make 420 default } - queryHardware(vidFormat); + query_hardware(vidFormat); if (width && lumaHeight && chromaHeight) { // cuvidCreateDecoder() has been called before, and now there's possible // config change - return ReconfigureDecoder(vidFormat); + return reconfigure_decoder(vidFormat); } videoFormat = *vidFormat; @@ -303,8 +303,8 @@ int Decoder::HandleVideoSequence(CUVIDEOFORMAT* vidFormat) { videoDecodeCreateInfo.ulTargetWidth = vidFormat->coded_width; videoDecodeCreateInfo.ulTargetHeight = vidFormat->coded_height; chromaHeight = - (int)(ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat))); - numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + (int)(ceil(lumaHeight * chroma_height_factor(videoOutputFormat))); + numChromaPlanes = chroma_plane_count(videoOutputFormat); surfaceHeight = videoDecodeCreateInfo.ulTargetHeight; surfaceWidth = videoDecodeCreateInfo.ulTargetWidth; displayRect.bottom = videoDecodeCreateInfo.display_area.bottom; @@ -312,14 +312,14 @@ int Decoder::HandleVideoSequence(CUVIDEOFORMAT* vidFormat) { displayRect.left = videoDecodeCreateInfo.display_area.left; displayRect.right = videoDecodeCreateInfo.display_area.right; - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors( + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors( cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); return nDecodeSurface; } -int Decoder::ReconfigureDecoder(CUVIDEOFORMAT* vidFormat) { +int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { if (vidFormat->bit_depth_luma_minus8 != videoFormat.bit_depth_luma_minus8 || vidFormat->bit_depth_chroma_minus8 != videoFormat.bit_depth_chroma_minus8) { @@ -360,8 +360,8 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT* vidFormat) { width = vidFormat->display_area.right - vidFormat->display_area.left; lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; chromaHeight = - (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); - numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + (int)ceil(lumaHeight * chroma_height_factor(videoOutputFormat)); + numChromaPlanes = chroma_plane_count(videoOutputFormat); } // no need for reconfigureDecoder(). Just return @@ -394,8 +394,8 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT* vidFormat) { reconfigParams.ulTargetWidth = vidFormat->coded_width; reconfigParams.ulTargetHeight = vidFormat->coded_height; chromaHeight = - (int)ceil(lumaHeight * GetChromaHeightFactor(videoOutputFormat)); - numChromaPlanes = GetChromaPlaneCount(videoOutputFormat); + (int)ceil(lumaHeight * chroma_height_factor(videoOutputFormat)); + numChromaPlanes = chroma_plane_count(videoOutputFormat); surfaceHeight = reconfigParams.ulTargetHeight; surfaceWidth = reconfigParams.ulTargetWidth; displayRect.bottom = reconfigParams.display_area.bottom; @@ -405,15 +405,15 @@ int Decoder::ReconfigureDecoder(CUVIDEOFORMAT* vidFormat) { } reconfigParams.ulNumDecodeSurfaces = nDecodeSurface; - CheckForCudaErrors(cuCtxPushCurrent(cuContext), __LINE__); - CheckForCudaErrors( + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors( cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__); - CheckForCudaErrors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); return nDecodeSurface; } -int Decoder::GetOperatingPoint(CUVIDOPERATINGPOINTINFO* operPointInfo) { +int Decoder::get_operating_point(CUVIDOPERATINGPOINTINFO* operPointInfo) { if (operPointInfo->codec == cudaVideoCodec_AV1) { if (operPointInfo->av1.operating_points_cnt > 1) { // clip has SVC enabled diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 8079bf93e3b..01b77a4ee0b 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -6,7 +6,7 @@ #include #include -static auto CheckForCudaErrors = [](CUresult result, int lineNum) { +static auto check_for_cuda_errors = [](CUresult result, int lineNum) { if (CUDA_SUCCESS != result) { std::stringstream errorStream; const char* errorName = nullptr; @@ -30,23 +30,20 @@ class Decoder { Decoder() {} ~Decoder(); void init(CUcontext, cudaVideoCodec); - unsigned long Decode(const uint8_t*, unsigned long); + unsigned long decode(const uint8_t*, unsigned long); void release(); - uint8_t* FetchFrame(); - cudaVideoSurfaceFormat GetOutputFormat() const { - return videoOutputFormat; - } - int GetFrameSize() const { - return GetWidth() * (lumaHeight + (chromaHeight * numChromaPlanes)) * + uint8_t* fetch_frame(); + int get_frame_size() const { + return get_width() * (lumaHeight + (chromaHeight * numChromaPlanes)) * bytesPerPixel; } - int GetWidth() const { + int get_width() const { return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || videoOutputFormat == cudaVideoSurfaceFormat_P016) ? (width + 1) & ~1 : width; } - int GetHeight() const { + int get_height() const { return lumaHeight; } @@ -78,27 +75,31 @@ class Decoder { bool reconfigExtPPChange = false; std::queue decodedFrames; - static int CUDAAPI - HandleVideoSequenceProc(void* pUserData, CUVIDEOFORMAT* pVideoFormat) { - return ((Decoder*)pUserData)->HandleVideoSequence(pVideoFormat); + static int video_sequence_handler( + void* pUserData, + CUVIDEOFORMAT* pVideoFormat) { + return ((Decoder*)pUserData)->handle_video_sequence(pVideoFormat); } - static int CUDAAPI - HandlePictureDecodeProc(void* pUserData, CUVIDPICPARAMS* pPicParams) { - return ((Decoder*)pUserData)->HandlePictureDecode(pPicParams); + static int picture_decode_handler( + void* pUserData, + CUVIDPICPARAMS* pPicParams) { + return ((Decoder*)pUserData)->handle_picture_decode(pPicParams); } - static int CUDAAPI - HandlePictureDisplayProc(void* pUserData, CUVIDPARSERDISPINFO* pDispInfo) { - return ((Decoder*)pUserData)->HandlePictureDisplay(pDispInfo); + static int picture_display_handler( + void* pUserData, + CUVIDPARSERDISPINFO* pDispInfo) { + return ((Decoder*)pUserData)->handle_picture_display(pDispInfo); } - static int CUDAAPI - HandleOperatingPointProc(void* pUserData, CUVIDOPERATINGPOINTINFO* pOPInfo) { - return ((Decoder*)pUserData)->GetOperatingPoint(pOPInfo); + static int operating_point_handler( + void* pUserData, + CUVIDOPERATINGPOINTINFO* pOPInfo) { + return ((Decoder*)pUserData)->get_operating_point(pOPInfo); } - void queryHardware(CUVIDEOFORMAT* videoFormat); - int ReconfigureDecoder(CUVIDEOFORMAT* pVideoFormat); - int HandleVideoSequence(CUVIDEOFORMAT* pVideoFormat); - int HandlePictureDecode(CUVIDPICPARAMS* pPicParams); - int HandlePictureDisplay(CUVIDPARSERDISPINFO* pDispInfo); - int GetOperatingPoint(CUVIDOPERATINGPOINTINFO* pOPInfo); + void query_hardware(CUVIDEOFORMAT* videoFormat); + int reconfigure_decoder(CUVIDEOFORMAT* pVideoFormat); + int handle_video_sequence(CUVIDEOFORMAT* pVideoFormat); + int handle_picture_decode(CUVIDPICPARAMS* pPicParams); + int handle_picture_display(CUVIDPARSERDISPINFO* pDispInfo); + int get_operating_point(CUVIDOPERATINGPOINTINFO* pOPInfo); }; diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index bb72ae3c3ea..653ce7bdd80 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -113,11 +113,11 @@ class Demuxer { } } - AVCodecID GetVideoCodec() { + AVCodecID get_video_codec() { return eVideoCodec; } - bool Demux(uint8_t** video, unsigned long* videoBytes) { + bool demux(uint8_t** video, unsigned long* videoBytes) { if (!fmtCtx) { return false; } @@ -176,7 +176,7 @@ class Demuxer { } }; -inline cudaVideoCodec FFmpeg2NvCodecId(AVCodecID id) { +inline cudaVideoCodec ffmpeg_to_codec(AVCodecID id) { switch (id) { case AV_CODEC_ID_MPEG1VIDEO: return cudaVideoCodec_MPEG1; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index a58bd08f74f..9b0c97e2cad 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -6,15 +6,15 @@ GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) printf("Error setting device\n"); return; } - CheckForCudaErrors(cuDevicePrimaryCtxRetain(&ctx, device), __LINE__); - dec.init(ctx, FFmpeg2NvCodecId(demuxer.GetVideoCodec())); + check_for_cuda_errors(cuDevicePrimaryCtxRetain(&ctx, device), __LINE__); + decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec())); initialised = true; } GPUDecoder::~GPUDecoder() { - dec.release(); + decoder.release(); if (initialised) { - CheckForCudaErrors(cuDevicePrimaryCtxRelease(device), __LINE__); + check_for_cuda_errors(cuDevicePrimaryCtxRelease(device), __LINE__); } } @@ -23,9 +23,9 @@ torch::Tensor GPUDecoder::decode() { unsigned long videoBytes = 0; uint8_t *frame = nullptr, *video = nullptr; do { - demuxer.Demux(&video, &videoBytes); - dec.Decode(video, videoBytes); - frame = dec.FetchFrame(); + demuxer.demux(&video, &videoBytes); + decoder.decode(video, videoBytes); + frame = decoder.fetch_frame(); } while (frame == nullptr && videoBytes > 0); if (frame == nullptr) { auto options = @@ -35,14 +35,14 @@ torch::Tensor GPUDecoder::decode() { auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); frameTensor = torch::from_blob( frame, - {dec.GetFrameSize()}, + {decoder.get_frame_size()}, [](auto p) { cuMemFree((CUdeviceptr)p); }, options); return frameTensor; } -torch::Tensor GPUDecoder::NV12ToYUV420(torch::Tensor frameTensor) { - int width = dec.GetWidth(), height = dec.GetHeight(); +torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) { + int width = decoder.get_width(), height = decoder.get_height(); int pitch = width; uint8_t* frame = frameTensor.data_ptr(); uint8_t* ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)]; @@ -81,5 +81,5 @@ TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") .def(torch::init()) .def("next", &GPUDecoder::decode) - .def("reformat", &GPUDecoder::NV12ToYUV420); + .def("reformat", &GPUDecoder::nv12_to_yuv420); } diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index e8dd6de5d0b..02b14fda99e 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -8,13 +8,12 @@ class GPUDecoder : public torch::CustomClassHolder { GPUDecoder(std::string, int64_t); ~GPUDecoder(); torch::Tensor decode(); - torch::Tensor NV12ToYUV420(torch::Tensor); + torch::Tensor nv12_to_yuv420(torch::Tensor); private: Demuxer demuxer; CUcontext ctx; - Decoder dec; + Decoder decoder; int64_t device; bool initialised = false; - std::string output_format; }; From 87ed21e1899aed7ff3a49e03a19dadaf27a1773c Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Sun, 19 Dec 2021 15:25:54 -0800 Subject: [PATCH 19/39] Add comments --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 23 +++++++++++++++++++ .../csrc/io/decoder/gpu/gpu_decoder.cpp | 7 ++++++ 2 files changed, 30 insertions(+) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index ee1142985f2..b4079b14f3e 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -18,6 +18,9 @@ static int chroma_plane_count(cudaVideoSurfaceFormat surfaceFormat) { : 1; } +/* Initialise cuContext and videoCodec, create context lock and create parser + * object. + */ void Decoder::init(CUcontext context, cudaVideoCodec codec) { cuContext = context; videoCodec = codec; @@ -38,6 +41,8 @@ void Decoder::init(CUcontext context, cudaVideoCodec codec) { cuvidCreateVideoParser(&parser, &parserParams), __LINE__); } +/* Destroy parser object and context lock. + */ Decoder::~Decoder() { if (parser) { cuvidDestroyVideoParser(parser); @@ -45,6 +50,8 @@ Decoder::~Decoder() { cuvidCtxLockDestroy(ctxLock); } +/* Destroy CUvideodecoder object and free up all the unreturned decoded frames. + */ void Decoder::release() { cuCtxPushCurrent(cuContext); if (decoder) { @@ -58,6 +65,8 @@ void Decoder::release() { cuCtxPopCurrent(NULL); } +/* Trigger video decoding. + */ unsigned long Decoder::decode(const uint8_t* data, unsigned long size) { numDecodedFrames = 0; CUVIDSOURCEDATAPACKET pkt = { @@ -73,6 +82,8 @@ unsigned long Decoder::decode(const uint8_t* data, unsigned long size) { return numDecodedFrames; } +/* Fetch a decoded frame and remove it from the queue. + */ uint8_t* Decoder::fetch_frame() { if (decodedFrames.empty()) { return nullptr; @@ -82,6 +93,8 @@ uint8_t* Decoder::fetch_frame() { return frame; } +/* Called when a picture is ready to be decoded. + */ int Decoder::handle_picture_decode(CUVIDPICPARAMS* picParams) { if (!decoder) { throw std::runtime_error("Uninitialised decoder."); @@ -93,6 +106,8 @@ int Decoder::handle_picture_decode(CUVIDPICPARAMS* picParams) { return 1; } +/* Process the decoded data and copy it to a cuda memory location. + */ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { CUVIDPROCPARAMS procParams = { .progressive_frame = dispInfo->progressive_frame, @@ -169,6 +184,9 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { return 1; } +/* Query the capabilities of the underlying hardware video decoder and + * verify if the hardware supports decoding the passed video. + */ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { CUVIDDECODECAPS decodeCaps = { .eCodecType = videoFormat->codec, @@ -228,6 +246,9 @@ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { } } +/* Called before decoding frames and/or whenever there is a configuration + * change. + */ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { // videoCodec has been set in the constructor (for parser). Here it's set // again for potential correction @@ -413,6 +434,8 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { return nDecodeSurface; } +/* Called from AV1 sequence header to get operating point of a AV1 bitstream. + */ int Decoder::get_operating_point(CUVIDOPERATINGPOINTINFO* operPointInfo) { if (operPointInfo->codec == cudaVideoCodec_AV1) { if (operPointInfo->av1.operating_points_cnt > 1) { diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 9b0c97e2cad..188b2880ccc 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -1,5 +1,7 @@ #include "gpu_decoder.h" +/* Set cuda device, create cuda context and initialise the demuxer and decoder. + */ GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) : demuxer(src_file.c_str()), device(dev) { if (cudaSuccess != cudaSetDevice(device)) { @@ -18,6 +20,8 @@ GPUDecoder::~GPUDecoder() { } } +/* Fetch a decoded frame tensor after demuxing and decoding. + */ torch::Tensor GPUDecoder::decode() { torch::Tensor frameTensor; unsigned long videoBytes = 0; @@ -41,6 +45,9 @@ torch::Tensor GPUDecoder::decode() { return frameTensor; } +/* Convert a tensor with data in NV12 format to a tensor with data in YUV420 + * format in-place. + */ torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) { int width = decoder.get_width(), height = decoder.get_height(); int pitch = width; From f6b6cfef6e911ccfe0131b6debca9ca6c266422d Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Sun, 19 Dec 2021 15:41:30 -0800 Subject: [PATCH 20/39] Variable renaming --- torchvision/csrc/io/decoder/gpu/decoder.h | 36 +++++++++++------------ torchvision/csrc/io/decoder/gpu/demuxer.h | 16 +++++----- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 01b77a4ee0b..d5d46477d6d 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -76,30 +76,30 @@ class Decoder { std::queue decodedFrames; static int video_sequence_handler( - void* pUserData, - CUVIDEOFORMAT* pVideoFormat) { - return ((Decoder*)pUserData)->handle_video_sequence(pVideoFormat); + void* user_data, + CUVIDEOFORMAT* video_format) { + return ((Decoder*)user_data)->handle_video_sequence(video_format); } static int picture_decode_handler( - void* pUserData, - CUVIDPICPARAMS* pPicParams) { - return ((Decoder*)pUserData)->handle_picture_decode(pPicParams); + void* user_data, + CUVIDPICPARAMS* pic_params) { + return ((Decoder*)user_data)->handle_picture_decode(pic_params); } static int picture_display_handler( - void* pUserData, - CUVIDPARSERDISPINFO* pDispInfo) { - return ((Decoder*)pUserData)->handle_picture_display(pDispInfo); + void* user_data, + CUVIDPARSERDISPINFO* disp_info) { + return ((Decoder*)user_data)->handle_picture_display(disp_info); } static int operating_point_handler( - void* pUserData, - CUVIDOPERATINGPOINTINFO* pOPInfo) { - return ((Decoder*)pUserData)->get_operating_point(pOPInfo); + void* user_data, + CUVIDOPERATINGPOINTINFO* operating_info) { + return ((Decoder*)user_data)->get_operating_point(operating_info); } - void query_hardware(CUVIDEOFORMAT* videoFormat); - int reconfigure_decoder(CUVIDEOFORMAT* pVideoFormat); - int handle_video_sequence(CUVIDEOFORMAT* pVideoFormat); - int handle_picture_decode(CUVIDPICPARAMS* pPicParams); - int handle_picture_display(CUVIDPARSERDISPINFO* pDispInfo); - int get_operating_point(CUVIDOPERATINGPOINTINFO* pOPInfo); + void query_hardware(CUVIDEOFORMAT*); + int reconfigure_decoder(CUVIDEOFORMAT*); + int handle_video_sequence(CUVIDEOFORMAT*); + int handle_picture_decode(CUVIDPICPARAMS*); + int handle_picture_display(CUVIDPARSERDISPINFO*); + int get_operating_point(CUVIDOPERATINGPOINTINFO*); }; diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index 653ce7bdd80..e92ce549172 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -21,7 +21,7 @@ class Demuxer { AVBSFContext* bsfCtx = NULL; AVPacket pkt, pktFiltered; AVCodecID eVideoCodec; - uint8_t* pDataWithHeader = NULL; + uint8_t* dataWithHeader = NULL; bool bMp4H264, bMp4HEVC, bMp4MPEG4; unsigned int frameCount = 0; int iVideoStream; @@ -108,8 +108,8 @@ class Demuxer { av_bsf_free(&bsfCtx); } avformat_close_input(&fmtCtx); - if (pDataWithHeader) { - av_free(pDataWithHeader); + if (dataWithHeader) { + av_free(dataWithHeader); } } @@ -149,21 +149,21 @@ class Demuxer { fmtCtx->streams[iVideoStream]->codecpar->extradata_size; if (extraDataSize > 0) { - pDataWithHeader = (uint8_t*)av_malloc( + dataWithHeader = (uint8_t*)av_malloc( extraDataSize + pkt.size - 3 * sizeof(uint8_t)); - if (!pDataWithHeader) { + if (!dataWithHeader) { printf("FFmpeg error: %d\n", __LINE__); return false; } memcpy( - pDataWithHeader, + dataWithHeader, fmtCtx->streams[iVideoStream]->codecpar->extradata, extraDataSize); memcpy( - pDataWithHeader + extraDataSize, + dataWithHeader + extraDataSize, pkt.data + 3, pkt.size - 3 * sizeof(uint8_t)); - *video = pDataWithHeader; + *video = dataWithHeader; *videoBytes = extraDataSize + pkt.size - 3 * sizeof(uint8_t); } } else { From 3e309a50d5c6bb91cebd181249f57f3fe91713a0 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Sun, 19 Dec 2021 16:16:35 -0800 Subject: [PATCH 21/39] Code cleanup --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 82 +++++++-------------- torchvision/csrc/io/decoder/gpu/decoder.h | 17 ++--- 2 files changed, 34 insertions(+), 65 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index b4079b14f3e..dd50a7a3e2b 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -250,8 +250,8 @@ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { * change. */ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { - // videoCodec has been set in the constructor (for parser). Here it's set - // again for potential correction + // videoCodec has been set in the init(). Here it's set + // again for potential correction. videoCodec = vidFormat->codec; videoChromaFormat = vidFormat->chroma_format; bitDepthMinus8 = vidFormat->bit_depth_luma_minus8; @@ -270,21 +270,21 @@ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { : cudaVideoSurfaceFormat_YUV444; break; case cudaVideoChromaFormat_422: - videoOutputFormat = - cudaVideoSurfaceFormat_NV12; // no 4:2:2 output format supported yet - // so make 420 default + videoOutputFormat = cudaVideoSurfaceFormat_NV12; } query_hardware(vidFormat); + if (width && lumaHeight && chromaHeight) { - // cuvidCreateDecoder() has been called before, and now there's possible - // config change + // cuvidCreateDecoder() has been called before and now there's possible + // config change. return reconfigure_decoder(vidFormat); } videoFormat = *vidFormat; - unsigned long nDecodeSurface = vidFormat->min_num_decode_surfaces; + unsigned long decodeSurface = vidFormat->min_num_decode_surfaces; cudaVideoDeinterlaceMode deinterlaceMode = cudaVideoDeinterlaceMode_Adaptive; + if (vidFormat->progressive_sequence) { deinterlaceMode = cudaVideoDeinterlaceMode_Weave; } @@ -292,7 +292,7 @@ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { CUVIDDECODECREATEINFO videoDecodeCreateInfo = { .ulWidth = vidFormat->coded_width, .ulHeight = vidFormat->coded_height, - .ulNumDecodeSurfaces = nDecodeSurface, + .ulNumDecodeSurfaces = decodeSurface, .CodecType = vidFormat->codec, .ChromaFormat = vidFormat->chroma_format, // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded @@ -337,7 +337,7 @@ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { check_for_cuda_errors( cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); - return nDecodeSurface; + return decodeSurface; } int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { @@ -351,87 +351,61 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { "Reconfigure Not supported for chroma format change"); } - bool bDecodeResChange = + bool decodeResChange = !(vidFormat->coded_width == videoFormat.coded_width && vidFormat->coded_height == videoFormat.coded_height); - bool bDisplayRectChange = + bool displayRectChange = !(vidFormat->display_area.bottom == videoFormat.display_area.bottom && vidFormat->display_area.top == videoFormat.display_area.top && vidFormat->display_area.left == videoFormat.display_area.left && vidFormat->display_area.right == videoFormat.display_area.right); - int nDecodeSurface = vidFormat->min_num_decode_surfaces; + unsigned int decodeSurface = vidFormat->min_num_decode_surfaces; if ((vidFormat->coded_width > maxWidth) || (vidFormat->coded_height > maxHeight)) { // For VP9, let driver handle the change if new width/height > // maxwidth/maxheight - if ((videoCodec != cudaVideoCodec_VP9) || reconfigExternal) { + if (videoCodec != cudaVideoCodec_VP9) { throw std::runtime_error( "Reconfigure Not supported when width/height > maxwidth/maxheight"); } return 1; } - if (!bDecodeResChange && !reconfigExtPPChange) { - // if the coded_width/coded_height hasn't changed but display resolution has + if (!decodeResChange) { + // If the coded_width/coded_height hasn't changed but display resolution has // changed, then need to update width/height for correct output without - // cropping. Example : 1920x1080 vs 1920x1088 - if (bDisplayRectChange) { + // cropping. Example : 1920x1080 vs 1920x1088. + if (displayRectChange) { width = vidFormat->display_area.right - vidFormat->display_area.left; lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; chromaHeight = (int)ceil(lumaHeight * chroma_height_factor(videoOutputFormat)); numChromaPlanes = chroma_plane_count(videoOutputFormat); } - - // no need for reconfigureDecoder(). Just return return 1; } - - CUVIDRECONFIGUREDECODERINFO reconfigParams = {0}; - - reconfigParams.ulWidth = videoFormat.coded_width = vidFormat->coded_width; - reconfigParams.ulHeight = videoFormat.coded_height = vidFormat->coded_height; - - // Dont change display rect and get scaled output from decoder. This will help - // display app to present apps smoothly + videoFormat.coded_width = vidFormat->coded_width; + videoFormat.coded_height = vidFormat->coded_height; + CUVIDRECONFIGUREDECODERINFO reconfigParams = { + .ulWidth = vidFormat->coded_width, + .ulHeight = vidFormat->coded_height, + .ulTargetWidth = surfaceWidth, + .ulTargetHeight = surfaceHeight, + .ulNumDecodeSurfaces = decodeSurface, + }; reconfigParams.display_area.bottom = displayRect.bottom; reconfigParams.display_area.top = displayRect.top; reconfigParams.display_area.left = displayRect.left; reconfigParams.display_area.right = displayRect.right; - reconfigParams.ulTargetWidth = surfaceWidth; - reconfigParams.ulTargetHeight = surfaceHeight; - - // If external reconfigure is called along with resolution change even if post - // processing params is not changed, do full reconfigure params update - if ((reconfigExternal && bDecodeResChange) || reconfigExtPPChange) { - // update display rect and target resolution if requested explicitely - reconfigExternal = false; - reconfigExtPPChange = false; - videoFormat = *vidFormat; - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - reconfigParams.ulTargetWidth = vidFormat->coded_width; - reconfigParams.ulTargetHeight = vidFormat->coded_height; - chromaHeight = - (int)ceil(lumaHeight * chroma_height_factor(videoOutputFormat)); - numChromaPlanes = chroma_plane_count(videoOutputFormat); - surfaceHeight = reconfigParams.ulTargetHeight; - surfaceWidth = reconfigParams.ulTargetWidth; - displayRect.bottom = reconfigParams.display_area.bottom; - displayRect.top = reconfigParams.display_area.top; - displayRect.left = reconfigParams.display_area.left; - displayRect.right = reconfigParams.display_area.right; - } - reconfigParams.ulNumDecodeSurfaces = nDecodeSurface; check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); check_for_cuda_errors( cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); - return nDecodeSurface; + return decodeSurface; } /* Called from AV1 sequence header to get operating point of a AV1 bitstream. diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index d5d46477d6d..a8b2dfd2213 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -56,23 +56,18 @@ class Decoder { int numDecodedFrames = 0; unsigned int numChromaPlanes = 0; // dimension of the output + bool dispAllLayers = false; unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; + unsigned int surfaceHeight = 0, surfaceWidth = 0; + unsigned int maxWidth = 0, maxHeight = 0; + unsigned int operatingPoint = 0; + int bitDepthMinus8 = 0, bytesPerPixel = 1; + int decodePicCount = 0, picNumInDecodeOrder[32]; cudaVideoCodec videoCodec = cudaVideoCodec_NumCodecs; cudaVideoChromaFormat videoChromaFormat = cudaVideoChromaFormat_420; cudaVideoSurfaceFormat videoOutputFormat = cudaVideoSurfaceFormat_NV12; - int bitDepthMinus8 = 0; - int bytesPerPixel = 1; CUVIDEOFORMAT videoFormat = {}; - unsigned int maxWidth = 0, maxHeight = 0; - // height of the mapped surface - int surfaceHeight = 0; - int surfaceWidth = 0; Rect displayRect = {}; - unsigned int operatingPoint = 0; - bool dispAllLayers = false; - int decodePicCount = 0, picNumInDecodeOrder[32]; - bool reconfigExternal = false; - bool reconfigExtPPChange = false; std::queue decodedFrames; static int video_sequence_handler( From d116edcf5991348df56ef6345489c06f83deeb9c Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Sun, 19 Dec 2021 16:26:06 -0800 Subject: [PATCH 22/39] Make return type of decode() void --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 6 +----- torchvision/csrc/io/decoder/gpu/decoder.h | 19 ++++++++----------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index dd50a7a3e2b..963bdfa0783 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -67,8 +67,7 @@ void Decoder::release() { /* Trigger video decoding. */ -unsigned long Decoder::decode(const uint8_t* data, unsigned long size) { - numDecodedFrames = 0; +void Decoder::decode(const uint8_t* data, unsigned long size) { CUVIDSOURCEDATAPACKET pkt = { .flags = CUVID_PKT_TIMESTAMP, .payload_size = size, @@ -79,7 +78,6 @@ unsigned long Decoder::decode(const uint8_t* data, unsigned long size) { } check_for_cuda_errors(cuvidParseVideoData(parser, &pkt), __LINE__); cuvidStream = 0; - return numDecodedFrames; } /* Fetch a decoded frame and remove it from the queue. @@ -144,8 +142,6 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { uint8_t* decodedFrame = nullptr; cuMemAlloc((CUdeviceptr*)&decodedFrame, get_frame_size()); - numDecodedFrames++; - // Copy luma plane CUDA_MEMCPY2D m = {0}; m.srcMemoryType = CU_MEMORYTYPE_DEVICE; diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index a8b2dfd2213..94fdfc0c838 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -30,8 +30,8 @@ class Decoder { Decoder() {} ~Decoder(); void init(CUcontext, cudaVideoCodec); - unsigned long decode(const uint8_t*, unsigned long); void release(); + void decode(const uint8_t*, unsigned long); uint8_t* fetch_frame(); int get_frame_size() const { return get_width() * (lumaHeight + (chromaHeight * numChromaPlanes)) * @@ -48,27 +48,24 @@ class Decoder { } private: - CUcontext cuContext = NULL; - CUvideoctxlock ctxLock; - CUvideoparser parser = NULL; - CUvideodecoder decoder = NULL; - CUstream cuvidStream = 0; - int numDecodedFrames = 0; - unsigned int numChromaPlanes = 0; - // dimension of the output bool dispAllLayers = false; unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; unsigned int surfaceHeight = 0, surfaceWidth = 0; unsigned int maxWidth = 0, maxHeight = 0; - unsigned int operatingPoint = 0; + unsigned int operatingPoint = 0, numChromaPlanes = 0; int bitDepthMinus8 = 0, bytesPerPixel = 1; int decodePicCount = 0, picNumInDecodeOrder[32]; + std::queue decodedFrames; + CUcontext cuContext = NULL; + CUvideoctxlock ctxLock; + CUvideoparser parser = NULL; + CUvideodecoder decoder = NULL; + CUstream cuvidStream = 0; cudaVideoCodec videoCodec = cudaVideoCodec_NumCodecs; cudaVideoChromaFormat videoChromaFormat = cudaVideoChromaFormat_420; cudaVideoSurfaceFormat videoOutputFormat = cudaVideoSurfaceFormat_NV12; CUVIDEOFORMAT videoFormat = {}; Rect displayRect = {}; - std::queue decodedFrames; static int video_sequence_handler( void* user_data, From c01b66b5ea26487c6043c4215a254c44733adbb9 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 21 Dec 2021 11:24:27 -0800 Subject: [PATCH 23/39] Replace printing errors with throwing runtime_error --- torchvision/csrc/io/decoder/gpu/demuxer.h | 37 ++++++++++++++--------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index e92ce549172..eee41985aeb 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -1,3 +1,4 @@ +#include extern "C" { #include #include @@ -5,12 +6,12 @@ extern "C" { #include } -inline bool check(int ret, int line) { +inline void check(int ret, int line) { if (ret < 0) { - printf("Error %d at line %d in demuxer.h.\n", ret, line); - return false; + std::ostringstream error_string; + error_string << "Error " << ret << " at line " << line << " in demuxer.h\n"; + throw std::runtime_error(error_string.str()); } - return true; } #define check_for_errors(call) check(call, __LINE__) @@ -33,17 +34,17 @@ class Demuxer { avformat_network_init(); check_for_errors(avformat_open_input(&fmtCtx, filePath, NULL, NULL)); if (!fmtCtx) { - printf("No AVFormatContext provided.\n"); - return; + throw std::runtime_error("No AVFormatContext provided.\n"); } check_for_errors(avformat_find_stream_info(fmtCtx, NULL)); iVideoStream = av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0); if (iVideoStream < 0) { - printf( - "FFmpeg error: %d, could not find stream in input file\n", __LINE__); - return; + std::ostringstream error_string; + error_string << "FFmpeg error at line: " << __LINE__ + << " in demuxer.h. Could not find stream in input file.\n"; + throw std::runtime_error(error_string.str()); } eVideoCodec = fmtCtx->streams[iVideoStream]->codecpar->codec_id; @@ -74,8 +75,10 @@ class Demuxer { if (bMp4H264) { const AVBitStreamFilter* bsf = av_bsf_get_by_name("h264_mp4toannexb"); if (!bsf) { - printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); - return; + std::ostringstream error_string; + error_string << "FFmpeg error at line: " << __LINE__ + << " in demuxer.h. av_bsf_get_by_name() failed\n"; + throw std::runtime_error(error_string.str()); } check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); avcodec_parameters_copy( @@ -85,8 +88,10 @@ class Demuxer { if (bMp4HEVC) { const AVBitStreamFilter* bsf = av_bsf_get_by_name("hevc_mp4toannexb"); if (!bsf) { - printf("FFmpeg error: %d, av_bsf_get_by_name() failed\n", __LINE__); - return; + std::ostringstream error_string; + error_string << "FFmpeg error at line: " << __LINE__ + << " in demuxer.h. av_bsf_get_by_name() failed\n"; + throw std::runtime_error(error_string.str()); } check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); avcodec_parameters_copy( @@ -152,8 +157,10 @@ class Demuxer { dataWithHeader = (uint8_t*)av_malloc( extraDataSize + pkt.size - 3 * sizeof(uint8_t)); if (!dataWithHeader) { - printf("FFmpeg error: %d\n", __LINE__); - return false; + std::ostringstream error_string; + error_string << "FFmpeg error at line: " << __LINE__ + << " in demuxer.h.\n"; + throw std::runtime_error(error_string.str()); } memcpy( dataWithHeader, From 7e0d884f12484cd12f51fe33d069c163413ad5b5 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 22 Dec 2021 09:51:59 -0800 Subject: [PATCH 24/39] Replaced runtime_error with TORCH_CHECK in demuxer.h --- torchvision/csrc/io/decoder/gpu/demuxer.h | 101 ++++++++++++++-------- 1 file changed, 65 insertions(+), 36 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/demuxer.h b/torchvision/csrc/io/decoder/gpu/demuxer.h index eee41985aeb..75d4765dd79 100644 --- a/torchvision/csrc/io/decoder/gpu/demuxer.h +++ b/torchvision/csrc/io/decoder/gpu/demuxer.h @@ -1,4 +1,3 @@ -#include extern "C" { #include #include @@ -6,16 +5,6 @@ extern "C" { #include } -inline void check(int ret, int line) { - if (ret < 0) { - std::ostringstream error_string; - error_string << "Error " << ret << " at line " << line << " in demuxer.h\n"; - throw std::runtime_error(error_string.str()); - } -} - -#define check_for_errors(call) check(call, __LINE__) - class Demuxer { private: AVFormatContext* fmtCtx = NULL; @@ -32,19 +21,32 @@ class Demuxer { public: Demuxer(const char* filePath, int64_t timeScale = 1000 /*Hz*/) { avformat_network_init(); - check_for_errors(avformat_open_input(&fmtCtx, filePath, NULL, NULL)); + TORCH_CHECK( + 0 <= avformat_open_input(&fmtCtx, filePath, NULL, NULL), + "avformat_open_input() failed at line ", + __LINE__, + " in demuxer.h\n"); if (!fmtCtx) { - throw std::runtime_error("No AVFormatContext provided.\n"); + TORCH_CHECK( + false, + "Encountered NULL AVFormatContext at line ", + __LINE__, + " in demuxer.h\n"); } - check_for_errors(avformat_find_stream_info(fmtCtx, NULL)); + TORCH_CHECK( + 0 <= avformat_find_stream_info(fmtCtx, NULL), + "avformat_find_stream_info() failed at line ", + __LINE__, + " in demuxer.h\n"); iVideoStream = av_find_best_stream(fmtCtx, AVMEDIA_TYPE_VIDEO, -1, -1, NULL, 0); if (iVideoStream < 0) { - std::ostringstream error_string; - error_string << "FFmpeg error at line: " << __LINE__ - << " in demuxer.h. Could not find stream in input file.\n"; - throw std::runtime_error(error_string.str()); + TORCH_CHECK( + false, + "av_find_best_stream() failed at line ", + __LINE__, + " in demuxer.h\n"); } eVideoCodec = fmtCtx->streams[iVideoStream]->codecpar->codec_id; @@ -75,28 +77,46 @@ class Demuxer { if (bMp4H264) { const AVBitStreamFilter* bsf = av_bsf_get_by_name("h264_mp4toannexb"); if (!bsf) { - std::ostringstream error_string; - error_string << "FFmpeg error at line: " << __LINE__ - << " in demuxer.h. av_bsf_get_by_name() failed\n"; - throw std::runtime_error(error_string.str()); + TORCH_CHECK( + false, + "av_bsf_get_by_name() failed at line ", + __LINE__, + " in demuxer.h\n"); } - check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); + TORCH_CHECK( + 0 <= av_bsf_alloc(bsf, &bsfCtx), + "av_bsf_alloc() failed at line ", + __LINE__, + " in demuxer.h\n"); avcodec_parameters_copy( bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); - check_for_errors(av_bsf_init(bsfCtx)); + TORCH_CHECK( + 0 <= av_bsf_init(bsfCtx), + "av_bsf_init() failed at line ", + __LINE__, + " in demuxer.h\n"); } if (bMp4HEVC) { const AVBitStreamFilter* bsf = av_bsf_get_by_name("hevc_mp4toannexb"); if (!bsf) { - std::ostringstream error_string; - error_string << "FFmpeg error at line: " << __LINE__ - << " in demuxer.h. av_bsf_get_by_name() failed\n"; - throw std::runtime_error(error_string.str()); + TORCH_CHECK( + false, + "av_bsf_get_by_name() failed at line ", + __LINE__, + " in demuxer.h\n"); } - check_for_errors(av_bsf_alloc(bsf, &bsfCtx)); + TORCH_CHECK( + 0 <= av_bsf_alloc(bsf, &bsfCtx), + "av_bsf_alloc() failed at line ", + __LINE__, + " in demuxer.h\n"); avcodec_parameters_copy( bsfCtx->par_in, fmtCtx->streams[iVideoStream]->codecpar); - check_for_errors(av_bsf_init(bsfCtx)); + TORCH_CHECK( + 0 <= av_bsf_init(bsfCtx), + "av_bsf_init() failed at line ", + __LINE__, + " in demuxer.h\n"); } } ~Demuxer() { @@ -144,8 +164,16 @@ class Demuxer { if (pktFiltered.data) { av_packet_unref(&pktFiltered); } - check_for_errors(av_bsf_send_packet(bsfCtx, &pkt)); - check_for_errors(av_bsf_receive_packet(bsfCtx, &pktFiltered)); + TORCH_CHECK( + 0 <= av_bsf_send_packet(bsfCtx, &pkt), + "av_bsf_send_packet() failed at line ", + __LINE__, + " in demuxer.h\n"); + TORCH_CHECK( + 0 <= av_bsf_receive_packet(bsfCtx, &pktFiltered), + "av_bsf_receive_packet() failed at line ", + __LINE__, + " in demuxer.h\n"); *video = pktFiltered.data; *videoBytes = pktFiltered.size; } else { @@ -157,10 +185,11 @@ class Demuxer { dataWithHeader = (uint8_t*)av_malloc( extraDataSize + pkt.size - 3 * sizeof(uint8_t)); if (!dataWithHeader) { - std::ostringstream error_string; - error_string << "FFmpeg error at line: " << __LINE__ - << " in demuxer.h.\n"; - throw std::runtime_error(error_string.str()); + TORCH_CHECK( + false, + "av_malloc() failed at line ", + __LINE__, + " in demuxer.h\n"); } memcpy( dataWithHeader, From 901501c60c77203aff3ec08239ed5cc9afe7c59e Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 22 Dec 2021 11:42:50 -0800 Subject: [PATCH 25/39] Use CUDAGuard instead of cudaSetDevice --- torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index 188b2880ccc..f6c4a21d956 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -1,19 +1,18 @@ #include "gpu_decoder.h" +#include /* Set cuda device, create cuda context and initialise the demuxer and decoder. */ GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) : demuxer(src_file.c_str()), device(dev) { - if (cudaSuccess != cudaSetDevice(device)) { - printf("Error setting device\n"); - return; - } + at::cuda::CUDAGuard device_guard(device); check_for_cuda_errors(cuDevicePrimaryCtxRetain(&ctx, device), __LINE__); decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec())); initialised = true; } GPUDecoder::~GPUDecoder() { + at::cuda::CUDAGuard device_guard(device); decoder.release(); if (initialised) { check_for_cuda_errors(cuDevicePrimaryCtxRelease(device), __LINE__); @@ -26,6 +25,7 @@ torch::Tensor GPUDecoder::decode() { torch::Tensor frameTensor; unsigned long videoBytes = 0; uint8_t *frame = nullptr, *video = nullptr; + at::cuda::CUDAGuard device_guard(device); do { demuxer.demux(&video, &videoBytes); decoder.decode(video, videoBytes); From 236590668bd4c2bcccad30d797ff5f5ab8381650 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 22 Dec 2021 14:29:41 -0800 Subject: [PATCH 26/39] Remove printf --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 963bdfa0783..044db83239b 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -1,4 +1,5 @@ #include "decoder.h" +#include #include #include #include @@ -134,9 +135,8 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { if (result == CUDA_SUCCESS && (decodeStatus.decodeStatus == cuvidDecodeStatus_Error || decodeStatus.decodeStatus == cuvidDecodeStatus_Error_Concealed)) { - printf( - "Decode Error occurred for picture %d\n", - picNumInDecodeOrder[dispInfo->picture_index]); + VLOG(1) << "Decode Error occurred for picture " + << picNumInDecodeOrder[dispInfo->picture_index]; } uint8_t* decodedFrame = nullptr; @@ -410,17 +410,9 @@ int Decoder::get_operating_point(CUVIDOPERATINGPOINTINFO* operPointInfo) { if (operPointInfo->codec == cudaVideoCodec_AV1) { if (operPointInfo->av1.operating_points_cnt > 1) { // clip has SVC enabled - if (operatingPoint >= operPointInfo->av1.operating_points_cnt) + if (operatingPoint >= operPointInfo->av1.operating_points_cnt) { operatingPoint = 0; - - printf( - "AV1 SVC clip: operating point count %d ", - operPointInfo->av1.operating_points_cnt); - printf( - "Selected operating point: %d, IDC 0x%x bOutputAllLayers %d\n", - operatingPoint, - operPointInfo->av1.operating_points_idc[operatingPoint], - dispAllLayers); + } return (operatingPoint | (dispAllLayers << 10)); } } From c77b558e0bf1ea0007b0b1847f00f48406f64042 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 22 Dec 2021 15:25:21 -0800 Subject: [PATCH 27/39] Use Tensor instead of uint8* and remove cuMemAlloc/cuMemFree --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 31 +++++++++---------- torchvision/csrc/io/decoder/gpu/decoder.h | 5 +-- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 19 +++--------- 3 files changed, 22 insertions(+), 33 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 044db83239b..150743e8c7a 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -58,11 +58,6 @@ void Decoder::release() { if (decoder) { cuvidDestroyDecoder(decoder); } - while (!decodedFrames.empty()) { - uint8_t* frame = decodedFrames.front(); - decodedFrames.pop(); - cuMemFree((CUdeviceptr)frame); - } cuCtxPopCurrent(NULL); } @@ -83,12 +78,14 @@ void Decoder::decode(const uint8_t* data, unsigned long size) { /* Fetch a decoded frame and remove it from the queue. */ -uint8_t* Decoder::fetch_frame() { - if (decodedFrames.empty()) { - return nullptr; +torch::Tensor Decoder::fetch_frame() { + if (decoded_frames.empty()) { + auto options = + torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); + return torch::zeros({0}, options); } - uint8_t* frame = decodedFrames.front(); - decodedFrames.pop(); + torch::Tensor frame = decoded_frames.front(); + decoded_frames.pop(); return frame; } @@ -139,8 +136,9 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { << picNumInDecodeOrder[dispInfo->picture_index]; } - uint8_t* decodedFrame = nullptr; - cuMemAlloc((CUdeviceptr*)&decodedFrame, get_frame_size()); + auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); + torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options); + uint8_t* frame_ptr = decoded_frame.data_ptr(); // Copy luma plane CUDA_MEMCPY2D m = {0}; @@ -148,7 +146,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { m.srcDevice = dpSrcFrame; m.srcPitch = nSrcPitch; m.dstMemoryType = CU_MEMORYTYPE_DEVICE; - m.dstDevice = (CUdeviceptr)(m.dstHost = decodedFrame); + m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr); m.dstPitch = get_width() * bytesPerPixel; m.WidthInBytes = get_width() * bytesPerPixel; m.Height = lumaHeight; @@ -159,8 +157,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { // height m.srcDevice = (CUdeviceptr)((uint8_t*)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1)); - m.dstDevice = - (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight); + m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * lumaHeight); m.Height = chromaHeight; check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); @@ -168,12 +165,12 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { m.srcDevice = (CUdeviceptr)((uint8_t*)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1) * 2); m.dstDevice = - (CUdeviceptr)(m.dstHost = decodedFrame + m.dstPitch * lumaHeight * 2); + (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * lumaHeight * 2); m.Height = chromaHeight; check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); } check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__); - decodedFrames.push(decodedFrame); + decoded_frames.push(decoded_frame); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); check_for_cuda_errors(cuvidUnmapVideoFrame(decoder, dpSrcFrame), __LINE__); diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 94fdfc0c838..dc737e9a5f8 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,7 @@ class Decoder { void init(CUcontext, cudaVideoCodec); void release(); void decode(const uint8_t*, unsigned long); - uint8_t* fetch_frame(); + torch::Tensor fetch_frame(); int get_frame_size() const { return get_width() * (lumaHeight + (chromaHeight * numChromaPlanes)) * bytesPerPixel; @@ -55,7 +56,7 @@ class Decoder { unsigned int operatingPoint = 0, numChromaPlanes = 0; int bitDepthMinus8 = 0, bytesPerPixel = 1; int decodePicCount = 0, picNumInDecodeOrder[32]; - std::queue decodedFrames; + std::queue decoded_frames; CUcontext cuContext = NULL; CUvideoctxlock ctxLock; CUvideoparser parser = NULL; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index f6c4a21d956..fca416f1d7e 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -24,25 +24,16 @@ GPUDecoder::~GPUDecoder() { torch::Tensor GPUDecoder::decode() { torch::Tensor frameTensor; unsigned long videoBytes = 0; - uint8_t *frame = nullptr, *video = nullptr; + uint8_t* video = nullptr; at::cuda::CUDAGuard device_guard(device); + auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); + torch::Tensor frame = torch::zeros({0}, options); do { demuxer.demux(&video, &videoBytes); decoder.decode(video, videoBytes); frame = decoder.fetch_frame(); - } while (frame == nullptr && videoBytes > 0); - if (frame == nullptr) { - auto options = - torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); - return torch::zeros({0}, options); - } - auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); - frameTensor = torch::from_blob( - frame, - {decoder.get_frame_size()}, - [](auto p) { cuMemFree((CUdeviceptr)p); }, - options); - return frameTensor; + } while (frame.numel() == 0 && videoBytes > 0); + return frame; } /* Convert a tensor with data in NV12 format to a tensor with data in YUV420 From e8c80edaa6541e560d3714fe568ce3d6d06794a7 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 23 Dec 2021 05:15:49 -0800 Subject: [PATCH 28/39] Use TORCH_CHECK instead of runtime_error --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 51 ++++++++++--------- torchvision/csrc/io/decoder/gpu/decoder.h | 30 ++++++----- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 6 ++- 3 files changed, 47 insertions(+), 40 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 150743e8c7a..a13f1d10cfd 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -1,6 +1,5 @@ #include "decoder.h" #include -#include #include #include #include @@ -25,7 +24,8 @@ static int chroma_plane_count(cudaVideoSurfaceFormat surfaceFormat) { void Decoder::init(CUcontext context, cudaVideoCodec codec) { cuContext = context; videoCodec = codec; - check_for_cuda_errors(cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__); + check_for_cuda_errors( + cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__, __FILE__); CUVIDPARSERPARAMS parserParams = { .CodecType = codec, @@ -39,7 +39,7 @@ void Decoder::init(CUcontext context, cudaVideoCodec codec) { .pfnGetOperatingPoint = operating_point_handler, }; check_for_cuda_errors( - cuvidCreateVideoParser(&parser, &parserParams), __LINE__); + cuvidCreateVideoParser(&parser, &parserParams), __LINE__, __FILE__); } /* Destroy parser object and context lock. @@ -72,7 +72,7 @@ void Decoder::decode(const uint8_t* data, unsigned long size) { if (!data || size == 0) { pkt.flags |= CUVID_PKT_ENDOFSTREAM; } - check_for_cuda_errors(cuvidParseVideoData(parser, &pkt), __LINE__); + check_for_cuda_errors(cuvidParseVideoData(parser, &pkt), __LINE__, __FILE__); cuvidStream = 0; } @@ -96,9 +96,10 @@ int Decoder::handle_picture_decode(CUVIDPICPARAMS* picParams) { throw std::runtime_error("Uninitialised decoder."); } picNumInDecodeOrder[picParams->CurrPicIdx] = decodePicCount++; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); - check_for_cuda_errors(cuvidDecodePicture(decoder, picParams), __LINE__); - check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); + check_for_cuda_errors( + cuvidDecodePicture(decoder, picParams), __LINE__, __FILE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); return 1; } @@ -115,7 +116,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { CUdeviceptr dpSrcFrame = 0; unsigned int nSrcPitch = 0; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); check_for_cuda_errors( cuvidMapVideoFrame( decoder, @@ -123,7 +124,8 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { &dpSrcFrame, &nSrcPitch, &procParams), - __LINE__); + __LINE__, + __FILE__); CUVIDGETDECODESTATUS decodeStatus; memset(&decodeStatus, 0, sizeof(decodeStatus)); @@ -150,7 +152,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { m.dstPitch = get_width() * bytesPerPixel; m.WidthInBytes = get_width() * bytesPerPixel; m.Height = lumaHeight; - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); // Copy chroma plane // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning @@ -159,7 +161,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { (CUdeviceptr)((uint8_t*)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1)); m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * lumaHeight); m.Height = chromaHeight; - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); if (numChromaPlanes == 2) { m.srcDevice = @@ -167,13 +169,14 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * lumaHeight * 2); m.Height = chromaHeight; - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__); + check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); } - check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__); + check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__); decoded_frames.push(decoded_frame); - check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); - check_for_cuda_errors(cuvidUnmapVideoFrame(decoder, dpSrcFrame), __LINE__); + check_for_cuda_errors( + cuvidUnmapVideoFrame(decoder, dpSrcFrame), __LINE__, __FILE__); return 1; } @@ -186,9 +189,9 @@ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { .eChromaFormat = videoFormat->chroma_format, .nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8, }; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); - check_for_cuda_errors(cuvidGetDecoderCaps(&decodeCaps), __LINE__); - check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); + check_for_cuda_errors(cuvidGetDecoderCaps(&decodeCaps), __LINE__, __FILE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); if (!decodeCaps.bIsSupported) { throw std::runtime_error("Codec not supported on this GPU"); @@ -326,10 +329,10 @@ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { displayRect.left = videoDecodeCreateInfo.display_area.left; displayRect.right = videoDecodeCreateInfo.display_area.right; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); check_for_cuda_errors( - cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__); - check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); + cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__, __FILE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); return decodeSurface; } @@ -393,10 +396,10 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { reconfigParams.display_area.left = displayRect.left; reconfigParams.display_area.right = displayRect.right; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__); + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); check_for_cuda_errors( - cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__); - check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__); + cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__, __FILE__); + check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); return decodeSurface; } diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index dc737e9a5f8..480f0fdef75 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -5,22 +5,24 @@ #include #include #include -#include -static auto check_for_cuda_errors = [](CUresult result, int lineNum) { - if (CUDA_SUCCESS != result) { - std::stringstream errorStream; - const char* errorName = nullptr; +static auto check_for_cuda_errors = + [](CUresult result, int line_num, std::string file_name) { + if (CUDA_SUCCESS != result) { + const char* error_name = nullptr; - errorStream << __FILE__ << ":" << lineNum << std::endl; - if (CUDA_SUCCESS != cuGetErrorName(result, &errorName)) { - errorStream << "CUDA error with code " << result << std::endl; - } else { - errorStream << "CUDA error: " << errorName << std::endl; - } - throw std::runtime_error(errorStream.str()); - } -}; + TORCH_CHECK( + CUDA_SUCCESS != cuGetErrorName(result, &error_name), + "CUDA error: ", + error_name, + " in ", + file_name, + " at line ", + line_num) + TORCH_CHECK( + false, "Error: ", result, " in ", file_name, " at line ", line_num); + } + }; struct Rect { int left, top, right, bottom; diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index fca416f1d7e..e6255aab5aa 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -6,7 +6,8 @@ GPUDecoder::GPUDecoder(std::string src_file, int64_t dev) : demuxer(src_file.c_str()), device(dev) { at::cuda::CUDAGuard device_guard(device); - check_for_cuda_errors(cuDevicePrimaryCtxRetain(&ctx, device), __LINE__); + check_for_cuda_errors( + cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__); decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec())); initialised = true; } @@ -15,7 +16,8 @@ GPUDecoder::~GPUDecoder() { at::cuda::CUDAGuard device_guard(device); decoder.release(); if (initialised) { - check_for_cuda_errors(cuDevicePrimaryCtxRelease(device), __LINE__); + check_for_cuda_errors( + cuDevicePrimaryCtxRelease(device), __LINE__, __FILE__); } } From 9d78ce5df47d9d7b734401460369147fedd80b01 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 23 Dec 2021 07:16:01 -0800 Subject: [PATCH 29/39] Use TORCHVISION_INCLUDE and TORCHVISION_LIBRARY to pass video codec location --- setup.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 5f375083f12..805a485597b 100644 --- a/setup.py +++ b/setup.py @@ -428,12 +428,15 @@ def get_extensions(): ) # Locating video codec - # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI + # CUDA_HOME should be set to the cuda root directory. + # TORCHVISION_INCLUDE and TORCHVISION_LIBRARY should include the location to + # video codec header files and libraries respectively. video_codec_found = ( extension is CUDAExtension and CUDA_HOME is not None - and os.path.exists("/usr/local/include/cuviddec.h") - and os.path.exists("/usr/local/include/nvcuvid.h") + and any([os.path.exists(os.path.join(folder, 'cuviddec.h')) for folder in vision_include]) + and any([os.path.exists(os.path.join(folder, 'nvcuvid.h')) for folder in vision_include]) + and any([os.path.exists(os.path.join(folder, 'libnvcuvid.so')) for folder in library_dirs]) ) print(f"video codec found: {video_codec_found}") @@ -449,7 +452,7 @@ def get_extensions(): "torchvision.Decoder", gpu_decoder_src, include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc], - library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs] + ["/usr/local/lib"], + library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs], libraries=[ "avcodec", "avformat", From f733a9790483dc7d0b4f76e1f60b18a209466b40 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 23 Dec 2021 08:25:24 -0800 Subject: [PATCH 30/39] Include ffmpeg_include_dir --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 805a485597b..bcf3b0ec8b9 100644 --- a/setup.py +++ b/setup.py @@ -451,7 +451,7 @@ def get_extensions(): extension( "torchvision.Decoder", gpu_decoder_src, - include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc], + include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir , library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs], libraries=[ "avcodec", From d794cb1a7dc1e23b73b550922b63fdc2fa38a8bf Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 23 Dec 2021 08:28:24 -0800 Subject: [PATCH 31/39] Remove space --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bcf3b0ec8b9..0e4e92c6615 100644 --- a/setup.py +++ b/setup.py @@ -451,7 +451,7 @@ def get_extensions(): extension( "torchvision.Decoder", gpu_decoder_src, - include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir , + include_dirs=include_dirs + [gpu_decoder_path] + [cuda_inc] + ffmpeg_include_dir, library_dirs=ffmpeg_library_dir + library_dirs + [cuda_libs], libraries=[ "avcodec", From 53a20b247da39839f597365c92a1dc4612c22b12 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Thu, 23 Dec 2021 14:46:29 -0800 Subject: [PATCH 32/39] Removed use of runtime_error --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 72 ++++++++++----------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index a13f1d10cfd..9125a14f287 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -93,7 +93,7 @@ torch::Tensor Decoder::fetch_frame() { */ int Decoder::handle_picture_decode(CUVIDPICPARAMS* picParams) { if (!decoder) { - throw std::runtime_error("Uninitialised decoder."); + TORCH_CHECK(false, "Uninitialised decoder"); } picNumInDecodeOrder[picParams->CurrPicIdx] = decodePicCount++; check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); @@ -194,51 +194,49 @@ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); if (!decodeCaps.bIsSupported) { - throw std::runtime_error("Codec not supported on this GPU"); + TORCH_CHECK(false, "Codec not supported on this GPU"); } if ((videoFormat->coded_width > decodeCaps.nMaxWidth) || (videoFormat->coded_height > decodeCaps.nMaxHeight)) { - std::ostringstream errorString; - errorString << std::endl - << "Resolution : " << videoFormat->coded_width << "x" - << videoFormat->coded_height << std::endl - << "Max Supported (wxh) : " << decodeCaps.nMaxWidth << "x" - << decodeCaps.nMaxHeight << std::endl - << "Resolution not supported on this GPU"; - - const std::string cErr = errorString.str(); - throw std::runtime_error(cErr); + TORCH_CHECK( + false, + "Resolution : ", + videoFormat->coded_width, + "x", + videoFormat->coded_height, + "\nMax Supported (wxh) : ", + decodeCaps.nMaxWidth, + "x", + decodeCaps.nMaxHeight, + "\nResolution not supported on this GPU"); } if ((videoFormat->coded_width >> 4) * (videoFormat->coded_height >> 4) > decodeCaps.nMaxMBCount) { - std::ostringstream errorString; - errorString << std::endl - << "MBCount : " - << (videoFormat->coded_width >> 4) * - (videoFormat->coded_height >> 4) - << std::endl - << "Max Supported mbcnt : " << decodeCaps.nMaxMBCount - << std::endl - << "MBCount not supported on this GPU"; - - const std::string cErr = errorString.str(); - throw std::runtime_error(cErr); + TORCH_CHECK( + false, + "MBCount : ", + (videoFormat->coded_width >> 4) * (videoFormat->coded_height >> 4), + "\nMax Supported mbcnt : ", + decodeCaps.nMaxMBCount, + "\nMBCount not supported on this GPU"); } // Check if output format supported. If not, check fallback options if (!(decodeCaps.nOutputFormatMask & (1 << videoOutputFormat))) { - if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) + if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) { videoOutputFormat = cudaVideoSurfaceFormat_NV12; - else if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) + } else if ( + decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) { videoOutputFormat = cudaVideoSurfaceFormat_P016; - else if ( - decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) + } else if ( + decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) { videoOutputFormat = cudaVideoSurfaceFormat_YUV444; - else if ( + } else if ( decodeCaps.nOutputFormatMask & - (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) + (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) { videoOutputFormat = cudaVideoSurfaceFormat_YUV444_16Bit; - else - throw std::runtime_error("No supported output format found"); + } else { + TORCH_CHECK(false, "No supported output format found"); + } } } @@ -340,11 +338,10 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { if (vidFormat->bit_depth_luma_minus8 != videoFormat.bit_depth_luma_minus8 || vidFormat->bit_depth_chroma_minus8 != videoFormat.bit_depth_chroma_minus8) { - throw std::runtime_error("Reconfigure Not supported for bit depth change"); + TORCH_CHECK(false, "Reconfigure not supported for bit depth change"); } if (vidFormat->chroma_format != videoFormat.chroma_format) { - throw std::runtime_error( - "Reconfigure Not supported for chroma format change"); + TORCH_CHECK(false, "Reconfigure not supported for chroma format change"); } bool decodeResChange = @@ -363,8 +360,9 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { // For VP9, let driver handle the change if new width/height > // maxwidth/maxheight if (videoCodec != cudaVideoCodec_VP9) { - throw std::runtime_error( - "Reconfigure Not supported when width/height > maxwidth/maxheight"); + TORCH_CHECK( + false, + "Reconfigure not supported when width/height > maxwidth/maxheight"); } return 1; } From 7ca13b73b0f58b5a34edf8ae1650194d30794c23 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 24 Dec 2021 07:27:07 -0800 Subject: [PATCH 33/39] Update Readme --- torchvision/csrc/io/decoder/gpu/README.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/README.rst b/torchvision/csrc/io/decoder/gpu/README.rst index b9d0947c6f3..cebd31cb557 100644 --- a/torchvision/csrc/io/decoder/gpu/README.rst +++ b/torchvision/csrc/io/decoder/gpu/README.rst @@ -4,14 +4,15 @@ GPU Decoder GPU decoder depends on ffmpeg for demuxing, uses NVDECODE APIs from the nvidia-video-codec sdk and uses cuda for processing on gpu. In order to use this, please follow the following steps: * Download the latest `nvidia-video-codec-sdk `_ -* Extract the zipped file and copy the header files and libraries. +* Extract the zipped file. +* Set TORCHVISION_INCLUDE environment variable to the location of the video codec headers(`nvcuvid.h` and `cuviddec.h`), which would be under `Interface` directory. +* Set TORCHVISION_LIBRARY environment variable to the location of the video codec library(`libnvcuvid.so`), which would be under `Lib/linux/stubs/x86_64` directory. +* Install the latest ffmpeg from `conda-forge` channel. .. code:: bash -sudo cp Interface/* /usr/local/include/ -sudo cp Lib/linux/stubs/x86_64/libnv* /usr/local/lib/ + conda install -c conda-forge ffmpeg -* Install ffmpeg and make sure ffmpeg headers and libraries are present under /usr/local/include and /usr/local/lib respectively. * Set CUDA_HOME environment variable to the cuda root directory. * Build torchvision from source: From 83ac2b1b8be724eb26c5bf4ea5fe55a4708c1107 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 24 Dec 2021 08:26:16 -0800 Subject: [PATCH 34/39] Check for bsf.h --- setup.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 0e4e92c6615..feef44e6e79 100644 --- a/setup.py +++ b/setup.py @@ -434,13 +434,21 @@ def get_extensions(): video_codec_found = ( extension is CUDAExtension and CUDA_HOME is not None - and any([os.path.exists(os.path.join(folder, 'cuviddec.h')) for folder in vision_include]) - and any([os.path.exists(os.path.join(folder, 'nvcuvid.h')) for folder in vision_include]) - and any([os.path.exists(os.path.join(folder, 'libnvcuvid.so')) for folder in library_dirs]) + and any([os.path.exists(os.path.join(folder, "cuviddec.h")) for folder in vision_include]) + and any([os.path.exists(os.path.join(folder, "nvcuvid.h")) for folder in vision_include]) + and any([os.path.exists(os.path.join(folder, "libnvcuvid.so")) for folder in library_dirs]) ) print(f"video codec found: {video_codec_found}") + if not any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir]): + print( + "The installed version of ffmpeg is missing the header file 'bsf.h' which is " + "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:" + " `conda install -c conda-forge ffmpeg`." + ) + has_ffmpeg = False + if video_codec_found and has_ffmpeg: gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu") gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) From d69f820bdc2daa1a1ff4d78677ce80e37ff00bff Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 24 Dec 2021 09:22:27 -0800 Subject: [PATCH 35/39] Change struct initialisation style --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 162 ++++++++++---------- 1 file changed, 81 insertions(+), 81 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 9125a14f287..a04f9b9ff76 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -27,19 +27,19 @@ void Decoder::init(CUcontext context, cudaVideoCodec codec) { check_for_cuda_errors( cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__, __FILE__); - CUVIDPARSERPARAMS parserParams = { - .CodecType = codec, - .ulMaxNumDecodeSurfaces = 1, - .ulClockRate = 1000, - .ulMaxDisplayDelay = 0u, - .pUserData = this, - .pfnSequenceCallback = video_sequence_handler, - .pfnDecodePicture = picture_decode_handler, - .pfnDisplayPicture = picture_display_handler, - .pfnGetOperatingPoint = operating_point_handler, - }; + CUVIDPARSERPARAMS parser_params = {}; + parser_params.CodecType = codec; + parser_params.ulMaxNumDecodeSurfaces = 1; + parser_params.ulClockRate = 1000; + parser_params.ulMaxDisplayDelay = 0u; + parser_params.pUserData = this; + parser_params.pfnSequenceCallback = video_sequence_handler; + parser_params.pfnDecodePicture = picture_decode_handler; + parser_params.pfnDisplayPicture = picture_display_handler; + parser_params.pfnGetOperatingPoint = operating_point_handler; + check_for_cuda_errors( - cuvidCreateVideoParser(&parser, &parserParams), __LINE__, __FILE__); + cuvidCreateVideoParser(&parser, &parser_params), __LINE__, __FILE__); } /* Destroy parser object and context lock. @@ -64,11 +64,11 @@ void Decoder::release() { /* Trigger video decoding. */ void Decoder::decode(const uint8_t* data, unsigned long size) { - CUVIDSOURCEDATAPACKET pkt = { - .flags = CUVID_PKT_TIMESTAMP, - .payload_size = size, - .payload = data, - .timestamp = 0}; + CUVIDSOURCEDATAPACKET pkt = {}; + pkt.flags = CUVID_PKT_TIMESTAMP; + pkt.payload_size = size; + pkt.payload = data; + pkt.timestamp = 0; if (!data || size == 0) { pkt.flags |= CUVID_PKT_ENDOFSTREAM; } @@ -106,13 +106,12 @@ int Decoder::handle_picture_decode(CUVIDPICPARAMS* picParams) { /* Process the decoded data and copy it to a cuda memory location. */ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { - CUVIDPROCPARAMS procParams = { - .progressive_frame = dispInfo->progressive_frame, - .second_field = dispInfo->repeat_first_field + 1, - .top_field_first = dispInfo->top_field_first, - .unpaired_field = dispInfo->repeat_first_field < 0, - .output_stream = cuvidStream, - }; + CUVIDPROCPARAMS proc_params = {}; + proc_params.progressive_frame = dispInfo->progressive_frame; + proc_params.second_field = dispInfo->repeat_first_field + 1; + proc_params.top_field_first = dispInfo->top_field_first; + proc_params.unpaired_field = dispInfo->repeat_first_field < 0; + proc_params.output_stream = cuvidStream; CUdeviceptr dpSrcFrame = 0; unsigned int nSrcPitch = 0; @@ -123,7 +122,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { dispInfo->picture_index, &dpSrcFrame, &nSrcPitch, - &procParams), + &proc_params), __LINE__, __FILE__); @@ -184,20 +183,20 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { * verify if the hardware supports decoding the passed video. */ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { - CUVIDDECODECAPS decodeCaps = { - .eCodecType = videoFormat->codec, - .eChromaFormat = videoFormat->chroma_format, - .nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8, - }; + CUVIDDECODECAPS decode_caps = {}; + decode_caps.eCodecType = videoFormat->codec; + decode_caps.eChromaFormat = videoFormat->chroma_format; + decode_caps.nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8; + check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); - check_for_cuda_errors(cuvidGetDecoderCaps(&decodeCaps), __LINE__, __FILE__); + check_for_cuda_errors(cuvidGetDecoderCaps(&decode_caps), __LINE__, __FILE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); - if (!decodeCaps.bIsSupported) { + if (!decode_caps.bIsSupported) { TORCH_CHECK(false, "Codec not supported on this GPU"); } - if ((videoFormat->coded_width > decodeCaps.nMaxWidth) || - (videoFormat->coded_height > decodeCaps.nMaxHeight)) { + if ((videoFormat->coded_width > decode_caps.nMaxWidth) || + (videoFormat->coded_height > decode_caps.nMaxHeight)) { TORCH_CHECK( false, "Resolution : ", @@ -205,33 +204,33 @@ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { "x", videoFormat->coded_height, "\nMax Supported (wxh) : ", - decodeCaps.nMaxWidth, + decode_caps.nMaxWidth, "x", - decodeCaps.nMaxHeight, + decode_caps.nMaxHeight, "\nResolution not supported on this GPU"); } if ((videoFormat->coded_width >> 4) * (videoFormat->coded_height >> 4) > - decodeCaps.nMaxMBCount) { + decode_caps.nMaxMBCount) { TORCH_CHECK( false, "MBCount : ", (videoFormat->coded_width >> 4) * (videoFormat->coded_height >> 4), "\nMax Supported mbcnt : ", - decodeCaps.nMaxMBCount, + decode_caps.nMaxMBCount, "\nMBCount not supported on this GPU"); } // Check if output format supported. If not, check fallback options - if (!(decodeCaps.nOutputFormatMask & (1 << videoOutputFormat))) { - if (decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) { + if (!(decode_caps.nOutputFormatMask & (1 << videoOutputFormat))) { + if (decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) { videoOutputFormat = cudaVideoSurfaceFormat_NV12; } else if ( - decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) { + decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) { videoOutputFormat = cudaVideoSurfaceFormat_P016; } else if ( - decodeCaps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) { + decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) { videoOutputFormat = cudaVideoSurfaceFormat_YUV444; } else if ( - decodeCaps.nOutputFormatMask & + decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) { videoOutputFormat = cudaVideoSurfaceFormat_YUV444_16Bit; } else { @@ -283,21 +282,21 @@ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { deinterlaceMode = cudaVideoDeinterlaceMode_Weave; } - CUVIDDECODECREATEINFO videoDecodeCreateInfo = { - .ulWidth = vidFormat->coded_width, - .ulHeight = vidFormat->coded_height, - .ulNumDecodeSurfaces = decodeSurface, - .CodecType = vidFormat->codec, - .ChromaFormat = vidFormat->chroma_format, - // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded - // by NVDEC hardware - .ulCreationFlags = cudaVideoCreate_PreferCUVID, - .bitDepthMinus8 = vidFormat->bit_depth_luma_minus8, - .OutputFormat = videoOutputFormat, - .DeinterlaceMode = deinterlaceMode, - .ulNumOutputSurfaces = 2, - .vidLock = ctxLock, - }; + CUVIDDECODECREATEINFO video_decode_create_info = {}; + video_decode_create_info.ulWidth = vidFormat->coded_width; + video_decode_create_info.ulHeight = vidFormat->coded_height; + video_decode_create_info.ulNumDecodeSurfaces = decodeSurface; + video_decode_create_info.CodecType = vidFormat->codec; + video_decode_create_info.ChromaFormat = vidFormat->chroma_format; + // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded + // by NVDEC hardware + video_decode_create_info.ulCreationFlags = cudaVideoCreate_PreferCUVID; + video_decode_create_info.bitDepthMinus8 = vidFormat->bit_depth_luma_minus8; + video_decode_create_info.OutputFormat = videoOutputFormat; + video_decode_create_info.DeinterlaceMode = deinterlaceMode; + video_decode_create_info.ulNumOutputSurfaces = 2; + video_decode_create_info.vidLock = ctxLock; + // AV1 has max width/height of sequence in sequence header if (vidFormat->codec == cudaVideoCodec_AV1 && vidFormat->seqhdr_data_length > 0) { @@ -311,25 +310,27 @@ int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { if (maxHeight < vidFormat->coded_height) { maxHeight = vidFormat->coded_height; } - videoDecodeCreateInfo.ulMaxWidth = maxWidth; - videoDecodeCreateInfo.ulMaxHeight = maxHeight; + video_decode_create_info.ulMaxWidth = maxWidth; + video_decode_create_info.ulMaxHeight = maxHeight; width = vidFormat->display_area.right - vidFormat->display_area.left; lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - videoDecodeCreateInfo.ulTargetWidth = vidFormat->coded_width; - videoDecodeCreateInfo.ulTargetHeight = vidFormat->coded_height; + video_decode_create_info.ulTargetWidth = vidFormat->coded_width; + video_decode_create_info.ulTargetHeight = vidFormat->coded_height; chromaHeight = (int)(ceil(lumaHeight * chroma_height_factor(videoOutputFormat))); numChromaPlanes = chroma_plane_count(videoOutputFormat); - surfaceHeight = videoDecodeCreateInfo.ulTargetHeight; - surfaceWidth = videoDecodeCreateInfo.ulTargetWidth; - displayRect.bottom = videoDecodeCreateInfo.display_area.bottom; - displayRect.top = videoDecodeCreateInfo.display_area.top; - displayRect.left = videoDecodeCreateInfo.display_area.left; - displayRect.right = videoDecodeCreateInfo.display_area.right; + surfaceHeight = video_decode_create_info.ulTargetHeight; + surfaceWidth = video_decode_create_info.ulTargetWidth; + displayRect.bottom = video_decode_create_info.display_area.bottom; + displayRect.top = video_decode_create_info.display_area.top; + displayRect.left = video_decode_create_info.display_area.left; + displayRect.right = video_decode_create_info.display_area.right; check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); check_for_cuda_errors( - cuvidCreateDecoder(&decoder, &videoDecodeCreateInfo), __LINE__, __FILE__); + cuvidCreateDecoder(&decoder, &video_decode_create_info), + __LINE__, + __FILE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); return decodeSurface; } @@ -382,21 +383,20 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { } videoFormat.coded_width = vidFormat->coded_width; videoFormat.coded_height = vidFormat->coded_height; - CUVIDRECONFIGUREDECODERINFO reconfigParams = { - .ulWidth = vidFormat->coded_width, - .ulHeight = vidFormat->coded_height, - .ulTargetWidth = surfaceWidth, - .ulTargetHeight = surfaceHeight, - .ulNumDecodeSurfaces = decodeSurface, - }; - reconfigParams.display_area.bottom = displayRect.bottom; - reconfigParams.display_area.top = displayRect.top; - reconfigParams.display_area.left = displayRect.left; - reconfigParams.display_area.right = displayRect.right; + CUVIDRECONFIGUREDECODERINFO reconfig_params = {}; + reconfig_params.ulWidth = vidFormat->coded_width; + reconfig_params.ulHeight = vidFormat->coded_height; + reconfig_params.ulTargetWidth = surfaceWidth; + reconfig_params.ulTargetHeight = surfaceHeight; + reconfig_params.ulNumDecodeSurfaces = decodeSurface; + reconfig_params.display_area.bottom = displayRect.bottom; + reconfig_params.display_area.top = displayRect.top; + reconfig_params.display_area.left = displayRect.left; + reconfig_params.display_area.right = displayRect.right; check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); check_for_cuda_errors( - cuvidReconfigureDecoder(decoder, &reconfigParams), __LINE__, __FILE__); + cuvidReconfigureDecoder(decoder, &reconfig_params), __LINE__, __FILE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); return decodeSurface; From 5c5162e473bf694a7cae4890f9e424e0b9622e52 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 24 Dec 2021 10:04:43 -0800 Subject: [PATCH 36/39] Clean-up get_operating_point --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 16 +++++----------- torchvision/csrc/io/decoder/gpu/decoder.h | 3 +-- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index a04f9b9ff76..0d0bc06849a 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -402,17 +402,11 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { return decodeSurface; } -/* Called from AV1 sequence header to get operating point of a AV1 bitstream. +/* Called from AV1 sequence header to get operating point of an AV1 bitstream. */ int Decoder::get_operating_point(CUVIDOPERATINGPOINTINFO* operPointInfo) { - if (operPointInfo->codec == cudaVideoCodec_AV1) { - if (operPointInfo->av1.operating_points_cnt > 1) { - // clip has SVC enabled - if (operatingPoint >= operPointInfo->av1.operating_points_cnt) { - operatingPoint = 0; - } - return (operatingPoint | (dispAllLayers << 10)); - } - } - return -1; + return operPointInfo->codec == cudaVideoCodec_AV1 && + operPointInfo->av1.operating_points_cnt > 1 + ? 0 + : -1; } diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 480f0fdef75..3697c81712c 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -51,11 +51,10 @@ class Decoder { } private: - bool dispAllLayers = false; unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; unsigned int surfaceHeight = 0, surfaceWidth = 0; unsigned int maxWidth = 0, maxHeight = 0; - unsigned int operatingPoint = 0, numChromaPlanes = 0; + unsigned int numChromaPlanes = 0; int bitDepthMinus8 = 0, bytesPerPixel = 1; int decodePicCount = 0, picNumInDecodeOrder[32]; std::queue decoded_frames; From 83d84b056938293f8e9870f39152c2a11d2e892e Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 24 Dec 2021 11:09:16 -0800 Subject: [PATCH 37/39] Make variable naming convention uniform --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 319 ++++++++++---------- torchvision/csrc/io/decoder/gpu/decoder.h | 36 +-- 2 files changed, 180 insertions(+), 175 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 0d0bc06849a..4471fd6b783 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -4,28 +4,28 @@ #include #include -static float chroma_height_factor(cudaVideoSurfaceFormat surfaceFormat) { - return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || - surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) +static float chroma_height_factor(cudaVideoSurfaceFormat surface_format) { + return (surface_format == cudaVideoSurfaceFormat_YUV444 || + surface_format == cudaVideoSurfaceFormat_YUV444_16Bit) ? 1.0 : 0.5; } -static int chroma_plane_count(cudaVideoSurfaceFormat surfaceFormat) { - return (surfaceFormat == cudaVideoSurfaceFormat_YUV444 || - surfaceFormat == cudaVideoSurfaceFormat_YUV444_16Bit) +static int chroma_plane_count(cudaVideoSurfaceFormat surface_format) { + return (surface_format == cudaVideoSurfaceFormat_YUV444 || + surface_format == cudaVideoSurfaceFormat_YUV444_16Bit) ? 2 : 1; } -/* Initialise cuContext and videoCodec, create context lock and create parser +/* Initialise cu_context and video_codec, create context lock and create parser * object. */ void Decoder::init(CUcontext context, cudaVideoCodec codec) { - cuContext = context; - videoCodec = codec; + cu_context = context; + video_codec = codec; check_for_cuda_errors( - cuvidCtxLockCreate(&ctxLock, cuContext), __LINE__, __FILE__); + cuvidCtxLockCreate(&ctx_lock, cu_context), __LINE__, __FILE__); CUVIDPARSERPARAMS parser_params = {}; parser_params.CodecType = codec; @@ -48,13 +48,13 @@ Decoder::~Decoder() { if (parser) { cuvidDestroyVideoParser(parser); } - cuvidCtxLockDestroy(ctxLock); + cuvidCtxLockDestroy(ctx_lock); } /* Destroy CUvideodecoder object and free up all the unreturned decoded frames. */ void Decoder::release() { - cuCtxPushCurrent(cuContext); + cuCtxPushCurrent(cu_context); if (decoder) { cuvidDestroyDecoder(decoder); } @@ -91,50 +91,50 @@ torch::Tensor Decoder::fetch_frame() { /* Called when a picture is ready to be decoded. */ -int Decoder::handle_picture_decode(CUVIDPICPARAMS* picParams) { +int Decoder::handle_picture_decode(CUVIDPICPARAMS* pic_params) { if (!decoder) { TORCH_CHECK(false, "Uninitialised decoder"); } - picNumInDecodeOrder[picParams->CurrPicIdx] = decodePicCount++; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); + pic_num_in_decode_order[pic_params->CurrPicIdx] = decode_pic_count++; + check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__); check_for_cuda_errors( - cuvidDecodePicture(decoder, picParams), __LINE__, __FILE__); + cuvidDecodePicture(decoder, pic_params), __LINE__, __FILE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); return 1; } /* Process the decoded data and copy it to a cuda memory location. */ -int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { +int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) { CUVIDPROCPARAMS proc_params = {}; - proc_params.progressive_frame = dispInfo->progressive_frame; - proc_params.second_field = dispInfo->repeat_first_field + 1; - proc_params.top_field_first = dispInfo->top_field_first; - proc_params.unpaired_field = dispInfo->repeat_first_field < 0; + proc_params.progressive_frame = disp_info->progressive_frame; + proc_params.second_field = disp_info->repeat_first_field + 1; + proc_params.top_field_first = disp_info->top_field_first; + proc_params.unpaired_field = disp_info->repeat_first_field < 0; proc_params.output_stream = cuvidStream; - CUdeviceptr dpSrcFrame = 0; - unsigned int nSrcPitch = 0; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); + CUdeviceptr source_frame = 0; + unsigned int source_pitch = 0; + check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__); check_for_cuda_errors( cuvidMapVideoFrame( decoder, - dispInfo->picture_index, - &dpSrcFrame, - &nSrcPitch, + disp_info->picture_index, + &source_frame, + &source_pitch, &proc_params), __LINE__, __FILE__); - CUVIDGETDECODESTATUS decodeStatus; - memset(&decodeStatus, 0, sizeof(decodeStatus)); + CUVIDGETDECODESTATUS decode_status; + memset(&decode_status, 0, sizeof(decode_status)); CUresult result = - cuvidGetDecodeStatus(decoder, dispInfo->picture_index, &decodeStatus); + cuvidGetDecodeStatus(decoder, disp_info->picture_index, &decode_status); if (result == CUDA_SUCCESS && - (decodeStatus.decodeStatus == cuvidDecodeStatus_Error || - decodeStatus.decodeStatus == cuvidDecodeStatus_Error_Concealed)) { + (decode_status.decodeStatus == cuvidDecodeStatus_Error || + decode_status.decodeStatus == cuvidDecodeStatus_Error_Concealed)) { VLOG(1) << "Decode Error occurred for picture " - << picNumInDecodeOrder[dispInfo->picture_index]; + << pic_num_in_decode_order[disp_info->picture_index]; } auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); @@ -144,30 +144,30 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { // Copy luma plane CUDA_MEMCPY2D m = {0}; m.srcMemoryType = CU_MEMORYTYPE_DEVICE; - m.srcDevice = dpSrcFrame; - m.srcPitch = nSrcPitch; + m.srcDevice = source_frame; + m.srcPitch = source_pitch; m.dstMemoryType = CU_MEMORYTYPE_DEVICE; m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr); - m.dstPitch = get_width() * bytesPerPixel; - m.WidthInBytes = get_width() * bytesPerPixel; - m.Height = lumaHeight; + m.dstPitch = get_width() * bytes_per_pixel; + m.WidthInBytes = get_width() * bytes_per_pixel; + m.Height = luma_height; check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); // Copy chroma plane // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning // height m.srcDevice = - (CUdeviceptr)((uint8_t*)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1)); - m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * lumaHeight); - m.Height = chromaHeight; + (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1)); + m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height); + m.Height = chroma_height; check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); - if (numChromaPlanes == 2) { + if (num_chroma_planes == 2) { m.srcDevice = - (CUdeviceptr)((uint8_t*)dpSrcFrame + m.srcPitch * ((surfaceHeight + 1) & ~1) * 2); + (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2); m.dstDevice = - (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * lumaHeight * 2); - m.Height = chromaHeight; + (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2); + m.Height = chroma_height; check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); } check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__); @@ -175,64 +175,64 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* dispInfo) { check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); check_for_cuda_errors( - cuvidUnmapVideoFrame(decoder, dpSrcFrame), __LINE__, __FILE__); + cuvidUnmapVideoFrame(decoder, source_frame), __LINE__, __FILE__); return 1; } /* Query the capabilities of the underlying hardware video decoder and * verify if the hardware supports decoding the passed video. */ -void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { +void Decoder::query_hardware(CUVIDEOFORMAT* video_format) { CUVIDDECODECAPS decode_caps = {}; - decode_caps.eCodecType = videoFormat->codec; - decode_caps.eChromaFormat = videoFormat->chroma_format; - decode_caps.nBitDepthMinus8 = videoFormat->bit_depth_luma_minus8; + decode_caps.eCodecType = video_format->codec; + decode_caps.eChromaFormat = video_format->chroma_format; + decode_caps.nBitDepthMinus8 = video_format->bit_depth_luma_minus8; - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); + check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__); check_for_cuda_errors(cuvidGetDecoderCaps(&decode_caps), __LINE__, __FILE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); if (!decode_caps.bIsSupported) { TORCH_CHECK(false, "Codec not supported on this GPU"); } - if ((videoFormat->coded_width > decode_caps.nMaxWidth) || - (videoFormat->coded_height > decode_caps.nMaxHeight)) { + if ((video_format->coded_width > decode_caps.nMaxWidth) || + (video_format->coded_height > decode_caps.nMaxHeight)) { TORCH_CHECK( false, "Resolution : ", - videoFormat->coded_width, + video_format->coded_width, "x", - videoFormat->coded_height, + video_format->coded_height, "\nMax Supported (wxh) : ", decode_caps.nMaxWidth, "x", decode_caps.nMaxHeight, "\nResolution not supported on this GPU"); } - if ((videoFormat->coded_width >> 4) * (videoFormat->coded_height >> 4) > + if ((video_format->coded_width >> 4) * (video_format->coded_height >> 4) > decode_caps.nMaxMBCount) { TORCH_CHECK( false, "MBCount : ", - (videoFormat->coded_width >> 4) * (videoFormat->coded_height >> 4), + (video_format->coded_width >> 4) * (video_format->coded_height >> 4), "\nMax Supported mbcnt : ", decode_caps.nMaxMBCount, "\nMBCount not supported on this GPU"); } // Check if output format supported. If not, check fallback options - if (!(decode_caps.nOutputFormatMask & (1 << videoOutputFormat))) { + if (!(decode_caps.nOutputFormatMask & (1 << video_output_format))) { if (decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_NV12)) { - videoOutputFormat = cudaVideoSurfaceFormat_NV12; + video_output_format = cudaVideoSurfaceFormat_NV12; } else if ( decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_P016)) { - videoOutputFormat = cudaVideoSurfaceFormat_P016; + video_output_format = cudaVideoSurfaceFormat_P016; } else if ( decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444)) { - videoOutputFormat = cudaVideoSurfaceFormat_YUV444; + video_output_format = cudaVideoSurfaceFormat_YUV444; } else if ( decode_caps.nOutputFormatMask & (1 << cudaVideoSurfaceFormat_YUV444_16Bit)) { - videoOutputFormat = cudaVideoSurfaceFormat_YUV444_16Bit; + video_output_format = cudaVideoSurfaceFormat_YUV444_16Bit; } else { TORCH_CHECK(false, "No supported output format found"); } @@ -242,125 +242,128 @@ void Decoder::query_hardware(CUVIDEOFORMAT* videoFormat) { /* Called before decoding frames and/or whenever there is a configuration * change. */ -int Decoder::handle_video_sequence(CUVIDEOFORMAT* vidFormat) { - // videoCodec has been set in the init(). Here it's set +int Decoder::handle_video_sequence(CUVIDEOFORMAT* video_format) { + // video_codec has been set in init(). Here it's set // again for potential correction. - videoCodec = vidFormat->codec; - videoChromaFormat = vidFormat->chroma_format; - bitDepthMinus8 = vidFormat->bit_depth_luma_minus8; - bytesPerPixel = bitDepthMinus8 > 0 ? 2 : 1; + video_codec = video_format->codec; + video_chroma_format = video_format->chroma_format; + bit_depth_minus8 = video_format->bit_depth_luma_minus8; + bytes_per_pixel = bit_depth_minus8 > 0 ? 2 : 1; // Set the output surface format same as chroma format - switch (videoChromaFormat) { + switch (video_chroma_format) { case cudaVideoChromaFormat_Monochrome: case cudaVideoChromaFormat_420: - videoOutputFormat = vidFormat->bit_depth_luma_minus8 + video_output_format = video_format->bit_depth_luma_minus8 ? cudaVideoSurfaceFormat_P016 : cudaVideoSurfaceFormat_NV12; break; case cudaVideoChromaFormat_444: - videoOutputFormat = vidFormat->bit_depth_luma_minus8 + video_output_format = video_format->bit_depth_luma_minus8 ? cudaVideoSurfaceFormat_YUV444_16Bit : cudaVideoSurfaceFormat_YUV444; break; case cudaVideoChromaFormat_422: - videoOutputFormat = cudaVideoSurfaceFormat_NV12; + video_output_format = cudaVideoSurfaceFormat_NV12; } - query_hardware(vidFormat); + query_hardware(video_format); - if (width && lumaHeight && chromaHeight) { + if (width && luma_height && chroma_height) { // cuvidCreateDecoder() has been called before and now there's possible // config change. - return reconfigure_decoder(vidFormat); + return reconfigure_decoder(video_format); } - videoFormat = *vidFormat; - unsigned long decodeSurface = vidFormat->min_num_decode_surfaces; - cudaVideoDeinterlaceMode deinterlaceMode = cudaVideoDeinterlaceMode_Adaptive; + cu_video_format = *video_format; + unsigned long decode_surface = video_format->min_num_decode_surfaces; + cudaVideoDeinterlaceMode deinterlace_mode = cudaVideoDeinterlaceMode_Adaptive; - if (vidFormat->progressive_sequence) { - deinterlaceMode = cudaVideoDeinterlaceMode_Weave; + if (video_format->progressive_sequence) { + deinterlace_mode = cudaVideoDeinterlaceMode_Weave; } CUVIDDECODECREATEINFO video_decode_create_info = {}; - video_decode_create_info.ulWidth = vidFormat->coded_width; - video_decode_create_info.ulHeight = vidFormat->coded_height; - video_decode_create_info.ulNumDecodeSurfaces = decodeSurface; - video_decode_create_info.CodecType = vidFormat->codec; - video_decode_create_info.ChromaFormat = vidFormat->chroma_format; + video_decode_create_info.ulWidth = video_format->coded_width; + video_decode_create_info.ulHeight = video_format->coded_height; + video_decode_create_info.ulNumDecodeSurfaces = decode_surface; + video_decode_create_info.CodecType = video_format->codec; + video_decode_create_info.ChromaFormat = video_format->chroma_format; // With PreferCUVID, JPEG is still decoded by CUDA while video is decoded // by NVDEC hardware video_decode_create_info.ulCreationFlags = cudaVideoCreate_PreferCUVID; - video_decode_create_info.bitDepthMinus8 = vidFormat->bit_depth_luma_minus8; - video_decode_create_info.OutputFormat = videoOutputFormat; - video_decode_create_info.DeinterlaceMode = deinterlaceMode; + video_decode_create_info.bitDepthMinus8 = video_format->bit_depth_luma_minus8; + video_decode_create_info.OutputFormat = video_output_format; + video_decode_create_info.DeinterlaceMode = deinterlace_mode; video_decode_create_info.ulNumOutputSurfaces = 2; - video_decode_create_info.vidLock = ctxLock; + video_decode_create_info.vidLock = ctx_lock; // AV1 has max width/height of sequence in sequence header - if (vidFormat->codec == cudaVideoCodec_AV1 && - vidFormat->seqhdr_data_length > 0) { - CUVIDEOFORMATEX* vidFormatEx = (CUVIDEOFORMATEX*)vidFormat; - maxWidth = vidFormatEx->av1.max_width; - maxHeight = vidFormatEx->av1.max_height; + if (video_format->codec == cudaVideoCodec_AV1 && + video_format->seqhdr_data_length > 0) { + CUVIDEOFORMATEX* video_format_ex = (CUVIDEOFORMATEX*)video_format; + max_width = video_format_ex->av1.max_width; + max_height = video_format_ex->av1.max_height; } - if (maxWidth < vidFormat->coded_width) { - maxWidth = vidFormat->coded_width; + if (max_width < video_format->coded_width) { + max_width = video_format->coded_width; } - if (maxHeight < vidFormat->coded_height) { - maxHeight = vidFormat->coded_height; + if (max_height < video_format->coded_height) { + max_height = video_format->coded_height; } - video_decode_create_info.ulMaxWidth = maxWidth; - video_decode_create_info.ulMaxHeight = maxHeight; - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - video_decode_create_info.ulTargetWidth = vidFormat->coded_width; - video_decode_create_info.ulTargetHeight = vidFormat->coded_height; - chromaHeight = - (int)(ceil(lumaHeight * chroma_height_factor(videoOutputFormat))); - numChromaPlanes = chroma_plane_count(videoOutputFormat); - surfaceHeight = video_decode_create_info.ulTargetHeight; - surfaceWidth = video_decode_create_info.ulTargetWidth; - displayRect.bottom = video_decode_create_info.display_area.bottom; - displayRect.top = video_decode_create_info.display_area.top; - displayRect.left = video_decode_create_info.display_area.left; - displayRect.right = video_decode_create_info.display_area.right; - - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); + video_decode_create_info.ulMaxWidth = max_width; + video_decode_create_info.ulMaxHeight = max_height; + width = video_format->display_area.right - video_format->display_area.left; + luma_height = + video_format->display_area.bottom - video_format->display_area.top; + video_decode_create_info.ulTargetWidth = video_format->coded_width; + video_decode_create_info.ulTargetHeight = video_format->coded_height; + chroma_height = + (int)(ceil(luma_height * chroma_height_factor(video_output_format))); + num_chroma_planes = chroma_plane_count(video_output_format); + surface_height = video_decode_create_info.ulTargetHeight; + surface_width = video_decode_create_info.ulTargetWidth; + display_rect.bottom = video_decode_create_info.display_area.bottom; + display_rect.top = video_decode_create_info.display_area.top; + display_rect.left = video_decode_create_info.display_area.left; + display_rect.right = video_decode_create_info.display_area.right; + + check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__); check_for_cuda_errors( cuvidCreateDecoder(&decoder, &video_decode_create_info), __LINE__, __FILE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); - return decodeSurface; + return decode_surface; } -int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { - if (vidFormat->bit_depth_luma_minus8 != videoFormat.bit_depth_luma_minus8 || - vidFormat->bit_depth_chroma_minus8 != - videoFormat.bit_depth_chroma_minus8) { +int Decoder::reconfigure_decoder(CUVIDEOFORMAT* video_format) { + if (video_format->bit_depth_luma_minus8 != + cu_video_format.bit_depth_luma_minus8 || + video_format->bit_depth_chroma_minus8 != + cu_video_format.bit_depth_chroma_minus8) { TORCH_CHECK(false, "Reconfigure not supported for bit depth change"); } - if (vidFormat->chroma_format != videoFormat.chroma_format) { + if (video_format->chroma_format != cu_video_format.chroma_format) { TORCH_CHECK(false, "Reconfigure not supported for chroma format change"); } - bool decodeResChange = - !(vidFormat->coded_width == videoFormat.coded_width && - vidFormat->coded_height == videoFormat.coded_height); - bool displayRectChange = - !(vidFormat->display_area.bottom == videoFormat.display_area.bottom && - vidFormat->display_area.top == videoFormat.display_area.top && - vidFormat->display_area.left == videoFormat.display_area.left && - vidFormat->display_area.right == videoFormat.display_area.right); + bool decode_res_change = + !(video_format->coded_width == cu_video_format.coded_width && + video_format->coded_height == cu_video_format.coded_height); + bool display_rect_change = + !(video_format->display_area.bottom == + cu_video_format.display_area.bottom && + video_format->display_area.top == cu_video_format.display_area.top && + video_format->display_area.left == cu_video_format.display_area.left && + video_format->display_area.right == cu_video_format.display_area.right); - unsigned int decodeSurface = vidFormat->min_num_decode_surfaces; + unsigned int decode_surface = video_format->min_num_decode_surfaces; - if ((vidFormat->coded_width > maxWidth) || - (vidFormat->coded_height > maxHeight)) { + if ((video_format->coded_width > max_width) || + (video_format->coded_height > max_height)) { // For VP9, let driver handle the change if new width/height > // maxwidth/maxheight - if (videoCodec != cudaVideoCodec_VP9) { + if (video_codec != cudaVideoCodec_VP9) { TORCH_CHECK( false, "Reconfigure not supported when width/height > maxwidth/maxheight"); @@ -368,45 +371,47 @@ int Decoder::reconfigure_decoder(CUVIDEOFORMAT* vidFormat) { return 1; } - if (!decodeResChange) { + if (!decode_res_change) { // If the coded_width/coded_height hasn't changed but display resolution has // changed, then need to update width/height for correct output without // cropping. Example : 1920x1080 vs 1920x1088. - if (displayRectChange) { - width = vidFormat->display_area.right - vidFormat->display_area.left; - lumaHeight = vidFormat->display_area.bottom - vidFormat->display_area.top; - chromaHeight = - (int)ceil(lumaHeight * chroma_height_factor(videoOutputFormat)); - numChromaPlanes = chroma_plane_count(videoOutputFormat); + if (display_rect_change) { + width = + video_format->display_area.right - video_format->display_area.left; + luma_height = + video_format->display_area.bottom - video_format->display_area.top; + chroma_height = + (int)ceil(luma_height * chroma_height_factor(video_output_format)); + num_chroma_planes = chroma_plane_count(video_output_format); } return 1; } - videoFormat.coded_width = vidFormat->coded_width; - videoFormat.coded_height = vidFormat->coded_height; + cu_video_format.coded_width = video_format->coded_width; + cu_video_format.coded_height = video_format->coded_height; CUVIDRECONFIGUREDECODERINFO reconfig_params = {}; - reconfig_params.ulWidth = vidFormat->coded_width; - reconfig_params.ulHeight = vidFormat->coded_height; - reconfig_params.ulTargetWidth = surfaceWidth; - reconfig_params.ulTargetHeight = surfaceHeight; - reconfig_params.ulNumDecodeSurfaces = decodeSurface; - reconfig_params.display_area.bottom = displayRect.bottom; - reconfig_params.display_area.top = displayRect.top; - reconfig_params.display_area.left = displayRect.left; - reconfig_params.display_area.right = displayRect.right; - - check_for_cuda_errors(cuCtxPushCurrent(cuContext), __LINE__, __FILE__); + reconfig_params.ulWidth = video_format->coded_width; + reconfig_params.ulHeight = video_format->coded_height; + reconfig_params.ulTargetWidth = surface_width; + reconfig_params.ulTargetHeight = surface_height; + reconfig_params.ulNumDecodeSurfaces = decode_surface; + reconfig_params.display_area.bottom = display_rect.bottom; + reconfig_params.display_area.top = display_rect.top; + reconfig_params.display_area.left = display_rect.left; + reconfig_params.display_area.right = display_rect.right; + + check_for_cuda_errors(cuCtxPushCurrent(cu_context), __LINE__, __FILE__); check_for_cuda_errors( cuvidReconfigureDecoder(decoder, &reconfig_params), __LINE__, __FILE__); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); - return decodeSurface; + return decode_surface; } /* Called from AV1 sequence header to get operating point of an AV1 bitstream. */ -int Decoder::get_operating_point(CUVIDOPERATINGPOINTINFO* operPointInfo) { - return operPointInfo->codec == cudaVideoCodec_AV1 && - operPointInfo->av1.operating_points_cnt > 1 +int Decoder::get_operating_point(CUVIDOPERATINGPOINTINFO* oper_point_info) { + return oper_point_info->codec == cudaVideoCodec_AV1 && + oper_point_info->av1.operating_points_cnt > 1 ? 0 : -1; } diff --git a/torchvision/csrc/io/decoder/gpu/decoder.h b/torchvision/csrc/io/decoder/gpu/decoder.h index 3697c81712c..c3064eb1663 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.h +++ b/torchvision/csrc/io/decoder/gpu/decoder.h @@ -37,37 +37,37 @@ class Decoder { void decode(const uint8_t*, unsigned long); torch::Tensor fetch_frame(); int get_frame_size() const { - return get_width() * (lumaHeight + (chromaHeight * numChromaPlanes)) * - bytesPerPixel; + return get_width() * (luma_height + (chroma_height * num_chroma_planes)) * + bytes_per_pixel; } int get_width() const { - return (videoOutputFormat == cudaVideoSurfaceFormat_NV12 || - videoOutputFormat == cudaVideoSurfaceFormat_P016) + return (video_output_format == cudaVideoSurfaceFormat_NV12 || + video_output_format == cudaVideoSurfaceFormat_P016) ? (width + 1) & ~1 : width; } int get_height() const { - return lumaHeight; + return luma_height; } private: - unsigned int width = 0, lumaHeight = 0, chromaHeight = 0; - unsigned int surfaceHeight = 0, surfaceWidth = 0; - unsigned int maxWidth = 0, maxHeight = 0; - unsigned int numChromaPlanes = 0; - int bitDepthMinus8 = 0, bytesPerPixel = 1; - int decodePicCount = 0, picNumInDecodeOrder[32]; + unsigned int width = 0, luma_height = 0, chroma_height = 0; + unsigned int surface_height = 0, surface_width = 0; + unsigned int max_width = 0, max_height = 0; + unsigned int num_chroma_planes = 0; + int bit_depth_minus8 = 0, bytes_per_pixel = 1; + int decode_pic_count = 0, pic_num_in_decode_order[32]; std::queue decoded_frames; - CUcontext cuContext = NULL; - CUvideoctxlock ctxLock; + CUcontext cu_context = NULL; + CUvideoctxlock ctx_lock; CUvideoparser parser = NULL; CUvideodecoder decoder = NULL; CUstream cuvidStream = 0; - cudaVideoCodec videoCodec = cudaVideoCodec_NumCodecs; - cudaVideoChromaFormat videoChromaFormat = cudaVideoChromaFormat_420; - cudaVideoSurfaceFormat videoOutputFormat = cudaVideoSurfaceFormat_NV12; - CUVIDEOFORMAT videoFormat = {}; - Rect displayRect = {}; + cudaVideoCodec video_codec = cudaVideoCodec_NumCodecs; + cudaVideoChromaFormat video_chroma_format = cudaVideoChromaFormat_420; + cudaVideoSurfaceFormat video_output_format = cudaVideoSurfaceFormat_NV12; + CUVIDEOFORMAT cu_video_format = {}; + Rect display_rect = {}; static int video_sequence_handler( void* user_data, From d8d0fb5635274e67217eec8b027a252a72b3a402 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 24 Dec 2021 11:14:29 -0800 Subject: [PATCH 38/39] Move checking for bsf.h around --- setup.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index feef44e6e79..9678a2c17b7 100644 --- a/setup.py +++ b/setup.py @@ -441,15 +441,11 @@ def get_extensions(): print(f"video codec found: {video_codec_found}") - if not any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir]): - print( - "The installed version of ffmpeg is missing the header file 'bsf.h' which is " - "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:" - " `conda install -c conda-forge ffmpeg`." - ) - has_ffmpeg = False - - if video_codec_found and has_ffmpeg: + if ( + video_codec_found + and has_ffmpeg + and any([os.path.exists(os.path.join(folder, "libavcodec", "bsf.h")) for folder in ffmpeg_include_dir]) + ): gpu_decoder_path = os.path.join(extensions_dir, "io", "decoder", "gpu") gpu_decoder_src = glob.glob(os.path.join(gpu_decoder_path, "*.cpp")) cuda_libs = os.path.join(CUDA_HOME, "lib64") @@ -477,6 +473,12 @@ def get_extensions(): extra_compile_args=extra_compile_args, ) ) + else: + print( + "The installed version of ffmpeg is missing the header file 'bsf.h' which is " + "required for GPU video decoding. Please install the latest ffmpeg from conda-forge channel:" + " `conda install -c conda-forge ffmpeg`." + ) return ext_modules From 559639e855e264f10f26740f34181d5278ee3d0e Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Fri, 24 Dec 2021 11:19:16 -0800 Subject: [PATCH 39/39] Fix linter error --- torchvision/io/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index a745aee8607..410ec5bfc2c 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -40,6 +40,7 @@ def _has_video_opt() -> bool: return True + else: def _has_video_opt() -> bool: