Skip to content

Commit d102d39

Browse files
committed
Add support for png decoding on linux
1 parent e2573a7 commit d102d39

File tree

14 files changed

+266
-31
lines changed

14 files changed

+266
-31
lines changed

.circleci/config.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ commands:
2121
description: "checkout merge branch"
2222
steps:
2323
- checkout
24+
- run:
25+
name: initialize submodules
26+
command: git submodule update --init --recursive
2427
# - run:
2528
# name: Checkout merge branch
2629
# command: |
@@ -83,6 +86,8 @@ jobs:
8386
resource_class: 2xlarge+
8487
steps:
8588
- checkout_merge
89+
- run:
90+
command: yum install -yq zlib-devel
8691
- run: packaging/build_wheel.sh
8792
- store_artifacts:
8893
path: dist
@@ -128,7 +133,7 @@ jobs:
128133
ca-certificates \
129134
curl \
130135
gnupg-agent \
131-
software-properties-common
136+
software-properties-common
132137
133138
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
134139

.circleci/config.yml.in

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ commands:
2121
description: "checkout merge branch"
2222
steps:
2323
- checkout
24+
- run:
25+
name: initialize submodules
26+
command: git submodule update --init --recursive
2427
# - run:
2528
# name: Checkout merge branch
2629
# command: |
@@ -83,6 +86,8 @@ jobs:
8386
resource_class: 2xlarge+
8487
steps:
8588
- checkout_merge
89+
- run:
90+
command: yum install -yq zlib-devel
8691
- run: packaging/build_wheel.sh
8792
- store_artifacts:
8893
path: dist
@@ -128,7 +133,7 @@ jobs:
128133
ca-certificates \
129134
curl \
130135
gnupg-agent \
131-
software-properties-common
136+
software-properties-common
132137

133138
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
134139

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "third_party/libpng"]
2+
path = third_party/libpng
3+
url = https://github.com/glennrp/libpng

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ matrix:
2626

2727
before_install:
2828
- sudo apt-get update
29+
- sudo apt-get install -y zlib1g-dev
2930
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
3031
- bash miniconda.sh -b -p $HOME/miniconda
3132
- export PATH="$HOME/miniconda/bin:$PATH"

CMakeLists.txt

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ if(WITH_CUDA)
1010
add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
1111
endif()
1212

13+
# Build the vendored libpng on Unix-like hosts only.
# CMake variable names are case-sensitive and the platform variable is `UNIX`;
# `Unix` is never defined, so the previous condition was always false and the
# subdirectory was never added.
if(UNIX)
  add_subdirectory("third_party/libpng")
endif()
16+
1317
find_package(Torch REQUIRED)
1418
find_package(pybind11 REQUIRED)
1519

@@ -21,8 +25,17 @@ endif()
2125
file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h)
2226
file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp)
2327

