Skip to content

Support path-like objects for reading/writing #233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions shapefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,24 @@ def u(v, encoding='utf-8', encodingErrors='strict'):
def is_string(v):
return isinstance(v, basestring)

if sys.version_info[0:2] >= (3, 6):
def pathlike_obj(path):
if isinstance(path, os.PathLike):
return os.fsdecode(path)
else:
return path
else:
def pathlike_obj(path):
if is_string(path):
return path
elif hasattr(path, "__fspath__"):
return path.__fspath__()
else:
try:
return str(path)
except:
return path


# Begin

Expand Down Expand Up @@ -930,14 +948,14 @@ def __init__(self, *args, **kwargs):
self.encodingErrors = kwargs.pop('encodingErrors', 'strict')
# See if a shapefile name was passed as the first argument
if len(args) > 0:
if is_string(args[0]):
path = args[0]
path = pathlike_obj(args[0])
if is_string(path):

if '.zip' in path:
# Shapefile is inside a zipfile
if path.count('.zip') > 1:
# Multiple nested zipfiles
raise ShapefileException('Reading from multiple nested zipfiles is not supported: %s' % args[0])
raise ShapefileException('Reading from multiple nested zipfiles is not supported: %s' % path)
# Split into zipfile and shapefile paths
if path.endswith('.zip'):
zpath = path
Expand Down Expand Up @@ -1708,8 +1726,9 @@ def __init__(self, target=None, shapeType=None, autoBalance=False, **kwargs):
self.shapeType = shapeType
self.shp = self.shx = self.dbf = None
if target:
target = pathlike_obj(target)
if not is_string(target):
raise Exception('The target filepath {} must be of type str/unicode, not {}.'.format(repr(target), type(target)) )
raise Exception('The target filepath {} must be of type str/unicode or path-like, not {}.'.format(repr(target), type(target)) )
self.shp = self.__getFileObj(os.path.splitext(target)[0] + '.shp')
self.shx = self.__getFileObj(os.path.splitext(target)[0] + '.shx')
self.dbf = self.__getFileObj(os.path.splitext(target)[0] + '.dbf')
Expand Down
30 changes: 30 additions & 0 deletions test_shapefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,22 @@
"""
# std lib imports
import os.path
import sys
if sys.version_info.major == 3:
from pathlib import Path

# third party imports
import pytest
import json
import datetime
if sys.version_info.major == 2:
# required by pytest for python <36
from pathlib2 import Path

# our imports
import shapefile


# define various test shape tuples of (type, points, parts indexes, and expected geo interface output)
geo_interface_tests = [ (shapefile.POINT, # point
[(1,1)],
Expand Down Expand Up @@ -403,6 +410,15 @@ def test_reader_shapefile_extension_ignored():
assert not os.path.exists(filename)


def test_reader_pathlike():
"""
Assert that path-like objects can be read.
"""
base = Path("shapefiles")
with shapefile.Reader(base / "blockgroups") as sf:
assert len(sf) == 663


def test_reader_filelike_dbf_only():
"""
Assert that specifying just the
Expand Down Expand Up @@ -888,6 +904,20 @@ def test_write_default_shp_shx_dbf(tmpdir):
assert os.path.exists(filename + ".dbf")


def test_write_pathlike(tmpdir):
"""
Assert that path-like objects can be written.
Similar to test_write_default_shp_shx_dbf.
"""
filename = tmpdir.join("test")
assert not isinstance(filename, str)
with shapefile.Writer(filename) as writer:
writer.field('field1', 'C')
assert (filename + ".shp").ensure()
assert (filename + ".shx").ensure()
assert (filename + ".dbf").ensure()


def test_write_shapefile_extension_ignored(tmpdir):
"""
Assert that the filename's extension is
Expand Down