diff --git a/doc/io.rst b/doc/io.rst
index 192890e112a..fa14a491658 100644
--- a/doc/io.rst
+++ b/doc/io.rst
@@ -176,6 +176,10 @@ for dealing with datasets too big to fit into memory. Instead, xarray
 integrates with dask.array (see :ref:`dask`), which provides a fully
 featured engine for streaming computation.
 
+It is possible to append or overwrite netCDF variables using the ``mode='a'``
+argument. When using this option, all variables in the dataset will be written
+to the original netCDF file, regardless of whether they already exist there.
+
 .. _io.encoding:
 
 Reading encoded data
@@ -390,7 +394,7 @@ over the network until we look at particular values:
 
 Some servers require authentication before we can access the data. For this
 purpose we can explicitly create a :py:class:`~xarray.backends.PydapDataStore`
-and pass in a `Requests`__ session object. For example for 
+and pass in a `Requests`__ session object. For example for
 HTTP Basic authentication::
 
     import xarray as xr
    import requests
 
     session = requests.Session()
     session.auth = ('username', 'password')
 
     store = xr.backends.PydapDataStore.open('http://example.com/data',
                                             session=session)
     ds = xr.open_dataset(store)
 
-`Pydap's cas module`__ has functions that generate custom sessions for 
+`Pydap's cas module`__ has functions that generate custom sessions for
 servers that use CAS single sign-on. For example, to connect to servers
 that require NASA's URS authentication::
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 44425baa153..231f93199a1 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -290,6 +290,10 @@ Bug fixes
   the first argument was a numpy variable (:issue:`1588`).
   By `Guido Imperiale `_.
 
+- Fix bug in :py:meth:`~xarray.Dataset.to_netcdf` when writing in append mode
+  (:issue:`1215`).
+  By `Joe Hamman `_.
+
 - Fix ``netCDF4`` backend to properly roundtrip the ``shuffle`` encoding
   option (:issue:`1606`).
   By `Joe Hamman `_.
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index e49bdd344f4..6a7c5e4beb4 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -223,8 +223,13 @@ def set_variables(self, variables, check_encoding_set,
         for vn, v in iteritems(variables):
             name = _encode_variable_name(vn)
             check = vn in check_encoding_set
-            target, source = self.prepare_variable(
-                name, v, check, unlimited_dims=unlimited_dims)
+            if vn not in self.variables:
+                target, source = self.prepare_variable(
+                    name, v, check, unlimited_dims=unlimited_dims)
+            else:
+                # variable already exists in the file: write into it in place
+                target, source = self.ds.variables[name], v.data
+
             self.writer.add(source, target)
 
     def set_necessary_dimensions(self, variable, unlimited_dims=None):
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index e54f307c075..49f14ffccc5 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -974,7 +974,8 @@ def to_netcdf(self, path=None, mode='w', format=None, group=None,
             default format becomes NETCDF3_64BIT).
         mode : {'w', 'a'}, optional
             Write ('w') or append ('a') mode. If mode='w', any existing file at
-            this location will be overwritten.
+            this location will be overwritten. If mode='a', existing variables
+            will be overwritten and new variables will be added.
         format : {'NETCDF4', 'NETCDF4_CLASSIC', 'NETCDF3_64BIT',
                   'NETCDF3_CLASSIC'}, optional
             File format for the resulting netCDF file:
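Illustration (not part of the patch): a minimal sketch of the append-mode
behaviour that the hunks above document and implement. The file name and
variable names here are made up for the example::

    import numpy as np
    import xarray as xr

    # create a file, then write a second dataset into it with mode='a'
    ds = xr.Dataset({'var1': ('x', np.arange(5.0))})
    ds.to_netcdf('example.nc', mode='w')

    # every variable in this dataset is written to example.nc: 'var1'
    # already exists there, so its values are overwritten in place, while
    # 'var2' is created as a new variable
    ds2 = xr.Dataset({'var1': ('x', np.zeros(5)),
                      'var2': ('x', np.ones(5))})
    ds2.to_netcdf('example.nc', mode='a')

    with xr.open_dataset('example.nc') as actual:
        print(actual)  # var1 == zeros, var2 == ones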
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 6cb291693f4..3370e5aae26 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -128,12 +128,42 @@ class Only32BitTypes(object):
 
 class DatasetIOTestCases(object):
     autoclose = False
+    engine = None
+    file_format = None
 
     def create_store(self):
         raise NotImplementedError
 
-    def roundtrip(self, data, **kwargs):
-        raise NotImplementedError
+    @contextlib.contextmanager
+    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
+                  allow_cleanup_failure=False):
+        with create_tmp_file(
+                allow_cleanup_failure=allow_cleanup_failure) as path:
+            self.save(data, path, **save_kwargs)
+            with self.open(path, **open_kwargs) as ds:
+                yield ds
+
+    @contextlib.contextmanager
+    def roundtrip_append(self, data, save_kwargs={}, open_kwargs={},
+                         allow_cleanup_failure=False):
+        with create_tmp_file(
+                allow_cleanup_failure=allow_cleanup_failure) as path:
+            for i, key in enumerate(data.variables):
+                mode = 'a' if i > 0 else 'w'
+                self.save(data[[key]], path, mode=mode, **save_kwargs)
+            with self.open(path, **open_kwargs) as ds:
+                yield ds
+
+    # The save/open methods may be overridden below
+    def save(self, dataset, path, **kwargs):
+        dataset.to_netcdf(path, engine=self.engine, format=self.file_format,
+                          **kwargs)
+
+    @contextlib.contextmanager
+    def open(self, path, **kwargs):
+        with open_dataset(path, engine=self.engine, autoclose=self.autoclose,
+                          **kwargs) as ds:
+            yield ds
 
     def test_zero_dimensional_variable(self):
         expected = create_test_data()
@@ -563,6 +593,23 @@ def test_encoding_same_dtype(self):
         self.assertEqual(actual.x.encoding['dtype'], 'f4')
         self.assertEqual(ds.x.encoding, {})
 
+    def test_append_write(self):
+        # regression for GH1215
+        data = create_test_data()
+        with self.roundtrip_append(data) as actual:
+            assert_allclose(data, actual)
+
+    def test_append_overwrite_values(self):
+        # regression for GH1215
+        data = create_test_data()
+        with create_tmp_file(allow_cleanup_failure=False) as tmp_file:
+            self.save(data, tmp_file, mode='w')
+            data['var2'][:] = -999
+            data['var9'] = data['var2'] * 3
+            self.save(data[['var2', 'var9']], tmp_file, mode='a')
+            with self.open(tmp_file) as actual:
+                assert_allclose(data, actual)
+
 
 _counter = itertools.count()
 
@@ -592,6 +639,9 @@ def create_tmp_files(nfiles, suffix='.nc', allow_cleanup_failure=False):
 
 @requires_netCDF4
 class BaseNetCDF4Test(CFEncodedDataTest):
+
+    engine = 'netcdf4'
+
     def test_open_group(self):
         # Create a netCDF file with a dataset stored within a group
         with create_tmp_file() as tmp_file:
@@ -813,16 +863,6 @@ def create_store(self):
             with backends.NetCDF4DataStore.open(tmp_file, mode='w') as store:
                 yield store
 
-    @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        with create_tmp_file(
-                allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
-            data.to_netcdf(tmp_file, **save_kwargs)
-            with open_dataset(tmp_file,
-                              autoclose=self.autoclose, **open_kwargs) as ds:
-                yield ds
-
     def test_variable_order(self):
         # doesn't work with scipy or h5py :(
         ds = Dataset()
@@ -883,19 +923,13 @@ class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest):
 
 @requires_scipy
 class ScipyInMemoryDataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
+    engine = 'scipy'
+
     @contextlib.contextmanager
     def create_store(self):
         fobj = BytesIO()
         yield backends.ScipyDataStore(fobj, 'w')
 
-    @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        serialized = data.to_netcdf(**save_kwargs)
-        with open_dataset(serialized, engine='scipy',
-                          autoclose=self.autoclose, **open_kwargs) as ds:
-            yield ds
-
     def test_to_netcdf_explicit_engine(self):
         # regression test for GH1321
         Dataset({'foo': 42}).to_netcdf(engine='scipy')
@@ -915,6 +949,8 @@ class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest):
 
 @requires_scipy
 class ScipyFileObjectTest(CFEncodedDataTest, Only32BitTypes, TestCase):
+    engine = 'scipy'
+
     @contextlib.contextmanager
     def create_store(self):
         fobj = BytesIO()
@@ -925,9 +961,9 @@ def roundtrip(self, data, save_kwargs={}, open_kwargs={},
                   allow_cleanup_failure=False):
         with create_tmp_file() as tmp_file:
             with open(tmp_file, 'wb') as f:
-                data.to_netcdf(f, **save_kwargs)
+                self.save(data, f, **save_kwargs)
             with open(tmp_file, 'rb') as f:
-                with open_dataset(f, engine='scipy', **open_kwargs) as ds:
+                with self.open(f, **open_kwargs) as ds:
                     yield ds
 
     @pytest.mark.skip(reason='cannot pickle file objects')
@@ -941,22 +977,14 @@ def test_pickle_dataarray(self):
 
 @requires_scipy
 class ScipyFilePathTest(CFEncodedDataTest, Only32BitTypes, TestCase):
+    engine = 'scipy'
+
     @contextlib.contextmanager
     def create_store(self):
         with create_tmp_file() as tmp_file:
             with backends.ScipyDataStore(tmp_file, mode='w') as store:
                 yield store
 
-    @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        with create_tmp_file(
-                allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
-            data.to_netcdf(tmp_file, engine='scipy', **save_kwargs)
-            with open_dataset(tmp_file, engine='scipy',
-                              autoclose=self.autoclose, **open_kwargs) as ds:
-                yield ds
-
     def test_array_attrs(self):
         ds = Dataset(attrs={'foo': [[1, 2], [3, 4]]})
         with self.assertRaisesRegexp(ValueError, 'must be 1-dimensional'):
@@ -995,6 +1023,9 @@ class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest):
 
 @requires_netCDF4
 class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
+    engine = 'netcdf4'
+    file_format = 'NETCDF3_CLASSIC'
+
     @contextlib.contextmanager
     def create_store(self):
         with create_tmp_file() as tmp_file:
@@ -1002,17 +1033,6 @@ def create_store(self):
                     tmp_file, mode='w', format='NETCDF3_CLASSIC') as store:
                 yield store
 
-    @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        with create_tmp_file(
-                allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
-            data.to_netcdf(tmp_file, format='NETCDF3_CLASSIC',
-                           engine='netcdf4', **save_kwargs)
-            with open_dataset(tmp_file, engine='netcdf4',
-                              autoclose=self.autoclose, **open_kwargs) as ds:
-                yield ds
-
 
 class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
     autoclose = True
@@ -1021,6 +1041,9 @@ class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest):
 
 @requires_netCDF4
 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, Only32BitTypes,
                                        TestCase):
+    engine = 'netcdf4'
+    file_format = 'NETCDF4_CLASSIC'
+
     @contextlib.contextmanager
     def create_store(self):
         with create_tmp_file() as tmp_file:
@@ -1028,17 +1051,6 @@ def create_store(self):
                     tmp_file, mode='w', format='NETCDF4_CLASSIC') as store:
                 yield store
 
-    @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        with create_tmp_file(
-                allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
-            data.to_netcdf(tmp_file, format='NETCDF4_CLASSIC',
-                           engine='netcdf4', **save_kwargs)
-            with open_dataset(tmp_file, engine='netcdf4',
-                              autoclose=self.autoclose, **open_kwargs) as ds:
-                yield ds
-
 
 class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(
         NetCDF4ClassicViaNetCDF4DataTest):
@@ -1049,21 +1061,12 @@ class NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue(
 class GenericNetCDFDataTest(CFEncodedDataTest, Only32BitTypes, TestCase):
     # verify that we can read and write netCDF3 files as long as we have scipy
     # or netCDF4-python installed
+    file_format = 'netcdf3_64bit'
 
     def test_write_store(self):
         # there's no specific store to test here
         pass
 
-    @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        with create_tmp_file(
-                allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
-            data.to_netcdf(tmp_file, format='netcdf3_64bit', **save_kwargs)
-            with open_dataset(tmp_file,
-                              autoclose=self.autoclose, **open_kwargs) as ds:
-                yield ds
-
     def test_engine(self):
         data = create_test_data()
         with self.assertRaisesRegexp(ValueError, 'unrecognized engine'):
@@ -1122,21 +1125,13 @@ class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest):
 
 @requires_h5netcdf
 @requires_netCDF4
 class H5NetCDFDataTest(BaseNetCDF4Test, TestCase):
+    engine = 'h5netcdf'
+
     @contextlib.contextmanager
     def create_store(self):
         with create_tmp_file() as tmp_file:
             yield backends.H5NetCDFStore(tmp_file, 'w')
 
-    @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        with create_tmp_file(
-                allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
-            data.to_netcdf(tmp_file, engine='h5netcdf', **save_kwargs)
-            with open_dataset(tmp_file, engine='h5netcdf',
-                              autoclose=self.autoclose, **open_kwargs) as ds:
-                yield ds
-
     def test_orthogonal_indexing(self):
         # doesn't work for h5py (without using dask as an intermediate layer)
         pass
@@ -1646,14 +1641,13 @@ def test_orthogonal_indexing(self):
         pass
 
     @contextlib.contextmanager
-    def roundtrip(self, data, save_kwargs={}, open_kwargs={},
-                  allow_cleanup_failure=False):
-        with create_tmp_file(
-                allow_cleanup_failure=allow_cleanup_failure) as tmp_file:
-            data.to_netcdf(tmp_file, engine='scipy', **save_kwargs)
-            with open_dataset(tmp_file, engine='pynio',
-                              autoclose=self.autoclose, **open_kwargs) as ds:
-                yield ds
+    def open(self, path, **kwargs):
+        with open_dataset(path, engine='pynio', autoclose=self.autoclose,
+                          **kwargs) as ds:
+            yield ds
+
+    def save(self, dataset, path, **kwargs):
+        dataset.to_netcdf(path, engine='scipy', **kwargs)
 
     def test_weakrefs(self):
         example = Dataset({'foo': ('x', np.arange(5.0))})
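Illustration (not part of the patch): with ``roundtrip()`` and
``roundtrip_append()`` now written once on ``DatasetIOTestCases`` in terms of
the ``save()``/``open()`` hooks, a backend test case only declares its engine
and format, or overrides the hooks themselves as the pynio case above does. A
hypothetical new backend test would reduce to::

    @requires_netCDF4
    class MyNewBackendTest(CFEncodedDataTest, TestCase):
        # hypothetical subclass, shown only to illustrate the pattern;
        # the roundtrip/append machinery is inherited unchanged
        engine = 'netcdf4'
        file_format = 'NETCDF3_64BIT'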