24-
add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES})
25-
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} pybind11::pybind11)
28+
file(GLOB IMAGE_HEADERS torchvision/csrc/image.h)
29+
file(GLOB IMAGE_SOURCES torchvision/csrc/cpu/image/*.h torchvision/csrc/cpu/image/*.cpp)
30+
31+
# Create the shared library. On Unix-like hosts the PNG decoding sources are
# compiled in and libpng is linked; elsewhere the image code is left out.
# `UNIX` (upper case) is the variable CMake defines; `Unix` is always unset.
if(UNIX)
  # `${IMAGE_SOURCES}` needs the `$`: a bare `{IMAGE_SOURCES}` is passed to
  # add_library() as a literal (non-existent) source file name.
  add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES} ${IMAGE_SOURCES})
  # NOTE(review): nothing in the visible hunks sets PNG_LIBRARIES; confirm that
  # third_party/libpng exports it, otherwise link the libpng target directly.
  target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} pybind11::pybind11 "${PNG_LIBRARIES}")
else()
  add_library(${PROJECT_NAME} SHARED ${MODELS_SOURCES} ${OPERATOR_SOURCES})
  target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} pybind11::pybind11)
endif()
38+
2639
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchVision)
2740

2841
target_include_directories(${PROJECT_NAME} INTERFACE
@@ -49,7 +62,7 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/TorchVisionConfig.cmake
4962
install(TARGETS ${PROJECT_NAME}
5063
EXPORT TorchVisionTargets)
5164

52-
install(EXPORT TorchVisionTargets
65+
install(EXPORT TorchVisionTargets
5366
NAMESPACE TorchVision::
5467
DESTINATION ${TORCHVISION_CMAKECONFIG_INSTALL_DIR})
5568

packaging/torchvision/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ source:
88
requirements:
99
build:
1010
- {{ compiler('c') }} # [win]
11+
- zlib
1112

1213
host:
1314
- python

setup.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,21 @@ def get_extensions():
8383

8484
main_file = glob.glob(os.path.join(extensions_dir, '*.cpp'))
8585
source_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', '*.cpp'))
86+
source_image_cpu = glob.glob(os.path.join(extensions_dir, 'cpu', 'image', '*.cpp'))
8687
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
8788

8889
sources = main_file + source_cpu
90+
91+
libraries = []
92+
extra_objects= []
93+
extra_compile_args = {}
94+
third_party_search_directories = []
95+
96+
if sys.platform.startswith('linux'):
97+
sources = sources + source_image_cpu
98+
libraries.append('png')
99+
third_party_search_directories.append(os.path.join(cwd, "third_party/libpng"))
100+
89101
extension = CppExtension
90102

91103
compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1'
@@ -102,7 +114,6 @@ def get_extensions():
102114

103115
define_macros = []
104116

105-
extra_compile_args = {}
106117
if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
107118
extension = CUDAExtension
108119
sources += source_cuda
@@ -142,9 +153,12 @@ def get_extensions():
142153
extension(
143154
'torchvision._C',
144155
sources,
145-
include_dirs=include_dirs,
156+
libraries= libraries,
157+
library_dirs=third_party_search_directories,
158+
include_dirs=include_dirs + third_party_search_directories,
146159
define_macros=define_macros,
147160
extra_compile_args=extra_compile_args,
161+
extra_objects=extra_objects
148162
)
149163
]
150164
if compile_cpp_tests:
@@ -196,29 +210,42 @@ def run(self):
196210
# It's an old-style class in Python 2.7...
197211
distutils.command.clean.clean.run(self)
198212

213+
def build_deps():
    """Configure and build the vendored libpng in-tree (Linux only).

    On any other platform this function does nothing. On Linux it runs
    ``cmake .`` followed by ``cmake --build .`` inside ``third_party/libpng``
    next to this file, and always switches the working directory back to the
    directory of this file afterwards.

    Raises:
        RuntimeError: if one of the cmake commands exits with a non-zero
            status.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    if not sys.platform.startswith('linux'):
        return

    # Resolve the submodule relative to this file rather than the caller's
    # current directory, so the build also works when setup.py is invoked
    # from somewhere else.
    libpng_dir = os.path.join(this_dir, "third_party", "libpng")
    os.chdir(libpng_dir)
    try:
        for command in ('cmake .', "cmake --build ."):
            status = os.system(command)
            # os.system() only reports failure through its return value;
            # stop here instead of failing later with an obscure link error.
            if status != 0:
                raise RuntimeError(
                    "'{}' failed in {} (exit status {})".format(command, libpng_dir, status))
    finally:
        # Restore the working directory even when a build step failed.
        os.chdir(this_dir)
220+
221+
222+
223+
def build_ext_with_dependencies(self):
    """Build the third-party dependencies, then create the ``build_ext`` command.

    Registered under ``cmdclass['build_ext']``. The single argument is
    forwarded unchanged to the ``BuildExtension`` constructor (presumably the
    distribution object handed over by setuptools -- confirm).

    NOTE(review): this is a plain function rather than a command class;
    verify that every setuptools entry path used by the project accepts a
    callable here.
    """
    build_deps()
    make_build_ext = BuildExtension.with_options(no_python_abi_suffix=True)
    return make_build_ext(self)
199226

200227
setup(
201-
# Metadata
202-
name=package_name,
203-
version=version,
204-
author='PyTorch Core Team',
205-
author_email='[email protected]',
206-
url='https://github.com/pytorch/vision',
207-
description='image and video datasets and models for torch deep learning',
208-
long_description=readme,
209-
license='BSD',
210-
211-
# Package info
212-
packages=find_packages(exclude=('test',)),
213-
214-
zip_safe=False,
215-
install_requires=requirements,
216-
extras_require={
217-
"scipy": ["scipy"],
218-
},
219-
ext_modules=get_extensions(),
220-
cmdclass={
221-
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True),
222-
'clean': clean,
223-
}
224-
)
228+
# Metadata
229+
name=package_name,
230+
version=version,
231+
author='PyTorch Core Team',
232+
author_email='[email protected]',
233+
url='https://github.com/pytorch/vision',
234+
description='image and video datasets and models for torch deep learning',
235+
long_description=readme,
236+
license='BSD',
237+
238+
# Package info
239+
packages=find_packages(exclude=('test',)),
240+
241+
zip_safe=False,
242+
install_requires=requirements,
243+
extras_require={
244+
"scipy": ["scipy"],
245+
},
246+
ext_modules=get_extensions(),
247+
cmdclass={
248+
'build_ext': build_ext_with_dependencies,
249+
'clean': clean,
250+
}
251+
)

test/test_image.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import unittest
3+
import sys
4+
5+
import torch
6+
from PIL import Image
7+
if sys.platform.startswith('linux'):
8+
from torchvision.io.image import read_png, decode_png
9+
import numpy as np
10+
11+
IMAGE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "imagefolder")
12+
13+
14+
def get_images(directory, img_ext):
    """Yield the paths of all files below ``directory`` with extension ``img_ext``.

    Args:
        directory: root directory; it is walked recursively.
        img_ext: file extension to match, with or without the leading dot
            (``"png"`` and ``".png"`` are equivalent). The comparison is
            case-sensitive.

    Yields:
        ``os.path.join(root, filename)`` for every matching file.

    Raises:
        AssertionError: if ``directory`` is not an existing directory. As this
            is a generator, the check runs when iteration starts.
    """
    assert os.path.isdir(directory)
    # os.path.splitext() returns the extension *including* the dot, so a bare
    # "png" would never match; normalise the requested extension first.
    wanted = img_ext
    if wanted and not wanted.startswith('.'):
        wanted = '.' + wanted
    for root, _, files in os.walk(directory):
        for fl in files:
            if os.path.splitext(fl)[1] == wanted:
                yield os.path.join(root, fl)
21+
22+
23+
class ImageTester(unittest.TestCase):
    """Checks the libpng-backed PNG readers against PIL on the test assets."""

    @unittest.skipUnless(sys.platform.startswith("linux"), "Support only available on linux for now.")
    def test_read_png(self):
        # ".png" with the dot: that is the form os.path.splitext() produces,
        # so the loop actually visits the asset files.
        for img_path in get_images(IMAGE_DIR, ".png"):
            img_pil = torch.from_numpy(np.array(Image.open(img_path)))
            img_lpng = read_png(img_path)
            # assertEqual() would evaluate `not (a == b)`, which raises for
            # tensors with more than one element; compare with torch.equal().
            self.assertTrue(torch.equal(img_lpng, img_pil))

    @unittest.skipUnless(sys.platform.startswith("linux"), "Support only available on linux for now.")
    def test_decode_png(self):
        for img_path in get_images(IMAGE_DIR, ".png"):
            img_pil = torch.from_numpy(np.array(Image.open(img_path)))
            # Load the encoded file as a flat uint8 tensor of its exact size.
            size = os.path.getsize(img_path)
            img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
            self.assertTrue(torch.equal(img_lpng, img_pil))
38+
39+
if __name__ == '__main__':
40+
unittest.main()

third_party/libpng

Submodule libpng added at 301f7a1
torchvision/csrc/cpu/image/readpng_cpu.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#include "readpng_cpu.h"
2+
3+
#include <png.h>
4+
#include <setjmp.h>
5+
#include <string>
6+
7+
// Decodes a PNG file held in memory into an image tensor.
//
// `data` is a one-dimensional uint8 tensor containing the complete encoded
// file. The result is a uint8 tensor of shape {height, width, 3}.
//
// Only 8-bit RGB images are supported. A TORCH_CHECK fails when the content
// does not start with the PNG signature, when the header cannot be read, when
// the image is not 8-bit RGB, or when libpng reports an error while decoding
// (including a truncated stream).
torch::Tensor decodePNG(const torch::Tensor& data) {
  constexpr png_size_t kSignatureBytes = 8;

  // The raw pointer below ignores strides, so work on a contiguous tensor.
  const torch::Tensor input = data.contiguous();
  auto encoded = input.accessor<unsigned char, 1>();
  auto datap = encoded.data();
  const auto encoded_size = static_cast<png_size_t>(encoded.size(0));

  // Check the signature before allocating any libpng state: nothing has to be
  // released on failure and no byte beyond the input is read.
  TORCH_CHECK(
      encoded_size >= kSignatureBytes &&
          png_sig_cmp(datap, 0, kSignatureBytes) == 0,
      "Content is not png!");

  auto png_ptr =
      png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr);
  TORCH_CHECK(png_ptr, "libpng read structure allocation failed!");
  auto info_ptr = png_create_info_struct(png_ptr);
  if (!info_ptr) {
    // Release the read struct first so the failing check does not leak it.
    png_destroy_read_struct(&png_ptr, nullptr, nullptr);
    TORCH_CHECK(false, "libpng info structure allocation failed!");
  }

  // libpng reports errors by longjmp-ing back to this point.
  if (setjmp(png_jmpbuf(png_ptr)) != 0) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Internal error.");
  }

  // Cursor over the encoded bytes that follow the signature.
  struct Reader {
    png_const_bytep ptr;
    png_size_t remaining;
  } reader;
  reader.ptr = png_const_bytep(datap) + kSignatureBytes;
  reader.remaining = encoded_size - kSignatureBytes;

  auto read_callback =
      [](png_structp png, png_bytep output, png_size_t count) {
        auto source = static_cast<Reader*>(png_get_io_ptr(png));
        // A truncated or corrupt stream may ask for more bytes than are left;
        // signal a libpng error instead of reading past the buffer.
        if (count > source->remaining) {
          png_error(png, "Read past the end of the PNG data.");
        }
        std::copy(source->ptr, source->ptr + count, output);
        source->ptr += count;
        source->remaining -= count;
      };
  png_set_sig_bytes(png_ptr, kSignatureBytes);
  png_set_read_fn(png_ptr, &reader, read_callback);
  png_read_info(png_ptr, info_ptr);

  png_uint_32 width, height;
  int bit_depth, color_type;
  auto retval = png_get_IHDR(
      png_ptr,
      info_ptr,
      &width,
      &height,
      &bit_depth,
      &color_type,
      nullptr,
      nullptr,
      nullptr);

  if (retval != 1) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Could not read image metadata from content.");
  }
  if (color_type != PNG_COLOR_TYPE_RGB) {
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Non RGB images are not supported.");
  }
  if (bit_depth != 8) {
    // A 16-bit row is twice as wide as a row of the uint8 output tensor and
    // would be written past its end.
    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
    TORCH_CHECK(false, "Only 8-bit RGB images are supported.");
  }

  // Interlaced images are delivered in several passes over the same rows; for
  // non-interlaced images this is a single pass.
  const int passes = png_set_interlace_handling(png_ptr);
  png_read_update_info(png_ptr, info_ptr);

  auto tensor =
      torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
  auto pixels = tensor.accessor<uint8_t, 3>().data();
  const auto row_bytes = png_get_rowbytes(png_ptr, info_ptr);
  for (int pass = 0; pass < passes; ++pass) {
    auto row = pixels;
    for (png_uint_32 y = 0; y < height; ++y) {
      png_read_row(png_ptr, row, nullptr);
      row += row_bytes;
    }
  }
  png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
  return tensor;
}
torchvision/csrc/cpu/image/readpng_cpu.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
#include <string>
5+
6+
// Decodes the PNG file stored in `data`, a one-dimensional tensor of unsigned
// bytes, and returns the image as a uint8 tensor of shape {height, width, 3}.
// Only RGB images are accepted; a TORCH_CHECK fails for content that is not a
// PNG, for other colour types, and for libpng errors.
torch::Tensor decodePNG(const torch::Tensor& data);

torchvision/csrc/image.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#pragma once
2+
3+
#include "cpu/image/readpng_cpu.h"
4+

torchvision/csrc/vision.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
#include "ROIAlign.h"
1212
#include "ROIPool.h"
1313
#include "empty_tensor_op.h"
14+
#ifdef __linux__
15+
#include "image.h"
16+
#endif
1417
#include "nms.h"
1518

1619
// If we are in a Windows environment, we need to define
@@ -49,4 +52,7 @@ static auto registry =
4952
.op("torchvision::ps_roi_align", &ps_roi_align)
5053
.op("torchvision::ps_roi_pool", &ps_roi_pool)
5154
.op("torchvision::deform_conv2d", &deform_conv2d)
55+
#ifdef __linux__
56+
.op("torchvision::decode_png", &decodePNG)
57+
#endif
5258
.op("torchvision::_cuda_version", &_cuda_version);

0 commit comments

Comments
 (0)