Skip to content

Commit 621be55

Browse files
committed
Support path-like objects for reading/writing
1 parent d3d83a9 commit 621be55

File tree

2 files changed

+53
-5
lines changed

2 files changed

+53
-5
lines changed

shapefile.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,22 @@ def u(v, encoding='utf-8', encodingErrors='strict'):
161161
def is_string(v):
162162
return isinstance(v, basestring)
163163

164+
if sys.version_info[0:2] >= (3, 6):
165+
def pathlike_obj(path):
166+
if isinstance(path, os.PathLike):
167+
return os.fsdecode(path)
168+
else:
169+
return path
170+
else:
171+
def pathlike_obj(path):
172+
# detect if the object is path-like
173+
if "path" in path.__class__.__name__.lower():
174+
try:
175+
return str(path)
176+
except TypeError:
177+
pass
178+
return path
179+
164180

165181
# Begin
166182

@@ -930,14 +946,14 @@ def __init__(self, *args, **kwargs):
930946
self.encodingErrors = kwargs.pop('encodingErrors', 'strict')
931947
# See if a shapefile name was passed as the first argument
932948
if len(args) > 0:
933-
if is_string(args[0]):
934-
path = args[0]
935-
949+
path = pathlike_obj(args[0])
950+
if is_string(path):
951+
936952
if '.zip' in path:
937953
# Shapefile is inside a zipfile
938954
if path.count('.zip') > 1:
939955
# Multiple nested zipfiles
940-
raise ShapefileException('Reading from multiple nested zipfiles is not supported: %s' % args[0])
956+
raise ShapefileException('Reading from multiple nested zipfiles is not supported: %s' % path)
941957
# Split into zipfile and shapefile paths
942958
if path.endswith('.zip'):
943959
zpath = path
@@ -1708,8 +1724,9 @@ def __init__(self, target=None, shapeType=None, autoBalance=False, **kwargs):
17081724
self.shapeType = shapeType
17091725
self.shp = self.shx = self.dbf = None
17101726
if target:
1727+
target = pathlike_obj(target)
17111728
if not is_string(target):
1712-
raise Exception('The target filepath {} must be of type str/unicode, not {}.'.format(repr(target), type(target)) )
1729+
raise Exception('The target filepath {} must be of type str/unicode or path-like, not {}.'.format(repr(target), type(target)) )
17131730
self.shp = self.__getFileObj(os.path.splitext(target)[0] + '.shp')
17141731
self.shx = self.__getFileObj(os.path.splitext(target)[0] + '.shx')
17151732
self.dbf = self.__getFileObj(os.path.splitext(target)[0] + '.dbf')

test_shapefile.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
"""
44
# std lib imports
55
import os.path
6+
import sys
7+
if sys.version_info[0] >= 3:
8+
from pathlib import Path
69

710
# third party imports
811
import pytest
@@ -12,6 +15,10 @@
1215
# our imports
1316
import shapefile
1417

18+
19+
min_py3 = pytest.mark.skipif(
20+
sys.version_info[0] < 3, reason="minimum Python 3 required")
21+
1522
# define various test shape tuples of (type, points, parts indexes, and expected geo interface output)
1623
geo_interface_tests = [ (shapefile.POINT, # point
1724
[(1,1)],
@@ -403,6 +410,16 @@ def test_reader_shapefile_extension_ignored():
403410
assert not os.path.exists(filename)
404411

405412

413+
@min_py3
414+
def test_reader_pathlike():
415+
"""
416+
Assert that path-like objects can be read.
417+
"""
418+
base = Path("shapefiles")
419+
with shapefile.Reader(base / "blockgroups") as sf:
420+
assert len(sf) == 663
421+
422+
406423
def test_reader_filelike_dbf_only():
407424
"""
408425
Assert that specifying just the
@@ -888,6 +905,20 @@ def test_write_default_shp_shx_dbf(tmpdir):
888905
assert os.path.exists(filename + ".dbf")
889906

890907

908+
def test_write_pathlike(tmpdir):
909+
"""
910+
Assert that path-like objects can be written.
911+
Similar to test_write_default_shp_shx_dbf.
912+
"""
913+
filename = tmpdir.join("test")
914+
assert not isinstance(filename, str)
915+
with shapefile.Writer(filename) as writer:
916+
writer.field('field1', 'C')
917+
assert (filename + ".shp").ensure()
918+
assert (filename + ".shx").ensure()
919+
assert (filename + ".dbf").ensure()
920+
921+
891922
def test_write_shapefile_extension_ignored(tmpdir):
892923
"""
893924
Assert that the filename's extension is

0 commit comments

Comments
 (0)