Reference (Gold): xarray
Pytest Summary for xarray tests
| status    | count |
|-----------|-------|
| passed    | 15632 |
| skipped   | 1098  |
| xfailed   | 67    |
| failed    | 301   |
| xpassed   | 10    |
| total     | 17108 |
| collected | 17108 |
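For context, a minimal sketch of how one of the failures listed below could be rerun in isolation via pytest's Python API; the node ID is taken from the list, while the checkout layout and working directory are assumptions not stated in this report:

```python
# Minimal sketch (not part of the report itself): rerun one failing node ID
# from the list below in isolation, using pytest's Python API.
# Assumptions: pytest is installed in the environment that produced this
# report (/testbed/.venv, Python 3.10.12) and this is run from the root of
# an xarray checkout, where the tests live under xarray/tests/.
import pytest

failing_node = (
    "xarray/tests/test_backends.py::"
    "TestH5NetCDFData::test_roundtrip_test_data"
)

# The [gw0]-[gw4] markers below show the original run used pytest-xdist
# workers; a single node can simply be rerun serially as done here.
exit_code = pytest.main(["-rA", failing_node])
print("pytest exit code:", exit_code)
```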
Failed tests:
test_backends.py::TestScipyInMemoryData::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::test_source_encoding_always_present_with_fsspec
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestScipyInMemoryData::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestScipyFileObject::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::test_use_cftime_false_standard_calendar_in_range[gregorian]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::test_use_cftime_false_standard_calendar_in_range[proleptic_gregorian]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::test_use_cftime_false_standard_calendar_in_range[standard]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestScipyFileObject::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::test_load_single_value_h5netcdf
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::test_h5netcdf_entrypoint
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestScipyFilePath::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestScipyFilePath::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestZarrDirectoryStore::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends_datatree.py::TestH5NetCDFDatatreeIO::test_to_netcdf
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends_datatree.py::TestH5NetCDFDatatreeIO::test_to_netcdf_inherited_coords
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF3ViaNetCDF4Data::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends_datatree.py::TestH5NetCDFDatatreeIO::test_netcdf_encoding
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF3ViaNetCDF4Data::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ClassicViaNetCDF4Data::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ClassicViaNetCDF4Data::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestGenericNetCDFData::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestGenericNetCDFData::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_compression_encoding[zstd]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_compression_encoding[blosc_lz]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_zero_dimensional_variable
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_write_store
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_compression_encoding[blosc_lz4]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_compression_encoding[blosc_lz4hc]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_test_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_compression_encoding[blosc_zlib]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_load
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_dataset_compute
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4Data::test_compression_encoding[blosc_zstd]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_pickle
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_pickle_dataarray
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_dataset_caching
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_None_variable
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_object_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_string_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_string_encoded_characters
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestZarrWriteEmpty::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_numpy_datetime_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_cftime_datetime_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_timedelta_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_float64_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_example_1_netcdf
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_coordinates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_compression_encoding[zstd]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_global_coordinates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_compression_encoding[blosc_lz]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_boolean_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_compression_encoding[blosc_lz4]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_orthogonal_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_compression_encoding[blosc_lz4hc]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_compression_encoding[blosc_zlib]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_vectorized_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_vectorized_indexing_negative_step
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestNetCDF4ViaDaskData::test_compression_encoding[blosc_zstd]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_outer_indexing_reversed
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_isel_dataarray
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestZarrDictStore::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_array_type_after_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_conventions.py::TestCFEncodedDataStore::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_dropna
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_conventions.py::TestCFEncodedDataStore::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_ondisk_after_print
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_bytes_with_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_empty_vlen_string_array
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype0-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype0-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype0-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype1-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype1-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_mask_and_scale[dtype1-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestReduce2D::test_idxmin[dask-datetime]
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_unsigned[fillvalue0]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_unsigned[fillvalue1]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_unsigned[-1]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestReduce2D::test_idxmax[dask-datetime]
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_unsigned[255]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_coordinate_variables_after_dataset_roundtrip
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_grid_mapping_and_bounds_are_coordinates_after_dataarray_roundtrip
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_coordinates_encoding
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_endian
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_combine.py::TestNestedCombine::test_nested_concat_too_many_dims_at_once
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_kwarg_dates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_kwarg_fixed_width_string
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_default_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_explicitly_omit_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_explicitly_omit_fill_value_via_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_explicitly_omit_fill_value_in_coord
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataset.py::TestDataset::test_copy_coords[True-expected_orig0]
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataset.py::TestDataset::test_copy_coords[False-expected_orig1]
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree.py::TestGetItem::test_getitem_multiple_data_variables
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree.py::TestGetItem::test_getitem_dict_like_selection_access_to_dataset
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree.py::TestCopy::test_copy_with_data
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree.py::TestSetItem::test_setitem_dataset_on_this_node
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree.py::TestSetItem::test_setitem_dataset_as_new_node
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_same_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree.py::TestSetItem::test_setitem_dataset_as_new_node_requiring_intermediate_nodes
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_append_write
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree.py::TestTreeFromDict::test_roundtrip_unnamed_root
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_append_overwrite_values
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_append_with_invalid_dim_raises
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::TestToDaskDataFrame::test_to_dask_dataframe_2D
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree_mapping.py::TestMapOverSubTree::test_return_inconsistent_number_of_results
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_multiindex_not_implemented
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree_mapping.py::TestMapOverSubTree::test_trees_with_different_node_names
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree_mapping.py::TestMapOverSubTree::test_error_contains_path_of_offending_node
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_datatree_mapping.py::TestMapOverSubTreeInplace::test_map_over_subtree_inplace
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::TestToDaskDataFrame::test_to_dask_dataframe_2D_set_index
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::TestToDaskDataFrame::test_to_dask_dataframe_not_daskarray
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_refresh_from_disk
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_open_group
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_open_subgroup
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_write_groups
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_kwarg_vlen_string[input_strings0-True]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_kwarg_vlen_string[input_strings1-False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_kwarg_vlen_string[input_strings2-False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_string_with_fill_value_vlen[XXX]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_string_with_fill_value_vlen[]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_string_with_fill_value_vlen[b\xe1r]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_roundtrip_character_array
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_default_to_char_arrays
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_dump_encodings
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_compression_encoding_legacy
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_kwarg_compression
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_keep_chunksizes_if_no_original_shape
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_computation.py::test_cross[a5-b5-ae5-be5-cartesian--1-True]
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_preferred_chunks_is_present
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_computation.py::test_cross[a6-b6-ae6-be6-cartesian--1-True]
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_auto_chunking_is_based_on_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_base_chunking_uses_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_preferred_chunks_are_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::test_map_blocks_da_ds_with_template[obj0]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_chunksizes_unlimited
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::test_map_blocks_da_ds_with_template[obj1]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_raise_on_forward_slashes_in_names
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::test_map_blocks_template_convert_object
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::test_map_blocks_errors_bad_template[obj0]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dask.py::test_map_blocks_errors_bad_template[obj1]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_enum__no_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_enum__multiple_variable_with_enum
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_complex
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_complex_error[None]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_complex_error[False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_numpy_bool_
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_cross_engine_read_write_netcdf4
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_encoding_unlimited_dims
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_compression_encoding_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_compression_check_encoding_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFData::test_dump_encodings_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFAlreadyOpen::test_deepcopy
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestDataArray::test_astype_subok
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_zero_dimensional_variable
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_write_store
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_test_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_distributed.py::test_dask_distributed_netcdf_roundtrip[h5netcdf-NETCDF4]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_load
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_dataset_compute
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_pickle
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_pickle_dataarray
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_dataset_caching
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_None_variable
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_object_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_string_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_string_encoded_characters
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_numpy_datetime_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_cftime_datetime_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_timedelta_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_float64_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_example_1_netcdf
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_coordinates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_global_coordinates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_boolean_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_orthogonal_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_vectorized_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_vectorized_indexing_negative_step
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_outer_indexing_reversed
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_isel_dataarray
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_array_type_after_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestDataArray::test_to_dask_dataframe
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_dropna
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_ondisk_after_print
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestDataArray::test_copy_coords[True-expected_orig0]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestDataArray::test_copy_coords[False-expected_orig1]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_bytes_with_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_empty_vlen_string_array
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype0-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype0-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype0-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestReduce1D::test_idxmin[True-datetime]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype1-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype1-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_interp.py::test_datetime[2000-01-01T12:00-0.5]
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_dataarray.py::TestReduce1D::test_idxmax[True-datetime]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_mask_and_scale[dtype1-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_unsigned[fillvalue0]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_unsigned[fillvalue1]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_unsigned[-1]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_unsigned[255]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_missing.py::test_interpolate_na_2d[None]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_coordinate_variables_after_dataset_roundtrip
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_namedarray.py::TestNamedArray::test_init[expected1]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_grid_mapping_and_bounds_are_coordinates_after_dataarray_roundtrip
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_coordinates_encoding
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_endian
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_kwarg_dates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_kwarg_fixed_width_string
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_default_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_explicitly_omit_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_distributed.py::test_dask_distributed_read_netcdf_integration_test[h5netcdf-NETCDF4]
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_explicitly_omit_fill_value_via_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_explicitly_omit_fill_value_in_coord
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_same_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_append_write
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_append_overwrite_values
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_append_with_invalid_dim_raises
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_multiindex_not_implemented
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_refresh_from_disk
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_open_group
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_open_subgroup
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_write_groups
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_kwarg_vlen_string[input_strings0-True]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_kwarg_vlen_string[input_strings1-False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_kwarg_vlen_string[input_strings2-False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_string_with_fill_value_vlen[XXX]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_string_with_fill_value_vlen[]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_string_with_fill_value_vlen[b\xe1r]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_roundtrip_character_array
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_plot.py::TestImshow::test_dates_are_concise
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_default_to_char_arrays
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_dump_encodings
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_compression_encoding_legacy
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_kwarg_compression
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_keep_chunksizes_if_no_original_shape
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_preferred_chunks_is_present
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_auto_chunking_is_based_on_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_plot.py::TestSurface::test_dates_are_concise
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_base_chunking_uses_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_preferred_chunks_are_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_chunksizes_unlimited
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_raise_on_forward_slashes_in_names
test_backends.py::TestH5NetCDFFileObject::test_raise_on_forward_slashes_in_names
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_enum__no_fill_value
test_backends.py::TestH5NetCDFFileObject::test_encoding_enum__no_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_enum__multiple_variable_with_enum
test_backends.py::TestH5NetCDFFileObject::test_encoding_enum__multiple_variable_with_enum
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_complex
test_backends.py::TestH5NetCDFFileObject::test_complex
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_complex_error[None]
test_backends.py::TestH5NetCDFFileObject::test_complex_error[None]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_complex_error[False]
test_backends.py::TestH5NetCDFFileObject::test_complex_error[False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_numpy_bool_
test_backends.py::TestH5NetCDFFileObject::test_numpy_bool_
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_cross_engine_read_write_netcdf4
test_backends.py::TestH5NetCDFFileObject::test_cross_engine_read_write_netcdf4
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_encoding_unlimited_dims
test_backends.py::TestH5NetCDFFileObject::test_encoding_unlimited_dims
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_compression_encoding_h5py
test_backends.py::TestH5NetCDFFileObject::test_compression_encoding_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_compression_check_encoding_h5py
test_backends.py::TestH5NetCDFFileObject::test_compression_check_encoding_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_dump_encodings_h5py
test_backends.py::TestH5NetCDFFileObject::test_dump_encodings_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_open_twice
test_backends.py::TestH5NetCDFFileObject::test_open_twice
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFFileObject::test_open_fileobj
test_backends.py::TestH5NetCDFFileObject::test_open_fileobj
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_zero_dimensional_variable
test_backends.py::TestH5NetCDFViaDaskData::test_zero_dimensional_variable
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_write_store
test_backends.py::TestH5NetCDFViaDaskData::test_write_store
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_test_data
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_test_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_load
test_backends.py::TestH5NetCDFViaDaskData::test_load
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_dataset_compute
test_backends.py::TestH5NetCDFViaDaskData::test_dataset_compute
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_pickle
test_backends.py::TestH5NetCDFViaDaskData::test_pickle
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_pickle_dataarray
test_backends.py::TestH5NetCDFViaDaskData::test_pickle_dataarray
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_None_variable
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_None_variable
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_object_dtype
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_object_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_data
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_encoded_characters
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_encoded_characters
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_numpy_datetime_data
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_numpy_datetime_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_cftime_datetime_data
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_cftime_datetime_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_timedelta_data
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_timedelta_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_float64_data
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_float64_data
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_example_1_netcdf
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_example_1_netcdf
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_coordinates
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_coordinates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_global_coordinates
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_global_coordinates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_boolean_dtype
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_boolean_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_orthogonal_indexing
test_backends.py::TestH5NetCDFViaDaskData::test_orthogonal_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_vectorized_indexing
test_backends.py::TestH5NetCDFViaDaskData::test_vectorized_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_vectorized_indexing_negative_step
test_backends.py::TestH5NetCDFViaDaskData::test_vectorized_indexing_negative_step
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_outer_indexing_reversed
test_backends.py::TestH5NetCDFViaDaskData::test_outer_indexing_reversed
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_isel_dataarray
test_backends.py::TestH5NetCDFViaDaskData::test_isel_dataarray
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_array_type_after_indexing
test_backends.py::TestH5NetCDFViaDaskData::test_array_type_after_indexing
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_dropna
test_backends.py::TestH5NetCDFViaDaskData::test_dropna
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_ondisk_after_print
test_backends.py::TestH5NetCDFViaDaskData::test_ondisk_after_print
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_bytes_with_fill_value
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_bytes_with_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_empty_vlen_string_array
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_empty_vlen_string_array
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype0-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_unsigned_masked_scaled_data-create_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_bad_unsigned_masked_scaled_data-create_bad_encoded_unsigned_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_signed_masked_scaled_data-create_encoded_signed_masked_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_mask_and_scale[dtype1-create_masked_and_scaled_data-create_encoded_masked_and_scaled_data]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[fillvalue0]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[fillvalue0]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[fillvalue1]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[fillvalue1]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[-1]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[-1]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[255]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_unsigned[255]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_coordinate_variables_after_dataset_roundtrip
test_backends.py::TestH5NetCDFViaDaskData::test_coordinate_variables_after_dataset_roundtrip
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_grid_mapping_and_bounds_are_coordinates_after_dataarray_roundtrip
test_backends.py::TestH5NetCDFViaDaskData::test_grid_mapping_and_bounds_are_coordinates_after_dataarray_roundtrip
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_coordinates_encoding
test_backends.py::TestH5NetCDFViaDaskData::test_coordinates_encoding
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_endian
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_endian
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_dates
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_dates
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_fixed_width_string
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_fixed_width_string
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_default_fill_value
test_backends.py::TestH5NetCDFViaDaskData::test_default_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value_via_encoding_kwarg
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value_via_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value_in_coord
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value_in_coord
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg
test_backends.py::TestH5NetCDFViaDaskData::test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_same_dtype
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_same_dtype
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_append_write
test_backends.py::TestH5NetCDFViaDaskData::test_append_write
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_append_overwrite_values
test_backends.py::TestH5NetCDFViaDaskData::test_append_overwrite_values
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_append_with_invalid_dim_raises
test_backends.py::TestH5NetCDFViaDaskData::test_append_with_invalid_dim_raises
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_multiindex_not_implemented
test_backends.py::TestH5NetCDFViaDaskData::test_multiindex_not_implemented
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_refresh_from_disk
test_backends.py::TestH5NetCDFViaDaskData::test_refresh_from_disk
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_open_group
test_backends.py::TestH5NetCDFViaDaskData::test_open_group
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_open_subgroup
test_backends.py::TestH5NetCDFViaDaskData::test_open_subgroup
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_write_groups
test_backends.py::TestH5NetCDFViaDaskData::test_write_groups
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_vlen_string[input_strings0-True]
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_vlen_string[input_strings0-True]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_vlen_string[input_strings1-False]
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_vlen_string[input_strings1-False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_vlen_string[input_strings2-False]
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_vlen_string[input_strings2-False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_with_fill_value_vlen[XXX]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_with_fill_value_vlen[XXX]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_with_fill_value_vlen[]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_with_fill_value_vlen[]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_with_fill_value_vlen[b\xe1r]
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_string_with_fill_value_vlen[b\xe1r]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_character_array
test_backends.py::TestH5NetCDFViaDaskData::test_roundtrip_character_array
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_default_to_char_arrays
test_backends.py::TestH5NetCDFViaDaskData::test_default_to_char_arrays
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_dump_encodings
test_backends.py::TestH5NetCDFViaDaskData::test_dump_encodings
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_compression_encoding_legacy
test_backends.py::TestH5NetCDFViaDaskData::test_compression_encoding_legacy
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_compression
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_kwarg_compression
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_keep_chunksizes_if_no_original_shape
test_backends.py::TestH5NetCDFViaDaskData::test_keep_chunksizes_if_no_original_shape
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_preferred_chunks_is_present
test_backends.py::TestH5NetCDFViaDaskData::test_preferred_chunks_is_present
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_auto_chunking_is_based_on_disk_chunk_sizes
test_backends.py::TestH5NetCDFViaDaskData::test_auto_chunking_is_based_on_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_base_chunking_uses_disk_chunk_sizes
test_backends.py::TestH5NetCDFViaDaskData::test_base_chunking_uses_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_preferred_chunks_are_disk_chunk_sizes
test_backends.py::TestH5NetCDFViaDaskData::test_preferred_chunks_are_disk_chunk_sizes
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_chunksizes_unlimited
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_chunksizes_unlimited
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_raise_on_forward_slashes_in_names
test_backends.py::TestH5NetCDFViaDaskData::test_raise_on_forward_slashes_in_names
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_enum__no_fill_value
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_enum__no_fill_value
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariable::test_timedelta64_valid_range
test_variable.py::TestVariable::test_timedelta64_valid_range
[gw4] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_enum__multiple_variable_with_enum
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_enum__multiple_variable_with_enum
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_complex
test_backends.py::TestH5NetCDFViaDaskData::test_complex
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_complex_error[None]
test_backends.py::TestH5NetCDFViaDaskData::test_complex_error[None]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_complex_error[False]
test_backends.py::TestH5NetCDFViaDaskData::test_complex_error[False]
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_numpy_bool_
test_backends.py::TestH5NetCDFViaDaskData::test_numpy_bool_
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_cross_engine_read_write_netcdf4
test_backends.py::TestH5NetCDFViaDaskData::test_cross_engine_read_write_netcdf4
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_unlimited_dims
test_backends.py::TestH5NetCDFViaDaskData::test_encoding_unlimited_dims
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_compression_encoding_h5py
test_backends.py::TestH5NetCDFViaDaskData::test_compression_encoding_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_compression_check_encoding_h5py
test_backends.py::TestH5NetCDFViaDaskData::test_compression_check_encoding_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_dump_encodings_h5py
test_backends.py::TestH5NetCDFViaDaskData::test_dump_encodings_h5py
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_backends.py::TestH5NetCDFViaDaskData::test_write_inconsistent_chunks
test_backends.py::TestH5NetCDFViaDaskData::test_write_inconsistent_chunks
[gw2] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_timedelta64_valid_range
test_variable.py::TestVariableWithDask::test_timedelta64_valid_range
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg0-np_arg0-median]
test_variable.py::TestVariableWithDask::test_pad[xr_arg0-np_arg0-median]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg0-np_arg0-reflect]
test_variable.py::TestVariableWithDask::test_pad[xr_arg0-np_arg0-reflect]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg1-np_arg1-median]
test_variable.py::TestVariableWithDask::test_pad[xr_arg1-np_arg1-median]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_0d_object_array_with_list
test_variable.py::TestVariableWithDask::test_0d_object_array_with_list
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_array_interface
test_variable.py::TestVariableWithDask::test_array_interface
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg1-np_arg1-reflect]
test_variable.py::TestVariableWithDask::test_pad[xr_arg1-np_arg1-reflect]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_copy_index
test_variable.py::TestVariableWithDask::test_copy_index
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg2-np_arg2-median]
test_variable.py::TestVariableWithDask::test_pad[xr_arg2-np_arg2-median]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_eq_all_dtypes
test_variable.py::TestVariableWithDask::test_eq_all_dtypes
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg2-np_arg2-reflect]
test_variable.py::TestVariableWithDask::test_pad[xr_arg2-np_arg2-reflect]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_strategies.py::TestReduction::test_mean
test_strategies.py::TestReduction::test_mean
[gw1] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg3-np_arg3-median]
test_variable.py::TestVariableWithDask::test_pad[xr_arg3-np_arg3-median]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg3-np_arg3-reflect]
test_variable.py::TestVariableWithDask::test_pad[xr_arg3-np_arg3-reflect]
[gw0] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg4-np_arg4-median]
test_variable.py::TestVariableWithDask::test_pad[xr_arg4-np_arg4-median]
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestVariableWithDask::test_pad[xr_arg4-np_arg4-reflect]
test_variable.py::TestVariableWithDask::test_pad[xr_arg4-np_arg4-reflect]
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
test_variable.py::TestIndexVariable::test_timedelta64_valid_range
test_variable.py::TestIndexVariable::test_timedelta64_valid_range
[gw3] linux -- Python 3.10.12 /testbed/.venv/bin/python3
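To iterate on any single failure from the list above, its node id can be passed straight back to pytest. A minimal sketch, assuming pytest is installed and the command is run from the directory containing test_backends.py (the node ids above are relative to that directory); the chosen test id is just one example taken from the list:

import pytest

# Re-run one failing test in isolation, stopping at the first failure;
# "-rA" prints the full result summary for the selected node.
exit_code = pytest.main([
    "test_backends.py::TestH5NetCDFFileObject::test_open_twice",
    "-x",
    "-rA",
])
print("pytest exit code:", exit_code)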
Patch diff
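The diff below fills in the previously stubbed I/O entry points in xarray/backends/api.py (open_dataset, open_dataarray, open_mfdataset, load_dataset, to_netcdf, and their helpers). For orientation, here is a minimal round-trip sketch exercising two of the restored code paths; it is illustrative only and assumes a patched xarray plus scipy are importable:

import numpy as np
import xarray as xr

# With no target path, to_netcdf() takes the BytesIO branch restored in the
# patch and returns the serialized netCDF3 file as bytes (scipy engine).
ds = xr.Dataset({"a": ("x", np.linspace(0.0, 1.0, 3))})
raw = ds.to_netcdf()

# load_dataset() opens the byte string, loads everything into memory, and
# closes the underlying store, per the docstring restored in the patch.
back = xr.load_dataset(raw)
assert back["a"].equals(ds["a"])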
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 51972ac3..ece60a2b 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -1,17 +1,38 @@
from __future__ import annotations
+
import os
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from functools import partial
from io import BytesIO
from numbers import Number
-from typing import TYPE_CHECKING, Any, Callable, Final, Literal, Union, cast, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Final,
+ Literal,
+ Union,
+ cast,
+ overload,
+)
+
import numpy as np
+
from xarray import backends, conventions
from xarray.backends import plugins
-from xarray.backends.common import AbstractDataStore, ArrayWriter, _find_absolute_paths, _normalize_path
+from xarray.backends.common import (
+ AbstractDataStore,
+ ArrayWriter,
+ _find_absolute_paths,
+ _normalize_path,
+)
from xarray.backends.locks import _get_scheduler
from xarray.core import indexing
-from xarray.core.combine import _infer_concat_order_from_positions, _nested_combine, combine_by_coords
+from xarray.core.combine import (
+ _infer_concat_order_from_positions,
+ _nested_combine,
+ combine_by_coords,
+)
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.indexes import Index
@@ -19,30 +40,125 @@ from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
from xarray.core.utils import is_remote_uri
from xarray.namedarray.daskmanager import DaskManager
from xarray.namedarray.parallelcompat import guess_chunkmanager
+
if TYPE_CHECKING:
try:
from dask.delayed import Delayed
except ImportError:
- Delayed = None
+ Delayed = None # type: ignore
from io import BufferedIOBase
+
from xarray.backends.common import BackendEntrypoint
- from xarray.core.types import CombineAttrsOptions, CompatOptions, JoinOptions, NestedSequence, T_Chunks
- T_NetcdfEngine = Literal['netcdf4', 'scipy', 'h5netcdf']
- T_Engine = Union[T_NetcdfEngine, Literal['pydap', 'zarr'], type[
- BackendEntrypoint], str, None]
- T_NetcdfTypes = Literal['NETCDF4', 'NETCDF4_CLASSIC', 'NETCDF3_64BIT',
- 'NETCDF3_CLASSIC']
+ from xarray.core.types import (
+ CombineAttrsOptions,
+ CompatOptions,
+ JoinOptions,
+ NestedSequence,
+ T_Chunks,
+ )
+
+ T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"]
+ T_Engine = Union[
+ T_NetcdfEngine,
+ Literal["pydap", "zarr"],
+ type[BackendEntrypoint],
+ str, # no nice typing support for custom backends
+ None,
+ ]
+ T_NetcdfTypes = Literal[
+ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"
+ ]
from xarray.core.datatree import DataTree
-DATAARRAY_NAME = '__xarray_dataarray_name__'
-DATAARRAY_VARIABLE = '__xarray_dataarray_variable__'
-ENGINES = {'netcdf4': backends.NetCDF4DataStore.open, 'scipy': backends.
- ScipyDataStore, 'pydap': backends.PydapDataStore.open, 'h5netcdf':
- backends.H5NetCDFStore.open, 'zarr': backends.ZarrStore.open_group}
+DATAARRAY_NAME = "__xarray_dataarray_name__"
+DATAARRAY_VARIABLE = "__xarray_dataarray_variable__"
+
+ENGINES = {
+ "netcdf4": backends.NetCDF4DataStore.open,
+ "scipy": backends.ScipyDataStore,
+ "pydap": backends.PydapDataStore.open,
+ "h5netcdf": backends.H5NetCDFStore.open,
+ "zarr": backends.ZarrStore.open_group,
+}
+
+
+def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]:
+ engine: Literal["netcdf4", "pydap"]
+ try:
+ import netCDF4 # noqa: F401
+
+ engine = "netcdf4"
+ except ImportError: # pragma: no cover
+ try:
+ import pydap # noqa: F401
+
+ engine = "pydap"
+ except ImportError:
+ raise ValueError(
+ "netCDF4 or pydap is required for accessing "
+ "remote datasets via OPeNDAP"
+ )
+ return engine
+
+
+def _get_default_engine_gz() -> Literal["scipy"]:
+ try:
+ import scipy # noqa: F401
+
+ engine: Final = "scipy"
+ except ImportError: # pragma: no cover
+ raise ValueError("scipy is required for accessing .gz files")
+ return engine
-def _validate_dataset_names(dataset: Dataset) ->None:
+
+def _get_default_engine_netcdf() -> Literal["netcdf4", "scipy"]:
+ engine: Literal["netcdf4", "scipy"]
+ try:
+ import netCDF4 # noqa: F401
+
+ engine = "netcdf4"
+ except ImportError: # pragma: no cover
+ try:
+ import scipy.io.netcdf # noqa: F401
+
+ engine = "scipy"
+ except ImportError:
+ raise ValueError(
+ "cannot read or write netCDF files without "
+ "netCDF4-python or scipy installed"
+ )
+ return engine
+
+
+def _get_default_engine(path: str, allow_remote: bool = False) -> T_NetcdfEngine:
+ if allow_remote and is_remote_uri(path):
+ return _get_default_engine_remote_uri() # type: ignore[return-value]
+ elif path.endswith(".gz"):
+ return _get_default_engine_gz()
+ else:
+ return _get_default_engine_netcdf()
+
+
+def _validate_dataset_names(dataset: Dataset) -> None:
"""DataArray.name and Dataset keys must be a string or None"""
- pass
+
+ def check_name(name: Hashable):
+ if isinstance(name, str):
+ if not name:
+ raise ValueError(
+ f"Invalid name {name!r} for DataArray or Dataset key: "
+ "string must be length 1 or greater for "
+ "serialization to netCDF files"
+ )
+ elif name is not None:
+ raise TypeError(
+ f"Invalid name {name!r} for DataArray or Dataset key: "
+ "must be either a string or None for serialization to netCDF "
+ "files"
+ )
+
+ for k in dataset.variables:
+ check_name(k)
def _validate_attrs(dataset, invalid_netcdf=False):
@@ -54,15 +170,89 @@ def _validate_attrs(dataset, invalid_netcdf=False):
A numpy.bool_ is only allowed when using the h5netcdf engine with
`invalid_netcdf=True`.
"""
- pass
+
+ valid_types = (str, Number, np.ndarray, np.number, list, tuple)
+ if invalid_netcdf:
+ valid_types += (np.bool_,)
+
+ def check_attr(name, value, valid_types):
+ if isinstance(name, str):
+ if not name:
+ raise ValueError(
+ f"Invalid name for attr {name!r}: string must be "
+ "length 1 or greater for serialization to "
+ "netCDF files"
+ )
+ else:
+ raise TypeError(
+ f"Invalid name for attr: {name!r} must be a string for "
+ "serialization to netCDF files"
+ )
+
+ if not isinstance(value, valid_types):
+ raise TypeError(
+ f"Invalid value for attr {name!r}: {value!r}. For serialization to "
+ "netCDF files, its value must be of one of the following types: "
+ f"{', '.join([vtype.__name__ for vtype in valid_types])}"
+ )
+
+ # Check attrs on the dataset itself
+ for k, v in dataset.attrs.items():
+ check_attr(k, v, valid_types)
+
+ # Check attrs on each variable within the dataset
+ for variable in dataset.variables.values():
+ for k, v in variable.attrs.items():
+ check_attr(k, v, valid_types)
+
+
+def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders):
+ for d in list(decoders):
+ if decode_cf is False and d in open_backend_dataset_parameters:
+ decoders[d] = False
+ if decoders[d] is None:
+ decoders.pop(d)
+ return decoders
+
+
+def _get_mtime(filename_or_obj):
+ # if passed an actual file path, augment the token with
+ # the file modification time
+ mtime = None
+
+ try:
+ path = os.fspath(filename_or_obj)
+ except TypeError:
+ path = None
+
+ if path and not is_remote_uri(path):
+ mtime = os.path.getmtime(os.path.expanduser(filename_or_obj))
+
+ return mtime
+
+
+def _protect_dataset_variables_inplace(dataset, cache):
+ for name, variable in dataset.variables.items():
+ if name not in dataset._indexes:
+ # no need to protect IndexVariable objects
+ data = indexing.CopyOnWriteArray(variable._data)
+ if cache:
+ data = indexing.MemoryCachedArray(data)
+ variable.data = data
def _finalize_store(write, store):
"""Finalize this store by explicitly syncing and closing"""
- pass
+ del write # ensure writing is done first
+ store.close()
+
+
+def _multi_file_closer(closers):
+ for closer in closers:
+ closer()
-def load_dataset(filename_or_obj, **kwargs) ->Dataset:
+def load_dataset(filename_or_obj, **kwargs) -> Dataset:
"""Open, load into memory, and close a Dataset from a file or file-like
object.
@@ -81,7 +271,11 @@ def load_dataset(filename_or_obj, **kwargs) ->Dataset:
--------
open_dataset
"""
- pass
+ if "cache" in kwargs:
+ raise TypeError("cache has no effect in this context")
+
+ with open_dataset(filename_or_obj, **kwargs) as ds:
+ return ds.load()
def load_dataarray(filename_or_obj, **kwargs):
@@ -103,21 +297,120 @@ def load_dataarray(filename_or_obj, **kwargs):
--------
open_dataarray
"""
- pass
-
-
-def open_dataset(filename_or_obj: (str | os.PathLike[Any] | BufferedIOBase |
- AbstractDataStore), *, engine: T_Engine=None, chunks: T_Chunks=None,
- cache: (bool | None)=None, decode_cf: (bool | None)=None,
- mask_and_scale: (bool | Mapping[str, bool] | None)=None, decode_times:
- (bool | Mapping[str, bool] | None)=None, decode_timedelta: (bool |
- Mapping[str, bool] | None)=None, use_cftime: (bool | Mapping[str, bool] |
- None)=None, concat_characters: (bool | Mapping[str, bool] | None)=None,
- decode_coords: (Literal['coordinates', 'all'] | bool | None)=None,
- drop_variables: (str | Iterable[str] | None)=None, inline_array: bool=
- False, chunked_array_type: (str | None)=None, from_array_kwargs: (dict[
- str, Any] | None)=None, backend_kwargs: (dict[str, Any] | None)=None,
- **kwargs) ->Dataset:
+ if "cache" in kwargs:
+ raise TypeError("cache has no effect in this context")
+
+ with open_dataarray(filename_or_obj, **kwargs) as da:
+ return da.load()
+
+
+def _chunk_ds(
+ backend_ds,
+ filename_or_obj,
+ engine,
+ chunks,
+ overwrite_encoded_chunks,
+ inline_array,
+ chunked_array_type,
+ from_array_kwargs,
+ **extra_tokens,
+):
+ chunkmanager = guess_chunkmanager(chunked_array_type)
+
+ # TODO refactor to move this dask-specific logic inside the DaskManager class
+ if isinstance(chunkmanager, DaskManager):
+ from dask.base import tokenize
+
+ mtime = _get_mtime(filename_or_obj)
+ token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
+ name_prefix = "open_dataset-"
+ else:
+ # not used
+ token = (None,)
+ name_prefix = None
+
+ variables = {}
+ for name, var in backend_ds.variables.items():
+ var_chunks = _get_chunk(var, chunks, chunkmanager)
+ variables[name] = _maybe_chunk(
+ name,
+ var,
+ var_chunks,
+ overwrite_encoded_chunks=overwrite_encoded_chunks,
+ name_prefix=name_prefix,
+ token=token,
+ inline_array=inline_array,
+ chunked_array_type=chunkmanager,
+ from_array_kwargs=from_array_kwargs.copy(),
+ )
+ return backend_ds._replace(variables)
+
+
+def _dataset_from_backend_dataset(
+ backend_ds,
+ filename_or_obj,
+ engine,
+ chunks,
+ cache,
+ overwrite_encoded_chunks,
+ inline_array,
+ chunked_array_type,
+ from_array_kwargs,
+ **extra_tokens,
+):
+ if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
+ raise ValueError(
+ f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
+ )
+
+ _protect_dataset_variables_inplace(backend_ds, cache)
+ if chunks is None:
+ ds = backend_ds
+ else:
+ ds = _chunk_ds(
+ backend_ds,
+ filename_or_obj,
+ engine,
+ chunks,
+ overwrite_encoded_chunks,
+ inline_array,
+ chunked_array_type,
+ from_array_kwargs,
+ **extra_tokens,
+ )
+
+ ds.set_close(backend_ds._close)
+
+ # Ensure source filename always stored in dataset object
+ if "source" not in ds.encoding:
+ path = getattr(filename_or_obj, "path", filename_or_obj)
+
+ if isinstance(path, (str, os.PathLike)):
+ ds.encoding["source"] = _normalize_path(path)
+
+ return ds
+
+
+def open_dataset(
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ engine: T_Engine = None,
+ chunks: T_Chunks = None,
+ cache: bool | None = None,
+ decode_cf: bool | None = None,
+ mask_and_scale: bool | Mapping[str, bool] | None = None,
+ decode_times: bool | Mapping[str, bool] | None = None,
+ decode_timedelta: bool | Mapping[str, bool] | None = None,
+ use_cftime: bool | Mapping[str, bool] | None = None,
+ concat_characters: bool | Mapping[str, bool] | None = None,
+ decode_coords: Literal["coordinates", "all"] | bool | None = None,
+ drop_variables: str | Iterable[str] | None = None,
+ inline_array: bool = False,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+ backend_kwargs: dict[str, Any] | None = None,
+ **kwargs,
+) -> Dataset:
"""Open and decode a dataset from a file or file-like object.
Parameters
@@ -128,7 +421,9 @@ def open_dataset(filename_or_obj: (str | os.PathLike[Any] | BufferedIOBase |
ends with .gz, in which case the file is gunzipped and opened with
scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
- engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None} , installed backend or subclass of xarray.backends.BackendEntrypoint, optional
+ engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\
+ , installed backend \
+ or subclass of xarray.backends.BackendEntrypoint, optional
Engine to use when reading files. If not provided, the default engine
is chosen based on available dependencies, with a preference for
"netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``)
@@ -263,20 +558,76 @@ def open_dataset(filename_or_obj: (str | os.PathLike[Any] | BufferedIOBase |
--------
open_mfdataset
"""
- pass
-
-
-def open_dataarray(filename_or_obj: (str | os.PathLike[Any] |
- BufferedIOBase | AbstractDataStore), *, engine: (T_Engine | None)=None,
- chunks: (T_Chunks | None)=None, cache: (bool | None)=None, decode_cf: (
- bool | None)=None, mask_and_scale: (bool | None)=None, decode_times: (
- bool | None)=None, decode_timedelta: (bool | None)=None, use_cftime: (
- bool | None)=None, concat_characters: (bool | None)=None, decode_coords:
- (Literal['coordinates', 'all'] | bool | None)=None, drop_variables: (
- str | Iterable[str] | None)=None, inline_array: bool=False,
- chunked_array_type: (str | None)=None, from_array_kwargs: (dict[str,
- Any] | None)=None, backend_kwargs: (dict[str, Any] | None)=None, **kwargs
- ) ->DataArray:
+
+ if cache is None:
+ cache = chunks is None
+
+ if backend_kwargs is not None:
+ kwargs.update(backend_kwargs)
+
+ if engine is None:
+ engine = plugins.guess_engine(filename_or_obj)
+
+ if from_array_kwargs is None:
+ from_array_kwargs = {}
+
+ backend = plugins.get_backend(engine)
+
+ decoders = _resolve_decoders_kwargs(
+ decode_cf,
+ open_backend_dataset_parameters=backend.open_dataset_parameters,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ decode_timedelta=decode_timedelta,
+ concat_characters=concat_characters,
+ use_cftime=use_cftime,
+ decode_coords=decode_coords,
+ )
+
+ overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
+ backend_ds = backend.open_dataset(
+ filename_or_obj,
+ drop_variables=drop_variables,
+ **decoders,
+ **kwargs,
+ )
+ ds = _dataset_from_backend_dataset(
+ backend_ds,
+ filename_or_obj,
+ engine,
+ chunks,
+ cache,
+ overwrite_encoded_chunks,
+ inline_array,
+ chunked_array_type,
+ from_array_kwargs,
+ drop_variables=drop_variables,
+ **decoders,
+ **kwargs,
+ )
+ return ds
+
+
+def open_dataarray(
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ engine: T_Engine | None = None,
+ chunks: T_Chunks | None = None,
+ cache: bool | None = None,
+ decode_cf: bool | None = None,
+ mask_and_scale: bool | None = None,
+ decode_times: bool | None = None,
+ decode_timedelta: bool | None = None,
+ use_cftime: bool | None = None,
+ concat_characters: bool | None = None,
+ decode_coords: Literal["coordinates", "all"] | bool | None = None,
+ drop_variables: str | Iterable[str] | None = None,
+ inline_array: bool = False,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+ backend_kwargs: dict[str, Any] | None = None,
+ **kwargs,
+) -> DataArray:
"""Open an DataArray from a file or file-like object containing a single
data variable.
@@ -291,7 +642,9 @@ def open_dataarray(filename_or_obj: (str | os.PathLike[Any] |
ends with .gz, in which case the file is gunzipped and opened with
scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
- engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None} , installed backend or subclass of xarray.backends.BackendEntrypoint, optional
+ engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\
+ , installed backend \
+ or subclass of xarray.backends.BackendEntrypoint, optional
Engine to use when reading files. If not provided, the default engine
is chosen based on available dependencies, with a preference for
"netcdf4".
@@ -412,11 +765,55 @@ def open_dataarray(filename_or_obj: (str | os.PathLike[Any] |
--------
open_dataset
"""
- pass
-
-def open_datatree(filename_or_obj: (str | os.PathLike[Any] | BufferedIOBase |
- AbstractDataStore), engine: T_Engine=None, **kwargs) ->DataTree:
+ dataset = open_dataset(
+ filename_or_obj,
+ decode_cf=decode_cf,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ engine=engine,
+ chunks=chunks,
+ cache=cache,
+ drop_variables=drop_variables,
+ inline_array=inline_array,
+ chunked_array_type=chunked_array_type,
+ from_array_kwargs=from_array_kwargs,
+ backend_kwargs=backend_kwargs,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ **kwargs,
+ )
+
+ if len(dataset.data_vars) != 1:
+ raise ValueError(
+ "Given file dataset contains more than one data "
+ "variable. Please read with xarray.open_dataset and "
+ "then select the variable you want."
+ )
+ else:
+ (data_array,) = dataset.data_vars.values()
+
+ data_array.set_close(dataset._close)
+
+ # Reset names if they were changed during saving
+ # to ensure that we can 'roundtrip' perfectly
+ if DATAARRAY_NAME in dataset.attrs:
+ data_array.name = dataset.attrs[DATAARRAY_NAME]
+ del dataset.attrs[DATAARRAY_NAME]
+
+ if data_array.name == DATAARRAY_VARIABLE:
+ data_array.name = None
+
+ return data_array
+
+
+def open_datatree(
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ engine: T_Engine = None,
+ **kwargs,
+) -> DataTree:
"""
Open and decode a DataTree from a file or file-like object, creating one tree node for each group in the file.
@@ -432,18 +829,38 @@ def open_datatree(filename_or_obj: (str | os.PathLike[Any] | BufferedIOBase |
-------
xarray.DataTree
"""
- pass
-
-
-def open_mfdataset(paths: (str | NestedSequence[str | os.PathLike]), chunks:
- (T_Chunks | None)=None, concat_dim: (str | DataArray | Index | Sequence
- [str] | Sequence[DataArray] | Sequence[Index] | None)=None, compat:
- CompatOptions='no_conflicts', preprocess: (Callable[[Dataset], Dataset] |
- None)=None, engine: (T_Engine | None)=None, data_vars: (Literal['all',
- 'minimal', 'different'] | list[str])='all', coords='different', combine:
- Literal['by_coords', 'nested']='by_coords', parallel: bool=False, join:
- JoinOptions='outer', attrs_file: (str | os.PathLike | None)=None,
- combine_attrs: CombineAttrsOptions='override', **kwargs) ->Dataset:
+ if engine is None:
+ engine = plugins.guess_engine(filename_or_obj)
+
+ backend = plugins.get_backend(engine)
+
+ return backend.open_datatree(filename_or_obj, **kwargs)
+
+
+def open_mfdataset(
+ paths: str | NestedSequence[str | os.PathLike],
+ chunks: T_Chunks | None = None,
+ concat_dim: (
+ str
+ | DataArray
+ | Index
+ | Sequence[str]
+ | Sequence[DataArray]
+ | Sequence[Index]
+ | None
+ ) = None,
+ compat: CompatOptions = "no_conflicts",
+ preprocess: Callable[[Dataset], Dataset] | None = None,
+ engine: T_Engine | None = None,
+ data_vars: Literal["all", "minimal", "different"] | list[str] = "all",
+ coords="different",
+ combine: Literal["by_coords", "nested"] = "by_coords",
+ parallel: bool = False,
+ join: JoinOptions = "outer",
+ attrs_file: str | os.PathLike | None = None,
+ combine_attrs: CombineAttrsOptions = "override",
+ **kwargs,
+) -> Dataset:
"""Open multiple files as a single dataset.
If combine='by_coords' then the function ``combine_by_coords`` is used to combine
@@ -482,7 +899,8 @@ def open_mfdataset(paths: (str | NestedSequence[str | os.PathLike]), chunks:
combine : {"by_coords", "nested"}, optional
Whether ``xarray.combine_by_coords`` or ``xarray.combine_nested`` is used to
combine all the data. Default is to use ``xarray.combine_by_coords``.
- compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, default: "no_conflicts"
+ compat : {"identical", "equals", "broadcast_equals", \
+ "no_conflicts", "override"}, default: "no_conflicts"
String indicating how to compare variables of the same name for
potential conflicts when merging:
@@ -500,7 +918,9 @@ def open_mfdataset(paths: (str | NestedSequence[str | os.PathLike]), chunks:
If provided, call this function on each dataset prior to concatenation.
You can find the file-name from which each dataset was loaded in
``ds.encoding["source"]``.
- engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None} , installed backend or subclass of xarray.backends.BackendEntrypoint, optional
+ engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "zarr", None}\
+ , installed backend \
+ or subclass of xarray.backends.BackendEntrypoint, optional
Engine to use when reading files. If not provided, the default engine
is chosen based on available dependencies, with a preference for
"netcdf4".
@@ -549,7 +969,8 @@ def open_mfdataset(paths: (str | NestedSequence[str | os.PathLike]), chunks:
Path of the file used to read global attributes from.
By default global attributes are read from the first file provided,
with wildcard matches sorted by filename.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -615,21 +1036,244 @@ def open_mfdataset(paths: (str | NestedSequence[str | os.PathLike]), chunks:
.. [1] https://docs.xarray.dev/en/stable/dask.html
.. [2] https://docs.xarray.dev/en/stable/dask.html#chunking-and-performance
"""
- pass
-
-
-WRITEABLE_STORES: dict[T_NetcdfEngine, Callable] = {'netcdf4': backends.
- NetCDF4DataStore.open, 'scipy': backends.ScipyDataStore, 'h5netcdf':
- backends.H5NetCDFStore.open}
-
-
-def to_netcdf(dataset: Dataset, path_or_file: (str | os.PathLike | None)=
- None, mode: NetcdfWriteModes='w', format: (T_NetcdfTypes | None)=None,
- group: (str | None)=None, engine: (T_NetcdfEngine | None)=None,
- encoding: (Mapping[Hashable, Mapping[str, Any]] | None)=None,
- unlimited_dims: (Iterable[Hashable] | None)=None, compute: bool=True,
- multifile: bool=False, invalid_netcdf: bool=False) ->(tuple[ArrayWriter,
- AbstractDataStore] | bytes | Delayed | None):
+ paths = _find_absolute_paths(paths, engine=engine, **kwargs)
+
+ if not paths:
+ raise OSError("no files to open")
+
+ if combine == "nested":
+ if isinstance(concat_dim, (str, DataArray)) or concat_dim is None:
+ concat_dim = [concat_dim] # type: ignore[assignment]
+
+ # This creates a flat list which is easier to iterate over, whilst
+ # encoding the originally-supplied structure as "ids".
+ # The "ids" are not used at all if combine='by_coords`.
+ combined_ids_paths = _infer_concat_order_from_positions(paths)
+ ids, paths = (
+ list(combined_ids_paths.keys()),
+ list(combined_ids_paths.values()),
+ )
+ elif combine == "by_coords" and concat_dim is not None:
+ raise ValueError(
+ "When combine='by_coords', passing a value for `concat_dim` has no "
+ "effect. To manually combine along a specific dimension you should "
+ "instead specify combine='nested' along with a value for `concat_dim`.",
+ )
+
+ open_kwargs = dict(engine=engine, chunks=chunks or {}, **kwargs)
+
+ if parallel:
+ import dask
+
+ # wrap the open_dataset, getattr, and preprocess with delayed
+ open_ = dask.delayed(open_dataset)
+ getattr_ = dask.delayed(getattr)
+ if preprocess is not None:
+ preprocess = dask.delayed(preprocess)
+ else:
+ open_ = open_dataset
+ getattr_ = getattr
+
+ datasets = [open_(p, **open_kwargs) for p in paths]
+ closers = [getattr_(ds, "_close") for ds in datasets]
+ if preprocess is not None:
+ datasets = [preprocess(ds) for ds in datasets]
+
+ if parallel:
+ # calling compute here will return the datasets/file_objs lists,
+ # the underlying datasets will still be stored as dask arrays
+ datasets, closers = dask.compute(datasets, closers)
+
+ # Combine all datasets, closing them in case of a ValueError
+ try:
+ if combine == "nested":
+ # Combined nested list by successive concat and merge operations
+ # along each dimension, using structure given by "ids"
+ combined = _nested_combine(
+ datasets,
+ concat_dims=concat_dim,
+ compat=compat,
+ data_vars=data_vars,
+ coords=coords,
+ ids=ids,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+ elif combine == "by_coords":
+ # Redo ordering from coordinates, ignoring how they were ordered
+ # previously
+ combined = combine_by_coords(
+ datasets,
+ compat=compat,
+ data_vars=data_vars,
+ coords=coords,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+ else:
+ raise ValueError(
+ f"{combine} is an invalid option for the keyword argument"
+ " ``combine``"
+ )
+ except ValueError:
+ for ds in datasets:
+ ds.close()
+ raise
+
+ combined.set_close(partial(_multi_file_closer, closers))
+
+ # read global attributes from the attrs_file or from the first dataset
+ if attrs_file is not None:
+ if isinstance(attrs_file, os.PathLike):
+ attrs_file = cast(str, os.fspath(attrs_file))
+ combined.attrs = datasets[paths.index(attrs_file)].attrs
+
+ return combined
+
+
+WRITEABLE_STORES: dict[T_NetcdfEngine, Callable] = {
+ "netcdf4": backends.NetCDF4DataStore.open,
+ "scipy": backends.ScipyDataStore,
+ "h5netcdf": backends.H5NetCDFStore.open,
+}
+
+
+# multifile=True returns writer and datastore
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike | None = None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ *,
+ multifile: Literal[True],
+ invalid_netcdf: bool = False,
+) -> tuple[ArrayWriter, AbstractDataStore]: ...
+
+
+# path=None writes to bytes
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: None = None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ multifile: Literal[False] = False,
+ invalid_netcdf: bool = False,
+) -> bytes: ...
+
+
+# compute=False returns dask.Delayed
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ *,
+ compute: Literal[False],
+ multifile: Literal[False] = False,
+ invalid_netcdf: bool = False,
+) -> Delayed: ...
+
+
+# default return None
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: Literal[True] = True,
+ multifile: Literal[False] = False,
+ invalid_netcdf: bool = False,
+) -> None: ...
+
+
+# if compute cannot be evaluated at type check time
+# we may get back either Delayed or None
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = False,
+ multifile: Literal[False] = False,
+ invalid_netcdf: bool = False,
+) -> Delayed | None: ...
+
+
+# if multifile cannot be evaluated at type check time
+# we may get back either writer and datastore or Delayed or None
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = False,
+ multifile: bool = False,
+ invalid_netcdf: bool = False,
+) -> tuple[ArrayWriter, AbstractDataStore] | Delayed | None: ...
+
+
+# Any
+@overload
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike | None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = False,
+ multifile: bool = False,
+ invalid_netcdf: bool = False,
+) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: ...
+
+
+def to_netcdf(
+ dataset: Dataset,
+ path_or_file: str | os.PathLike | None = None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ multifile: bool = False,
+ invalid_netcdf: bool = False,
+) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
"""This function creates an appropriate datastore for writing a dataset to
disk as a netCDF file
@@ -637,17 +1281,142 @@ def to_netcdf(dataset: Dataset, path_or_file: (str | os.PathLike | None)=
The ``multifile`` argument is only for the private use of save_mfdataset.
"""
- pass
+ if isinstance(path_or_file, os.PathLike):
+ path_or_file = os.fspath(path_or_file)
+
+ if encoding is None:
+ encoding = {}
+
+ if path_or_file is None:
+ if engine is None:
+ engine = "scipy"
+ elif engine != "scipy":
+ raise ValueError(
+ "invalid engine for creating bytes with "
+ f"to_netcdf: {engine!r}. Only the default engine "
+ "or engine='scipy' is supported"
+ )
+ if not compute:
+ raise NotImplementedError(
+ "to_netcdf() with compute=False is not yet implemented when "
+ "returning bytes"
+ )
+ elif isinstance(path_or_file, str):
+ if engine is None:
+ engine = _get_default_engine(path_or_file)
+ path_or_file = _normalize_path(path_or_file)
+ else: # file-like object
+ engine = "scipy"
+
+ # validate Dataset keys, DataArray names, and attr keys/values
+ _validate_dataset_names(dataset)
+ _validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf")
+
+ try:
+ store_open = WRITEABLE_STORES[engine]
+ except KeyError:
+ raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}")
+
+ if format is not None:
+ format = format.upper() # type: ignore[assignment]
+
+ # handle scheduler specific logic
+ scheduler = _get_scheduler()
+ have_chunks = any(v.chunks is not None for v in dataset.variables.values())
+
+ autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
+ if autoclose and engine == "scipy":
+ raise NotImplementedError(
+ f"Writing netCDF files with the {engine} backend "
+ f"is not currently supported with dask's {scheduler} scheduler"
+ )
+
+ target = path_or_file if path_or_file is not None else BytesIO()
+ kwargs = dict(autoclose=True) if autoclose else {}
+ if invalid_netcdf:
+ if engine == "h5netcdf":
+ kwargs["invalid_netcdf"] = invalid_netcdf
+ else:
+ raise ValueError(
+ f"unrecognized option 'invalid_netcdf' for engine {engine}"
+ )
+ store = store_open(target, mode, format, group, **kwargs)
+
+ if unlimited_dims is None:
+ unlimited_dims = dataset.encoding.get("unlimited_dims", None)
+ if unlimited_dims is not None:
+ if isinstance(unlimited_dims, str) or not isinstance(unlimited_dims, Iterable):
+ unlimited_dims = [unlimited_dims]
+ else:
+ unlimited_dims = list(unlimited_dims)
+
+ writer = ArrayWriter()
+
+ # TODO: figure out how to refactor this logic (here and in save_mfdataset)
+ # to avoid this mess of conditionals
+ try:
+ # TODO: allow this work (setting up the file for writing array data)
+ # to be parallelized with dask
+ dump_to_store(
+ dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
+ )
+ if autoclose:
+ store.close()
+
+ if multifile:
+ return writer, store
+
+ writes = writer.sync(compute=compute)
+
+ if isinstance(target, BytesIO):
+ store.sync()
+ return target.getvalue()
+ finally:
+ if not multifile and compute:
+ store.close()
+ if not compute:
+ import dask
-def dump_to_store(dataset, store, writer=None, encoder=None, encoding=None,
- unlimited_dims=None):
+ return dask.delayed(_finalize_store)(writes, store)
+ return None
+
+
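# Illustrative usage sketch (not part of the patch): how the overloads above
# resolve in practice. Assumes xarray plus a netCDF-capable engine are
# installed, and dask for the compute=False branch; "eager.nc"/"lazy.nc" are
# throwaway paths.
import xarray as xr

ds = xr.Dataset({"t": ("x", [1.0, 2.0, 3.0])})

raw = ds.to_netcdf()                     # no target -> serialized bytes (scipy engine)
ds.to_netcdf("eager.nc")                 # path + compute=True (default) -> returns None
delayed = ds.chunk({"x": 1}).to_netcdf("lazy.nc", compute=False)
delayed.compute()                        # path + compute=False -> dask Delayed; written here
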
+def dump_to_store(
+ dataset, store, writer=None, encoder=None, encoding=None, unlimited_dims=None
+):
"""Store dataset contents to a backends.*DataStore object."""
- pass
+ if writer is None:
+ writer = ArrayWriter()
+
+ if encoding is None:
+ encoding = {}
+ variables, attrs = conventions.encode_dataset_coordinates(dataset)
-def save_mfdataset(datasets, paths, mode='w', format=None, groups=None,
- engine=None, compute=True, **kwargs):
+ check_encoding = set()
+ for k, enc in encoding.items():
+ # no need to shallow copy the variable again; that already happened
+ # in encode_dataset_coordinates
+ variables[k].encoding = enc
+ check_encoding.add(k)
+
+ if encoder:
+ variables, attrs = encoder(variables, attrs)
+
+ store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
+
+
+def save_mfdataset(
+ datasets,
+ paths,
+ mode="w",
+ format=None,
+ groups=None,
+ engine=None,
+ compute=True,
+ **kwargs,
+):
"""Write multiple datasets to disk as netCDF files simultaneously.
This function is intended for use with datasets consisting of dask.array
@@ -666,7 +1435,8 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None,
mode : {"w", "a"}, optional
Write ("w") or append ("a") mode. If mode="w", any existing file at
these locations will be overwritten.
- format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"}, optional
+ format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \
+ "NETCDF3_CLASSIC"}, optional
File format for the resulting netCDF file:
* NETCDF4: Data is stored in an HDF5 file, using netCDF4 API
@@ -720,22 +1490,243 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None,
>>> paths = [f"{y}.nc" for y in years]
>>> xr.save_mfdataset(datasets, paths)
"""
- pass
-
-
-def to_zarr(dataset: Dataset, store: (MutableMapping | str | os.PathLike[
- str] | None)=None, chunk_store: (MutableMapping | str | os.PathLike |
- None)=None, mode: (ZarrWriteModes | None)=None, synchronizer=None,
- group: (str | None)=None, encoding: (Mapping | None)=None, *, compute:
- bool=True, consolidated: (bool | None)=None, append_dim: (Hashable |
- None)=None, region: (Mapping[str, slice | Literal['auto']] | Literal[
- 'auto'] | None)=None, safe_chunks: bool=True, storage_options: (dict[
- str, str] | None)=None, zarr_version: (int | None)=None,
- write_empty_chunks: (bool | None)=None, chunkmanager_store_kwargs: (
- dict[str, Any] | None)=None) ->(backends.ZarrStore | Delayed):
+ if mode == "w" and len(set(paths)) < len(paths):
+ raise ValueError(
+ "cannot use mode='w' when writing multiple datasets to the same path"
+ )
+
+ for obj in datasets:
+ if not isinstance(obj, Dataset):
+ raise TypeError(
+ "save_mfdataset only supports writing Dataset "
+ f"objects, received type {type(obj)}"
+ )
+
+ if groups is None:
+ groups = [None] * len(datasets)
+
+ if len({len(datasets), len(paths), len(groups)}) > 1:
+ raise ValueError(
+ "must supply lists of the same length for the "
+ "datasets, paths and groups arguments to "
+ "save_mfdataset"
+ )
+
+ writers, stores = zip(
+ *[
+ to_netcdf(
+ ds,
+ path,
+ mode,
+ format,
+ group,
+ engine,
+ compute=compute,
+ multifile=True,
+ **kwargs,
+ )
+ for ds, path, group in zip(datasets, paths, groups)
+ ]
+ )
+
+ try:
+ writes = [w.sync(compute=compute) for w in writers]
+ finally:
+ if compute:
+ for store in stores:
+ store.close()
+
+ if not compute:
+ import dask
+
+ return dask.delayed(
+ [dask.delayed(_finalize_store)(w, s) for w, s in zip(writes, stores)]
+ )
+
+
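# Illustrative sketch (not part of the patch) of the compute=False branch of
# save_mfdataset() above: every per-file write is wrapped into a single dask
# Delayed. Assumes dask is installed; the .nc paths are placeholders.
import xarray as xr

ds = xr.Dataset({"v": ("x", [1, 2, 3, 4])}).chunk({"x": 2})
parts = [ds.isel(x=slice(0, 2)), ds.isel(x=slice(2, 4))]
delayed = xr.save_mfdataset(parts, ["part0.nc", "part1.nc"], compute=False)
delayed.compute()  # the deferred array writes run and each store is closed here
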
+# compute=True returns ZarrStore
+@overload
+def to_zarr(
+ dataset: Dataset,
+ store: MutableMapping | str | os.PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | os.PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: Literal[True] = True,
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ write_empty_chunks: bool | None = None,
+ chunkmanager_store_kwargs: dict[str, Any] | None = None,
+) -> backends.ZarrStore: ...
+
+
+# compute=False returns dask.Delayed
+@overload
+def to_zarr(
+ dataset: Dataset,
+ store: MutableMapping | str | os.PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | os.PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: Literal[False],
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ write_empty_chunks: bool | None = None,
+ chunkmanager_store_kwargs: dict[str, Any] | None = None,
+) -> Delayed: ...
+
+
+def to_zarr(
+ dataset: Dataset,
+ store: MutableMapping | str | os.PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | os.PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: bool = True,
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ write_empty_chunks: bool | None = None,
+ chunkmanager_store_kwargs: dict[str, Any] | None = None,
+) -> backends.ZarrStore | Delayed:
"""This function creates an appropriate datastore for writing a dataset to
    a zarr store
See `Dataset.to_zarr` for full API docs.
"""
- pass
+
+ # Load empty arrays to avoid bug saving zero length dimensions (Issue #5741)
+ for v in dataset.variables.values():
+ if v.size == 0:
+ v.load()
+
+ # expand str and path-like arguments
+ store = _normalize_path(store)
+ chunk_store = _normalize_path(chunk_store)
+
+ if storage_options is None:
+ mapper = store
+ chunk_mapper = chunk_store
+ else:
+ from fsspec import get_mapper
+
+ if not isinstance(store, str):
+ raise ValueError(
+ f"store must be a string to use storage_options. Got {type(store)}"
+ )
+ mapper = get_mapper(store, **storage_options)
+ if chunk_store is not None:
+ chunk_mapper = get_mapper(chunk_store, **storage_options)
+ else:
+ chunk_mapper = chunk_store
+
+ if encoding is None:
+ encoding = {}
+
+ if mode is None:
+ if append_dim is not None:
+ mode = "a"
+ elif region is not None:
+ mode = "r+"
+ else:
+ mode = "w-"
+
+ if mode not in ["a", "a-"] and append_dim is not None:
+ raise ValueError("cannot set append_dim unless mode='a' or mode=None")
+
+ if mode not in ["a", "a-", "r+"] and region is not None:
+ raise ValueError(
+ "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None"
+ )
+
+ if mode not in ["w", "w-", "a", "a-", "r+"]:
+ raise ValueError(
+ "The only supported options for mode are 'w', "
+ f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}"
+ )
+
+ # validate Dataset keys, DataArray names
+ _validate_dataset_names(dataset)
+
+ if zarr_version is None:
+ # default to 2 if store doesn't specify its version (e.g. a path)
+ zarr_version = int(getattr(store, "_store_version", 2))
+
+ if consolidated is None and zarr_version > 2:
+ consolidated = False
+
+ if mode == "r+":
+ already_consolidated = consolidated
+ consolidate_on_close = False
+ else:
+ already_consolidated = False
+ consolidate_on_close = consolidated or consolidated is None
+ zstore = backends.ZarrStore.open_group(
+ store=mapper,
+ mode=mode,
+ synchronizer=synchronizer,
+ group=group,
+ consolidated=already_consolidated,
+ consolidate_on_close=consolidate_on_close,
+ chunk_store=chunk_mapper,
+ append_dim=append_dim,
+ write_region=region,
+ safe_chunks=safe_chunks,
+ stacklevel=4, # for Dataset.to_zarr()
+ zarr_version=zarr_version,
+ write_empty=write_empty_chunks,
+ )
+
+ if region is not None:
+ zstore._validate_and_autodetect_region(dataset)
+ # can't modify indexes with region writes
+ dataset = dataset.drop_vars(dataset.indexes)
+ if append_dim is not None and append_dim in region:
+ raise ValueError(
+ f"cannot list the same dimension in both ``append_dim`` and "
+ f"``region`` with to_zarr(), got {append_dim} in both"
+ )
+
+ if encoding and mode in ["a", "a-", "r+"]:
+ existing_var_names = set(zstore.zarr_group.array_keys())
+ for var_name in existing_var_names:
+ if var_name in encoding:
+ raise ValueError(
+ f"variable {var_name!r} already exists, but encoding was provided"
+ )
+
+ writer = ArrayWriter()
+ # TODO: figure out how to properly handle unlimited_dims
+ dump_to_store(dataset, zstore, writer, encoding=encoding)
+ writes = writer.sync(
+ compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
+ )
+
+ if compute:
+ _finalize_store(writes, zstore)
+ else:
+ import dask
+
+ return dask.delayed(_finalize_store)(writes, zstore)
+
+ return zstore
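
# Rough sketch (not part of the patch) of how the mode defaults above behave
# through the public API. Assumes zarr is installed; "example.zarr" is a
# placeholder path that does not exist yet.
import xarray as xr

ds = xr.Dataset({"v": ("x", [1.0, 2.0, 3.0])})

ds.to_zarr("example.zarr")                              # mode=None           -> "w-"
ds.to_zarr("example.zarr", append_dim="x")              # mode=None + append  -> "a"
ds.to_zarr("example.zarr", region={"x": slice(0, 3)})   # mode=None + region  -> "r+"
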
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index e51d2a10..e9bfdd9d 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import logging
import os
import time
@@ -6,19 +7,27 @@ import traceback
from collections.abc import Iterable
from glob import glob
from typing import TYPE_CHECKING, Any, ClassVar
+
import numpy as np
+
from xarray.conventions import cf_encoder
from xarray.core import indexing
from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
+
if TYPE_CHECKING:
from io import BufferedIOBase
+
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.types import NestedSequence
+
+# Create a logger object, but don't add any handlers. Leave that to user code.
logger = logging.getLogger(__name__)
-NONE_VAR_NAME = '__values__'
+
+
+NONE_VAR_NAME = "__values__"
def _normalize_path(path):
@@ -40,11 +49,18 @@ def _normalize_path(path):
>>> print([type(p) for p in (paths_str,)])
[<class 'str'>]
"""
- pass
+ if isinstance(path, os.PathLike):
+ path = os.fspath(path)
+
+ if isinstance(path, str) and not is_remote_uri(path):
+ path = os.path.abspath(os.path.expanduser(path))
+ return path
-def _find_absolute_paths(paths: (str | os.PathLike | NestedSequence[str |
- os.PathLike]), **kwargs) ->list[str]:
+
+def _find_absolute_paths(
+ paths: str | os.PathLike | NestedSequence[str | os.PathLike], **kwargs
+) -> list[str]:
"""
Find absolute paths from the pattern.
@@ -65,16 +81,74 @@ def _find_absolute_paths(paths: (str | os.PathLike | NestedSequence[str |
>>> [Path(p).name for p in paths]
['common.py']
"""
- pass
+ if isinstance(paths, str):
+ if is_remote_uri(paths) and kwargs.get("engine", None) == "zarr":
+ try:
+ from fsspec.core import get_fs_token_paths
+ except ImportError as e:
+ raise ImportError(
+ "The use of remote URLs for opening zarr requires the package fsspec"
+ ) from e
+
+ fs, _, _ = get_fs_token_paths(
+ paths,
+ mode="rb",
+ storage_options=kwargs.get("backend_kwargs", {}).get(
+ "storage_options", {}
+ ),
+ expand=False,
+ )
+ tmp_paths = fs.glob(fs._strip_protocol(paths)) # finds directories
+ paths = [fs.get_mapper(path) for path in tmp_paths]
+ elif is_remote_uri(paths):
+ raise ValueError(
+ "cannot do wild-card matching for paths that are remote URLs "
+ f"unless engine='zarr' is specified. Got paths: {paths}. "
+ "Instead, supply paths as an explicit list of strings."
+ )
+ else:
+ paths = sorted(glob(_normalize_path(paths)))
+ elif isinstance(paths, os.PathLike):
+ paths = [os.fspath(paths)]
+ else:
+ paths = [os.fspath(p) if isinstance(p, os.PathLike) else p for p in paths]
+
+ return paths
+
+
+def _encode_variable_name(name):
+ if name is None:
+ name = NONE_VAR_NAME
+ return name
+
+
+def _decode_variable_name(name):
+ if name == NONE_VAR_NAME:
+ name = None
+ return name
+
+
+def _iter_nc_groups(root, parent="/"):
+ from xarray.core.treenode import NodePath
+
+ parent = NodePath(parent)
+ for path, group in root.groups.items():
+ gpath = parent / path
+ yield str(gpath)
+ yield from _iter_nc_groups(group, parent=gpath)
def find_root_and_group(ds):
"""Find the root and group name of a netCDF4/h5netcdf dataset."""
- pass
+ hierarchy = ()
+ while ds.parent is not None:
+ hierarchy = (ds.name.split("/")[-1],) + hierarchy
+ ds = ds.parent
+ group = "/" + "/".join(hierarchy)
+ return ds, group
-def robust_getitem(array, key, catch=Exception, max_retries=6,
- initial_delay=500):
+def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500):
"""
Robustly index an array, using retry logic with exponential backoff if any
of the errors ``catch`` are raised. The initial_delay is measured in ms.
@@ -82,16 +156,46 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
With the default settings, the maximum delay will be in the range of 32-64
seconds.
"""
- pass
+ assert max_retries >= 0
+ for n in range(max_retries + 1):
+ try:
+ return array[key]
+ except catch:
+ if n == max_retries:
+ raise
+ base_delay = initial_delay * 2**n
+ next_delay = base_delay + np.random.randint(base_delay)
+ msg = (
+ f"getitem failed, waiting {next_delay} ms before trying again "
+ f"({max_retries - n} tries remaining). Full traceback: {traceback.format_exc()}"
+ )
+ logger.debug(msg)
+ time.sleep(1e-3 * next_delay)
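
# Worked example (not part of the patch) of the backoff schedule in
# robust_getitem() with the defaults initial_delay=500 and max_retries=6:
# sleeps happen for n = 0..5, each 500 * 2**n ms plus a random extra of up to
# the same amount, so the cumulative wait before the final raise lands in
# roughly the 32-64 s window mentioned in the docstring.
bases_ms = [500 * 2**n for n in range(6)]
print(bases_ms)                                         # [500, 1000, 2000, 4000, 8000, 16000]
print(sum(bases_ms) / 1000, 2 * sum(bases_ms) / 1000)   # ~31.5 s to ~63.0 s in total
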
class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
__slots__ = ()
+ def get_duck_array(self, dtype: np.typing.DTypeLike = None):
+ key = indexing.BasicIndexer((slice(None),) * self.ndim)
+ return self[key] # type: ignore [index]
+
class AbstractDataStore:
__slots__ = ()
+ def get_dimensions(self): # pragma: no cover
+ raise NotImplementedError()
+
+ def get_attrs(self): # pragma: no cover
+ raise NotImplementedError()
+
+ def get_variables(self): # pragma: no cover
+ raise NotImplementedError()
+
+ def get_encoding(self):
+ return {}
+
def load(self):
"""
This loads the variables and attributes simultaneously.
@@ -113,6 +217,13 @@ class AbstractDataStore:
This function will be called anytime variables or attributes
are requested, so care should be taken to make sure its fast.
"""
+ variables = FrozenDict(
+ (_decode_variable_name(k), v) for k, v in self.get_variables().items()
+ )
+ attributes = FrozenDict(self.get_attrs())
+ return variables, attributes
+
+ def close(self):
pass
def __enter__(self):
@@ -123,7 +234,7 @@ class AbstractDataStore:
class ArrayWriter:
- __slots__ = 'sources', 'targets', 'regions', 'lock'
+ __slots__ = ("sources", "targets", "regions", "lock")
def __init__(self, lock=None):
self.sources = []
@@ -131,6 +242,42 @@ class ArrayWriter:
self.regions = []
self.lock = lock
+ def add(self, source, target, region=None):
+ if is_chunked_array(source):
+ self.sources.append(source)
+ self.targets.append(target)
+ self.regions.append(region)
+ else:
+ if region:
+ target[region] = source
+ else:
+ target[...] = source
+
+ def sync(self, compute=True, chunkmanager_store_kwargs=None):
+ if self.sources:
+ chunkmanager = get_chunked_array_type(*self.sources)
+
+ # TODO: consider wrapping targets with dask.delayed, if this makes
+ # for any discernible difference in performance, e.g.,
+ # targets = [dask.delayed(t) for t in self.targets]
+
+ if chunkmanager_store_kwargs is None:
+ chunkmanager_store_kwargs = {}
+
+ delayed_store = chunkmanager.store(
+ self.sources,
+ self.targets,
+ lock=self.lock,
+ compute=compute,
+ flush=True,
+ regions=self.regions,
+ **chunkmanager_store_kwargs,
+ )
+ self.sources = []
+ self.targets = []
+ self.regions = []
+ return delayed_store
+
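# Sketch (not part of the patch) of the ArrayWriter contract above: in-memory
# sources are copied into their targets immediately inside add(), while
# chunked sources are queued and only materialized when sync() hands them to
# the chunk manager. Assumes numpy and dask.array are available.
import dask.array as da
import numpy as np

from xarray.backends.common import ArrayWriter

writer = ArrayWriter()
eager_target = np.zeros(3)
writer.add(np.array([1.0, 2.0, 3.0]), eager_target)   # written right away

lazy_target = np.zeros(4)
writer.add(da.arange(4, chunks=2), lazy_target)        # queued until sync()
writer.sync(compute=True)                              # chunked store happens here
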
class AbstractWritableDataStore(AbstractDataStore):
__slots__ = ()
@@ -152,15 +299,26 @@ class AbstractWritableDataStore(AbstractDataStore):
attributes : dict-like
"""
- pass
+ variables = {k: self.encode_variable(v) for k, v in variables.items()}
+ attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
+ return variables, attributes
def encode_variable(self, v):
"""encode one variable"""
- pass
+ return v
def encode_attribute(self, a):
"""encode one attribute"""
- pass
+ return a
+
+ def set_dimension(self, dim, length): # pragma: no cover
+ raise NotImplementedError()
+
+ def set_attribute(self, k, v): # pragma: no cover
+ raise NotImplementedError()
+
+ def set_variable(self, k, v): # pragma: no cover
+ raise NotImplementedError()
def store_dataset(self, dataset):
"""
@@ -169,10 +327,16 @@ class AbstractWritableDataStore(AbstractDataStore):
so here we pass the whole dataset in instead of doing
dataset.variables
"""
- pass
-
- def store(self, variables, attributes, check_encoding_set=frozenset(),
- writer=None, unlimited_dims=None):
+ self.store(dataset, dataset.attrs)
+
+ def store(
+ self,
+ variables,
+ attributes,
+ check_encoding_set=frozenset(),
+ writer=None,
+ unlimited_dims=None,
+ ):
"""
Top level method for putting data on this store, this method:
- encodes variables/attributes
@@ -193,7 +357,16 @@ class AbstractWritableDataStore(AbstractDataStore):
List of dimension names that should be treated as unlimited
dimensions.
"""
- pass
+ if writer is None:
+ writer = ArrayWriter()
+
+ variables, attributes = self.encode(variables, attributes)
+
+ self.set_attributes(attributes)
+ self.set_dimensions(variables, unlimited_dims=unlimited_dims)
+ self.set_variables(
+ variables, check_encoding_set, writer, unlimited_dims=unlimited_dims
+ )
def set_attributes(self, attributes):
"""
@@ -205,10 +378,10 @@ class AbstractWritableDataStore(AbstractDataStore):
attributes : dict-like
Dictionary of key/value (attribute name / attribute) pairs
"""
- pass
+ for k, v in attributes.items():
+ self.set_attribute(k, v)
- def set_variables(self, variables, check_encoding_set, writer,
- unlimited_dims=None):
+ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
"""
This provides a centralized method to set the variables on the data
store.
@@ -225,7 +398,15 @@ class AbstractWritableDataStore(AbstractDataStore):
List of dimension names that should be treated as unlimited
dimensions.
"""
- pass
+
+ for vn, v in variables.items():
+ name = _encode_variable_name(vn)
+ check = vn in check_encoding_set
+ target, source = self.prepare_variable(
+ name, v, check, unlimited_dims=unlimited_dims
+ )
+
+ writer.add(source, target)
def set_dimensions(self, variables, unlimited_dims=None):
"""
@@ -240,12 +421,39 @@ class AbstractWritableDataStore(AbstractDataStore):
List of dimension names that should be treated as unlimited
dimensions.
"""
- pass
+ if unlimited_dims is None:
+ unlimited_dims = set()
+
+ existing_dims = self.get_dimensions()
+
+ dims = {}
+ for v in unlimited_dims: # put unlimited_dims first
+ dims[v] = None
+ for v in variables.values():
+ dims.update(dict(zip(v.dims, v.shape)))
+
+ for dim, length in dims.items():
+ if dim in existing_dims and length != existing_dims[dim]:
+ raise ValueError(
+ "Unable to update size for existing dimension "
+ f"{dim!r} ({length} != {existing_dims[dim]})"
+ )
+ elif dim not in existing_dims:
+ is_unlimited = dim in unlimited_dims
+ self.set_dimension(dim, length, is_unlimited)
class WritableCFDataStore(AbstractWritableDataStore):
__slots__ = ()
+ def encode(self, variables, attributes):
+ # All NetCDF files get CF encoded by default, without this attempting
+ # to write times, for example, would fail.
+ variables, attributes = cf_encoder(variables, attributes)
+ variables = {k: self.encode_variable(v) for k, v in variables.items()}
+ attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
+ return variables, attributes
+
class BackendEntrypoint:
"""
@@ -280,39 +488,53 @@ class BackendEntrypoint:
A string with the URL to the backend's documentation.
The setting of this attribute is not mandatory.
"""
+
open_dataset_parameters: ClassVar[tuple | None] = None
- description: ClassVar[str] = ''
- url: ClassVar[str] = ''
+ description: ClassVar[str] = ""
+ url: ClassVar[str] = ""
- def __repr__(self) ->str:
- txt = f'<{type(self).__name__}>'
+ def __repr__(self) -> str:
+ txt = f"<{type(self).__name__}>"
if self.description:
- txt += f'\n {self.description}'
+ txt += f"\n {self.description}"
if self.url:
- txt += f'\n Learn more at {self.url}'
+ txt += f"\n Learn more at {self.url}"
return txt
- def open_dataset(self, filename_or_obj: (str | os.PathLike[Any] |
- BufferedIOBase | AbstractDataStore), *, drop_variables: (str |
- Iterable[str] | None)=None, **kwargs: Any) ->Dataset:
+ def open_dataset(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ drop_variables: str | Iterable[str] | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`.
"""
- pass
- def guess_can_open(self, filename_or_obj: (str | os.PathLike[Any] |
- BufferedIOBase | AbstractDataStore)) ->bool:
+ raise NotImplementedError()
+
+ def guess_can_open(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ ) -> bool:
"""
        Backend guess_can_open method used by Xarray in :py:func:`~xarray.open_dataset`.
"""
- pass
- def open_datatree(self, filename_or_obj: (str | os.PathLike[Any] |
- BufferedIOBase | AbstractDataStore), **kwargs: Any) ->DataTree:
+ return False
+
+ def open_datatree(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ **kwargs: Any,
+ ) -> DataTree:
"""
Backend open_datatree method used by Xarray in :py:func:`~xarray.open_datatree`.
"""
- pass
+
+ raise NotImplementedError()
+# mapping of engine name to (module name, BackendEntrypoint Class)
BACKEND_ENTRYPOINTS: dict[str, tuple[str | None, type[BackendEntrypoint]]] = {}
diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py
index 5058fcac..86d84f53 100644
--- a/xarray/backends/file_manager.py
+++ b/xarray/backends/file_manager.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import atexit
import contextlib
import io
@@ -7,15 +8,21 @@ import uuid
import warnings
from collections.abc import Hashable
from typing import Any
+
from xarray.backends.locks import acquire
from xarray.backends.lru_cache import LRUCache
from xarray.core import utils
from xarray.core.options import OPTIONS
-FILE_CACHE: LRUCache[Any, io.IOBase] = LRUCache(maxsize=OPTIONS[
- 'file_cache_maxsize'], on_evict=lambda k, v: v.close())
-assert FILE_CACHE.maxsize, 'file cache must be at least size one'
+
+# Global cache for storing open files.
+FILE_CACHE: LRUCache[Any, io.IOBase] = LRUCache(
+ maxsize=OPTIONS["file_cache_maxsize"], on_evict=lambda k, v: v.close()
+)
+assert FILE_CACHE.maxsize, "file cache must be at least size one"
+
REF_COUNTS: dict[Any, int] = {}
-_DEFAULT_MODE = utils.ReprObject('<unused>')
+
+_DEFAULT_MODE = utils.ReprObject("<unused>")
class FileManager:
@@ -28,7 +35,7 @@ class FileManager:
def acquire(self, needs_lock=True):
"""Acquire the file object from this manager."""
- pass
+ raise NotImplementedError()
def acquire_context(self, needs_lock=True):
"""Context manager for acquiring a file. Yields a file object.
@@ -37,11 +44,11 @@ class FileManager:
(i.e., removes it from any cache) if an exception is raised from the
context. It *does not* automatically close the file.
"""
- pass
+ raise NotImplementedError()
def close(self, needs_lock=True):
"""Close the file object associated with this manager, if needed."""
- pass
+ raise NotImplementedError()
class CachingFileManager(FileManager):
@@ -72,9 +79,17 @@ class CachingFileManager(FileManager):
"""
- def __init__(self, opener, *args, mode=_DEFAULT_MODE, kwargs=None, lock
- =None, cache=None, manager_id: (Hashable | None)=None, ref_counts=None
- ):
+ def __init__(
+ self,
+ opener,
+ *args,
+ mode=_DEFAULT_MODE,
+ kwargs=None,
+ lock=None,
+ cache=None,
+ manager_id: Hashable | None = None,
+ ref_counts=None,
+ ):
"""Initialize a CachingFileManager.
The cache, manager_id and ref_counts arguments exist solely to
@@ -118,15 +133,23 @@ class CachingFileManager(FileManager):
self._args = args
self._mode = mode
self._kwargs = {} if kwargs is None else dict(kwargs)
+
self._use_default_lock = lock is None or lock is False
self._lock = threading.Lock() if self._use_default_lock else lock
+
+ # cache[self._key] stores the file associated with this object.
if cache is None:
cache = FILE_CACHE
self._cache = cache
if manager_id is None:
+ # Each call to CachingFileManager should separately open files.
manager_id = str(uuid.uuid4())
self._manager_id = manager_id
self._key = self._make_key()
+
+ # ref_counts[self._key] stores the number of CachingFileManager objects
+ # in memory referencing this same file. We use this to know if we can
+ # close a file when the manager is deallocated.
if ref_counts is None:
ref_counts = REF_COUNTS
self._ref_counter = _RefCounter(ref_counts)
@@ -134,12 +157,23 @@ class CachingFileManager(FileManager):
def _make_key(self):
"""Make a key for caching files in the LRU cache."""
- pass
+ value = (
+ self._opener,
+ self._args,
+ "a" if self._mode == "w" else self._mode,
+ tuple(sorted(self._kwargs.items())),
+ self._manager_id,
+ )
+ return _HashedSequence(value)
@contextlib.contextmanager
def _optional_lock(self, needs_lock):
"""Context manager for optionally acquiring a lock."""
- pass
+ if needs_lock:
+ with self._lock:
+ yield
+ else:
+ yield
def acquire(self, needs_lock=True):
"""Acquire a file object from the manager.
@@ -156,53 +190,111 @@ class CachingFileManager(FileManager):
file-like
An open file object, as returned by ``opener(*args, **kwargs)``.
"""
- pass
+ file, _ = self._acquire_with_cache_info(needs_lock)
+ return file
@contextlib.contextmanager
def acquire_context(self, needs_lock=True):
"""Context manager for acquiring a file."""
- pass
+ file, cached = self._acquire_with_cache_info(needs_lock)
+ try:
+ yield file
+ except Exception:
+ if not cached:
+ self.close(needs_lock)
+ raise
def _acquire_with_cache_info(self, needs_lock=True):
"""Acquire a file, returning the file and whether it was cached."""
- pass
+ with self._optional_lock(needs_lock):
+ try:
+ file = self._cache[self._key]
+ except KeyError:
+ kwargs = self._kwargs
+ if self._mode is not _DEFAULT_MODE:
+ kwargs = kwargs.copy()
+ kwargs["mode"] = self._mode
+ file = self._opener(*self._args, **kwargs)
+ if self._mode == "w":
+ # ensure file doesn't get overridden when opened again
+ self._mode = "a"
+ self._cache[self._key] = file
+ return file, False
+ else:
+ return file, True
def close(self, needs_lock=True):
"""Explicitly close any associated file object (if necessary)."""
- pass
-
- def __del__(self) ->None:
+ # TODO: remove needs_lock if/when we have a reentrant lock in
+ # dask.distributed: https://github.com/dask/dask/issues/3832
+ with self._optional_lock(needs_lock):
+ default = None
+ file = self._cache.pop(self._key, default)
+ if file is not None:
+ file.close()
+
+ def __del__(self) -> None:
+ # If we're the only CachingFileManager referencing an unclosed file,
+ # remove it from the cache upon garbage collection.
+ #
+ # We keep track of our own reference count because we don't want to
+ # close files if another identical file manager needs it. This can
+ # happen if a CachingFileManager is pickled and unpickled without
+ # closing the original file.
ref_count = self._ref_counter.decrement(self._key)
+
if not ref_count and self._key in self._cache:
if acquire(self._lock, blocking=False):
+ # Only close files if we can do so immediately.
try:
self.close(needs_lock=False)
finally:
self._lock.release()
- if OPTIONS['warn_for_unclosed_files']:
+
+ if OPTIONS["warn_for_unclosed_files"]:
warnings.warn(
- f'deallocating {self}, but file is not already closed. This may indicate a bug.'
- , RuntimeWarning, stacklevel=2)
+ f"deallocating {self}, but file is not already closed. "
+ "This may indicate a bug.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
def __getstate__(self):
"""State for pickling."""
+ # cache is intentionally omitted: we don't want to try to serialize
+ # these global objects.
lock = None if self._use_default_lock else self._lock
- return (self._opener, self._args, self._mode, self._kwargs, lock,
- self._manager_id)
-
- def __setstate__(self, state) ->None:
+ return (
+ self._opener,
+ self._args,
+ self._mode,
+ self._kwargs,
+ lock,
+ self._manager_id,
+ )
+
+ def __setstate__(self, state) -> None:
"""Restore from a pickle."""
opener, args, mode, kwargs, lock, manager_id = state
- self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock,
- manager_id=manager_id)
+ self.__init__( # type: ignore
+ opener, *args, mode=mode, kwargs=kwargs, lock=lock, manager_id=manager_id
+ )
- def __repr__(self) ->str:
- args_string = ', '.join(map(repr, self._args))
+ def __repr__(self) -> str:
+ args_string = ", ".join(map(repr, self._args))
if self._mode is not _DEFAULT_MODE:
- args_string += f', mode={self._mode!r}'
+ args_string += f", mode={self._mode!r}"
return (
- f'{type(self).__name__}({self._opener!r}, {args_string}, kwargs={self._kwargs}, manager_id={self._manager_id!r})'
- )
+ f"{type(self).__name__}({self._opener!r}, {args_string}, "
+ f"kwargs={self._kwargs}, manager_id={self._manager_id!r})"
+ )
+
+
+@atexit.register
+def _remove_del_method():
+ # We don't need to close unclosed files at program exit, and may not be able
+ # to, because Python is cleaning up imports / globals.
+ del CachingFileManager.__del__
class _RefCounter:
@@ -212,6 +304,20 @@ class _RefCounter:
self._counts = counts
self._lock = threading.Lock()
+ def increment(self, name):
+ with self._lock:
+ count = self._counts[name] = self._counts.get(name, 0) + 1
+ return count
+
+ def decrement(self, name):
+ with self._lock:
+ count = self._counts[name] - 1
+ if count:
+ self._counts[name] = count
+ else:
+ del self._counts[name]
+ return count
+
class _HashedSequence(list):
"""Speedup repeated look-ups by caching hash values.
@@ -235,3 +341,16 @@ class DummyFileManager(FileManager):
def __init__(self, value):
self._value = value
+
+ def acquire(self, needs_lock=True):
+ del needs_lock # ignored
+ return self._value
+
+ @contextlib.contextmanager
+ def acquire_context(self, needs_lock=True):
+ del needs_lock
+ yield self._value
+
+ def close(self, needs_lock=True):
+ del needs_lock # ignored
+ self._value.close()
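
# Sketch (not part of the patch) of the caching behaviour implemented above:
# the opener only runs on a cache miss, mode="w" is downgraded to "a" after
# the first open so a re-acquire never truncates the file, and close() evicts
# the cached handle. "scratch.txt" is a throwaway path.
from xarray.backends.file_manager import CachingFileManager

manager = CachingFileManager(open, "scratch.txt", mode="w")
f = manager.acquire()            # cache miss: open() is called and the handle cached
f.write("hello\n")
assert manager.acquire() is f    # cache hit: same file object, no re-open
manager.close()                  # closes the handle and removes it from FILE_CACHE
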
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index bc9c82f4..3926ac05 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -1,62 +1,343 @@
from __future__ import annotations
+
import functools
import io
import os
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
-from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint, WritableCFDataStore, _normalize_path, find_root_and_group
+
+from xarray.backends.common import (
+ BACKEND_ENTRYPOINTS,
+ BackendEntrypoint,
+ WritableCFDataStore,
+ _normalize_path,
+ find_root_and_group,
+)
from xarray.backends.file_manager import CachingFileManager, DummyFileManager
from xarray.backends.locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock
-from xarray.backends.netCDF4_ import BaseNetCDF4Array, _encode_nc4_variable, _ensure_no_forward_slash_in_name, _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group
+from xarray.backends.netCDF4_ import (
+ BaseNetCDF4Array,
+ _encode_nc4_variable,
+ _ensure_no_forward_slash_in_name,
+ _extract_nc4_variable_encoding,
+ _get_datatype,
+ _nc4_require_group,
+)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
-from xarray.core.utils import FrozenDict, emit_user_level_warning, is_remote_uri, read_magic_number_from_file, try_read_magic_number_from_file_or_path
+from xarray.core.utils import (
+ FrozenDict,
+ emit_user_level_warning,
+ is_remote_uri,
+ read_magic_number_from_file,
+ try_read_magic_number_from_file_or_path,
+)
from xarray.core.variable import Variable
+
if TYPE_CHECKING:
from io import BufferedIOBase
+
from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
class H5NetCDFArrayWrapper(BaseNetCDF4Array):
+ def get_array(self, needs_lock=True):
+ ds = self.datastore._acquire(needs_lock)
+ return ds.variables[self.variable_name]
def __getitem__(self, key):
- return indexing.explicit_indexing_adapter(key, self.shape, indexing
- .IndexingSupport.OUTER_1VECTOR, self._getitem)
+ return indexing.explicit_indexing_adapter(
+ key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
+ )
+
+ def _getitem(self, key):
+ with self.datastore.lock:
+ array = self.get_array(needs_lock=False)
+ return array[key]
+
+
+def _read_attributes(h5netcdf_var):
+ # GH451
+ # to ensure conventions decoding works properly on Python 3, decode all
+ # bytes attributes to strings
+ attrs = {}
+ for k, v in h5netcdf_var.attrs.items():
+ if k not in ["_FillValue", "missing_value"]:
+ if isinstance(v, bytes):
+ try:
+ v = v.decode("utf-8")
+ except UnicodeDecodeError:
+ emit_user_level_warning(
+ f"'utf-8' codec can't decode bytes for attribute "
+ f"{k!r} of h5netcdf object {h5netcdf_var.name!r}, "
+ f"returning bytes undecoded.",
+ UnicodeWarning,
+ )
+ attrs[k] = v
+ return attrs
-_extract_h5nc_encoding = functools.partial(_extract_nc4_variable_encoding,
- lsd_okay=False, h5py_okay=True, backend='h5netcdf', unlimited_dims=None)
+_extract_h5nc_encoding = functools.partial(
+ _extract_nc4_variable_encoding,
+ lsd_okay=False,
+ h5py_okay=True,
+ backend="h5netcdf",
+ unlimited_dims=None,
+)
+
+
+def _h5netcdf_create_group(dataset, name):
+ return dataset.create_group(name)
class H5NetCDFStore(WritableCFDataStore):
"""Store for reading and writing data via h5netcdf"""
- __slots__ = ('autoclose', 'format', 'is_remote', 'lock', '_filename',
- '_group', '_manager', '_mode')
- def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK,
- autoclose=False):
+ __slots__ = (
+ "autoclose",
+ "format",
+ "is_remote",
+ "lock",
+ "_filename",
+ "_group",
+ "_manager",
+ "_mode",
+ )
+
+ def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=False):
import h5netcdf
+
if isinstance(manager, (h5netcdf.File, h5netcdf.Group)):
if group is None:
root, group = find_root_and_group(manager)
else:
if type(manager) is not h5netcdf.File:
raise ValueError(
- 'must supply a h5netcdf.File if the group argument is provided'
- )
+ "must supply a h5netcdf.File if the group "
+ "argument is provided"
+ )
root = manager
manager = DummyFileManager(root)
+
self._manager = manager
self._group = group
self._mode = mode
self.format = None
+ # todo: utilizing find_root_and_group seems a bit clunky
+ # making filename available on h5netcdf.Group seems better
self._filename = find_root_and_group(self.ds)[0].filename
self.is_remote = is_remote_uri(self._filename)
self.lock = ensure_lock(lock)
self.autoclose = autoclose
+ @classmethod
+ def open(
+ cls,
+ filename,
+ mode="r",
+ format=None,
+ group=None,
+ lock=None,
+ autoclose=False,
+ invalid_netcdf=None,
+ phony_dims=None,
+ decode_vlen_strings=True,
+ driver=None,
+ driver_kwds=None,
+ ):
+ import h5netcdf
+
+ if isinstance(filename, bytes):
+ raise ValueError(
+ "can't open netCDF4/HDF5 as bytes "
+ "try passing a path or file-like object"
+ )
+ elif isinstance(filename, io.IOBase):
+ magic_number = read_magic_number_from_file(filename)
+ if not magic_number.startswith(b"\211HDF\r\n\032\n"):
+ raise ValueError(
+ f"{magic_number} is not the signature of a valid netCDF4 file"
+ )
+
+ if format not in [None, "NETCDF4"]:
+ raise ValueError("invalid format for h5netcdf backend")
+
+ kwargs = {
+ "invalid_netcdf": invalid_netcdf,
+ "decode_vlen_strings": decode_vlen_strings,
+ "driver": driver,
+ }
+ if driver_kwds is not None:
+ kwargs.update(driver_kwds)
+ if phony_dims is not None:
+ kwargs["phony_dims"] = phony_dims
+
+ if lock is None:
+ if mode == "r":
+ lock = HDF5_LOCK
+ else:
+ lock = combine_locks([HDF5_LOCK, get_write_lock(filename)])
+
+ manager = CachingFileManager(h5netcdf.File, filename, mode=mode, kwargs=kwargs)
+ return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose)
+
+ def _acquire(self, needs_lock=True):
+ with self._manager.acquire_context(needs_lock) as root:
+ ds = _nc4_require_group(
+ root, self._group, self._mode, create_group=_h5netcdf_create_group
+ )
+ return ds
+
+ @property
+ def ds(self):
+ return self._acquire()
+
+ def open_store_variable(self, name, var):
+ import h5py
+
+ dimensions = var.dimensions
+ data = indexing.LazilyIndexedArray(H5NetCDFArrayWrapper(name, self))
+ attrs = _read_attributes(var)
+
+ # netCDF4 specific encoding
+ encoding = {
+ "chunksizes": var.chunks,
+ "fletcher32": var.fletcher32,
+ "shuffle": var.shuffle,
+ }
+ if var.chunks:
+ encoding["preferred_chunks"] = dict(zip(var.dimensions, var.chunks))
+ # Convert h5py-style compression options to NetCDF4-Python
+ # style, if possible
+ if var.compression == "gzip":
+ encoding["zlib"] = True
+ encoding["complevel"] = var.compression_opts
+ elif var.compression is not None:
+ encoding["compression"] = var.compression
+ encoding["compression_opts"] = var.compression_opts
+
+ # save source so __repr__ can detect if it's local or not
+ encoding["source"] = self._filename
+ encoding["original_shape"] = data.shape
+
+ vlen_dtype = h5py.check_dtype(vlen=var.dtype)
+ if vlen_dtype is str:
+ encoding["dtype"] = str
+ elif vlen_dtype is not None: # pragma: no cover
+ # xarray doesn't support writing arbitrary vlen dtypes yet.
+ pass
+ else:
+ encoding["dtype"] = var.dtype
+
+ return Variable(dimensions, data, attrs, encoding)
+
+ def get_variables(self):
+ return FrozenDict(
+ (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items()
+ )
+
+ def get_attrs(self):
+ return FrozenDict(_read_attributes(self.ds))
+
+ def get_dimensions(self):
+ return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items())
+
+ def get_encoding(self):
+ return {
+ "unlimited_dims": {
+ k for k, v in self.ds.dimensions.items() if v.isunlimited()
+ }
+ }
+
+ def set_dimension(self, name, length, is_unlimited=False):
+ _ensure_no_forward_slash_in_name(name)
+ if is_unlimited:
+ self.ds.dimensions[name] = None
+ self.ds.resize_dimension(name, length)
+ else:
+ self.ds.dimensions[name] = length
+
+ def set_attribute(self, key, value):
+ self.ds.attrs[key] = value
+
+ def encode_variable(self, variable):
+ return _encode_nc4_variable(variable)
+
+ def prepare_variable(
+ self, name, variable, check_encoding=False, unlimited_dims=None
+ ):
+ import h5py
+
+ _ensure_no_forward_slash_in_name(name)
+ attrs = variable.attrs.copy()
+ dtype = _get_datatype(variable, raise_on_invalid_encoding=check_encoding)
+
+ fillvalue = attrs.pop("_FillValue", None)
+
+ if dtype is str:
+ dtype = h5py.special_dtype(vlen=str)
+
+ encoding = _extract_h5nc_encoding(variable, raise_on_invalid=check_encoding)
+ kwargs = {}
+
+ # Convert from NetCDF4-Python style compression settings to h5py style
+ # If both styles are used together, h5py takes precedence
+ # If set_encoding=True, raise ValueError in case of mismatch
+ if encoding.pop("zlib", False):
+ if check_encoding and encoding.get("compression") not in (None, "gzip"):
+ raise ValueError("'zlib' and 'compression' encodings mismatch")
+ encoding.setdefault("compression", "gzip")
+
+ if (
+ check_encoding
+ and "complevel" in encoding
+ and "compression_opts" in encoding
+ and encoding["complevel"] != encoding["compression_opts"]
+ ):
+ raise ValueError("'complevel' and 'compression_opts' encodings mismatch")
+ complevel = encoding.pop("complevel", 0)
+ if complevel != 0:
+ encoding.setdefault("compression_opts", complevel)
+
+ encoding["chunks"] = encoding.pop("chunksizes", None)
+
+ # Do not apply compression, filters or chunking to scalars.
+ if variable.shape:
+ for key in [
+ "compression",
+ "compression_opts",
+ "shuffle",
+ "chunks",
+ "fletcher32",
+ ]:
+ if key in encoding:
+ kwargs[key] = encoding[key]
+ if name not in self.ds:
+ nc4_var = self.ds.create_variable(
+ name,
+ dtype=dtype,
+ dimensions=variable.dims,
+ fillvalue=fillvalue,
+ **kwargs,
+ )
+ else:
+ nc4_var = self.ds[name]
+
+ for k, v in attrs.items():
+ nc4_var.attrs[k] = v
+
+ target = H5NetCDFArrayWrapper(name, self)
+
+ return target, variable.data
+
+ def sync(self):
+ self.ds.sync()
+
+ def close(self, **kwargs):
+ self._manager.close(**kwargs)
+
class H5netcdfBackendEntrypoint(BackendEntrypoint):
"""
@@ -79,12 +360,142 @@ class H5netcdfBackendEntrypoint(BackendEntrypoint):
backends.NetCDF4BackendEntrypoint
backends.ScipyBackendEntrypoint
"""
+
description = (
- 'Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray'
+ "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using h5netcdf in Xarray"
+ )
+ url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.H5netcdfBackendEntrypoint.html"
+
+ def guess_can_open(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ ) -> bool:
+ magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
+ if magic_number is not None:
+ return magic_number.startswith(b"\211HDF\r\n\032\n")
+
+ if isinstance(filename_or_obj, (str, os.PathLike)):
+ _, ext = os.path.splitext(filename_or_obj)
+ return ext in {".nc", ".nc4", ".cdf"}
+
+ return False
+
+ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ format=None,
+ group=None,
+ lock=None,
+ invalid_netcdf=None,
+ phony_dims=None,
+ decode_vlen_strings=True,
+ driver=None,
+ driver_kwds=None,
+ ) -> Dataset:
+ filename_or_obj = _normalize_path(filename_or_obj)
+ store = H5NetCDFStore.open(
+ filename_or_obj,
+ format=format,
+ group=group,
+ lock=lock,
+ invalid_netcdf=invalid_netcdf,
+ phony_dims=phony_dims,
+ decode_vlen_strings=decode_vlen_strings,
+ driver=driver,
+ driver_kwds=driver_kwds,
)
- url = (
- 'https://docs.xarray.dev/en/stable/generated/xarray.backends.H5netcdfBackendEntrypoint.html'
+
+ store_entrypoint = StoreBackendEntrypoint()
+
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
)
+ return ds
+
+ def open_datatree(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ format=None,
+ group: str | Iterable[str] | Callable | None = None,
+ lock=None,
+ invalid_netcdf=None,
+ phony_dims=None,
+ decode_vlen_strings=True,
+ driver=None,
+ driver_kwds=None,
+ **kwargs,
+ ) -> DataTree:
+ from xarray.backends.api import open_dataset
+ from xarray.backends.common import _iter_nc_groups
+ from xarray.core.datatree import DataTree
+ from xarray.core.treenode import NodePath
+ from xarray.core.utils import close_on_error
+
+ filename_or_obj = _normalize_path(filename_or_obj)
+ store = H5NetCDFStore.open(
+ filename_or_obj,
+ format=format,
+ group=group,
+ lock=lock,
+ invalid_netcdf=invalid_netcdf,
+ phony_dims=phony_dims,
+ decode_vlen_strings=decode_vlen_strings,
+ driver=driver,
+ driver_kwds=driver_kwds,
+ )
+ if group:
+ parent = NodePath("/") / NodePath(group)
+ else:
+ parent = NodePath("/")
+
+ manager = store._manager
+ ds = open_dataset(store, **kwargs)
+ tree_root = DataTree.from_dict({str(parent): ds})
+ for path_group in _iter_nc_groups(store.ds, parent=parent):
+ group_store = H5NetCDFStore(manager, group=path_group, **kwargs)
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(group_store):
+ ds = store_entrypoint.open_dataset(
+ group_store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
+ tree_root._set_item(
+ path_group,
+ new_node,
+ allow_overwrite=False,
+ new_nodes_along_path=True,
+ )
+ return tree_root
-BACKEND_ENTRYPOINTS['h5netcdf'] = 'h5netcdf', H5netcdfBackendEntrypoint
+BACKEND_ENTRYPOINTS["h5netcdf"] = ("h5netcdf", H5netcdfBackendEntrypoint)
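
# Sketch (not part of the patch) of driving the entrypoint above through the
# public API. Assumes h5netcdf and h5py are installed; "data.nc" is a
# placeholder path.
import xarray as xr

xr.Dataset({"v": ("x", [1, 2, 3])}).to_netcdf("data.nc", engine="h5netcdf")

# guess_can_open() matches the HDF5 magic number (or the .nc/.nc4/.cdf
# extension), so explicit engine selection is optional here.
ds = xr.open_dataset("data.nc", engine="h5netcdf")
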
diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py
index cf8cb06f..69cef309 100644
--- a/xarray/backends/locks.py
+++ b/xarray/backends/locks.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import multiprocessing
import threading
import uuid
@@ -8,6 +9,9 @@ from typing import Any, ClassVar
from weakref import WeakValueDictionary
+# SerializableLock is adapted from Dask:
+# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224
+# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
class SerializableLock:
"""A Serializable per-process Lock
@@ -35,12 +39,14 @@ class SerializableLock:
The creation of locks is itself not threadsafe.
"""
- _locks: ClassVar[WeakValueDictionary[Hashable, threading.Lock]
- ] = WeakValueDictionary()
+
+ _locks: ClassVar[WeakValueDictionary[Hashable, threading.Lock]] = (
+ WeakValueDictionary()
+ )
token: Hashable
lock: threading.Lock
- def __init__(self, token: (Hashable | None)=None):
+ def __init__(self, token: Hashable | None = None):
self.token = token or str(uuid.uuid4())
if self.token in SerializableLock._locks:
self.lock = SerializableLock._locks[self.token]
@@ -48,12 +54,21 @@ class SerializableLock:
self.lock = threading.Lock()
SerializableLock._locks[self.token] = self.lock
+ def acquire(self, *args, **kwargs):
+ return self.lock.acquire(*args, **kwargs)
+
+ def release(self, *args, **kwargs):
+ return self.lock.release(*args, **kwargs)
+
def __enter__(self):
self.lock.__enter__()
def __exit__(self, *args):
self.lock.__exit__(*args)
+ def locked(self):
+ return self.lock.locked()
+
def __getstate__(self):
return self.token
@@ -61,14 +76,33 @@ class SerializableLock:
self.__init__(token)
def __str__(self):
- return f'<{self.__class__.__name__}: {self.token}>'
+ return f"<{self.__class__.__name__}: {self.token}>"
+
__repr__ = __str__
+# Locks used by multiple backends.
+# Neither HDF5 nor the netCDF-C library are thread-safe.
HDF5_LOCK = SerializableLock()
NETCDFC_LOCK = SerializableLock()
-_FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary(
- )
+
+
+_FILE_LOCKS: MutableMapping[Any, threading.Lock] = weakref.WeakValueDictionary()
+
+
+def _get_threaded_lock(key):
+ try:
+ lock = _FILE_LOCKS[key]
+ except KeyError:
+ lock = _FILE_LOCKS[key] = threading.Lock()
+ return lock
+
+
+def _get_multiprocessing_lock(key):
+ # TODO: make use of the key -- maybe use locket.py?
+ # https://github.com/mwilliamson/locket.py
+ del key # unused
+ return multiprocessing.Lock()
def _get_lock_maker(scheduler=None):
@@ -83,10 +117,26 @@ def _get_lock_maker(scheduler=None):
--------
dask.utils.get_scheduler_lock
"""
- pass
-
-def _get_scheduler(get=None, collection=None) ->(str | None):
+ if scheduler is None:
+ return _get_threaded_lock
+ elif scheduler == "threaded":
+ return _get_threaded_lock
+ elif scheduler == "multiprocessing":
+ return _get_multiprocessing_lock
+ elif scheduler == "distributed":
+ # Lazy import distributed since it can add a significant
+ # amount of time to import
+ try:
+ from dask.distributed import Lock as DistributedLock
+ except ImportError:
+ DistributedLock = None
+ return DistributedLock
+ else:
+ raise KeyError(scheduler)
+
+
+def _get_scheduler(get=None, collection=None) -> str | None:
"""Determine the dask scheduler that is being used.
None is returned if no dask scheduler is active.
@@ -95,7 +145,33 @@ def _get_scheduler(get=None, collection=None) ->(str | None):
--------
dask.base.get_scheduler
"""
- pass
+ try:
+ # Fix for bug caused by dask installation that doesn't involve the toolz library
+ # Issue: 4164
+ import dask
+ from dask.base import get_scheduler # noqa: F401
+
+ actual_get = get_scheduler(get, collection)
+ except ImportError:
+ return None
+
+ try:
+ from dask.distributed import Client
+
+ if isinstance(actual_get.__self__, Client):
+ return "distributed"
+ except (ImportError, AttributeError):
+ pass
+
+ try:
+ # As of dask=2.6, dask.multiprocessing requires cloudpickle to be installed
+ # Dependency removed in https://github.com/dask/dask/pull/5511
+ if actual_get is dask.multiprocessing.get:
+ return "multiprocessing"
+ except AttributeError:
+ pass
+
+ return "threaded"
def get_write_lock(key):
@@ -110,7 +186,9 @@ def get_write_lock(key):
-------
Lock object that can be used like a threading.Lock object.
"""
- pass
+ scheduler = _get_scheduler()
+ lock_maker = _get_lock_maker(scheduler)
+ return lock_maker(key)
def acquire(lock, blocking=True):
@@ -119,7 +197,16 @@ def acquire(lock, blocking=True):
Includes backwards compatibility hacks for old versions of Python, dask
and dask-distributed.
"""
- pass
+ if blocking:
+ # no arguments needed
+ return lock.acquire()
+ else:
+ # "blocking" keyword argument not supported for:
+ # - threading.Lock on Python 2.
+ # - dask.SerializableLock with dask v1.0.0 or earlier.
+ # - multiprocessing.Lock calls the argument "block" instead.
+ # - dask.distributed.Lock uses the blocking argument as the first one
+ return lock.acquire(blocking)
class CombinedLock:
@@ -130,7 +217,14 @@ class CombinedLock:
"""
def __init__(self, locks):
- self.locks = tuple(set(locks))
+ self.locks = tuple(set(locks)) # remove duplicates
+
+ def acquire(self, blocking=True):
+ return all(acquire(lock, blocking=blocking) for lock in self.locks)
+
+ def release(self):
+ for lock in self.locks:
+ lock.release()
def __enter__(self):
for lock in self.locks:
@@ -140,25 +234,52 @@ class CombinedLock:
for lock in self.locks:
lock.__exit__(*args)
+ def locked(self):
+ return any(lock.locked for lock in self.locks)
+
def __repr__(self):
- return f'CombinedLock({list(self.locks)!r})'
+ return f"CombinedLock({list(self.locks)!r})"
class DummyLock:
"""DummyLock provides the lock API without any actual locking."""
+ def acquire(self, blocking=True):
+ pass
+
+ def release(self):
+ pass
+
def __enter__(self):
pass
def __exit__(self, *args):
pass
+ def locked(self):
+ return False
+
def combine_locks(locks):
"""Combine a sequence of locks into a single lock."""
- pass
+ all_locks = []
+ for lock in locks:
+ if isinstance(lock, CombinedLock):
+ all_locks.extend(lock.locks)
+ elif lock is not None:
+ all_locks.append(lock)
+
+ num_locks = len(all_locks)
+ if num_locks > 1:
+ return CombinedLock(all_locks)
+ elif num_locks == 1:
+ return all_locks[0]
+ else:
+ return DummyLock()
def ensure_lock(lock):
"""Ensure that the given object is a lock."""
- pass
+ if lock is None or lock is False:
+ return DummyLock()
+ return lock
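
# Small sketch (not part of the patch) of the helpers above: combine_locks()
# flattens and deduplicates, returning a DummyLock when nothing real is left,
# and ensure_lock() maps None/False onto the same no-op lock API.
from xarray.backends.locks import (
    HDF5_LOCK,
    NETCDFC_LOCK,
    CombinedLock,
    DummyLock,
    combine_locks,
    ensure_lock,
)

assert isinstance(combine_locks([]), DummyLock)
assert combine_locks([HDF5_LOCK]) is HDF5_LOCK
assert isinstance(combine_locks([HDF5_LOCK, NETCDFC_LOCK]), CombinedLock)
with ensure_lock(None):   # DummyLock: a context manager that does nothing
    pass
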
diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py
index 81212837..c09bcb19 100644
--- a/xarray/backends/lru_cache.py
+++ b/xarray/backends/lru_cache.py
@@ -1,10 +1,12 @@
from __future__ import annotations
+
import threading
from collections import OrderedDict
from collections.abc import Iterator, MutableMapping
from typing import Any, Callable, TypeVar
-K = TypeVar('K')
-V = TypeVar('V')
+
+K = TypeVar("K")
+V = TypeVar("V")
class LRUCache(MutableMapping[K, V]):
@@ -21,14 +23,15 @@ class LRUCache(MutableMapping[K, V]):
The ``maxsize`` property can be used to view or adjust the capacity of
the cache, e.g., ``cache.maxsize = new_size``.
"""
+
_cache: OrderedDict[K, V]
_maxsize: int
_lock: threading.RLock
_on_evict: Callable[[K, V], Any] | None
- __slots__ = '_cache', '_lock', '_maxsize', '_on_evict'
- def __init__(self, maxsize: int, on_evict: (Callable[[K, V], Any] |
- None)=None):
+ __slots__ = ("_cache", "_lock", "_maxsize", "_on_evict")
+
+ def __init__(self, maxsize: int, on_evict: Callable[[K, V], Any] | None = None):
"""
Parameters
----------
@@ -39,50 +42,63 @@ class LRUCache(MutableMapping[K, V]):
evicted.
"""
if not isinstance(maxsize, int):
- raise TypeError('maxsize must be an integer')
+ raise TypeError("maxsize must be an integer")
if maxsize < 0:
- raise ValueError('maxsize must be non-negative')
+ raise ValueError("maxsize must be non-negative")
self._maxsize = maxsize
self._cache = OrderedDict()
self._lock = threading.RLock()
self._on_evict = on_evict
- def __getitem__(self, key: K) ->V:
+ def __getitem__(self, key: K) -> V:
+ # record recent use of the key by moving it to the front of the list
with self._lock:
value = self._cache[key]
self._cache.move_to_end(key)
return value
- def _enforce_size_limit(self, capacity: int) ->None:
+ def _enforce_size_limit(self, capacity: int) -> None:
"""Shrink the cache if necessary, evicting the oldest items."""
- pass
+ while len(self._cache) > capacity:
+ key, value = self._cache.popitem(last=False)
+ if self._on_evict is not None:
+ self._on_evict(key, value)
- def __setitem__(self, key: K, value: V) ->None:
+ def __setitem__(self, key: K, value: V) -> None:
with self._lock:
if key in self._cache:
+ # insert the new value at the end
del self._cache[key]
self._cache[key] = value
elif self._maxsize:
+ # make room if necessary
self._enforce_size_limit(self._maxsize - 1)
self._cache[key] = value
elif self._on_evict is not None:
+ # not saving, immediately evict
self._on_evict(key, value)
- def __delitem__(self, key: K) ->None:
+ def __delitem__(self, key: K) -> None:
del self._cache[key]
- def __iter__(self) ->Iterator[K]:
+ def __iter__(self) -> Iterator[K]:
+ # create a list, so accessing the cache during iteration cannot change
+ # the iteration order
return iter(list(self._cache))
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self._cache)
@property
- def maxsize(self) ->int:
+ def maxsize(self) -> int:
"""Maximum number of items can be held in the cache."""
- pass
+ return self._maxsize
@maxsize.setter
- def maxsize(self, size: int) ->None:
+ def maxsize(self, size: int) -> None:
"""Resize the cache, evicting the oldest items if necessary."""
- pass
+ if size < 0:
+ raise ValueError("maxsize must be non-negative")
+ with self._lock:
+ self._enforce_size_limit(size)
+ self._maxsize = size
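# Usage sketch (illustrative, not part of the diff): LRU eviction with an on_evict callback.
from xarray.backends.lru_cache import LRUCache

evicted = []
cache = LRUCache(maxsize=2, on_evict=lambda k, v: evicted.append(k))
cache["a"] = 1
cache["b"] = 2
_ = cache["a"]   # mark "a" as recently used
cache["c"] = 3   # evicts the least recently used key, "b", via on_evict
assert evicted == ["b"] and set(cache) == {"a", "c"}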
diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py
index df716315..9df6701d 100644
--- a/xarray/backends/memory.py
+++ b/xarray/backends/memory.py
@@ -1,6 +1,9 @@
from __future__ import annotations
+
import copy
+
import numpy as np
+
from xarray.backends.common import AbstractWritableDataStore
from xarray.core.variable import Variable
@@ -16,3 +19,29 @@ class InMemoryDataStore(AbstractWritableDataStore):
def __init__(self, variables=None, attributes=None):
self._variables = {} if variables is None else variables
self._attributes = {} if attributes is None else attributes
+
+ def get_attrs(self):
+ return self._attributes
+
+ def get_variables(self):
+ return self._variables
+
+ def get_dimensions(self):
+ dims = {}
+ for v in self._variables.values():
+ for d, s in v.dims.items():
+ dims[d] = s
+ return dims
+
+ def prepare_variable(self, k, v, *args, **kwargs):
+ new_var = Variable(v.dims, np.empty_like(v), v.attrs)
+ self._variables[k] = new_var
+ return new_var, v.data
+
+ def set_attribute(self, k, v):
+ # copy to imitate writing to disk.
+ self._attributes[k] = copy.deepcopy(v)
+
+ def set_dimension(self, dim, length, unlimited_dims=None):
+ # in this model, dimensions are accounted for in the variables
+ pass
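# Usage sketch (illustrative, not part of the diff): exercising the in-memory test store.
import numpy as np
from xarray.backends.memory import InMemoryDataStore
from xarray.core.variable import Variable

store = InMemoryDataStore()
target, source = store.prepare_variable("x", Variable(("dim",), np.arange(3)))
target[...] = source                  # the deferred write an ArrayWriter would perform
store.set_attribute("title", "demo")  # deep-copied, imitating a write to disk
assert "x" in store.get_variables() and store.get_attrs() == {"title": "demo"}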
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index 065e118c..302c002a 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -1,40 +1,75 @@
from __future__ import annotations
+
import functools
import operator
import os
from collections.abc import Callable, Iterable
from contextlib import suppress
from typing import TYPE_CHECKING, Any
+
import numpy as np
+
from xarray import coding
-from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, WritableCFDataStore, _normalize_path, find_root_and_group, robust_getitem
+from xarray.backends.common import (
+ BACKEND_ENTRYPOINTS,
+ BackendArray,
+ BackendEntrypoint,
+ WritableCFDataStore,
+ _normalize_path,
+ find_root_and_group,
+ robust_getitem,
+)
from xarray.backends.file_manager import CachingFileManager, DummyFileManager
-from xarray.backends.locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock
+from xarray.backends.locks import (
+ HDF5_LOCK,
+ NETCDFC_LOCK,
+ combine_locks,
+ ensure_lock,
+ get_write_lock,
+)
from xarray.backends.netcdf3 import encode_nc3_attr_value, encode_nc3_variable
from xarray.backends.store import StoreBackendEntrypoint
from xarray.coding.variables import pop_to
from xarray.core import indexing
-from xarray.core.utils import FrozenDict, close_on_error, is_remote_uri, try_read_magic_number_from_path
+from xarray.core.utils import (
+ FrozenDict,
+ close_on_error,
+ is_remote_uri,
+ try_read_magic_number_from_path,
+)
from xarray.core.variable import Variable
+
if TYPE_CHECKING:
from io import BufferedIOBase
+
from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
-_endian_lookup = {'=': 'native', '>': 'big', '<': 'little', '|': 'native'}
+
+# This lookup table maps from dtype.byteorder to a readable endian
+# string used by netCDF4.
+_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"}
+
NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK])
class BaseNetCDF4Array(BackendArray):
- __slots__ = 'datastore', 'dtype', 'shape', 'variable_name'
+ __slots__ = ("datastore", "dtype", "shape", "variable_name")
def __init__(self, variable_name, datastore):
self.datastore = datastore
self.variable_name = variable_name
+
array = self.get_array()
self.shape = array.shape
+
dtype = array.dtype
if dtype is str:
+ # use object dtype (with additional vlen string metadata) because that's
+ # the only way in numpy to represent variable length strings and to
+ # check vlen string dtype in further steps
+ # it also prevents automatic string concatenation via
+ # conventions.decode_cf_variable
dtype = coding.strings.create_vlen_dtype(str)
self.dtype = dtype
@@ -45,13 +80,239 @@ class BaseNetCDF4Array(BackendArray):
if self.datastore.autoclose:
self.datastore.close(needs_lock=False)
+ def get_array(self, needs_lock=True):
+ raise NotImplementedError("Virtual Method")
+
class NetCDF4ArrayWrapper(BaseNetCDF4Array):
__slots__ = ()
+ def get_array(self, needs_lock=True):
+ ds = self.datastore._acquire(needs_lock)
+ variable = ds.variables[self.variable_name]
+ variable.set_auto_maskandscale(False)
+ # only added in netCDF4-python v1.2.8
+ with suppress(AttributeError):
+ variable.set_auto_chartostring(False)
+ return variable
+
def __getitem__(self, key):
- return indexing.explicit_indexing_adapter(key, self.shape, indexing
- .IndexingSupport.OUTER, self._getitem)
+ return indexing.explicit_indexing_adapter(
+ key, self.shape, indexing.IndexingSupport.OUTER, self._getitem
+ )
+
+ def _getitem(self, key):
+ if self.datastore.is_remote: # pragma: no cover
+ getitem = functools.partial(robust_getitem, catch=RuntimeError)
+ else:
+ getitem = operator.getitem
+
+ try:
+ with self.datastore.lock:
+ original_array = self.get_array(needs_lock=False)
+ array = getitem(original_array, key)
+ except IndexError:
+ # Catch IndexError in netCDF4 and return a more informative
+ # error message. This is most often called when an unsorted
+ # indexer is used before the data is loaded from disk.
+ msg = (
+ "The indexing operation you are attempting to perform "
+ "is not valid on netCDF4.Variable object. Try loading "
+ "your data into memory first by calling .load()."
+ )
+ raise IndexError(msg)
+ return array
+
+
+def _encode_nc4_variable(var):
+ for coder in [
+ coding.strings.EncodedStringCoder(allows_unicode=True),
+ coding.strings.CharacterArrayCoder(),
+ ]:
+ var = coder.encode(var)
+ return var
+
+
+def _check_encoding_dtype_is_vlen_string(dtype):
+ if dtype is not str:
+ raise AssertionError( # pragma: no cover
+ f"unexpected dtype encoding {dtype!r}. This shouldn't happen: please "
+ "file a bug report at github.com/pydata/xarray"
+ )
+
+
+def _get_datatype(
+ var, nc_format="NETCDF4", raise_on_invalid_encoding=False
+) -> np.dtype:
+ if nc_format == "NETCDF4":
+ return _nc4_dtype(var)
+ if "dtype" in var.encoding:
+ encoded_dtype = var.encoding["dtype"]
+ _check_encoding_dtype_is_vlen_string(encoded_dtype)
+ if raise_on_invalid_encoding:
+ raise ValueError(
+ "encoding dtype=str for vlen strings is only supported "
+ "with format='NETCDF4'."
+ )
+ return var.dtype
+
+
+def _nc4_dtype(var):
+ if "dtype" in var.encoding:
+ dtype = var.encoding.pop("dtype")
+ _check_encoding_dtype_is_vlen_string(dtype)
+ elif coding.strings.is_unicode_dtype(var.dtype):
+ dtype = str
+ elif var.dtype.kind in ["i", "u", "f", "c", "S"]:
+ dtype = var.dtype
+ else:
+ raise ValueError(f"unsupported dtype for netCDF4 variable: {var.dtype}")
+ return dtype
+
+
+def _netcdf4_create_group(dataset, name):
+ return dataset.createGroup(name)
+
+
+def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group):
+ if group in {None, "", "/"}:
+ # use the root group
+ return ds
+ else:
+ # make sure it's a string
+ if not isinstance(group, str):
+ raise ValueError("group must be a string or None")
+ # support path-like syntax
+ path = group.strip("/").split("/")
+ for key in path:
+ try:
+ ds = ds.groups[key]
+ except KeyError as e:
+ if mode != "r":
+ ds = create_group(ds, key)
+ else:
+ # wrap error to provide slightly more helpful message
+ raise OSError(f"group not found: {key}", e)
+ return ds
+
+
+def _ensure_no_forward_slash_in_name(name):
+ if "/" in name:
+ raise ValueError(
+ f"Forward slashes '/' are not allowed in variable and dimension names (got {name!r}). "
+ "Forward slashes are used as hierarchy-separators for "
+ "HDF5-based files ('netcdf4'/'h5netcdf')."
+ )
+
+
+def _ensure_fill_value_valid(data, attributes):
+ # work around for netCDF4/scipy issue where _FillValue has the wrong type:
+ # https://github.com/Unidata/netcdf4-python/issues/271
+ if data.dtype.kind == "S" and "_FillValue" in attributes:
+ attributes["_FillValue"] = np.bytes_(attributes["_FillValue"])
+
+
+def _force_native_endianness(var):
+ # possible values for byteorder are:
+ # = native
+ # < little-endian
+ # > big-endian
+ # | not applicable
+ # Below we check if the data type is not native or NA
+ if var.dtype.byteorder not in ["=", "|"]:
+ # if endianness is specified explicitly, convert to the native type
+ data = var.data.astype(var.dtype.newbyteorder("="))
+ var = Variable(var.dims, data, var.attrs, var.encoding)
+ # if endian exists, remove it from the encoding.
+ var.encoding.pop("endian", None)
+    # check that any 'endian' value remaining in the encoding is 'native'
+ if var.encoding.get("endian", "native") != "native":
+ raise NotImplementedError(
+ "Attempt to write non-native endian type, "
+ "this is not supported by the netCDF4 "
+ "python library."
+ )
+ return var
+
+
+def _extract_nc4_variable_encoding(
+ variable: Variable,
+ raise_on_invalid=False,
+ lsd_okay=True,
+ h5py_okay=False,
+ backend="netCDF4",
+ unlimited_dims=None,
+) -> dict[str, Any]:
+ if unlimited_dims is None:
+ unlimited_dims = ()
+
+ encoding = variable.encoding.copy()
+
+ safe_to_drop = {"source", "original_shape"}
+ valid_encodings = {
+ "zlib",
+ "complevel",
+ "fletcher32",
+ "contiguous",
+ "chunksizes",
+ "shuffle",
+ "_FillValue",
+ "dtype",
+ "compression",
+ "significant_digits",
+ "quantize_mode",
+ "blosc_shuffle",
+ "szip_coding",
+ "szip_pixels_per_block",
+ "endian",
+ }
+ if lsd_okay:
+ valid_encodings.add("least_significant_digit")
+ if h5py_okay:
+ valid_encodings.add("compression_opts")
+
+ if not raise_on_invalid and encoding.get("chunksizes") is not None:
+ # It's possible to get encoded chunksizes larger than a dimension size
+ # if the original file had an unlimited dimension. This is problematic
+ # if the new file no longer has an unlimited dimension.
+ chunksizes = encoding["chunksizes"]
+ chunks_too_big = any(
+ c > d and dim not in unlimited_dims
+ for c, d, dim in zip(chunksizes, variable.shape, variable.dims)
+ )
+ has_original_shape = "original_shape" in encoding
+ changed_shape = (
+ has_original_shape and encoding.get("original_shape") != variable.shape
+ )
+ if chunks_too_big or changed_shape:
+ del encoding["chunksizes"]
+
+ var_has_unlim_dim = any(dim in unlimited_dims for dim in variable.dims)
+ if not raise_on_invalid and var_has_unlim_dim and "contiguous" in encoding.keys():
+ del encoding["contiguous"]
+
+ for k in safe_to_drop:
+ if k in encoding:
+ del encoding[k]
+
+ if raise_on_invalid:
+ invalid = [k for k in encoding if k not in valid_encodings]
+ if invalid:
+ raise ValueError(
+ f"unexpected encoding parameters for {backend!r} backend: {invalid!r}. Valid "
+ f"encodings are: {valid_encodings!r}"
+ )
+ else:
+ for k in list(encoding):
+ if k not in valid_encodings:
+ del encoding[k]
+
+ return encoding
+
+
+def _is_list_of_strings(value) -> bool:
+ arr = np.asarray(value)
+ return arr.dtype.kind in ["U", "S"] and arr.size > 1
class NetCDF4DataStore(WritableCFDataStore):
@@ -59,22 +320,35 @@ class NetCDF4DataStore(WritableCFDataStore):
This store supports NetCDF3, NetCDF4 and OpenDAP datasets.
"""
- __slots__ = ('autoclose', 'format', 'is_remote', 'lock', '_filename',
- '_group', '_manager', '_mode')
- def __init__(self, manager, group=None, mode=None, lock=
- NETCDF4_PYTHON_LOCK, autoclose=False):
+ __slots__ = (
+ "autoclose",
+ "format",
+ "is_remote",
+ "lock",
+ "_filename",
+ "_group",
+ "_manager",
+ "_mode",
+ )
+
+ def __init__(
+ self, manager, group=None, mode=None, lock=NETCDF4_PYTHON_LOCK, autoclose=False
+ ):
import netCDF4
+
if isinstance(manager, netCDF4.Dataset):
if group is None:
root, group = find_root_and_group(manager)
else:
if type(manager) is not netCDF4.Dataset:
raise ValueError(
- 'must supply a root netCDF4.Dataset if the group argument is provided'
- )
+ "must supply a root netCDF4.Dataset if the group "
+ "argument is provided"
+ )
root = manager
manager = DummyFileManager(root)
+
self._manager = manager
self._group = group
self._mode = mode
@@ -84,14 +358,223 @@ class NetCDF4DataStore(WritableCFDataStore):
self.lock = ensure_lock(lock)
self.autoclose = autoclose
- def _build_and_get_enum(self, var_name: str, dtype: np.dtype, enum_name:
- str, enum_dict: dict[str, int]) ->Any:
+ @classmethod
+ def open(
+ cls,
+ filename,
+ mode="r",
+ format="NETCDF4",
+ group=None,
+ clobber=True,
+ diskless=False,
+ persist=False,
+ lock=None,
+ lock_maker=None,
+ autoclose=False,
+ ):
+ import netCDF4
+
+ if isinstance(filename, os.PathLike):
+ filename = os.fspath(filename)
+
+ if not isinstance(filename, str):
+ raise ValueError(
+ "can only read bytes or file-like objects "
+ "with engine='scipy' or 'h5netcdf'"
+ )
+
+ if format is None:
+ format = "NETCDF4"
+
+ if lock is None:
+ if mode == "r":
+ if is_remote_uri(filename):
+ lock = NETCDFC_LOCK
+ else:
+ lock = NETCDF4_PYTHON_LOCK
+ else:
+ if format is None or format.startswith("NETCDF4"):
+ base_lock = NETCDF4_PYTHON_LOCK
+ else:
+ base_lock = NETCDFC_LOCK
+ lock = combine_locks([base_lock, get_write_lock(filename)])
+
+ kwargs = dict(
+ clobber=clobber, diskless=diskless, persist=persist, format=format
+ )
+ manager = CachingFileManager(
+ netCDF4.Dataset, filename, mode=mode, kwargs=kwargs
+ )
+ return cls(manager, group=group, mode=mode, lock=lock, autoclose=autoclose)
+
+ def _acquire(self, needs_lock=True):
+ with self._manager.acquire_context(needs_lock) as root:
+ ds = _nc4_require_group(root, self._group, self._mode)
+ return ds
+
+ @property
+ def ds(self):
+ return self._acquire()
+
+ def open_store_variable(self, name: str, var):
+ import netCDF4
+
+ dimensions = var.dimensions
+ attributes = {k: var.getncattr(k) for k in var.ncattrs()}
+ data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
+ encoding: dict[str, Any] = {}
+ if isinstance(var.datatype, netCDF4.EnumType):
+ encoding["dtype"] = np.dtype(
+ data.dtype,
+ metadata={
+ "enum": var.datatype.enum_dict,
+ "enum_name": var.datatype.name,
+ },
+ )
+ else:
+ encoding["dtype"] = var.dtype
+ _ensure_fill_value_valid(data, attributes)
+ # netCDF4 specific encoding; save _FillValue for later
+ filters = var.filters()
+ if filters is not None:
+ encoding.update(filters)
+ chunking = var.chunking()
+ if chunking is not None:
+ if chunking == "contiguous":
+ encoding["contiguous"] = True
+ encoding["chunksizes"] = None
+ else:
+ encoding["contiguous"] = False
+ encoding["chunksizes"] = tuple(chunking)
+ encoding["preferred_chunks"] = dict(zip(var.dimensions, chunking))
+ # TODO: figure out how to round-trip "endian-ness" without raising
+ # warnings from netCDF4
+ # encoding['endian'] = var.endian()
+ pop_to(attributes, encoding, "least_significant_digit")
+ # save source so __repr__ can detect if it's local or not
+ encoding["source"] = self._filename
+ encoding["original_shape"] = data.shape
+
+ return Variable(dimensions, data, attributes, encoding)
+
+ def get_variables(self):
+ return FrozenDict(
+ (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items()
+ )
+
+ def get_attrs(self):
+ return FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs())
+
+ def get_dimensions(self):
+ return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items())
+
+ def get_encoding(self):
+ return {
+ "unlimited_dims": {
+ k for k, v in self.ds.dimensions.items() if v.isunlimited()
+ }
+ }
+
+ def set_dimension(self, name, length, is_unlimited=False):
+ _ensure_no_forward_slash_in_name(name)
+ dim_length = length if not is_unlimited else None
+ self.ds.createDimension(name, size=dim_length)
+
+ def set_attribute(self, key, value):
+ if self.format != "NETCDF4":
+ value = encode_nc3_attr_value(value)
+ if _is_list_of_strings(value):
+ # encode as NC_STRING if attr is list of strings
+ self.ds.setncattr_string(key, value)
+ else:
+ self.ds.setncattr(key, value)
+
+ def encode_variable(self, variable):
+ variable = _force_native_endianness(variable)
+ if self.format == "NETCDF4":
+ variable = _encode_nc4_variable(variable)
+ else:
+ variable = encode_nc3_variable(variable)
+ return variable
+
+ def prepare_variable(
+ self, name, variable: Variable, check_encoding=False, unlimited_dims=None
+ ):
+ _ensure_no_forward_slash_in_name(name)
+ attrs = variable.attrs.copy()
+ fill_value = attrs.pop("_FillValue", None)
+ datatype = _get_datatype(
+ variable, self.format, raise_on_invalid_encoding=check_encoding
+ )
+ # check enum metadata and use netCDF4.EnumType
+ if (
+ (meta := np.dtype(datatype).metadata)
+ and (e_name := meta.get("enum_name"))
+ and (e_dict := meta.get("enum"))
+ ):
+ datatype = self._build_and_get_enum(name, datatype, e_name, e_dict)
+ encoding = _extract_nc4_variable_encoding(
+ variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
+ )
+ if name in self.ds.variables:
+ nc4_var = self.ds.variables[name]
+ else:
+ default_args = dict(
+ varname=name,
+ datatype=datatype,
+ dimensions=variable.dims,
+ zlib=False,
+ complevel=4,
+ shuffle=True,
+ fletcher32=False,
+ contiguous=False,
+ chunksizes=None,
+ endian="native",
+ least_significant_digit=None,
+ fill_value=fill_value,
+ )
+ default_args.update(encoding)
+ default_args.pop("_FillValue", None)
+ nc4_var = self.ds.createVariable(**default_args)
+
+ nc4_var.setncatts(attrs)
+
+ target = NetCDF4ArrayWrapper(name, self)
+
+ return target, variable.data
+
+ def _build_and_get_enum(
+ self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
+ ) -> Any:
"""
Add or get the netCDF4 Enum based on the dtype in encoding.
The return type should be ``netCDF4.EnumType``,
        but we avoid importing netCDF4 globally for performance reasons.
"""
- pass
+ if enum_name not in self.ds.enumtypes:
+ return self.ds.createEnumType(
+ dtype,
+ enum_name,
+ enum_dict,
+ )
+ datatype = self.ds.enumtypes[enum_name]
+ if datatype.enum_dict != enum_dict:
+ error_msg = (
+ f"Cannot save variable `{var_name}` because an enum"
+            f" `{enum_name}` already exists in the Dataset but has"
+            " a different definition. To fix this error, make sure"
+            " each variable has a uniquely named enum in its"
+ " `encoding['dtype'].metadata` or, if they should share"
+ " the same enum type, make sure the enums are identical."
+ )
+ raise ValueError(error_msg)
+ return datatype
+
+ def sync(self):
+ self.ds.sync()
+
+ def close(self, **kwargs):
+ self._manager.close(**kwargs)
class NetCDF4BackendEntrypoint(BackendEntrypoint):
@@ -115,12 +598,142 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
backends.H5netcdfBackendEntrypoint
backends.ScipyBackendEntrypoint
"""
+
description = (
- 'Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray'
+ "Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray"
+ )
+ url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html"
+
+ def guess_can_open(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ ) -> bool:
+ if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj):
+ return True
+ magic_number = try_read_magic_number_from_path(filename_or_obj)
+ if magic_number is not None:
+ # netcdf 3 or HDF5
+ return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n"))
+
+ if isinstance(filename_or_obj, (str, os.PathLike)):
+ _, ext = os.path.splitext(filename_or_obj)
+ return ext in {".nc", ".nc4", ".cdf"}
+
+ return False
+
+ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ group=None,
+ mode="r",
+ format="NETCDF4",
+ clobber=True,
+ diskless=False,
+ persist=False,
+ lock=None,
+ autoclose=False,
+ ) -> Dataset:
+ filename_or_obj = _normalize_path(filename_or_obj)
+ store = NetCDF4DataStore.open(
+ filename_or_obj,
+ mode=mode,
+ format=format,
+ group=group,
+ clobber=clobber,
+ diskless=diskless,
+ persist=persist,
+ lock=lock,
+ autoclose=autoclose,
)
- url = (
- 'https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html'
+
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
+
+ def open_datatree(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ group: str | Iterable[str] | Callable | None = None,
+ format="NETCDF4",
+ clobber=True,
+ diskless=False,
+ persist=False,
+ lock=None,
+ autoclose=False,
+ **kwargs,
+ ) -> DataTree:
+ from xarray.backends.api import open_dataset
+ from xarray.backends.common import _iter_nc_groups
+ from xarray.core.datatree import DataTree
+ from xarray.core.treenode import NodePath
+
+ filename_or_obj = _normalize_path(filename_or_obj)
+ store = NetCDF4DataStore.open(
+ filename_or_obj,
+ group=group,
+ format=format,
+ clobber=clobber,
+ diskless=diskless,
+ persist=persist,
+ lock=lock,
+ autoclose=autoclose,
)
+ if group:
+ parent = NodePath("/") / NodePath(group)
+ else:
+ parent = NodePath("/")
+
+ manager = store._manager
+ ds = open_dataset(store, **kwargs)
+ tree_root = DataTree.from_dict({str(parent): ds})
+ for path_group in _iter_nc_groups(store.ds, parent=parent):
+ group_store = NetCDF4DataStore(manager, group=path_group, **kwargs)
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(group_store):
+ ds = store_entrypoint.open_dataset(
+ group_store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
+ tree_root._set_item(
+ path_group,
+ new_node,
+ allow_overwrite=False,
+ new_nodes_along_path=True,
+ )
+ return tree_root
-BACKEND_ENTRYPOINTS['netcdf4'] = 'netCDF4', NetCDF4BackendEntrypoint
+BACKEND_ENTRYPOINTS["netcdf4"] = ("netCDF4", NetCDF4BackendEntrypoint)
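# Usage sketch (illustrative, not part of the diff): how _extract_nc4_variable_encoding
# filters a variable's encoding; importing the module does not require netCDF4 itself.
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
from xarray.core.variable import Variable

var = Variable(("x",), [1, 2, 3], encoding={"zlib": True, "foo": "bar", "source": "f.nc"})
# bookkeeping keys ("source") and unknown keys ("foo") are silently dropped ...
assert _extract_nc4_variable_encoding(var) == {"zlib": True}
# ... unless raise_on_invalid=True, which reports the unknown keys instead
try:
    _extract_nc4_variable_encoding(var, raise_on_invalid=True)
except ValueError as err:
    assert "unexpected encoding parameters" in str(err)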
diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py
index e42acafc..70ddbdd1 100644
--- a/xarray/backends/netcdf3.py
+++ b/xarray/backends/netcdf3.py
@@ -1,17 +1,62 @@
from __future__ import annotations
+
import unicodedata
+
import numpy as np
+
from xarray import coding
from xarray.core.variable import Variable
+
+# Special characters that are permitted in netCDF names except in the
+# 0th position of the string
_specialchars = '_.@+- !"#$%&\\()*,:;<=>?[]^`{|}~'
-_reserved_names = {'byte', 'char', 'short', 'ushort', 'int', 'uint',
- 'int64', 'uint64', 'float', 'real', 'double', 'bool', 'string'}
-_nc3_dtype_coercions = {'int64': 'int32', 'uint64': 'int32', 'uint32':
- 'int32', 'uint16': 'int16', 'uint8': 'int8', 'bool': 'int8'}
-STRING_ENCODING = 'utf-8'
+
+# The following are reserved names in CDL and may not be used as names of
+# variables, dimension, attributes
+_reserved_names = {
+ "byte",
+ "char",
+ "short",
+ "ushort",
+ "int",
+ "uint",
+ "int64",
+ "uint64",
+ "float",
+ "real",
+ "double",
+ "bool",
+ "string",
+}
+
+# These data-types aren't supported by netCDF3, so they are automatically
+# coerced instead as indicated by the "coerce_nc3_dtype" function
+_nc3_dtype_coercions = {
+ "int64": "int32",
+ "uint64": "int32",
+ "uint32": "int32",
+ "uint16": "int16",
+ "uint8": "int8",
+ "bool": "int8",
+}
+
+# encode all strings as UTF-8
+STRING_ENCODING = "utf-8"
COERCION_VALUE_ERROR = (
- "could not safely cast array from {dtype} to {new_dtype}. While it is not always the case, a common reason for this is that xarray has deemed it safest to encode np.datetime64[ns] or np.timedelta64[ns] values with int64 values representing units of 'nanoseconds'. This is either due to the fact that the times are known to require nanosecond precision for an accurate round trip, or that the times are unknown prior to writing due to being contained in a chunked array. Ways to work around this are either to use a backend that supports writing int64 values, or to manually specify the encoding['units'] and encoding['dtype'] (e.g. 'seconds since 1970-01-01' and np.dtype('int32')) on the time variable(s) such that the times can be serialized in a netCDF3 file (note that depending on the situation, however, this latter option may result in an inaccurate round trip)."
- )
+ "could not safely cast array from {dtype} to {new_dtype}. While it is not "
+ "always the case, a common reason for this is that xarray has deemed it "
+ "safest to encode np.datetime64[ns] or np.timedelta64[ns] values with "
+ "int64 values representing units of 'nanoseconds'. This is either due to "
+ "the fact that the times are known to require nanosecond precision for an "
+ "accurate round trip, or that the times are unknown prior to writing due "
+ "to being contained in a chunked array. Ways to work around this are "
+ "either to use a backend that supports writing int64 values, or to "
+ "manually specify the encoding['units'] and encoding['dtype'] (e.g. "
+ "'seconds since 1970-01-01' and np.dtype('int32')) on the time "
+ "variable(s) such that the times can be serialized in a netCDF3 file "
+ "(note that depending on the situation, however, this latter option may "
+ "result in an inaccurate round trip)."
+)
def coerce_nc3_dtype(arr):
@@ -29,7 +74,61 @@ def coerce_nc3_dtype(arr):
Data is checked for equality, or equivalence (non-NaN values) using the
``(cast_array == original_array).all()``.
"""
- pass
+ dtype = str(arr.dtype)
+ if dtype in _nc3_dtype_coercions:
+ new_dtype = _nc3_dtype_coercions[dtype]
+ # TODO: raise a warning whenever casting the data-type instead?
+ cast_arr = arr.astype(new_dtype)
+ if not (cast_arr == arr).all():
+ raise ValueError(
+ COERCION_VALUE_ERROR.format(dtype=dtype, new_dtype=new_dtype)
+ )
+ arr = cast_arr
+ return arr
+
+
+def encode_nc3_attr_value(value):
+ if isinstance(value, bytes):
+ pass
+ elif isinstance(value, str):
+ value = value.encode(STRING_ENCODING)
+ else:
+ value = coerce_nc3_dtype(np.atleast_1d(value))
+ if value.ndim > 1:
+ raise ValueError("netCDF attributes must be 1-dimensional")
+ return value
+
+
+def encode_nc3_attrs(attrs):
+ return {k: encode_nc3_attr_value(v) for k, v in attrs.items()}
+
+
+def _maybe_prepare_times(var):
+ # checks for integer-based time-like and
+ # replaces np.iinfo(np.int64).min with _FillValue or np.nan
+ # this keeps backwards compatibility
+
+ data = var.data
+ if data.dtype.kind in "iu":
+ units = var.attrs.get("units", None)
+ if units is not None:
+ if coding.variables._is_time_like(units):
+ mask = data == np.iinfo(np.int64).min
+ if mask.any():
+ data = np.where(mask, var.attrs.get("_FillValue", np.nan), data)
+ return data
+
+
+def encode_nc3_variable(var):
+ for coder in [
+ coding.strings.EncodedStringCoder(allows_unicode=False),
+ coding.strings.CharacterArrayCoder(),
+ ]:
+ var = coder.encode(var)
+ data = _maybe_prepare_times(var)
+ data = coerce_nc3_dtype(data)
+ attrs = encode_nc3_attrs(var.attrs)
+ return Variable(var.dims, data, attrs, var.encoding)
def _isalnumMUTF8(c):
@@ -38,7 +137,7 @@ def _isalnumMUTF8(c):
Input is not checked!
"""
- pass
+ return c.isalnum() or (len(c.encode("utf-8")) > 1)
def is_valid_nc3_name(s):
@@ -58,4 +157,15 @@ def is_valid_nc3_name(s):
names. Names that have trailing space characters are also not
permitted.
"""
- pass
+ if not isinstance(s, str):
+ return False
+ num_bytes = len(s.encode("utf-8"))
+ return (
+ (unicodedata.normalize("NFC", s) == s)
+ and (s not in _reserved_names)
+        and (num_bytes > 0)
+ and ("/" not in s)
+ and (s[-1] != " ")
+ and (_isalnumMUTF8(s[0]) or (s[0] == "_"))
+ and all(_isalnumMUTF8(c) or c in _specialchars for c in s)
+ )
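# Usage sketch (illustrative, not part of the diff): dtype coercion and name validation.
import numpy as np
from xarray.backends.netcdf3 import coerce_nc3_dtype, is_valid_nc3_name

assert coerce_nc3_dtype(np.array([1, 2], dtype="int64")).dtype == np.dtype("int32")
try:
    coerce_nc3_dtype(np.array([2**40], dtype="int64"))  # does not fit in int32
except ValueError as err:
    assert "could not safely cast" in str(err)

assert is_valid_nc3_name("temperature_2m")
assert not is_valid_nc3_name("float")  # reserved CDL name
assert not is_valid_nc3_name("x/y")    # '/' is not permitted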
diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py
index 4ebd98d2..a62ca6c9 100644
--- a/xarray/backends/plugins.py
+++ b/xarray/backends/plugins.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import functools
import inspect
import itertools
@@ -6,22 +7,117 @@ import sys
import warnings
from importlib.metadata import entry_points
from typing import TYPE_CHECKING, Any, Callable
+
from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendEntrypoint
from xarray.core.utils import module_available
+
if TYPE_CHECKING:
import os
from importlib.metadata import EntryPoint
+
if sys.version_info >= (3, 10):
from importlib.metadata import EntryPoints
else:
EntryPoints = list[EntryPoint]
from io import BufferedIOBase
+
from xarray.backends.common import AbstractDataStore
-STANDARD_BACKENDS_ORDER = ['netcdf4', 'h5netcdf', 'scipy']
+
+STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"]
+
+
+def remove_duplicates(entrypoints: EntryPoints) -> list[EntryPoint]:
+ # sort and group entrypoints by name
+ entrypoints_sorted = sorted(entrypoints, key=lambda ep: ep.name)
+ entrypoints_grouped = itertools.groupby(entrypoints_sorted, key=lambda ep: ep.name)
+ # check if there are multiple entrypoints for the same name
+ unique_entrypoints = []
+ for name, _matches in entrypoints_grouped:
+ # remove equal entrypoints
+ matches = list(set(_matches))
+ unique_entrypoints.append(matches[0])
+ matches_len = len(matches)
+ if matches_len > 1:
+ all_module_names = [e.value.split(":")[0] for e in matches]
+ selected_module_name = all_module_names[0]
+ warnings.warn(
+ f"Found {matches_len} entrypoints for the engine name {name}:"
+ f"\n {all_module_names}.\n "
+ f"The entrypoint {selected_module_name} will be used.",
+ RuntimeWarning,
+ )
+ return unique_entrypoints
+
+
+def detect_parameters(open_dataset: Callable) -> tuple[str, ...]:
+ signature = inspect.signature(open_dataset)
+ parameters = signature.parameters
+ parameters_list = []
+ for name, param in parameters.items():
+ if param.kind in (
+ inspect.Parameter.VAR_KEYWORD,
+ inspect.Parameter.VAR_POSITIONAL,
+ ):
+ raise TypeError(
+ f"All the parameters in {open_dataset!r} signature should be explicit. "
+ "*args and **kwargs is not supported"
+ )
+ if name != "self":
+ parameters_list.append(name)
+ return tuple(parameters_list)
+
+
+def backends_dict_from_pkg(
+ entrypoints: list[EntryPoint],
+) -> dict[str, type[BackendEntrypoint]]:
+ backend_entrypoints = {}
+ for entrypoint in entrypoints:
+ name = entrypoint.name
+ try:
+ backend = entrypoint.load()
+ backend_entrypoints[name] = backend
+ except Exception as ex:
+ warnings.warn(f"Engine {name!r} loading failed:\n{ex}", RuntimeWarning)
+ return backend_entrypoints
+
+
+def set_missing_parameters(
+ backend_entrypoints: dict[str, type[BackendEntrypoint]]
+) -> None:
+ for _, backend in backend_entrypoints.items():
+ if backend.open_dataset_parameters is None:
+ open_dataset = backend.open_dataset
+ backend.open_dataset_parameters = detect_parameters(open_dataset)
+
+
+def sort_backends(
+ backend_entrypoints: dict[str, type[BackendEntrypoint]]
+) -> dict[str, type[BackendEntrypoint]]:
+ ordered_backends_entrypoints = {}
+ for be_name in STANDARD_BACKENDS_ORDER:
+ if backend_entrypoints.get(be_name, None) is not None:
+ ordered_backends_entrypoints[be_name] = backend_entrypoints.pop(be_name)
+ ordered_backends_entrypoints.update(
+ {name: backend_entrypoints[name] for name in sorted(backend_entrypoints)}
+ )
+ return ordered_backends_entrypoints
+
+
+def build_engines(entrypoints: EntryPoints) -> dict[str, BackendEntrypoint]:
+ backend_entrypoints: dict[str, type[BackendEntrypoint]] = {}
+ for backend_name, (module_name, backend) in BACKEND_ENTRYPOINTS.items():
+ if module_name is None or module_available(module_name):
+ backend_entrypoints[backend_name] = backend
+ entrypoints_unique = remove_duplicates(entrypoints)
+ external_backend_entrypoints = backends_dict_from_pkg(entrypoints_unique)
+ backend_entrypoints.update(external_backend_entrypoints)
+ backend_entrypoints = sort_backends(backend_entrypoints)
+ set_missing_parameters(backend_entrypoints)
+ return {name: backend() for name, backend in backend_entrypoints.items()}
@functools.lru_cache(maxsize=1)
-def list_engines() ->dict[str, BackendEntrypoint]:
+def list_engines() -> dict[str, BackendEntrypoint]:
"""
Return a dictionary of available engines and their BackendEntrypoint objects.
@@ -36,14 +132,86 @@ def list_engines() ->dict[str, BackendEntrypoint]:
# New selection mechanism introduced with Python 3.10. See GH6514.
"""
- pass
+ if sys.version_info >= (3, 10):
+ entrypoints = entry_points(group="xarray.backends")
+ else:
+ entrypoints = entry_points().get("xarray.backends", [])
+ return build_engines(entrypoints)
-def refresh_engines() ->None:
+def refresh_engines() -> None:
"""Refreshes the backend engines based on installed packages."""
- pass
+ list_engines.cache_clear()
+
+
+def guess_engine(
+ store_spec: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+) -> str | type[BackendEntrypoint]:
+ engines = list_engines()
+ for engine, backend in engines.items():
+ try:
+ if backend.guess_can_open(store_spec):
+ return engine
+ except PermissionError:
+ raise
+ except Exception:
+ warnings.warn(f"{engine!r} fails while guessing", RuntimeWarning)
-def get_backend(engine: (str | type[BackendEntrypoint])) ->BackendEntrypoint:
+ compatible_engines = []
+ for engine, (_, backend_cls) in BACKEND_ENTRYPOINTS.items():
+ try:
+ backend = backend_cls()
+ if backend.guess_can_open(store_spec):
+ compatible_engines.append(engine)
+ except Exception:
+ warnings.warn(f"{engine!r} fails while guessing", RuntimeWarning)
+
+ installed_engines = [k for k in engines if k != "store"]
+ if not compatible_engines:
+ if installed_engines:
+ error_msg = (
+ "did not find a match in any of xarray's currently installed IO "
+ f"backends {installed_engines}. Consider explicitly selecting one of the "
+ "installed engines via the ``engine`` parameter, or installing "
+ "additional IO dependencies, see:\n"
+ "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html\n"
+ "https://docs.xarray.dev/en/stable/user-guide/io.html"
+ )
+ else:
+ error_msg = (
+ "xarray is unable to open this file because it has no currently "
+ "installed IO backends. Xarray's read/write support requires "
+ "installing optional IO dependencies, see:\n"
+ "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html\n"
+ "https://docs.xarray.dev/en/stable/user-guide/io"
+ )
+ else:
+ error_msg = (
+ "found the following matches with the input file in xarray's IO "
+ f"backends: {compatible_engines}. But their dependencies may not be installed, see:\n"
+ "https://docs.xarray.dev/en/stable/user-guide/io.html \n"
+ "https://docs.xarray.dev/en/stable/getting-started-guide/installing.html"
+ )
+
+ raise ValueError(error_msg)
+
+
+def get_backend(engine: str | type[BackendEntrypoint]) -> BackendEntrypoint:
"""Select open_dataset method based on current engine."""
- pass
+ if isinstance(engine, str):
+ engines = list_engines()
+ if engine not in engines:
+ raise ValueError(
+ f"unrecognized engine {engine} must be one of: {list(engines)}"
+ )
+ backend = engines[engine]
+ elif isinstance(engine, type) and issubclass(engine, BackendEntrypoint):
+ backend = engine()
+ else:
+ raise TypeError(
+ "engine must be a string or a subclass of "
+ f"xarray.backends.BackendEntrypoint: {engine}"
+ )
+
+ return backend
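# Usage sketch (illustrative, not part of the diff): inspecting the registered engines.
from xarray.backends.plugins import get_backend, list_engines

engines = list_engines()
print(sorted(engines))          # e.g. ['h5netcdf', 'netcdf4', 'scipy', 'store', ...]
backend = get_backend("store")  # 'store' is always registered (no external dependency)
print(type(backend).__name__)   # StoreBackendEntrypoint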
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index 075d17e6..5a475a7c 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -1,27 +1,84 @@
from __future__ import annotations
+
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
+
import numpy as np
-from xarray.backends.common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendArray, BackendEntrypoint, robust_getitem
+
+from xarray.backends.common import (
+ BACKEND_ENTRYPOINTS,
+ AbstractDataStore,
+ BackendArray,
+ BackendEntrypoint,
+ robust_getitem,
+)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
-from xarray.core.utils import Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri
+from xarray.core.utils import (
+ Frozen,
+ FrozenDict,
+ close_on_error,
+ is_dict_like,
+ is_remote_uri,
+)
from xarray.core.variable import Variable
from xarray.namedarray.pycompat import integer_types
+
if TYPE_CHECKING:
import os
from io import BufferedIOBase
+
from xarray.core.dataset import Dataset
class PydapArrayWrapper(BackendArray):
-
def __init__(self, array):
self.array = array
+ @property
+ def shape(self) -> tuple[int, ...]:
+ return self.array.shape
+
+ @property
+ def dtype(self):
+ return self.array.dtype
+
def __getitem__(self, key):
- return indexing.explicit_indexing_adapter(key, self.shape, indexing
- .IndexingSupport.BASIC, self._getitem)
+ return indexing.explicit_indexing_adapter(
+ key, self.shape, indexing.IndexingSupport.BASIC, self._getitem
+ )
+
+ def _getitem(self, key):
+ # pull the data from the array attribute if possible, to avoid
+ # downloading coordinate data twice
+ array = getattr(self.array, "array", self.array)
+ result = robust_getitem(array, key, catch=ValueError)
+ result = np.asarray(result)
+ # in some cases, pydap doesn't squeeze axes automatically like numpy
+ axis = tuple(n for n, k in enumerate(key) if isinstance(k, integer_types))
+ if result.ndim + len(axis) != array.ndim and axis:
+ result = np.squeeze(result, axis)
+
+ return result
+
+
+def _fix_attributes(attributes):
+ attributes = dict(attributes)
+ for k in list(attributes):
+ if k.lower() == "global" or k.lower().endswith("_global"):
+ # move global attributes to the top level, like the netcdf-C
+ # DAP client
+ attributes.update(attributes.pop(k))
+ elif is_dict_like(attributes[k]):
+ # Make Hierarchical attributes to a single level with a
+ # dot-separated key
+ attributes.update(
+ {
+ f"{k}.{k_child}": v_child
+ for k_child, v_child in attributes.pop(k).items()
+ }
+ )
+ return attributes
class PydapDataStore(AbstractDataStore):
@@ -39,6 +96,54 @@ class PydapDataStore(AbstractDataStore):
"""
self.ds = ds
+ @classmethod
+ def open(
+ cls,
+ url,
+ application=None,
+ session=None,
+ output_grid=None,
+ timeout=None,
+ verify=None,
+ user_charset=None,
+ ):
+ import pydap.client
+ import pydap.lib
+
+ if timeout is None:
+ from pydap.lib import DEFAULT_TIMEOUT
+
+ timeout = DEFAULT_TIMEOUT
+
+ kwargs = {
+ "url": url,
+ "application": application,
+ "session": session,
+ "output_grid": output_grid or True,
+ "timeout": timeout,
+ }
+ if verify is not None:
+ kwargs.update({"verify": verify})
+ if user_charset is not None:
+ kwargs.update({"user_charset": user_charset})
+ ds = pydap.client.open_url(**kwargs)
+ return cls(ds)
+
+ def open_store_variable(self, var):
+ data = indexing.LazilyIndexedArray(PydapArrayWrapper(var))
+ return Variable(var.dimensions, data, _fix_attributes(var.attributes))
+
+ def get_variables(self):
+ return FrozenDict(
+ (k, self.open_store_variable(self.ds[k])) for k in self.ds.keys()
+ )
+
+ def get_attrs(self):
+ return Frozen(_fix_attributes(self.ds.attributes))
+
+ def get_dimensions(self):
+ return Frozen(self.ds.dimensions)
+
class PydapBackendEntrypoint(BackendEntrypoint):
"""
@@ -55,10 +160,57 @@ class PydapBackendEntrypoint(BackendEntrypoint):
--------
backends.PydapDataStore
"""
- description = 'Open remote datasets via OPeNDAP using pydap in Xarray'
- url = (
- 'https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html'
+
+ description = "Open remote datasets via OPeNDAP using pydap in Xarray"
+ url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.PydapBackendEntrypoint.html"
+
+ def guess_can_open(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ ) -> bool:
+ return isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj)
+
+ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ application=None,
+ session=None,
+ output_grid=None,
+ timeout=None,
+ verify=None,
+ user_charset=None,
+ ) -> Dataset:
+ store = PydapDataStore.open(
+ url=filename_or_obj,
+ application=application,
+ session=session,
+ output_grid=output_grid,
+ timeout=timeout,
+ verify=verify,
+ user_charset=user_charset,
)
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
+
-BACKEND_ENTRYPOINTS['pydap'] = 'pydap', PydapBackendEntrypoint
+BACKEND_ENTRYPOINTS["pydap"] = ("pydap", PydapBackendEntrypoint)
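# Usage sketch (illustrative, not part of the diff): flattening pydap-style attributes;
# _fix_attributes is pure Python, so pydap does not need to be installed to run this.
from xarray.backends.pydap_ import _fix_attributes

attrs = {"NC_GLOBAL": {"title": "demo"}, "scale": {"factor": 2}}
assert _fix_attributes(attrs) == {"title": "demo", "scale.factor": 2}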
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py
index 47f34b36..f8c486e5 100644
--- a/xarray/backends/scipy_.py
+++ b/xarray/backends/scipy_.py
@@ -1,27 +1,61 @@
from __future__ import annotations
+
import gzip
import io
import os
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
+
import numpy as np
-from xarray.backends.common import BACKEND_ENTRYPOINTS, BackendArray, BackendEntrypoint, WritableCFDataStore, _normalize_path
+
+from xarray.backends.common import (
+ BACKEND_ENTRYPOINTS,
+ BackendArray,
+ BackendEntrypoint,
+ WritableCFDataStore,
+ _normalize_path,
+)
from xarray.backends.file_manager import CachingFileManager, DummyFileManager
from xarray.backends.locks import ensure_lock, get_write_lock
-from xarray.backends.netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name
+from xarray.backends.netcdf3 import (
+ encode_nc3_attr_value,
+ encode_nc3_variable,
+ is_valid_nc3_name,
+)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
-from xarray.core.utils import Frozen, FrozenDict, close_on_error, module_available, try_read_magic_number_from_file_or_path
+from xarray.core.utils import (
+ Frozen,
+ FrozenDict,
+ close_on_error,
+ module_available,
+ try_read_magic_number_from_file_or_path,
+)
from xarray.core.variable import Variable
+
if TYPE_CHECKING:
from io import BufferedIOBase
+
from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
-HAS_NUMPY_2_0 = module_available('numpy', minversion='2.0.0.dev0')
-class ScipyArrayWrapper(BackendArray):
+HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0")
+
+
+def _decode_string(s):
+ if isinstance(s, bytes):
+ return s.decode("utf-8", "replace")
+ return s
+
+def _decode_attrs(d):
+ # don't decode _FillValue from bytes -> unicode, because we want to ensure
+ # that its type matches the data exactly
+ return {k: v if k == "_FillValue" else _decode_string(v) for (k, v) in d.items()}
+
+
+class ScipyArrayWrapper(BackendArray):
def __init__(self, variable_name, datastore):
self.datastore = datastore
self.variable_name = variable_name
@@ -29,11 +63,29 @@ class ScipyArrayWrapper(BackendArray):
self.shape = array.shape
self.dtype = np.dtype(array.dtype.kind + str(array.dtype.itemsize))
+ def get_variable(self, needs_lock=True):
+ ds = self.datastore._manager.acquire(needs_lock)
+ return ds.variables[self.variable_name]
+
+ def _getitem(self, key):
+ with self.datastore.lock:
+ data = self.get_variable(needs_lock=False).data
+ return data[key]
+
def __getitem__(self, key):
- data = indexing.explicit_indexing_adapter(key, self.shape, indexing
- .IndexingSupport.OUTER_1VECTOR, self._getitem)
+ data = indexing.explicit_indexing_adapter(
+ key, self.shape, indexing.IndexingSupport.OUTER_1VECTOR, self._getitem
+ )
+ # Copy data if the source file is mmapped. This makes things consistent
+ # with the netCDF4 library by ensuring we can safely read arrays even
+ # after closing associated files.
copy = self.datastore.ds.use_mmap
+
+ # adapt handling of copy-kwarg to numpy 2.0
+ # see https://github.com/numpy/numpy/issues/25916
+ # and https://github.com/numpy/numpy/pull/25922
copy = None if HAS_NUMPY_2_0 and copy is False else copy
+
return np.array(data, dtype=self.dtype, copy=copy)
def __setitem__(self, key, value):
@@ -43,11 +95,50 @@ class ScipyArrayWrapper(BackendArray):
data[key] = value
except TypeError:
if key is Ellipsis:
+ # workaround for GH: scipy/scipy#6880
data[:] = value
else:
raise
+def _open_scipy_netcdf(filename, mode, mmap, version):
+ import scipy.io
+
+ # if the string ends with .gz, then gunzip and open as netcdf file
+ if isinstance(filename, str) and filename.endswith(".gz"):
+ try:
+ return scipy.io.netcdf_file(
+ gzip.open(filename), mode=mode, mmap=mmap, version=version
+ )
+ except TypeError as e:
+ # TODO: gzipped loading only works with NetCDF3 files.
+ errmsg = e.args[0]
+ if "is not a valid NetCDF 3 file" in errmsg:
+ raise ValueError("gzipped file loading only supports NetCDF 3 files.")
+ else:
+ raise
+
+ if isinstance(filename, bytes) and filename.startswith(b"CDF"):
+ # it's a NetCDF3 bytestring
+ filename = io.BytesIO(filename)
+
+ try:
+ return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version)
+ except TypeError as e: # netcdf3 message is obscure in this case
+ errmsg = e.args[0]
+ if "is not a valid NetCDF 3 file" in errmsg:
+ msg = """
+ If this is a NetCDF4 file, you may need to install the
+ netcdf4 library, e.g.,
+
+ $ pip install netcdf4
+ """
+ errmsg += msg
+ raise TypeError(errmsg)
+ else:
+ raise
+
+
class ScipyDataStore(WritableCFDataStore):
"""Store for reading and writing data via scipy.io.netcdf.
@@ -57,31 +148,121 @@ class ScipyDataStore(WritableCFDataStore):
It only supports the NetCDF3 file-format.
"""
- def __init__(self, filename_or_obj, mode='r', format=None, group=None,
- mmap=None, lock=None):
+ def __init__(
+ self, filename_or_obj, mode="r", format=None, group=None, mmap=None, lock=None
+ ):
if group is not None:
- raise ValueError(
- 'cannot save to a group with the scipy.io.netcdf backend')
- if format is None or format == 'NETCDF3_64BIT':
+ raise ValueError("cannot save to a group with the scipy.io.netcdf backend")
+
+ if format is None or format == "NETCDF3_64BIT":
version = 2
- elif format == 'NETCDF3_CLASSIC':
+ elif format == "NETCDF3_CLASSIC":
version = 1
else:
- raise ValueError(
- f'invalid format for scipy.io.netcdf backend: {format!r}')
- if lock is None and mode != 'r' and isinstance(filename_or_obj, str):
+ raise ValueError(f"invalid format for scipy.io.netcdf backend: {format!r}")
+
+ if lock is None and mode != "r" and isinstance(filename_or_obj, str):
lock = get_write_lock(filename_or_obj)
+
self.lock = ensure_lock(lock)
+
if isinstance(filename_or_obj, str):
- manager = CachingFileManager(_open_scipy_netcdf,
- filename_or_obj, mode=mode, lock=lock, kwargs=dict(mmap=
- mmap, version=version))
+ manager = CachingFileManager(
+ _open_scipy_netcdf,
+ filename_or_obj,
+ mode=mode,
+ lock=lock,
+ kwargs=dict(mmap=mmap, version=version),
+ )
else:
- scipy_dataset = _open_scipy_netcdf(filename_or_obj, mode=mode,
- mmap=mmap, version=version)
+ scipy_dataset = _open_scipy_netcdf(
+ filename_or_obj, mode=mode, mmap=mmap, version=version
+ )
manager = DummyFileManager(scipy_dataset)
+
self._manager = manager
+ @property
+ def ds(self):
+ return self._manager.acquire()
+
+ def open_store_variable(self, name, var):
+ return Variable(
+ var.dimensions,
+ ScipyArrayWrapper(name, self),
+ _decode_attrs(var._attributes),
+ )
+
+ def get_variables(self):
+ return FrozenDict(
+ (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items()
+ )
+
+ def get_attrs(self):
+ return Frozen(_decode_attrs(self.ds._attributes))
+
+ def get_dimensions(self):
+ return Frozen(self.ds.dimensions)
+
+ def get_encoding(self):
+ return {
+ "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None}
+ }
+
+ def set_dimension(self, name, length, is_unlimited=False):
+ if name in self.ds.dimensions:
+ raise ValueError(
+ f"{type(self).__name__} does not support modifying dimensions"
+ )
+ dim_length = length if not is_unlimited else None
+ self.ds.createDimension(name, dim_length)
+
+ def _validate_attr_key(self, key):
+ if not is_valid_nc3_name(key):
+ raise ValueError("Not a valid attribute name")
+
+ def set_attribute(self, key, value):
+ self._validate_attr_key(key)
+ value = encode_nc3_attr_value(value)
+ setattr(self.ds, key, value)
+
+ def encode_variable(self, variable):
+ variable = encode_nc3_variable(variable)
+ return variable
+
+ def prepare_variable(
+ self, name, variable, check_encoding=False, unlimited_dims=None
+ ):
+ if (
+ check_encoding
+ and variable.encoding
+ and variable.encoding != {"_FillValue": None}
+ ):
+ raise ValueError(
+ f"unexpected encoding for scipy backend: {list(variable.encoding)}"
+ )
+
+ data = variable.data
+        # nb. this still creates the full numpy array in memory, even though we
+        # don't write the data yet; scipy.io.netcdf does not support
+        # incremental writes.
+ if name not in self.ds.variables:
+ self.ds.createVariable(name, data.dtype, variable.dims)
+ scipy_var = self.ds.variables[name]
+ for k, v in variable.attrs.items():
+ self._validate_attr_key(k)
+ setattr(scipy_var, k, v)
+
+ target = ScipyArrayWrapper(name, self)
+
+ return target, data
+
+ def sync(self):
+ self.ds.sync()
+
+ def close(self):
+ self._manager.close()
+
class ScipyBackendEntrypoint(BackendEntrypoint):
"""
@@ -103,11 +284,62 @@ class ScipyBackendEntrypoint(BackendEntrypoint):
backends.NetCDF4BackendEntrypoint
backends.H5netcdfBackendEntrypoint
"""
- description = (
- 'Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray')
- url = (
- 'https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html'
+
+ description = "Open netCDF files (.nc, .nc4, .cdf and .gz) using scipy in Xarray"
+ url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ScipyBackendEntrypoint.html"
+
+ def guess_can_open(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ ) -> bool:
+ magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
+ if magic_number is not None and magic_number.startswith(b"\x1f\x8b"):
+ with gzip.open(filename_or_obj) as f: # type: ignore[arg-type]
+ magic_number = try_read_magic_number_from_file_or_path(f)
+ if magic_number is not None:
+ return magic_number.startswith(b"CDF")
+
+ if isinstance(filename_or_obj, (str, os.PathLike)):
+ _, ext = os.path.splitext(filename_or_obj)
+ return ext in {".nc", ".nc4", ".cdf", ".gz"}
+
+ return False
+
+ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ mode="r",
+ format=None,
+ group=None,
+ mmap=None,
+ lock=None,
+ ) -> Dataset:
+ filename_or_obj = _normalize_path(filename_or_obj)
+ store = ScipyDataStore(
+ filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock
)
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
+
-BACKEND_ENTRYPOINTS['scipy'] = 'scipy', ScipyBackendEntrypoint
+BACKEND_ENTRYPOINTS["scipy"] = ("scipy", ScipyBackendEntrypoint)
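# Usage sketch (illustrative, not part of the diff; assumes scipy is installed):
# a NetCDF3 round trip through the scipy backend, writing to and reading from bytes.
import xarray as xr

ds = xr.Dataset({"t": ("time", [1.0, 2.0, 3.0])})
raw = ds.to_netcdf(format="NETCDF3_CLASSIC", engine="scipy")  # no path -> returns bytes
assert raw[:3] == b"CDF"
roundtripped = xr.open_dataset(raw, engine="scipy")           # bytes are wrapped in BytesIO
assert roundtripped["t"].values.tolist() == [1.0, 2.0, 3.0]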
diff --git a/xarray/backends/store.py b/xarray/backends/store.py
index 031b247b..a507ee37 100644
--- a/xarray/backends/store.py
+++ b/xarray/backends/store.py
@@ -1,19 +1,66 @@
from __future__ import annotations
+
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
+
from xarray import conventions
-from xarray.backends.common import BACKEND_ENTRYPOINTS, AbstractDataStore, BackendEntrypoint
+from xarray.backends.common import (
+ BACKEND_ENTRYPOINTS,
+ AbstractDataStore,
+ BackendEntrypoint,
+)
from xarray.core.dataset import Dataset
+
if TYPE_CHECKING:
import os
from io import BufferedIOBase
class StoreBackendEntrypoint(BackendEntrypoint):
- description = 'Open AbstractDataStore instances in Xarray'
- url = (
- 'https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html'
+ description = "Open AbstractDataStore instances in Xarray"
+ url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html"
+
+ def guess_can_open(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ ) -> bool:
+ return isinstance(filename_or_obj, AbstractDataStore)
+
+ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ ) -> Dataset:
+ assert isinstance(filename_or_obj, AbstractDataStore)
+
+ vars, attrs = filename_or_obj.load()
+ encoding = filename_or_obj.get_encoding()
+
+ vars, attrs, coord_names = conventions.decode_cf_variables(
+ vars,
+ attrs,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
)
+ ds = Dataset(vars, attrs=attrs)
+ ds = ds.set_coords(coord_names.intersection(vars))
+ ds.set_close(filename_or_obj.close)
+ ds.encoding = encoding
+
+ return ds
+
-BACKEND_ENTRYPOINTS['store'] = None, StoreBackendEntrypoint
+BACKEND_ENTRYPOINTS["store"] = (None, StoreBackendEntrypoint)
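# Usage sketch (illustrative, not part of the diff): decoding any AbstractDataStore
# through the 'store' entrypoint, here backed by the in-memory test store.
import numpy as np
from xarray.backends.memory import InMemoryDataStore
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core.variable import Variable

store = InMemoryDataStore(
    variables={"t": Variable(("time",), np.arange(3), {"units": "seconds since 2000-01-01"})},
    attributes={"title": "demo"},
)
ds = StoreBackendEntrypoint().open_dataset(store)
print(ds["t"].dtype)      # datetime64[ns]: decode_times applied the CF units
print(ds.attrs["title"])  # demo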
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index bf3f1484..8c526ddb 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -1,26 +1,44 @@
from __future__ import annotations
+
import json
import os
import warnings
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
+
import numpy as np
import pandas as pd
+
from xarray import coding, conventions
-from xarray.backends.common import BACKEND_ENTRYPOINTS, AbstractWritableDataStore, BackendArray, BackendEntrypoint, _encode_variable_name, _normalize_path
+from xarray.backends.common import (
+ BACKEND_ENTRYPOINTS,
+ AbstractWritableDataStore,
+ BackendArray,
+ BackendEntrypoint,
+ _encode_variable_name,
+ _normalize_path,
+)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
from xarray.core.types import ZarrWriteModes
-from xarray.core.utils import FrozenDict, HiddenKeyDict, close_on_error
+from xarray.core.utils import (
+ FrozenDict,
+ HiddenKeyDict,
+ close_on_error,
+)
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import guess_chunkmanager
from xarray.namedarray.pycompat import integer_types
+
if TYPE_CHECKING:
from io import BufferedIOBase
+
from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
-DIMENSION_KEY = '_ARRAY_DIMENSIONS'
+
+# need some special secret attributes to tell us the dimensions
+DIMENSION_KEY = "_ARRAY_DIMENSIONS"
def encode_zarr_attr_value(value):
@@ -34,22 +52,47 @@ def encode_zarr_attr_value(value):
scalar array -> scalar
other -> other (no change)
"""
- pass
+ if isinstance(value, np.ndarray):
+ encoded = value.tolist()
+ # this checks if it's a scalar number
+ elif isinstance(value, np.generic):
+ encoded = value.item()
+ else:
+ encoded = value
+ return encoded
class ZarrArrayWrapper(BackendArray):
- __slots__ = 'dtype', 'shape', '_array'
+ __slots__ = ("dtype", "shape", "_array")
def __init__(self, zarr_array):
+ # some callers attempt to evaluate an array if an `array` property exists on the object.
+ # we prefix with _ to avoid this inference.
self._array = zarr_array
self.shape = self._array.shape
- if self._array.filters is not None and any([(filt.codec_id ==
- 'vlen-utf8') for filt in self._array.filters]):
+
+ # preserve vlen string object dtype (GH 7328)
+ if self._array.filters is not None and any(
+ [filt.codec_id == "vlen-utf8" for filt in self._array.filters]
+ ):
dtype = coding.strings.create_vlen_dtype(str)
else:
dtype = self._array.dtype
+
self.dtype = dtype
+ def get_array(self):
+ return self._array
+
+ def _oindex(self, key):
+ return self._array.oindex[key]
+
+ def _vindex(self, key):
+ return self._array.vindex[key]
+
+ def _getitem(self, key):
+ return self._array[key]
+
def __getitem__(self, key):
array = self._array
if isinstance(key, indexing.BasicIndexer):
@@ -58,8 +101,12 @@ class ZarrArrayWrapper(BackendArray):
method = self._vindex
elif isinstance(key, indexing.OuterIndexer):
method = self._oindex
- return indexing.explicit_indexing_adapter(key, array.shape,
- indexing.IndexingSupport.VECTORIZED, method)
+ return indexing.explicit_indexing_adapter(
+ key, array.shape, indexing.IndexingSupport.VECTORIZED, method
+ )
+
+ # if self.ndim == 0:
+ # could possibly have a work-around for 0d data here
def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
@@ -67,11 +114,134 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks):
Given encoding chunks (possibly None or []) and variable chunks
(possibly None or []).
"""
- pass
+ # zarr chunk spec:
+ # chunks : int or tuple of ints, optional
+ # Chunk shape. If not provided, will be guessed from shape and dtype.
+
+ # if there are no chunks in encoding and the variable data is a numpy
+ # array, then we let zarr use its own heuristics to pick the chunks
+ if not var_chunks and not enc_chunks:
+ return None
+
+ # if there are no chunks in encoding but there are dask chunks, we try to
+ # use the same chunks in zarr
+ # However, zarr chunks needs to be uniform for each array
+ # http://zarr.readthedocs.io/en/latest/spec/v1.html#chunks
+ # while dask chunks can be variable sized
+ # http://dask.pydata.org/en/latest/array-design.html#chunks
+ if var_chunks and not enc_chunks:
+ if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks):
+ raise ValueError(
+ "Zarr requires uniform chunk sizes except for final chunk. "
+ f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. "
+ "Consider rechunking using `chunk()`."
+ )
+ if any((chunks[0] < chunks[-1]) for chunks in var_chunks):
+ raise ValueError(
+ "Final chunk of Zarr array must be the same size or smaller "
+                f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}. "
+ "Consider either rechunking using `chunk()` or instead deleting "
+ "or modifying `encoding['chunks']`."
+ )
+ # return the first chunk for each dimension
+ return tuple(chunk[0] for chunk in var_chunks)
+
+ # from here on, we are dealing with user-specified chunks in encoding
+ # zarr allows chunks to be an integer, in which case it uses the same chunk
+ # size on each dimension.
+ # Here we re-implement this expansion ourselves. That makes the logic of
+ # checking chunk compatibility easier
-def extract_zarr_variable_encoding(variable, raise_on_invalid=False, name=
- None, safe_chunks=True):
+ if isinstance(enc_chunks, integer_types):
+ enc_chunks_tuple = ndim * (enc_chunks,)
+ else:
+ enc_chunks_tuple = tuple(enc_chunks)
+
+ if len(enc_chunks_tuple) != ndim:
+ # throw away encoding chunks, start over
+ return _determine_zarr_chunks(None, var_chunks, ndim, name, safe_chunks)
+
+ for x in enc_chunks_tuple:
+ if not isinstance(x, int):
+ raise TypeError(
+ "zarr chunk sizes specified in `encoding['chunks']` "
+ "must be an int or a tuple of ints. "
+ f"Instead found encoding['chunks']={enc_chunks_tuple!r} "
+ f"for variable named {name!r}."
+ )
+
+ # if there are chunks in encoding and the variable data is a numpy array,
+ # we use the specified chunks
+ if not var_chunks:
+ return enc_chunks_tuple
+
+ # the hard case
+ # DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk
+ # this avoids the need to get involved in zarr synchronization / locking
+ # From zarr docs:
+ # "If each worker in a parallel computation is writing to a
+ # separate region of the array, and if region boundaries are perfectly aligned
+ # with chunk boundaries, then no synchronization is required."
+ # TODO: incorporate synchronizer to allow writes from multiple dask
+ # threads
+ if var_chunks and enc_chunks_tuple:
+ for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks):
+ for dchunk in dchunks[:-1]:
+ if dchunk % zchunk:
+ base_error = (
+ f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for "
+ f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. "
+ f"Writing this array in parallel with dask could lead to corrupted data."
+ )
+ if safe_chunks:
+ raise ValueError(
+ base_error
+ + " Consider either rechunking using `chunk()`, deleting "
+ "or modifying `encoding['chunks']`, or specify `safe_chunks=False`."
+ )
+ return enc_chunks_tuple
+
+ raise AssertionError("We should never get here. Function logic must be wrong.")
+
+
+def _get_zarr_dims_and_attrs(zarr_obj, dimension_key, try_nczarr):
+ # Zarr arrays do not have dimensions. To get around this problem, we add
+ # an attribute that specifies the dimension. We have to hide this attribute
+ # when we send the attributes to the user.
+ # zarr_obj can be either a zarr group or zarr array
+ try:
+ # Xarray-Zarr
+ dimensions = zarr_obj.attrs[dimension_key]
+ except KeyError as e:
+ if not try_nczarr:
+ raise KeyError(
+ f"Zarr object is missing the attribute `{dimension_key}`, which is "
+ "required for xarray to determine variable dimensions."
+ ) from e
+
+ # NCZarr defines dimensions through metadata in .zarray
+ zarray_path = os.path.join(zarr_obj.path, ".zarray")
+ zarray = json.loads(zarr_obj.store[zarray_path])
+ try:
+ # NCZarr uses Fully Qualified Names
+ dimensions = [
+ os.path.basename(dim) for dim in zarray["_NCZARR_ARRAY"]["dimrefs"]
+ ]
+ except KeyError as e:
+ raise KeyError(
+ f"Zarr object is missing the attribute `{dimension_key}` and the NCZarr metadata, "
+ "which are required for xarray to determine variable dimensions."
+ ) from e
+
+ nc_attrs = [attr for attr in zarr_obj.attrs if attr.lower().startswith("_nc")]
+ attributes = HiddenKeyDict(zarr_obj.attrs, [dimension_key] + nc_attrs)
+ return dimensions, attributes
+
+
+def extract_zarr_variable_encoding(
+ variable, raise_on_invalid=False, name=None, safe_chunks=True
+):
"""
Extract zarr encoding dictionary from xarray Variable
@@ -85,9 +255,41 @@ def extract_zarr_variable_encoding(variable, raise_on_invalid=False, name=
encoding : dict
Zarr encoding for `variable`
"""
- pass
+ encoding = variable.encoding.copy()
+
+ safe_to_drop = {"source", "original_shape"}
+ valid_encodings = {
+ "chunks",
+ "compressor",
+ "filters",
+ "cache_metadata",
+ "write_empty_chunks",
+ }
+
+ for k in safe_to_drop:
+ if k in encoding:
+ del encoding[k]
+ if raise_on_invalid:
+ invalid = [k for k in encoding if k not in valid_encodings]
+ if invalid:
+ raise ValueError(
+ f"unexpected encoding parameters for zarr backend: {invalid!r}"
+ )
+ else:
+ for k in list(encoding):
+ if k not in valid_encodings:
+ del encoding[k]
+ chunks = _determine_zarr_chunks(
+ encoding.get("chunks"), variable.chunks, variable.ndim, name, safe_chunks
+ )
+ encoding["chunks"] = chunks
+ return encoding
+
+
+# The function below wraps conventions.encode_cf_variable and additionally
+# applies zarr-specific string encoding (UTF-8 / fixed-length bytes).
def encode_zarr_variable(var, needs_copy=True, name=None):
"""
Converts an Variable into an Variable which follows some
@@ -108,30 +310,211 @@ def encode_zarr_variable(var, needs_copy=True, name=None):
out : Variable
A variable which has been encoded as described above.
"""
- pass
+
+ var = conventions.encode_cf_variable(var, name=name)
+
+ # zarr allows unicode, but not variable-length strings, so it's both
+ # simpler and more compact to always encode as UTF-8 explicitly.
+ # TODO: allow toggling this explicitly via dtype in encoding.
+ coder = coding.strings.EncodedStringCoder(allows_unicode=True)
+ var = coder.encode(var, name=name)
+ var = coding.strings.ensure_fixed_length_bytes(var)
+
+ return var
def _validate_datatypes_for_zarr_append(vname, existing_var, new_var):
"""If variable exists in the store, confirm dtype of the data to append is compatible with
existing dtype.
"""
- pass
+ if (
+ np.issubdtype(new_var.dtype, np.number)
+ or np.issubdtype(new_var.dtype, np.datetime64)
+ or np.issubdtype(new_var.dtype, np.bool_)
+ or new_var.dtype == object
+ ):
+ # We can skip dtype equality checks under two conditions: (1) if the var to append is
+ # new to the dataset, because in this case there is no existing var to compare it to;
+ # or (2) if var to append's dtype is known to be easy-to-append, because in this case
+ # we can be confident appending won't cause problems. Examples of dtypes which are not
+ # easy-to-append include length-specified strings of type `|S*` or `<U*` (where * is a
+ # positive integer character length). For these dtypes, appending dissimilar lengths
+ # can result in truncation of appended data. Therefore, variables which already exist
+ # in the dataset, and with dtypes which are not known to be easy-to-append, necessitate
+ # exact dtype equality, as checked below.
+ pass
+ elif not new_var.dtype == existing_var.dtype:
+ raise ValueError(
+ f"Mismatched dtypes for variable {vname} between Zarr store on disk "
+ f"and dataset to append. Store has dtype {existing_var.dtype} but "
+ f"dataset to append has dtype {new_var.dtype}."
+ )
+
+
+def _validate_and_transpose_existing_dims(
+ var_name, new_var, existing_var, region, append_dim
+):
+ if new_var.dims != existing_var.dims:
+ if set(existing_var.dims) == set(new_var.dims):
+ new_var = new_var.transpose(*existing_var.dims)
+ else:
+ raise ValueError(
+ f"variable {var_name!r} already exists with different "
+ f"dimension names {existing_var.dims} != "
+ f"{new_var.dims}, but changing variable "
+ f"dimensions is not supported by to_zarr()."
+ )
+
+ existing_sizes = {}
+ for dim, size in existing_var.sizes.items():
+ if region is not None and dim in region:
+ start, stop, stride = region[dim].indices(size)
+ assert stride == 1 # region was already validated
+ size = stop - start
+ if dim != append_dim:
+ existing_sizes[dim] = size
+
+ new_sizes = {dim: size for dim, size in new_var.sizes.items() if dim != append_dim}
+ if existing_sizes != new_sizes:
+ raise ValueError(
+ f"variable {var_name!r} already exists with different "
+ f"dimension sizes: {existing_sizes} != {new_sizes}. "
+ f"to_zarr() only supports changing dimension sizes when "
+ f"explicitly appending, but append_dim={append_dim!r}. "
+ f"If you are attempting to write to a subset of the "
+ f"existing store without changing dimension sizes, "
+ f"consider using the region argument in to_zarr()."
+ )
+
+ return new_var
def _put_attrs(zarr_obj, attrs):
"""Raise a more informative error message for invalid attrs."""
- pass
+ try:
+ zarr_obj.attrs.put(attrs)
+ except TypeError as e:
+ raise TypeError("Invalid attribute in Dataset.attrs.") from e
+ return zarr_obj
class ZarrStore(AbstractWritableDataStore):
"""Store for reading and writing data via zarr"""
- __slots__ = ('zarr_group', '_append_dim', '_consolidate_on_close',
- '_group', '_mode', '_read_only', '_synchronizer', '_write_region',
- '_safe_chunks', '_write_empty', '_close_store_on_close')
- def __init__(self, zarr_group, mode=None, consolidate_on_close=False,
- append_dim=None, write_region=None, safe_chunks=True, write_empty:
- (bool | None)=None, close_store_on_close: bool=False):
+ __slots__ = (
+ "zarr_group",
+ "_append_dim",
+ "_consolidate_on_close",
+ "_group",
+ "_mode",
+ "_read_only",
+ "_synchronizer",
+ "_write_region",
+ "_safe_chunks",
+ "_write_empty",
+ "_close_store_on_close",
+ )
+
+ @classmethod
+ def open_store(
+ cls,
+ store,
+ mode: ZarrWriteModes = "r",
+ synchronizer=None,
+ group=None,
+ consolidated=False,
+ consolidate_on_close=False,
+ chunk_store=None,
+ storage_options=None,
+ append_dim=None,
+ write_region=None,
+ safe_chunks=True,
+ stacklevel=2,
+ zarr_version=None,
+ write_empty: bool | None = None,
+ ):
+
+ zarr_group, consolidate_on_close, close_store_on_close = _get_open_params(
+ store=store,
+ mode=mode,
+ synchronizer=synchronizer,
+ group=group,
+ consolidated=consolidated,
+ consolidate_on_close=consolidate_on_close,
+ chunk_store=chunk_store,
+ storage_options=storage_options,
+ stacklevel=stacklevel,
+ zarr_version=zarr_version,
+ )
+ group_paths = [node for node in _iter_zarr_groups(zarr_group, parent=group)]
+ return {
+ group: cls(
+ zarr_group.get(group),
+ mode,
+ consolidate_on_close,
+ append_dim,
+ write_region,
+ safe_chunks,
+ write_empty,
+ close_store_on_close,
+ )
+ for group in group_paths
+ }
+
+ @classmethod
+ def open_group(
+ cls,
+ store,
+ mode: ZarrWriteModes = "r",
+ synchronizer=None,
+ group=None,
+ consolidated=False,
+ consolidate_on_close=False,
+ chunk_store=None,
+ storage_options=None,
+ append_dim=None,
+ write_region=None,
+ safe_chunks=True,
+ stacklevel=2,
+ zarr_version=None,
+ write_empty: bool | None = None,
+ ):
+
+ zarr_group, consolidate_on_close, close_store_on_close = _get_open_params(
+ store=store,
+ mode=mode,
+ synchronizer=synchronizer,
+ group=group,
+ consolidated=consolidated,
+ consolidate_on_close=consolidate_on_close,
+ chunk_store=chunk_store,
+ storage_options=storage_options,
+ stacklevel=stacklevel,
+ zarr_version=zarr_version,
+ )
+
+ return cls(
+ zarr_group,
+ mode,
+ consolidate_on_close,
+ append_dim,
+ write_region,
+ safe_chunks,
+ write_empty,
+ close_store_on_close,
+ )
+
+ def __init__(
+ self,
+ zarr_group,
+ mode=None,
+ consolidate_on_close=False,
+ append_dim=None,
+ write_region=None,
+ safe_chunks=True,
+ write_empty: bool | None = None,
+ close_store_on_close: bool = False,
+ ):
self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
self._synchronizer = self.zarr_group.synchronizer
@@ -144,8 +527,88 @@ class ZarrStore(AbstractWritableDataStore):
self._write_empty = write_empty
self._close_store_on_close = close_store_on_close
- def store(self, variables, attributes, check_encoding_set=frozenset(),
- writer=None, unlimited_dims=None):
+ @property
+ def ds(self):
+ # TODO: consider deprecating this in favor of zarr_group
+ return self.zarr_group
+
+ def open_store_variable(self, name, zarr_array=None):
+ if zarr_array is None:
+ zarr_array = self.zarr_group[name]
+ data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
+ try_nczarr = self._mode == "r"
+ dimensions, attributes = _get_zarr_dims_and_attrs(
+ zarr_array, DIMENSION_KEY, try_nczarr
+ )
+ attributes = dict(attributes)
+
+ # TODO: this should not be needed once
+ # https://github.com/zarr-developers/zarr-python/issues/1269 is resolved.
+ attributes.pop("filters", None)
+
+ encoding = {
+ "chunks": zarr_array.chunks,
+ "preferred_chunks": dict(zip(dimensions, zarr_array.chunks)),
+ "compressor": zarr_array.compressor,
+ "filters": zarr_array.filters,
+ }
+ # _FillValue needs to be in attributes, not encoding, so it will get
+ # picked up by decode_cf
+ if getattr(zarr_array, "fill_value") is not None:
+ attributes["_FillValue"] = zarr_array.fill_value
+
+ return Variable(dimensions, data, attributes, encoding)
+
+ def get_variables(self):
+ return FrozenDict(
+ (k, self.open_store_variable(k, v)) for k, v in self.zarr_group.arrays()
+ )
+
+ def get_attrs(self):
+ return {
+ k: v
+ for k, v in self.zarr_group.attrs.asdict().items()
+ if not k.lower().startswith("_nc")
+ }
+
+ def get_dimensions(self):
+ try_nczarr = self._mode == "r"
+ dimensions = {}
+ for k, v in self.zarr_group.arrays():
+ dim_names, _ = _get_zarr_dims_and_attrs(v, DIMENSION_KEY, try_nczarr)
+ for d, s in zip(dim_names, v.shape):
+ if d in dimensions and dimensions[d] != s:
+ raise ValueError(
+ f"found conflicting lengths for dimension {d} "
+ f"({s} != {dimensions[d]})"
+ )
+ dimensions[d] = s
+ return dimensions
+
+ def set_dimensions(self, variables, unlimited_dims=None):
+ if unlimited_dims is not None:
+ raise NotImplementedError(
+ "Zarr backend doesn't know how to handle unlimited dimensions"
+ )
+
+ def set_attributes(self, attributes):
+ _put_attrs(self.zarr_group, attributes)
+
+ def encode_variable(self, variable):
+ variable = encode_zarr_variable(variable)
+ return variable
+
+ def encode_attribute(self, a):
+ return encode_zarr_attr_value(a)
+
+ def store(
+ self,
+ variables,
+ attributes,
+ check_encoding_set=frozenset(),
+ writer=None,
+ unlimited_dims=None,
+ ):
"""
Top level method for putting data on this store, this method:
- encodes variables/attributes
@@ -168,10 +631,102 @@ class ZarrStore(AbstractWritableDataStore):
dimension on which the zarray will be appended
only needed in append mode
"""
+ import zarr
+
+ existing_keys = tuple(self.zarr_group.array_keys())
+
+ if self._mode == "r+":
+ new_names = [k for k in variables if k not in existing_keys]
+ if new_names:
+ raise ValueError(
+ f"dataset contains non-pre-existing variables {new_names}, "
+ "which is not allowed in ``xarray.Dataset.to_zarr()`` with "
+ "``mode='r+'``. To allow writing new variables, set ``mode='a'``."
+ )
+
+ if self._append_dim is not None and self._append_dim not in existing_keys:
+ # For dimensions without coordinate values, we must parse
+ # the _ARRAY_DIMENSIONS attribute on *all* arrays to check if it
+ # is a valid existing dimension name.
+ # TODO: This `get_dimensions` method also does shape checking
+ # which isn't strictly necessary for our check.
+ existing_dims = self.get_dimensions()
+ if self._append_dim not in existing_dims:
+ raise ValueError(
+ f"append_dim={self._append_dim!r} does not match any existing "
+ f"dataset dimensions {existing_dims}"
+ )
+
+ existing_variable_names = {
+ vn for vn in variables if _encode_variable_name(vn) in existing_keys
+ }
+ new_variable_names = set(variables) - existing_variable_names
+ variables_encoded, attributes = self.encode(
+ {vn: variables[vn] for vn in new_variable_names}, attributes
+ )
+
+ if existing_variable_names:
+ # We make sure that values to be appended are encoded *exactly*
+ # as the current values in the store.
+ # To do so, we decode variables directly to access the proper encoding,
+ # without going via xarray.Dataset to avoid needing to load
+ # index variables into memory.
+ existing_vars, _, _ = conventions.decode_cf_variables(
+ variables={
+ k: self.open_store_variable(name=k) for k in existing_variable_names
+ },
+ # attributes = {} since we don't care about parsing the global
+ # "coordinates" attribute
+ attributes={},
+ )
+ # Modified variables must use the same encoding as the store.
+ vars_with_encoding = {}
+ for vn in existing_variable_names:
+ if self._mode in ["a", "a-", "r+"]:
+ _validate_datatypes_for_zarr_append(
+ vn, existing_vars[vn], variables[vn]
+ )
+ vars_with_encoding[vn] = variables[vn].copy(deep=False)
+ vars_with_encoding[vn].encoding = existing_vars[vn].encoding
+ vars_with_encoding, _ = self.encode(vars_with_encoding, {})
+ variables_encoded.update(vars_with_encoding)
+
+ for var_name in existing_variable_names:
+ variables_encoded[var_name] = _validate_and_transpose_existing_dims(
+ var_name,
+ variables_encoded[var_name],
+ existing_vars[var_name],
+ self._write_region,
+ self._append_dim,
+ )
+
+ if self._mode not in ["r", "r+"]:
+ self.set_attributes(attributes)
+ self.set_dimensions(variables_encoded, unlimited_dims=unlimited_dims)
+
+ # if we are appending to an append_dim, only write either
+ # - new variables not already present, OR
+ # - variables with the append_dim in their dimensions
+ # We do NOT overwrite other variables.
+ if self._mode == "a-" and self._append_dim is not None:
+ variables_to_set = {
+ k: v
+ for k, v in variables_encoded.items()
+ if (k not in existing_variable_names) or (self._append_dim in v.dims)
+ }
+ else:
+ variables_to_set = variables_encoded
+
+ self.set_variables(
+ variables_to_set, check_encoding_set, writer, unlimited_dims=unlimited_dims
+ )
+ if self._consolidate_on_close:
+ zarr.consolidate_metadata(self.zarr_group.store)
+
+ def sync(self):
pass
- def set_variables(self, variables, check_encoding_set, writer,
- unlimited_dims=None):
+ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
"""
This provides a centralized method to set the variables on the data
store.
@@ -188,16 +743,221 @@ class ZarrStore(AbstractWritableDataStore):
List of dimension names that should be treated as unlimited
dimensions.
"""
- pass
+ import zarr
+
+ existing_keys = tuple(self.zarr_group.array_keys())
+
+ for vn, v in variables.items():
+ name = _encode_variable_name(vn)
+ attrs = v.attrs.copy()
+ dims = v.dims
+ dtype = v.dtype
+ shape = v.shape
+
+ fill_value = attrs.pop("_FillValue", None)
+ if v.encoding == {"_FillValue": None} and fill_value is None:
+ v.encoding = {}
+
+ # We need to do this for both new and existing variables to ensure we're not
+ # writing to a partial chunk, even though we don't use the `encoding` value
+ # when writing to an existing variable. See
+ # https://github.com/pydata/xarray/issues/8371 for details.
+ encoding = extract_zarr_variable_encoding(
+ v,
+ raise_on_invalid=vn in check_encoding_set,
+ name=vn,
+ safe_chunks=self._safe_chunks,
+ )
+
+ if name in existing_keys:
+ # existing variable
+ # TODO: if mode="a", consider overriding the existing variable
+ # metadata. This would need some case work properly with region
+ # and append_dim.
+ if self._write_empty is not None:
+ # Write to zarr_group.chunk_store instead of zarr_group.store
+ # See https://github.com/pydata/xarray/pull/8326#discussion_r1365311316 for a longer explanation
+ # The open_consolidated() enforces a mode of r or r+
+ # (and to_zarr with region provided enforces a read mode of r+),
+ # and this function makes sure the resulting Group has a store of type ConsolidatedMetadataStore
+ # and a 'normal Store subtype for chunk_store.
+ # The exact type depends on if a local path was used, or a URL of some sort,
+ # but the point is that it's not a read-only ConsolidatedMetadataStore.
+ # It is safe to write chunk data to the chunk_store because no metadata would be changed by
+ # to_zarr with the region parameter:
+ # - Because the write mode is enforced to be r+, no new variables can be added to the store
+ # (this is also checked and enforced in xarray.backends.api.py::to_zarr()).
+ # - Existing variables already have their attrs included in the consolidated metadata file.
+ # - The size of dimensions can not be expanded, that would require a call using `append_dim`
+ # which is mutually exclusive with `region`
+ zarr_array = zarr.open(
+ store=self.zarr_group.chunk_store,
+ path=f"{self.zarr_group.name}/{name}",
+ write_empty_chunks=self._write_empty,
+ )
+ else:
+ zarr_array = self.zarr_group[name]
+ else:
+ # new variable
+ encoded_attrs = {}
+ # the magic for storing the hidden dimension data
+ encoded_attrs[DIMENSION_KEY] = dims
+ for k2, v2 in attrs.items():
+ encoded_attrs[k2] = self.encode_attribute(v2)
+
+ if coding.strings.check_vlen_dtype(dtype) is str:
+ dtype = str
+
+ if self._write_empty is not None:
+ if (
+ "write_empty_chunks" in encoding
+ and encoding["write_empty_chunks"] != self._write_empty
+ ):
+ raise ValueError(
+                            'Differing "write_empty_chunks" values in encoding and parameters. '
+ f'Got {encoding["write_empty_chunks"] = } and {self._write_empty = }'
+ )
+ else:
+ encoding["write_empty_chunks"] = self._write_empty
+
+ zarr_array = self.zarr_group.create(
+ name,
+ shape=shape,
+ dtype=dtype,
+ fill_value=fill_value,
+ **encoding,
+ )
+ zarr_array = _put_attrs(zarr_array, encoded_attrs)
+
+ write_region = self._write_region if self._write_region is not None else {}
+ write_region = {dim: write_region.get(dim, slice(None)) for dim in dims}
+
+ if self._append_dim is not None and self._append_dim in dims:
+ # resize existing variable
+ append_axis = dims.index(self._append_dim)
+ assert write_region[self._append_dim] == slice(None)
+ write_region[self._append_dim] = slice(
+ zarr_array.shape[append_axis], None
+ )
+
+ new_shape = list(zarr_array.shape)
+ new_shape[append_axis] += v.shape[append_axis]
+ zarr_array.resize(new_shape)
+
+ region = tuple(write_region[dim] for dim in dims)
+ writer.add(v.data, zarr_array, region)
+
+ def close(self) -> None:
+ if self._close_store_on_close:
+ self.zarr_group.store.close()
+
+ def _auto_detect_regions(self, ds, region):
+ for dim, val in region.items():
+ if val != "auto":
+ continue
+
+ if dim not in ds._variables:
+ # unindexed dimension
+ region[dim] = slice(0, ds.sizes[dim])
+ continue
+
+ variable = conventions.decode_cf_variable(
+ dim, self.open_store_variable(dim).compute()
+ )
+ assert variable.dims == (dim,)
+ index = pd.Index(variable.data)
+ idxs = index.get_indexer(ds[dim].data)
+ if (idxs == -1).any():
+ raise KeyError(
+ f"Not all values of coordinate '{dim}' in the new array were"
+ " found in the original store. Writing to a zarr region slice"
+ " requires that no dimensions or metadata are changed by the write."
+ )
+
+ if (np.diff(idxs) != 1).any():
+ raise ValueError(
+ f"The auto-detected region of coordinate '{dim}' for writing new data"
+ " to the original store had non-contiguous indices. Writing to a zarr"
+ " region slice requires that the new data constitute a contiguous subset"
+ " of the original store."
+ )
+ region[dim] = slice(idxs[0], idxs[-1] + 1)
+ return region
+
+ def _validate_and_autodetect_region(self, ds) -> None:
+ region = self._write_region
+
+ if region == "auto":
+ region = {dim: "auto" for dim in ds.dims}
-def open_zarr(store, group=None, synchronizer=None, chunks='auto',
- decode_cf=True, mask_and_scale=True, decode_times=True,
- concat_characters=True, decode_coords=True, drop_variables=None,
- consolidated=None, overwrite_encoded_chunks=False, chunk_store=None,
- storage_options=None, decode_timedelta=None, use_cftime=None,
- zarr_version=None, chunked_array_type: (str | None)=None,
- from_array_kwargs: (dict[str, Any] | None)=None, **kwargs):
+ if not isinstance(region, dict):
+ raise TypeError(f"``region`` must be a dict, got {type(region)}")
+ if any(v == "auto" for v in region.values()):
+ if self._mode != "r+":
+ raise ValueError(
+ f"``mode`` must be 'r+' when using ``region='auto'``, got {self._mode!r}"
+ )
+ region = self._auto_detect_regions(ds, region)
+
+        # validate the region; the auto-detection above is expected to always
+        # return valid slices, but user-supplied values may not be.
+ for k, v in region.items():
+ if k not in ds.dims:
+ raise ValueError(
+ f"all keys in ``region`` are not in Dataset dimensions, got "
+ f"{list(region)} and {list(ds.dims)}"
+ )
+ if not isinstance(v, slice):
+ raise TypeError(
+ "all values in ``region`` must be slice objects, got "
+ f"region={region}"
+ )
+ if v.step not in {1, None}:
+ raise ValueError(
+ "step on all slices in ``region`` must be 1 or None, got "
+ f"region={region}"
+ )
+
+ non_matching_vars = [
+ k for k, v in ds.variables.items() if not set(region).intersection(v.dims)
+ ]
+ if non_matching_vars:
+ raise ValueError(
+ f"when setting `region` explicitly in to_zarr(), all "
+ f"variables in the dataset to write must have at least "
+ f"one dimension in common with the region's dimensions "
+ f"{list(region.keys())}, but that is not "
+ f"the case for some variables here. To drop these variables "
+ f"from this dataset before exporting to zarr, write: "
+ f".drop_vars({non_matching_vars!r})"
+ )
+
+ self._write_region = region
+
+
+def open_zarr(
+ store,
+ group=None,
+ synchronizer=None,
+ chunks="auto",
+ decode_cf=True,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables=None,
+ consolidated=None,
+ overwrite_encoded_chunks=False,
+ chunk_store=None,
+ storage_options=None,
+ decode_timedelta=None,
+ use_cftime=None,
+ zarr_version=None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+ **kwargs,
+):
"""Load and decode a dataset from a Zarr store.
The `store` object should be a valid store for a Zarr group. `store`
@@ -310,7 +1070,55 @@ def open_zarr(store, group=None, synchronizer=None, chunks='auto',
----------
http://zarr.readthedocs.io/
"""
- pass
+ from xarray.backends.api import open_dataset
+
+ if from_array_kwargs is None:
+ from_array_kwargs = {}
+
+ if chunks == "auto":
+ try:
+ guess_chunkmanager(
+ chunked_array_type
+ ) # attempt to import that parallel backend
+
+ chunks = {}
+ except ValueError:
+ chunks = None
+
+ if kwargs:
+ raise TypeError(
+ "open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys())
+ )
+
+ backend_kwargs = {
+ "synchronizer": synchronizer,
+ "consolidated": consolidated,
+ "overwrite_encoded_chunks": overwrite_encoded_chunks,
+ "chunk_store": chunk_store,
+ "storage_options": storage_options,
+ "stacklevel": 4,
+ "zarr_version": zarr_version,
+ }
+
+ ds = open_dataset(
+ filename_or_obj=store,
+ group=group,
+ decode_cf=decode_cf,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ engine="zarr",
+ chunks=chunks,
+ drop_variables=drop_variables,
+ chunked_array_type=chunked_array_type,
+ from_array_kwargs=from_array_kwargs,
+ backend_kwargs=backend_kwargs,
+ decode_timedelta=decode_timedelta,
+ use_cftime=use_cftime,
+ zarr_version=zarr_version,
+ )
+ return ds
class ZarrBackendEntrypoint(BackendEntrypoint):
@@ -324,10 +1132,212 @@ class ZarrBackendEntrypoint(BackendEntrypoint):
--------
backends.ZarrStore
"""
- description = 'Open zarr files (.zarr) using zarr in Xarray'
- url = (
- 'https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html'
- )
+
+ description = "Open zarr files (.zarr) using zarr in Xarray"
+ url = "https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html"
+
+ def guess_can_open(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ ) -> bool:
+ if isinstance(filename_or_obj, (str, os.PathLike)):
+ _, ext = os.path.splitext(filename_or_obj)
+ return ext in {".zarr"}
+
+ return False
+
+ def open_dataset( # type: ignore[override] # allow LSP violation, not supporting **kwargs
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ group=None,
+ mode="r",
+ synchronizer=None,
+ consolidated=None,
+ chunk_store=None,
+ storage_options=None,
+ stacklevel=3,
+ zarr_version=None,
+ store=None,
+ engine=None,
+ ) -> Dataset:
+ filename_or_obj = _normalize_path(filename_or_obj)
+ if not store:
+ store = ZarrStore.open_group(
+ filename_or_obj,
+ group=group,
+ mode=mode,
+ synchronizer=synchronizer,
+ consolidated=consolidated,
+ consolidate_on_close=False,
+ chunk_store=chunk_store,
+ storage_options=storage_options,
+ stacklevel=stacklevel + 1,
+ zarr_version=zarr_version,
+ )
+
+ store_entrypoint = StoreBackendEntrypoint()
+ with close_on_error(store):
+ ds = store_entrypoint.open_dataset(
+ store,
+ mask_and_scale=mask_and_scale,
+ decode_times=decode_times,
+ concat_characters=concat_characters,
+ decode_coords=decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ return ds
+
+ def open_datatree(
+ self,
+ filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+ *,
+ mask_and_scale=True,
+ decode_times=True,
+ concat_characters=True,
+ decode_coords=True,
+ drop_variables: str | Iterable[str] | None = None,
+ use_cftime=None,
+ decode_timedelta=None,
+ group: str | Iterable[str] | Callable | None = None,
+ mode="r",
+ synchronizer=None,
+ consolidated=None,
+ chunk_store=None,
+ storage_options=None,
+ stacklevel=3,
+ zarr_version=None,
+ **kwargs,
+ ) -> DataTree:
+ from xarray.backends.api import open_dataset
+ from xarray.core.datatree import DataTree
+ from xarray.core.treenode import NodePath
+
+ filename_or_obj = _normalize_path(filename_or_obj)
+ if group:
+ parent = NodePath("/") / NodePath(group)
+ stores = ZarrStore.open_store(filename_or_obj, group=parent)
+ if not stores:
+ ds = open_dataset(
+ filename_or_obj, group=parent, engine="zarr", **kwargs
+ )
+ return DataTree.from_dict({str(parent): ds})
+ else:
+ parent = NodePath("/")
+ stores = ZarrStore.open_store(filename_or_obj, group=parent)
+ ds = open_dataset(filename_or_obj, group=parent, engine="zarr", **kwargs)
+ tree_root = DataTree.from_dict({str(parent): ds})
+ for path_group, store in stores.items():
+ ds = open_dataset(
+ filename_or_obj, store=store, group=path_group, engine="zarr", **kwargs
+ )
+ new_node: DataTree = DataTree(name=NodePath(path_group).name, data=ds)
+ tree_root._set_item(
+ path_group,
+ new_node,
+ allow_overwrite=False,
+ new_nodes_along_path=True,
+ )
+ return tree_root
+
+
+def _iter_zarr_groups(root, parent="/"):
+ from xarray.core.treenode import NodePath
+
+ parent = NodePath(parent)
+ for path, group in root.groups():
+ gpath = parent / path
+ yield str(gpath)
+ yield from _iter_zarr_groups(group, parent=gpath)
+
+
+def _get_open_params(
+ store,
+ mode,
+ synchronizer,
+ group,
+ consolidated,
+ consolidate_on_close,
+ chunk_store,
+ storage_options,
+ stacklevel,
+ zarr_version,
+):
+ import zarr
+
+ # zarr doesn't support pathlib.Path objects yet. zarr-python#601
+ if isinstance(store, os.PathLike):
+ store = os.fspath(store)
+
+ if zarr_version is None:
+        # default to 2 if store doesn't specify its version (e.g. a path)
+ zarr_version = getattr(store, "_store_version", 2)
+
+ open_kwargs = dict(
+ # mode='a-' is a handcrafted xarray specialty
+ mode="a" if mode == "a-" else mode,
+ synchronizer=synchronizer,
+ path=group,
+ )
+ open_kwargs["storage_options"] = storage_options
+ if zarr_version > 2:
+ open_kwargs["zarr_version"] = zarr_version
+
+        if consolidated or consolidate_on_close:
+            raise ValueError(
+                "consolidated metadata has not been implemented for zarr "
+                f"version {zarr_version} yet. Set consolidated=False for "
+                f"zarr version {zarr_version}. See also "
+                "https://github.com/zarr-developers/zarr-specs/issues/136"
+            )
+
+        if consolidated is None:
+            consolidated = False
+
+    if chunk_store is not None:
+        open_kwargs["chunk_store"] = chunk_store
+        if consolidated is None:
+            consolidated = False
+
+ if consolidated is None:
+ try:
+ zarr_group = zarr.open_consolidated(store, **open_kwargs)
+ except KeyError:
+ try:
+ zarr_group = zarr.open_group(store, **open_kwargs)
+ warnings.warn(
+ "Failed to open Zarr store with consolidated metadata, "
+ "but successfully read with non-consolidated metadata. "
+ "This is typically much slower for opening a dataset. "
+ "To silence this warning, consider:\n"
+ "1. Consolidating metadata in this existing store with "
+ "zarr.consolidate_metadata().\n"
+ "2. Explicitly setting consolidated=False, to avoid trying "
+                    "to read consolidated metadata, or\n"
+ "3. Explicitly setting consolidated=True, to raise an "
+ "error in this case instead of falling back to try "
+ "reading non-consolidated metadata.",
+ RuntimeWarning,
+ stacklevel=stacklevel,
+ )
+ except zarr.errors.GroupNotFoundError:
+ raise FileNotFoundError(f"No such file or directory: '{store}'")
+ elif consolidated:
+ # TODO: an option to pass the metadata_key keyword
+ zarr_group = zarr.open_consolidated(store, **open_kwargs)
+ else:
+ zarr_group = zarr.open_group(store, **open_kwargs)
+ close_store_on_close = zarr_group.store is not store
+ return zarr_group, consolidate_on_close, close_store_on_close
-BACKEND_ENTRYPOINTS['zarr'] = 'zarr', ZarrBackendEntrypoint
+BACKEND_ENTRYPOINTS["zarr"] = ("zarr", ZarrBackendEntrypoint)
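
The zarr.py changes above fill in `_determine_zarr_chunks` and `extract_zarr_variable_encoding`. A minimal sketch (not part of the patch) of those chunk rules, seen through the public `to_zarr()`/`open_zarr()` API, follows; the store paths and chunk sizes are illustrative assumptions and require `dask` and `zarr` to be installed.

```python
import numpy as np
import xarray as xr

# Uniform dask chunks map directly onto the zarr chunks of the written array.
ds = xr.Dataset({"t": ("x", np.arange(12))}).chunk({"x": 4})
ds.to_zarr("example.zarr", mode="w")
roundtrip = xr.open_zarr("example.zarr")
print(roundtrip["t"].encoding["chunks"])  # (4,)

# Non-uniform dask chunks (other than a smaller final chunk) are rejected.
bad = xr.Dataset({"t": ("x", np.arange(12))}).chunk({"x": (5, 3, 4)})
try:
    bad.to_zarr("bad.zarr", mode="w")
except ValueError as err:
    print(err)  # "Zarr requires uniform chunk sizes except for final chunk. ..."
```
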
diff --git a/xarray/coding/calendar_ops.py b/xarray/coding/calendar_ops.py
index 7de8fc94..6f492e78 100644
--- a/xarray/coding/calendar_ops.py
+++ b/xarray/coding/calendar_ops.py
@@ -1,25 +1,48 @@
from __future__ import annotations
+
import numpy as np
import pandas as pd
+
from xarray.coding.cftime_offsets import date_range_like, get_date_type
from xarray.coding.cftimeindex import CFTimeIndex
-from xarray.coding.times import _should_cftime_be_used, convert_times
+from xarray.coding.times import (
+ _should_cftime_be_used,
+ convert_times,
+)
from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like
+
try:
import cftime
except ImportError:
cftime = None
-_CALENDARS_WITHOUT_YEAR_ZERO = ['gregorian', 'proleptic_gregorian',
- 'julian', 'standard']
-def _days_in_year(year, calendar, use_cftime=True):
- """Return the number of days in the input year according to the input calendar."""
- pass
+_CALENDARS_WITHOUT_YEAR_ZERO = [
+ "gregorian",
+ "proleptic_gregorian",
+ "julian",
+ "standard",
+]
-def convert_calendar(obj, calendar, dim='time', align_on=None, missing=None,
- use_cftime=None):
+def _days_in_year(year, calendar, use_cftime=True):
+ """Return the number of days in the input year according to the input calendar."""
+ date_type = get_date_type(calendar, use_cftime=use_cftime)
+ if year == -1 and calendar in _CALENDARS_WITHOUT_YEAR_ZERO:
+ difference = date_type(year + 2, 1, 1) - date_type(year, 1, 1)
+ else:
+ difference = date_type(year + 1, 1, 1) - date_type(year, 1, 1)
+ return difference.days
+
+
+def convert_calendar(
+ obj,
+ calendar,
+ dim="time",
+ align_on=None,
+ missing=None,
+ use_cftime=None,
+):
"""Transform a time-indexed Dataset or DataArray to one that uses another calendar.
This function only converts the individual timestamps; it does not modify any
@@ -134,14 +157,102 @@ def convert_calendar(obj, calendar, dim='time', align_on=None, missing=None,
This option is best used on daily data.
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ time = obj[dim]
+ if not _contains_datetime_like_objects(time.variable):
+ raise ValueError(f"Coordinate {dim} must contain datetime objects.")
+
+ use_cftime = _should_cftime_be_used(time, calendar, use_cftime)
+
+ source_calendar = time.dt.calendar
+    # Do nothing if the requested calendar is the same as the source calendar
+    # AND the data is already in the requested representation (np datetimes XOR use_cftime)
+ if source_calendar == calendar and is_np_datetime_like(time.dtype) ^ use_cftime:
+ return obj
+
+ if (time.dt.year == 0).any() and calendar in _CALENDARS_WITHOUT_YEAR_ZERO:
+ raise ValueError(
+ f"Source time coordinate contains dates with year 0, which is not supported by target calendar {calendar}."
+ )
+
+ if (source_calendar == "360_day" or calendar == "360_day") and align_on is None:
+ raise ValueError(
+ "Argument `align_on` must be specified with either 'date' or "
+ "'year' when converting to or from a '360_day' calendar."
+ )
+
+ if source_calendar != "360_day" and calendar != "360_day":
+ align_on = "date"
+
+ out = obj.copy()
+
+ if align_on in ["year", "random"]:
+ # Special case for conversion involving 360_day calendar
+ if align_on == "year":
+ # Instead of translating dates directly, this tries to keep the position within a year similar.
+ new_doy = time.groupby(f"{dim}.year").map(
+ _interpolate_day_of_year,
+ target_calendar=calendar,
+ use_cftime=use_cftime,
+ )
+ elif align_on == "random":
+ # The 5 days to remove are randomly chosen, one for each of the five 72-days periods of the year.
+ new_doy = time.groupby(f"{dim}.year").map(
+ _random_day_of_year, target_calendar=calendar, use_cftime=use_cftime
+ )
+ # Convert the source datetimes, but override the day of year with our new day of years.
+ out[dim] = DataArray(
+ [
+ _convert_to_new_calendar_with_new_day_of_year(
+ date, newdoy, calendar, use_cftime
+ )
+ for date, newdoy in zip(time.variable._data.array, new_doy)
+ ],
+ dims=(dim,),
+ name=dim,
+ )
+ # Remove duplicate timestamps, happens when reducing the number of days
+ out = out.isel({dim: np.unique(out[dim], return_index=True)[1]})
+ elif align_on == "date":
+ new_times = convert_times(
+ time.data,
+ get_date_type(calendar, use_cftime=use_cftime),
+ raise_on_invalid=False,
+ )
+ out[dim] = new_times
+
+    # Remove NaN that were put on invalid dates in target calendar
+ out = out.where(out[dim].notnull(), drop=True)
+
+ if use_cftime:
+ # Reassign times to ensure time index of output is a CFTimeIndex
+ # (previously it was an Index due to the presence of NaN values).
+ # Note this is not needed in the case that the output time index is
+ # a DatetimeIndex, since DatetimeIndexes can handle NaN values.
+ out[dim] = CFTimeIndex(out[dim].data)
+
+ if missing is not None:
+ time_target = date_range_like(time, calendar=calendar, use_cftime=use_cftime)
+ out = out.reindex({dim: time_target}, fill_value=missing)
+
+ # Copy attrs but remove `calendar` if still present.
+ out[dim].attrs.update(time.attrs)
+ out[dim].attrs.pop("calendar", None)
+ return out
def _interpolate_day_of_year(time, target_calendar, use_cftime):
"""Returns the nearest day in the target calendar of the corresponding
"decimal year" in the source calendar.
"""
- pass
+ year = int(time.dt.year[0])
+ source_calendar = time.dt.calendar
+ return np.round(
+ _days_in_year(year, target_calendar, use_cftime)
+ * time.dt.dayofyear
+ / _days_in_year(year, source_calendar, use_cftime)
+ ).astype(int)
def _random_day_of_year(time, target_calendar, use_cftime):
@@ -149,21 +260,51 @@ def _random_day_of_year(time, target_calendar, use_cftime):
Removes Feb 29th and five other days chosen randomly within five sections of 72 days.
"""
- pass
-
-
-def _convert_to_new_calendar_with_new_day_of_year(date, day_of_year,
- calendar, use_cftime):
+ year = int(time.dt.year[0])
+ source_calendar = time.dt.calendar
+ new_doy = np.arange(360) + 1
+ rm_idx = np.random.default_rng().integers(0, 72, 5) + 72 * np.arange(5)
+ if source_calendar == "360_day":
+ for idx in rm_idx:
+ new_doy[idx + 1 :] = new_doy[idx + 1 :] + 1
+ if _days_in_year(year, target_calendar, use_cftime) == 366:
+ new_doy[new_doy >= 60] = new_doy[new_doy >= 60] + 1
+ elif target_calendar == "360_day":
+ new_doy = np.insert(new_doy, rm_idx - np.arange(5), -1)
+ if _days_in_year(year, source_calendar, use_cftime) == 366:
+ new_doy = np.insert(new_doy, 60, -1)
+ return new_doy[time.dt.dayofyear - 1]
+
+
+def _convert_to_new_calendar_with_new_day_of_year(
+ date, day_of_year, calendar, use_cftime
+):
"""Convert a datetime object to another calendar with a new day of year.
Redefines the day of year (and thus ignores the month and day information
from the source datetime).
Nanosecond information is lost as cftime.datetime doesn't support it.
"""
- pass
-
-
-def _datetime_to_decimal_year(times, dim='time', calendar=None):
+ new_date = cftime.num2date(
+ day_of_year - 1,
+ f"days since {date.year}-01-01",
+ calendar=calendar if use_cftime else "standard",
+ )
+ try:
+ return get_date_type(calendar, use_cftime)(
+ date.year,
+ new_date.month,
+ new_date.day,
+ date.hour,
+ date.minute,
+ date.second,
+ date.microsecond,
+ )
+ except ValueError:
+ return np.nan
+
+
+def _datetime_to_decimal_year(times, dim="time", calendar=None):
"""Convert a datetime DataArray to decimal years according to its calendar or the given one.
The decimal year of a timestamp is its year plus its sub-year component
@@ -171,10 +312,27 @@ def _datetime_to_decimal_year(times, dim='time', calendar=None):
Ex: '2000-03-01 12:00' is 2000.1653 in a standard calendar,
2000.16301 in a "noleap" or 2000.16806 in a "360_day".
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ calendar = calendar or times.dt.calendar
+
+ if is_np_datetime_like(times.dtype):
+ times = times.copy(data=convert_times(times.values, get_date_type("standard")))
+
+ def _make_index(time):
+ year = int(time.dt.year[0])
+ doys = cftime.date2num(time, f"days since {year:04d}-01-01", calendar=calendar)
+ return DataArray(
+ year + doys / _days_in_year(year, calendar),
+ dims=(dim,),
+ coords=time.coords,
+ name=dim,
+ )
+
+ return times.groupby(f"{dim}.year").map(_make_index)
-def interp_calendar(source, target, dim='time'):
+def interp_calendar(source, target, dim="time"):
"""Interpolates a DataArray or Dataset indexed by a time coordinate to
another calendar based on decimal year measure.
@@ -202,4 +360,31 @@ def interp_calendar(source, target, dim='time'):
DataArray or Dataset
The source interpolated on the decimal years of target,
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ if isinstance(target, (pd.DatetimeIndex, CFTimeIndex)):
+ target = DataArray(target, dims=(dim,), name=dim)
+
+ if not _contains_datetime_like_objects(
+ source[dim].variable
+ ) or not _contains_datetime_like_objects(target.variable):
+ raise ValueError(
+ f"Both 'source.{dim}' and 'target' must contain datetime objects."
+ )
+
+ source_calendar = source[dim].dt.calendar
+ target_calendar = target.dt.calendar
+
+ if (
+ source[dim].time.dt.year == 0
+ ).any() and target_calendar in _CALENDARS_WITHOUT_YEAR_ZERO:
+ raise ValueError(
+ f"Source time coordinate contains dates with year 0, which is not supported by target calendar {target_calendar}."
+ )
+
+ out = source.copy()
+ out[dim] = _datetime_to_decimal_year(source[dim], dim=dim, calendar=source_calendar)
+ target_idx = _datetime_to_decimal_year(target, dim=dim, calendar=target_calendar)
+ out = out.interp(**{dim: target_idx})
+ out[dim] = target
+ return out
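
The calendar_ops.py changes above restore `convert_calendar` and its helpers. A minimal sketch (not part of the patch) of the behaviour through the public `Dataset.convert_calendar` API follows; the variable name "tas" and the random data are illustrative assumptions.

```python
import numpy as np
import xarray as xr

times = xr.cftime_range("2000-01-01", periods=366, freq="D", calendar="standard")
ds = xr.Dataset({"tas": ("time", np.random.rand(366))}, coords={"time": times})

# Neither calendar is 360_day, so align_on falls back to "date": Feb 29 has no
# counterpart in "noleap" and is dropped; missing=np.nan would reindex and fill
# the gap instead.
converted = ds.convert_calendar("noleap")
print(converted.sizes["time"])  # 365
```
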
diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py
index a1176778..9dbc60ef 100644
--- a/xarray/coding/cftime_offsets.py
+++ b/xarray/coding/cftime_offsets.py
@@ -1,30 +1,114 @@
"""Time offset classes for use with cftime.datetime objects"""
+
+# The offset classes and mechanisms for generating time ranges defined in
+# this module were copied/adapted from those defined in pandas. See in
+# particular the objects and methods defined in pandas.tseries.offsets
+# and pandas.core.indexes.datetimes.
+
+# For reference, here is a copy of the pandas copyright notice:
+
+# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2008-2011 AQR Capital Management, LLC
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the copyright holder nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
+
import re
from collections.abc import Mapping
from datetime import datetime, timedelta
from functools import partial
from typing import TYPE_CHECKING, ClassVar, Literal
+
import numpy as np
import pandas as pd
from packaging.version import Version
+
from xarray.coding.cftimeindex import CFTimeIndex, _parse_iso8601_with_reso
-from xarray.coding.times import _is_standard_calendar, _should_cftime_be_used, convert_time_or_go_back, format_cftime_datetime
+from xarray.coding.times import (
+ _is_standard_calendar,
+ _should_cftime_be_used,
+ convert_time_or_go_back,
+ format_cftime_datetime,
+)
from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like
-from xarray.core.pdcompat import NoDefault, count_not_none, nanosecond_precision_timestamp, no_default
+from xarray.core.pdcompat import (
+ NoDefault,
+ count_not_none,
+ nanosecond_precision_timestamp,
+ no_default,
+)
from xarray.core.utils import emit_user_level_warning
+
try:
import cftime
except ImportError:
cftime = None
+
+
if TYPE_CHECKING:
from xarray.core.types import InclusiveOptions, Self, SideOptions, TypeAlias
-DayOption: TypeAlias = Literal['start', 'end']
+
+
+DayOption: TypeAlias = Literal["start", "end"]
+
+
+def _nanosecond_precision_timestamp(*args, **kwargs):
+ # As of pandas version 3.0, pd.to_datetime(Timestamp(...)) will try to
+ # infer the appropriate datetime precision. Until xarray supports
+ # non-nanosecond precision times, we will use this constructor wrapper to
+ # explicitly create nanosecond-precision Timestamp objects.
+ return pd.Timestamp(*args, **kwargs).as_unit("ns")
def get_date_type(calendar, use_cftime=True):
"""Return the cftime date type for a given calendar name."""
- pass
+ if cftime is None:
+ raise ImportError("cftime is required for dates with non-standard calendars")
+ else:
+ if _is_standard_calendar(calendar) and not use_cftime:
+ return _nanosecond_precision_timestamp
+
+ calendars = {
+ "noleap": cftime.DatetimeNoLeap,
+ "360_day": cftime.Datetime360Day,
+ "365_day": cftime.DatetimeNoLeap,
+ "366_day": cftime.DatetimeAllLeap,
+ "gregorian": cftime.DatetimeGregorian,
+ "proleptic_gregorian": cftime.DatetimeProlepticGregorian,
+ "julian": cftime.DatetimeJulian,
+ "all_leap": cftime.DatetimeAllLeap,
+ "standard": cftime.DatetimeGregorian,
+ }
+ return calendars[calendar]
class BaseCFTimeOffset:
@@ -32,19 +116,23 @@ class BaseCFTimeOffset:
_day_option: ClassVar[DayOption | None] = None
n: int
- def __init__(self, n: int=1) ->None:
+ def __init__(self, n: int = 1) -> None:
if not isinstance(n, int):
raise TypeError(
- f"The provided multiple 'n' must be an integer. Instead a value of type {type(n)!r} was provided."
- )
+ "The provided multiple 'n' must be an integer. "
+ f"Instead a value of type {type(n)!r} was provided."
+ )
self.n = n
- def __eq__(self, other: object) ->bool:
+ def rule_code(self) -> str | None:
+ return self._freq
+
+ def __eq__(self, other: object) -> bool:
if not isinstance(other, BaseCFTimeOffset):
return NotImplemented
return self.n == other.n and self.rule_code() == other.rule_code()
- def __ne__(self, other: object) ->bool:
+ def __ne__(self, other: object) -> bool:
return not self == other
def __add__(self, other):
@@ -53,20 +141,20 @@ class BaseCFTimeOffset:
def __sub__(self, other):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
+
if isinstance(other, cftime.datetime):
- raise TypeError(
- 'Cannot subtract a cftime.datetime from a time offset.')
+ raise TypeError("Cannot subtract a cftime.datetime from a time offset.")
elif type(other) == type(self):
return type(self)(self.n - other.n)
else:
return NotImplemented
- def __mul__(self, other: int) ->Self:
+ def __mul__(self, other: int) -> Self:
if not isinstance(other, int):
return NotImplemented
return type(self)(n=other * self.n)
- def __neg__(self) ->Self:
+ def __neg__(self) -> Self:
return self * -1
def __rmul__(self, other):
@@ -77,44 +165,80 @@ class BaseCFTimeOffset:
def __rsub__(self, other):
if isinstance(other, BaseCFTimeOffset) and type(self) != type(other):
- raise TypeError('Cannot subtract cftime offsets of differing types'
- )
+ raise TypeError("Cannot subtract cftime offsets of differing types")
return -self + other
def __apply__(self, other):
return NotImplemented
- def onOffset(self, date) ->bool:
+ def onOffset(self, date) -> bool:
"""Check if the given date is in the set of possible dates created
using a length-one version of this offset class."""
- pass
+ test_date = (self + date) - self
+ return date == test_date
+
+ def rollforward(self, date):
+ if self.onOffset(date):
+ return date
+ else:
+ return date + type(self)()
+
+ def rollback(self, date):
+ if self.onOffset(date):
+ return date
+ else:
+ return date - type(self)()
def __str__(self):
- return f'<{type(self).__name__}: n={self.n}>'
+ return f"<{type(self).__name__}: n={self.n}>"
def __repr__(self):
return str(self)
+ def _get_offset_day(self, other):
+ # subclass must implement `_day_option`; calling from the base class
+ # will raise NotImplementedError.
+ return _get_day_of_month(other, self._day_option)
-class Tick(BaseCFTimeOffset):
- def __mul__(self, other: (int | float)) ->Tick:
+class Tick(BaseCFTimeOffset):
+ # analogous https://github.com/pandas-dev/pandas/blob/ccb25ab1d24c4fb9691270706a59c8d319750870/pandas/_libs/tslibs/offsets.pyx#L806
+
+ def _next_higher_resolution(self) -> Tick:
+ self_type = type(self)
+ if self_type is Day:
+ return Hour(self.n * 24)
+ if self_type is Hour:
+ return Minute(self.n * 60)
+ if self_type is Minute:
+ return Second(self.n * 60)
+ if self_type is Second:
+ return Millisecond(self.n * 1000)
+ if self_type is Millisecond:
+ return Microsecond(self.n * 1000)
+ raise ValueError("Could not convert to integer offset at any resolution")
+
+ def __mul__(self, other: int | float) -> Tick:
if not isinstance(other, (int, float)):
return NotImplemented
if isinstance(other, float):
n = other * self.n
+ # If the new `n` is an integer, we can represent it using the
+ # same BaseCFTimeOffset subclass as self, otherwise we need to move up
+ # to a higher-resolution subclass
if np.isclose(n % 1, 0):
return type(self)(int(n))
+
new_self = self._next_higher_resolution()
return new_self * other
return type(self)(n=other * self.n)
- def as_timedelta(self) ->timedelta:
+ def as_timedelta(self) -> timedelta:
"""All Tick subclasses must implement an as_timedelta method."""
- pass
+ raise NotImplementedError
-def _get_day_of_month(other, day_option: DayOption) ->int:
+def _get_day_of_month(other, day_option: DayOption) -> int:
"""Find the day in `other`'s month that satisfies a BaseCFTimeOffset's
onOffset policy, as described by the `day_option` argument.
@@ -130,34 +254,76 @@ def _get_day_of_month(other, day_option: DayOption) ->int:
day_of_month : int
"""
- pass
+
+ if day_option == "start":
+ return 1
+ if day_option == "end":
+ return _days_in_month(other)
+ if day_option is None:
+ # Note: unlike `_shift_month`, _get_day_of_month does not
+ # allow day_option = None
+ raise NotImplementedError()
+ raise ValueError(day_option)
def _days_in_month(date):
"""The number of days in the month of the given date"""
- pass
+ if date.month == 12:
+ reference = type(date)(date.year + 1, 1, 1)
+ else:
+ reference = type(date)(date.year, date.month + 1, 1)
+ return (reference - timedelta(days=1)).day
def _adjust_n_months(other_day, n, reference_day):
"""Adjust the number of times a monthly offset is applied based
on the day of a given date, and the reference day provided.
"""
- pass
+ if n > 0 and other_day < reference_day:
+ n = n - 1
+ elif n <= 0 and other_day > reference_day:
+ n = n + 1
+ return n
def _adjust_n_years(other, n, month, reference_day):
"""Adjust the number of times an annual offset is applied based on
another date, and the reference day provided"""
- pass
+ if n > 0:
+ if other.month < month or (other.month == month and other.day < reference_day):
+ n -= 1
+ else:
+ if other.month > month or (other.month == month and other.day > reference_day):
+ n += 1
+ return n
-def _shift_month(date, months, day_option: DayOption='start'):
+def _shift_month(date, months, day_option: DayOption = "start"):
"""Shift the date to a month start or end a given number of months away."""
- pass
-
-
-def roll_qtrday(other, n: int, month: int, day_option: DayOption, modby: int=3
- ) ->int:
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
+
+ delta_year = (date.month + months) // 12
+ month = (date.month + months) % 12
+
+ if month == 0:
+ month = 12
+ delta_year = delta_year - 1
+ year = date.year + delta_year
+
+ if day_option == "start":
+ day = 1
+ elif day_option == "end":
+ reference = type(date)(year, month, 1)
+ day = _days_in_month(reference)
+ else:
+ raise ValueError(day_option)
+ return date.replace(year=year, month=month, day=day)
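A quick check of `_shift_month` under a non-standard calendar, assuming cftime is installed (the helper is internal, so the import path is an implementation detail):

    import cftime
    from xarray.coding.cftime_offsets import _shift_month

    date = cftime.DatetimeNoLeap(2000, 1, 31)
    # Shift one month forward and land on the month start ...
    print(_shift_month(date, 1, "start"))  # 2000-02-01 00:00:00
    # ... or on the month end, which is Feb 28 in the noleap calendar.
    print(_shift_month(date, 1, "end"))    # 2000-02-28 00:00:00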
+
+
+def roll_qtrday(
+ other, n: int, month: int, day_option: DayOption, modby: int = 3
+) -> int:
"""Possibly increment or decrement the number of periods to shift
based on rollforward/rollbackward conventions.
@@ -179,66 +345,119 @@ def roll_qtrday(other, n: int, month: int, day_option: DayOption, modby: int=3
--------
_get_day_of_month : Find the day in a month provided an offset.
"""
- pass
+
+ months_since = other.month % modby - month % modby
+
+ if n > 0:
+ if months_since < 0 or (
+ months_since == 0 and other.day < _get_day_of_month(other, day_option)
+ ):
+ # pretend to roll back if on same month but
+ # before compare_day
+ n -= 1
+ else:
+ if months_since > 0 or (
+ months_since == 0 and other.day > _get_day_of_month(other, day_option)
+ ):
+ # make sure to roll forward, so negate
+ n += 1
+ return n
+
+
+def _validate_month(month: int | None, default_month: int) -> int:
+ result_month = default_month if month is None else month
+ if not isinstance(result_month, int):
+ raise TypeError(
+ "'self.month' must be an integer value between 1 "
+ "and 12. Instead, it was set to a value of "
+ f"{result_month!r}"
+ )
+ elif not (1 <= result_month <= 12):
+ raise ValueError(
+ "'self.month' must be an integer value between 1 "
+ "and 12. Instead, it was set to a value of "
+ f"{result_month!r}"
+ )
+ return result_month
class MonthBegin(BaseCFTimeOffset):
- _freq = 'MS'
+ _freq = "MS"
def __apply__(self, other):
n = _adjust_n_months(other.day, self.n, 1)
- return _shift_month(other, n, 'start')
+ return _shift_month(other, n, "start")
- def onOffset(self, date) ->bool:
+ def onOffset(self, date) -> bool:
"""Check if the given date is in the set of possible dates created
using a length-one version of this offset class."""
- pass
+ return date.day == 1
class MonthEnd(BaseCFTimeOffset):
- _freq = 'ME'
+ _freq = "ME"
def __apply__(self, other):
n = _adjust_n_months(other.day, self.n, _days_in_month(other))
- return _shift_month(other, n, 'end')
+ return _shift_month(other, n, "end")
- def onOffset(self, date) ->bool:
+ def onOffset(self, date) -> bool:
"""Check if the given date is in the set of possible dates created
using a length-one version of this offset class."""
- pass
+ return date.day == _days_in_month(date)
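The two month offsets behave as sketched below, assuming cftime is installed; adding a length-one offset to a date that is not on-offset jumps to the next on-offset date, which is what `rollforward` relies on:

    import cftime
    from xarray.coding.cftime_offsets import MonthBegin, MonthEnd

    date = cftime.DatetimeGregorian(2000, 1, 15)
    print(date + MonthBegin())             # 2000-02-01: the next month start
    print(MonthBegin().rollforward(date))  # 2000-02-01: same result via rollforward
    # For MonthEnd, _adjust_n_months drops n to 0 because day 15 precedes the
    # reference day 31, so the result stays within January.
    print(date + MonthEnd())               # 2000-01-31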
-_MONTH_ABBREVIATIONS = {(1): 'JAN', (2): 'FEB', (3): 'MAR', (4): 'APR', (5):
- 'MAY', (6): 'JUN', (7): 'JUL', (8): 'AUG', (9): 'SEP', (10): 'OCT', (11
- ): 'NOV', (12): 'DEC'}
+_MONTH_ABBREVIATIONS = {
+ 1: "JAN",
+ 2: "FEB",
+ 3: "MAR",
+ 4: "APR",
+ 5: "MAY",
+ 6: "JUN",
+ 7: "JUL",
+ 8: "AUG",
+ 9: "SEP",
+ 10: "OCT",
+ 11: "NOV",
+ 12: "DEC",
+}
class QuarterOffset(BaseCFTimeOffset):
"""Quarter representation copied off of pandas/tseries/offsets.py"""
+
_default_month: ClassVar[int]
month: int
- def __init__(self, n: int=1, month: (int | None)=None) ->None:
+ def __init__(self, n: int = 1, month: int | None = None) -> None:
BaseCFTimeOffset.__init__(self, n)
self.month = _validate_month(month, self._default_month)
def __apply__(self, other):
+ # months_since: find the calendar quarter containing other.month,
+ # e.g. if other.month == 8, the calendar quarter is [Jul, Aug, Sep].
+ # Then find the month in that quarter containing an onOffset date for
+ # self. `months_since` is the number of months to shift other.month
+ # to get to this on-offset month.
months_since = other.month % 3 - self.month % 3
- qtrs = roll_qtrday(other, self.n, self.month, day_option=self.
- _day_option, modby=3)
+ qtrs = roll_qtrday(
+ other, self.n, self.month, day_option=self._day_option, modby=3
+ )
months = qtrs * 3 - months_since
return _shift_month(other, months, self._day_option)
- def onOffset(self, date) ->bool:
+ def onOffset(self, date) -> bool:
"""Check if the given date is in the set of possible dates created
using a length-one version of this offset class."""
- pass
+ mod_month = (date.month - self.month) % 3
+ return mod_month == 0 and date.day == self._get_offset_day(date)
- def __sub__(self, other: Self) ->Self:
+ def __sub__(self, other: Self) -> Self:
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
+
if isinstance(other, cftime.datetime):
- raise TypeError('Cannot subtract cftime.datetime from offset.')
+ raise TypeError("Cannot subtract cftime.datetime from offset.")
if type(other) == type(self) and other.month == self.month:
return type(self)(self.n - other.n, month=self.month)
return NotImplemented
@@ -248,43 +467,68 @@ class QuarterOffset(BaseCFTimeOffset):
return NotImplemented
return type(self)(n=other * self.n, month=self.month)
+ def rule_code(self) -> str:
+ return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}"
+
def __str__(self):
- return f'<{type(self).__name__}: n={self.n}, month={self.month}>'
+ return f"<{type(self).__name__}: n={self.n}, month={self.month}>"
class QuarterBegin(QuarterOffset):
+ # When converting a string to an offset, pandas converts
+ # 'QS' to a QuarterBegin offset starting in the month of
+ # January. When creating a QuarterBegin offset directly
+ # from the constructor, however, the default month is March.
+ # We follow that behavior here.
_default_month = 3
- _freq = 'QS'
- _day_option = 'start'
+ _freq = "QS"
+ _day_option = "start"
def rollforward(self, date):
"""Roll date forward to nearest start of quarter"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date + QuarterBegin(month=self.month)
def rollback(self, date):
"""Roll date backward to nearest start of quarter"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date - QuarterBegin(month=self.month)
class QuarterEnd(QuarterOffset):
+ # When converting a string to an offset, pandas converts
+ # 'Q' to a QuarterEnd offset starting in the month of
+ # December. When creating a QuarterEnd offset directly
+ # from the constructor, however, the default month is March.
+ # We follow that behavior here.
_default_month = 3
- _freq = 'QE'
- _day_option = 'end'
+ _freq = "QE"
+ _day_option = "end"
def rollforward(self, date):
"""Roll date forward to nearest end of quarter"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date + QuarterEnd(month=self.month)
def rollback(self, date):
"""Roll date backward to nearest end of quarter"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date - QuarterEnd(month=self.month)
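For the quarter offsets, `rollforward` and `rollback` snap to the nearest on-offset date on either side; a small example under the standard calendar, assuming cftime is installed:

    import cftime
    from xarray.coding.cftime_offsets import QuarterEnd

    qe = QuarterEnd(month=3)  # quarters anchored on March
    date = cftime.DatetimeGregorian(2000, 2, 10)
    print(qe.rollforward(date))  # 2000-03-31, the next quarter end
    print(qe.rollback(date))     # 1999-12-31, the previous quarter end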
class YearOffset(BaseCFTimeOffset):
_default_month: ClassVar[int]
month: int
- def __init__(self, n: int=1, month: (int | None)=None) ->None:
+ def __init__(self, n: int = 1, month: int | None = None) -> None:
BaseCFTimeOffset.__init__(self, n)
self.month = _validate_month(month, self._default_month)
@@ -297,8 +541,9 @@ class YearOffset(BaseCFTimeOffset):
def __sub__(self, other):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
+
if isinstance(other, cftime.datetime):
- raise TypeError('Cannot subtract cftime.datetime from offset.')
+ raise TypeError("Cannot subtract cftime.datetime from offset.")
elif type(other) == type(self) and other.month == self.month:
return type(self)(self.n - other.n, month=self.month)
else:
@@ -309,138 +554,289 @@ class YearOffset(BaseCFTimeOffset):
return NotImplemented
return type(self)(n=other * self.n, month=self.month)
- def __str__(self) ->str:
- return f'<{type(self).__name__}: n={self.n}, month={self.month}>'
+ def rule_code(self) -> str:
+ return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}"
+
+ def __str__(self) -> str:
+ return f"<{type(self).__name__}: n={self.n}, month={self.month}>"
class YearBegin(YearOffset):
- _freq = 'YS'
- _day_option = 'start'
+ _freq = "YS"
+ _day_option = "start"
_default_month = 1
- def onOffset(self, date) ->bool:
+ def onOffset(self, date) -> bool:
"""Check if the given date is in the set of possible dates created
using a length-one version of this offset class."""
- pass
+ return date.day == 1 and date.month == self.month
def rollforward(self, date):
"""Roll date forward to nearest start of year"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date + YearBegin(month=self.month)
def rollback(self, date):
"""Roll date backward to nearest start of year"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date - YearBegin(month=self.month)
class YearEnd(YearOffset):
- _freq = 'YE'
- _day_option = 'end'
+ _freq = "YE"
+ _day_option = "end"
_default_month = 12
- def onOffset(self, date) ->bool:
+ def onOffset(self, date) -> bool:
"""Check if the given date is in the set of possible dates created
using a length-one version of this offset class."""
- pass
+ return date.day == _days_in_month(date) and date.month == self.month
def rollforward(self, date):
"""Roll date forward to nearest end of year"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date + YearEnd(month=self.month)
def rollback(self, date):
"""Roll date backward to nearest end of year"""
- pass
+ if self.onOffset(date):
+ return date
+ else:
+ return date - YearEnd(month=self.month)
class Day(Tick):
- _freq = 'D'
+ _freq = "D"
+
+ def as_timedelta(self) -> timedelta:
+ return timedelta(days=self.n)
def __apply__(self, other):
return other + self.as_timedelta()
class Hour(Tick):
- _freq = 'h'
+ _freq = "h"
+
+ def as_timedelta(self) -> timedelta:
+ return timedelta(hours=self.n)
def __apply__(self, other):
return other + self.as_timedelta()
class Minute(Tick):
- _freq = 'min'
+ _freq = "min"
+
+ def as_timedelta(self) -> timedelta:
+ return timedelta(minutes=self.n)
def __apply__(self, other):
return other + self.as_timedelta()
class Second(Tick):
- _freq = 's'
+ _freq = "s"
+
+ def as_timedelta(self) -> timedelta:
+ return timedelta(seconds=self.n)
def __apply__(self, other):
return other + self.as_timedelta()
class Millisecond(Tick):
- _freq = 'ms'
+ _freq = "ms"
+
+ def as_timedelta(self) -> timedelta:
+ return timedelta(milliseconds=self.n)
def __apply__(self, other):
return other + self.as_timedelta()
class Microsecond(Tick):
- _freq = 'us'
+ _freq = "us"
+
+ def as_timedelta(self) -> timedelta:
+ return timedelta(microseconds=self.n)
def __apply__(self, other):
return other + self.as_timedelta()
-_FREQUENCIES: Mapping[str, type[BaseCFTimeOffset]] = {'A': YearEnd, 'AS':
- YearBegin, 'Y': YearEnd, 'YE': YearEnd, 'YS': YearBegin, 'Q': partial(
- QuarterEnd, month=12), 'QE': partial(QuarterEnd, month=12), 'QS':
- partial(QuarterBegin, month=1), 'M': MonthEnd, 'ME': MonthEnd, 'MS':
- MonthBegin, 'D': Day, 'H': Hour, 'h': Hour, 'T': Minute, 'min': Minute,
- 'S': Second, 's': Second, 'L': Millisecond, 'ms': Millisecond, 'U':
- Microsecond, 'us': Microsecond, **_generate_anchored_offsets('AS',
- YearBegin), **_generate_anchored_offsets('A', YearEnd), **
- _generate_anchored_offsets('YS', YearBegin), **
- _generate_anchored_offsets('Y', YearEnd), **_generate_anchored_offsets(
- 'YE', YearEnd), **_generate_anchored_offsets('QS', QuarterBegin), **
- _generate_anchored_offsets('Q', QuarterEnd), **
- _generate_anchored_offsets('QE', QuarterEnd)}
-_FREQUENCY_CONDITION = '|'.join(_FREQUENCIES.keys())
-_PATTERN = f'^((?P<multiple>[+-]?\\d+)|())(?P<freq>({_FREQUENCY_CONDITION}))$'
-CFTIME_TICKS = Day, Hour, Minute, Second
-_DEPRECATED_FREQUENICES: dict[str, str] = {'A': 'YE', 'Y': 'YE', 'AS': 'YS',
- 'Q': 'QE', 'M': 'ME', 'H': 'h', 'T': 'min', 'S': 's', 'L': 'ms', 'U':
- 'us', **_generate_anchored_deprecated_frequencies('A', 'YE'), **
- _generate_anchored_deprecated_frequencies('Y', 'YE'), **
- _generate_anchored_deprecated_frequencies('AS', 'YS'), **
- _generate_anchored_deprecated_frequencies('Q', 'QE')}
+def _generate_anchored_offsets(
+ base_freq: str, offset: type[YearOffset | QuarterOffset]
+) -> dict[str, type[BaseCFTimeOffset]]:
+ offsets: dict[str, type[BaseCFTimeOffset]] = {}
+ for month, abbreviation in _MONTH_ABBREVIATIONS.items():
+ anchored_freq = f"{base_freq}-{abbreviation}"
+ offsets[anchored_freq] = partial(offset, month=month) # type: ignore[assignment]
+ return offsets
+
+
+_FREQUENCIES: Mapping[str, type[BaseCFTimeOffset]] = {
+ "A": YearEnd,
+ "AS": YearBegin,
+ "Y": YearEnd,
+ "YE": YearEnd,
+ "YS": YearBegin,
+ "Q": partial(QuarterEnd, month=12), # type: ignore[dict-item]
+ "QE": partial(QuarterEnd, month=12), # type: ignore[dict-item]
+ "QS": partial(QuarterBegin, month=1), # type: ignore[dict-item]
+ "M": MonthEnd,
+ "ME": MonthEnd,
+ "MS": MonthBegin,
+ "D": Day,
+ "H": Hour,
+ "h": Hour,
+ "T": Minute,
+ "min": Minute,
+ "S": Second,
+ "s": Second,
+ "L": Millisecond,
+ "ms": Millisecond,
+ "U": Microsecond,
+ "us": Microsecond,
+ **_generate_anchored_offsets("AS", YearBegin),
+ **_generate_anchored_offsets("A", YearEnd),
+ **_generate_anchored_offsets("YS", YearBegin),
+ **_generate_anchored_offsets("Y", YearEnd),
+ **_generate_anchored_offsets("YE", YearEnd),
+ **_generate_anchored_offsets("QS", QuarterBegin),
+ **_generate_anchored_offsets("Q", QuarterEnd),
+ **_generate_anchored_offsets("QE", QuarterEnd),
+}
+
+
+_FREQUENCY_CONDITION = "|".join(_FREQUENCIES.keys())
+_PATTERN = rf"^((?P<multiple>[+-]?\d+)|())(?P<freq>({_FREQUENCY_CONDITION}))$"
+
+
+# pandas defines these offsets as "Tick" objects, which for instance have
+# distinct behavior from monthly or longer frequencies in resample.
+CFTIME_TICKS = (Day, Hour, Minute, Second)
+
+
+def _generate_anchored_deprecated_frequencies(
+ deprecated: str, recommended: str
+) -> dict[str, str]:
+ pairs = {}
+ for abbreviation in _MONTH_ABBREVIATIONS.values():
+ anchored_deprecated = f"{deprecated}-{abbreviation}"
+ anchored_recommended = f"{recommended}-{abbreviation}"
+ pairs[anchored_deprecated] = anchored_recommended
+ return pairs
+
+
+_DEPRECATED_FREQUENICES: dict[str, str] = {
+ "A": "YE",
+ "Y": "YE",
+ "AS": "YS",
+ "Q": "QE",
+ "M": "ME",
+ "H": "h",
+ "T": "min",
+ "S": "s",
+ "L": "ms",
+ "U": "us",
+ **_generate_anchored_deprecated_frequencies("A", "YE"),
+ **_generate_anchored_deprecated_frequencies("Y", "YE"),
+ **_generate_anchored_deprecated_frequencies("AS", "YS"),
+ **_generate_anchored_deprecated_frequencies("Q", "QE"),
+}
+
+
_DEPRECATION_MESSAGE = (
- '{deprecated_freq!r} is deprecated and will be removed in a future version. Please use {recommended_freq!r} instead of {deprecated_freq!r}.'
+ "{deprecated_freq!r} is deprecated and will be removed in a future "
+ "version. Please use {recommended_freq!r} instead of "
+ "{deprecated_freq!r}."
+)
+
+
+def _emit_freq_deprecation_warning(deprecated_freq):
+ recommended_freq = _DEPRECATED_FREQUENICES[deprecated_freq]
+ message = _DEPRECATION_MESSAGE.format(
+ deprecated_freq=deprecated_freq, recommended_freq=recommended_freq
)
+ emit_user_level_warning(message, FutureWarning)
-def to_offset(freq: (BaseCFTimeOffset | str), warn: bool=True
- ) ->BaseCFTimeOffset:
+def to_offset(freq: BaseCFTimeOffset | str, warn: bool = True) -> BaseCFTimeOffset:
"""Convert a frequency string to the appropriate subclass of
BaseCFTimeOffset."""
- pass
+ if isinstance(freq, BaseCFTimeOffset):
+ return freq
+
+ match = re.match(_PATTERN, freq)
+ if match is None:
+ raise ValueError("Invalid frequency string provided")
+ freq_data = match.groupdict()
+
+ freq = freq_data["freq"]
+ if warn and freq in _DEPRECATED_FREQUENICES:
+ _emit_freq_deprecation_warning(freq)
+ multiples = freq_data["multiple"]
+ multiples = 1 if multiples is None else int(multiples)
+ return _FREQUENCIES[freq](n=multiples)
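`to_offset` accepts plain and anchored frequency strings, optionally prefixed with a signed multiple; for example (outputs shown as the offsets' `__str__`):

    from xarray.coding.cftime_offsets import to_offset

    print(to_offset("MS"))       # <MonthBegin: n=1>
    print(to_offset("3YS-JUN"))  # <YearBegin: n=3, month=6>
    print(to_offset("-6h"))      # <Hour: n=-6>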
+
+
+def to_cftime_datetime(date_str_or_date, calendar=None):
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
+
+ if isinstance(date_str_or_date, str):
+ if calendar is None:
+ raise ValueError(
+ "If converting a string to a cftime.datetime object, "
+ "a calendar type must be provided"
+ )
+ date, _ = _parse_iso8601_with_reso(get_date_type(calendar), date_str_or_date)
+ return date
+ elif isinstance(date_str_or_date, cftime.datetime):
+ return date_str_or_date
+ elif isinstance(date_str_or_date, (datetime, pd.Timestamp)):
+ return cftime.DatetimeProlepticGregorian(*date_str_or_date.timetuple())
+ else:
+ raise TypeError(
+ "date_str_or_date must be a string or a "
+ "subclass of cftime.datetime. Instead got "
+ f"{date_str_or_date!r}."
+ )
def normalize_date(date):
"""Round datetime down to midnight."""
- pass
+ return date.replace(hour=0, minute=0, second=0, microsecond=0)
def _maybe_normalize_date(date, normalize):
"""Round datetime down to midnight if normalize is True."""
- pass
+ if normalize:
+ return normalize_date(date)
+ else:
+ return date
def _generate_linear_range(start, end, periods):
"""Generate an equally-spaced sequence of cftime.datetime objects between
and including two dates (whose length equals the number of periods)."""
- pass
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
+
+ total_seconds = (end - start).total_seconds()
+ values = np.linspace(0.0, total_seconds, periods, endpoint=True)
+ units = f"seconds since {format_cftime_datetime(start)}"
+ calendar = start.calendar
+ return cftime.num2date(
+ values, units=units, calendar=calendar, only_use_cftime_datetimes=True
+ )
def _generate_range(start, end, periods, offset):
@@ -465,24 +861,90 @@ def _generate_range(start, end, periods, offset):
-------
A generator object
"""
- pass
+ if start:
+ # From pandas GH 56147 / 56832 to account for negative direction and
+ # range bounds
+ if offset.n >= 0:
+ start = offset.rollforward(start)
+ else:
+ start = offset.rollback(start)
+ if periods is None and end < start and offset.n >= 0:
+ end = None
+ periods = 0
-def _translate_closed_to_inclusive(closed):
- """Follows code added in pandas #43504."""
- pass
+ if end is None:
+ end = start + (periods - 1) * offset
+ if start is None:
+ start = end - (periods - 1) * offset
-def _infer_inclusive(closed: (NoDefault | SideOptions), inclusive: (
- InclusiveOptions | None)) ->InclusiveOptions:
- """Follows code added in pandas #43504."""
- pass
+ current = start
+ if offset.n >= 0:
+ while current <= end:
+ yield current
+
+ next_date = current + offset
+ if next_date <= current:
+ raise ValueError(f"Offset {offset} did not increment date")
+ current = next_date
+ else:
+ while current >= end:
+ yield current
+ next_date = current + offset
+ if next_date >= current:
+ raise ValueError(f"Offset {offset} did not decrement date")
+ current = next_date
-def cftime_range(start=None, end=None, periods=None, freq=None, normalize=
- False, name=None, closed: (NoDefault | SideOptions)=no_default,
- inclusive: (None | InclusiveOptions)=None, calendar='standard'
- ) ->CFTimeIndex:
+
+def _translate_closed_to_inclusive(closed):
+ """Follows code added in pandas #43504."""
+ emit_user_level_warning(
+ "Following pandas, the `closed` parameter is deprecated in "
+ "favor of the `inclusive` parameter, and will be removed in "
+ "a future version of xarray.",
+ FutureWarning,
+ )
+ if closed is None:
+ inclusive = "both"
+ elif closed in ("left", "right"):
+ inclusive = closed
+ else:
+ raise ValueError(
+ f"Argument `closed` must be either 'left', 'right', or None. "
+ f"Got {closed!r}."
+ )
+ return inclusive
+
+
+def _infer_inclusive(
+ closed: NoDefault | SideOptions, inclusive: InclusiveOptions | None
+) -> InclusiveOptions:
+ """Follows code added in pandas #43504."""
+ if closed is not no_default and inclusive is not None:
+ raise ValueError(
+ "Following pandas, deprecated argument `closed` cannot be "
+ "passed if argument `inclusive` is not None."
+ )
+ if closed is not no_default:
+ return _translate_closed_to_inclusive(closed)
+ if inclusive is None:
+ return "both"
+ return inclusive
+
+
+def cftime_range(
+ start=None,
+ end=None,
+ periods=None,
+ freq=None,
+ normalize=False,
+ name=None,
+ closed: NoDefault | SideOptions = no_default,
+ inclusive: None | InclusiveOptions = None,
+ calendar="standard",
+) -> CFTimeIndex:
"""Return a fixed frequency CFTimeIndex.
Parameters
@@ -662,13 +1124,71 @@ def cftime_range(start=None, end=None, periods=None, freq=None, normalize=
--------
pandas.date_range
"""
- pass
-
-def date_range(start=None, end=None, periods=None, freq=None, tz=None,
- normalize=False, name=None, closed: (NoDefault | SideOptions)=
- no_default, inclusive: (None | InclusiveOptions)=None, calendar=
- 'standard', use_cftime=None):
+ if freq is None and any(arg is None for arg in [periods, start, end]):
+ freq = "D"
+
+ # Adapted from pandas.core.indexes.datetimes._generate_range.
+ if count_not_none(start, end, periods, freq) != 3:
+ raise ValueError(
+ "Of the arguments 'start', 'end', 'periods', and 'freq', three "
+ "must be specified at a time."
+ )
+
+ if start is not None:
+ start = to_cftime_datetime(start, calendar)
+ start = _maybe_normalize_date(start, normalize)
+ if end is not None:
+ end = to_cftime_datetime(end, calendar)
+ end = _maybe_normalize_date(end, normalize)
+
+ if freq is None:
+ dates = _generate_linear_range(start, end, periods)
+ else:
+ offset = to_offset(freq)
+ dates = np.array(list(_generate_range(start, end, periods, offset)))
+
+ inclusive = _infer_inclusive(closed, inclusive)
+
+ if inclusive == "neither":
+ left_closed = False
+ right_closed = False
+ elif inclusive == "left":
+ left_closed = True
+ right_closed = False
+ elif inclusive == "right":
+ left_closed = False
+ right_closed = True
+ elif inclusive == "both":
+ left_closed = True
+ right_closed = True
+ else:
+ raise ValueError(
+ f"Argument `inclusive` must be either 'both', 'neither', "
+ f"'left', 'right', or None. Got {inclusive}."
+ )
+
+ if not left_closed and len(dates) and start is not None and dates[0] == start:
+ dates = dates[1:]
+ if not right_closed and len(dates) and end is not None and dates[-1] == end:
+ dates = dates[:-1]
+
+ return CFTimeIndex(dates, name=name)
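A typical use of `cftime_range` is generating dates that pandas cannot represent, either because of the calendar or the year range; a short example, assuming cftime is installed and using the public `xarray.cftime_range` re-export:

    import xarray as xr

    times = xr.cftime_range(start="0001-01-01", periods=4, freq="ME", calendar="noleap")
    print(times)
    # CFTimeIndex with four month-end dates, 0001-01-31 through 0001-04-30,
    # in the noleap calendar -- far outside the pandas.Timestamp range.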
+
+
+def date_range(
+ start=None,
+ end=None,
+ periods=None,
+ freq=None,
+ tz=None,
+ normalize=False,
+ name=None,
+ closed: NoDefault | SideOptions = no_default,
+ inclusive: None | InclusiveOptions = None,
+ calendar="standard",
+ use_cftime=None,
+):
"""Return a fixed frequency datetime index.
The type (:py:class:`xarray.CFTimeIndex` or :py:class:`pandas.DatetimeIndex`)
@@ -724,7 +1244,136 @@ def date_range(start=None, end=None, periods=None, freq=None, tz=None,
cftime_range
date_range_like
"""
- pass
+ from xarray.coding.times import _is_standard_calendar
+
+ if tz is not None:
+ use_cftime = False
+
+ inclusive = _infer_inclusive(closed, inclusive)
+
+ if _is_standard_calendar(calendar) and use_cftime is not True:
+ try:
+ return pd.date_range(
+ start=start,
+ end=end,
+ periods=periods,
+ # TODO remove translation once requiring pandas >= 2.2
+ freq=_new_to_legacy_freq(freq),
+ tz=tz,
+ normalize=normalize,
+ name=name,
+ inclusive=inclusive,
+ )
+ except pd.errors.OutOfBoundsDatetime as err:
+ if use_cftime is False:
+ raise ValueError(
+ "Date range is invalid for pandas DatetimeIndex, try using `use_cftime=True`."
+ ) from err
+ elif use_cftime is False:
+ raise ValueError(
+ f"Invalid calendar {calendar} for pandas DatetimeIndex, try using `use_cftime=True`."
+ )
+
+ return cftime_range(
+ start=start,
+ end=end,
+ periods=periods,
+ freq=freq,
+ normalize=normalize,
+ name=name,
+ inclusive=inclusive,
+ calendar=calendar,
+ )
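`date_range` dispatches between pandas and cftime depending on the calendar, the date bounds and `use_cftime`; a minimal sketch of the two paths:

    import xarray as xr

    # Standard calendar, in-bounds dates -> pandas.DatetimeIndex
    print(type(xr.date_range("2000-01-01", periods=3, freq="D")))
    # Forcing cftime (or using out-of-bounds dates / a non-standard calendar)
    # -> xarray.CFTimeIndex
    print(type(xr.date_range("2000-01-01", periods=3, freq="D", use_cftime=True)))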
+
+
+def _new_to_legacy_freq(freq):
+ # xarray will now always return "ME" and "QE" for MonthEnd and QuarterEnd
+ # frequencies, but older versions of pandas do not support these as
+ # frequency strings. Until xarray's minimum pandas version is 2.2 or above,
+ # we add logic to continue using the deprecated "M" and "Q" frequency
+ # strings in these circumstances.
+
+ # NOTE: other conversions ("h" -> "H", ..., "ns" -> "N") not required
+
+ # TODO: remove once requiring pandas >= 2.2
+ if not freq or Version(pd.__version__) >= Version("2.2"):
+ return freq
+
+ try:
+ freq_as_offset = to_offset(freq)
+ except ValueError:
+ # freq may be valid in pandas but not in xarray
+ return freq
+
+ if isinstance(freq_as_offset, MonthEnd) and "ME" in freq:
+ freq = freq.replace("ME", "M")
+ elif isinstance(freq_as_offset, QuarterEnd) and "QE" in freq:
+ freq = freq.replace("QE", "Q")
+ elif isinstance(freq_as_offset, YearBegin) and "YS" in freq:
+ freq = freq.replace("YS", "AS")
+ elif isinstance(freq_as_offset, YearEnd):
+ # testing for "Y" is required as this was valid in xarray 2023.11 - 2024.01
+ if "Y-" in freq:
+ # Check for and replace "Y-" instead of just "Y" to prevent
+ # corrupting anchored offsets that contain "Y" in the month
+ # abbreviation, e.g. "Y-MAY" -> "A-MAY".
+ freq = freq.replace("Y-", "A-")
+ elif "YE-" in freq:
+ freq = freq.replace("YE-", "A-")
+ elif "A-" not in freq and freq.endswith("Y"):
+ freq = freq.replace("Y", "A")
+ elif freq.endswith("YE"):
+ freq = freq.replace("YE", "A")
+
+ return freq
+
+
+def _legacy_to_new_freq(freq):
+ # to avoid internal deprecation warnings when freq is determined using pandas < 2.2
+
+ # TODO: remove once requiring pandas >= 2.2
+
+ if not freq or Version(pd.__version__) >= Version("2.2"):
+ return freq
+
+ try:
+ freq_as_offset = to_offset(freq, warn=False)
+ except ValueError:
+ # freq may be valid in pandas but not in xarray
+ return freq
+
+ if isinstance(freq_as_offset, MonthEnd) and "ME" not in freq:
+ freq = freq.replace("M", "ME")
+ elif isinstance(freq_as_offset, QuarterEnd) and "QE" not in freq:
+ freq = freq.replace("Q", "QE")
+ elif isinstance(freq_as_offset, YearBegin) and "YS" not in freq:
+ freq = freq.replace("AS", "YS")
+ elif isinstance(freq_as_offset, YearEnd):
+ if "A-" in freq:
+ # Check for and replace "A-" instead of just "A" to prevent
+ # corrupting anchored offsets that contain "Y" in the month
+ # abbreviation, e.g. "A-MAY" -> "YE-MAY".
+ freq = freq.replace("A-", "YE-")
+ elif "Y-" in freq:
+ freq = freq.replace("Y-", "YE-")
+ elif freq.endswith("A"):
+ # the "A-MAY" case is already handled above
+ freq = freq.replace("A", "YE")
+ elif "YE" not in freq and freq.endswith("Y"):
+ # the "Y-MAY" case is already handled above
+ freq = freq.replace("Y", "YE")
+ elif isinstance(freq_as_offset, Hour):
+ freq = freq.replace("H", "h")
+ elif isinstance(freq_as_offset, Minute):
+ freq = freq.replace("T", "min")
+ elif isinstance(freq_as_offset, Second):
+ freq = freq.replace("S", "s")
+ elif isinstance(freq_as_offset, Millisecond):
+ freq = freq.replace("L", "ms")
+ elif isinstance(freq_as_offset, Microsecond):
+ freq = freq.replace("U", "us")
+
+ return freq
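These two translation helpers are internal and version-dependent: on pandas older than 2.2 they map between the legacy and the new aliases, while on pandas 2.2+ they return their input unchanged. A sketch of the intended behavior on an older pandas:

    from xarray.coding.cftime_offsets import _legacy_to_new_freq, _new_to_legacy_freq

    # On pandas < 2.2 these print "M" and "YE-MAY"; on pandas >= 2.2 the
    # helpers are no-ops and simply echo their inputs.
    print(_new_to_legacy_freq("ME"))
    print(_legacy_to_new_freq("A-MAY"))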
def date_range_like(source, calendar, use_cftime=None):
@@ -753,4 +1402,65 @@ def date_range_like(source, calendar, use_cftime=None):
last day of the month. Then the output range will also end on the last
day of the month in the new calendar.
"""
- pass
+ from xarray.coding.frequencies import infer_freq
+ from xarray.core.dataarray import DataArray
+
+ if not isinstance(source, (pd.DatetimeIndex, CFTimeIndex)) and (
+ isinstance(source, DataArray)
+ and (source.ndim != 1)
+ or not _contains_datetime_like_objects(source.variable)
+ ):
+ raise ValueError(
+ "'source' must be a 1D array of datetime objects for inferring its range."
+ )
+
+ freq = infer_freq(source)
+ if freq is None:
+ raise ValueError(
+ "`date_range_like` was unable to generate a range as the source frequency was not inferable."
+ )
+
+ # TODO remove once requiring pandas >= 2.2
+ freq = _legacy_to_new_freq(freq)
+
+ use_cftime = _should_cftime_be_used(source, calendar, use_cftime)
+
+ source_start = source.values.min()
+ source_end = source.values.max()
+
+ freq_as_offset = to_offset(freq)
+ if freq_as_offset.n < 0:
+ source_start, source_end = source_end, source_start
+
+ if is_np_datetime_like(source.dtype):
+ # We want to use datetime fields (datetime64 object don't have them)
+ source_calendar = "standard"
+ # TODO: the strict enforcement of nanosecond precision Timestamps can be
+ # relaxed when addressing GitHub issue #7493.
+ source_start = nanosecond_precision_timestamp(source_start)
+ source_end = nanosecond_precision_timestamp(source_end)
+ else:
+ if isinstance(source, CFTimeIndex):
+ source_calendar = source.calendar
+ else: # DataArray
+ source_calendar = source.dt.calendar
+
+ if calendar == source_calendar and is_np_datetime_like(source.dtype) ^ use_cftime:
+ return source
+
+ date_type = get_date_type(calendar, use_cftime)
+ start = convert_time_or_go_back(source_start, date_type)
+ end = convert_time_or_go_back(source_end, date_type)
+
+ # For the cases where the source ends on the end of the month, we expect the same in the new calendar.
+ if source_end.day == source_end.daysinmonth and isinstance(
+ freq_as_offset, (YearEnd, QuarterEnd, MonthEnd, Day)
+ ):
+ end = end.replace(day=end.daysinmonth)
+
+ return date_range(
+ start=start.isoformat(),
+ end=end.isoformat(),
+ freq=freq,
+ calendar=calendar,
+ )
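`date_range_like` transposes an index onto another calendar while keeping the inferred frequency; in the sketch below the no-leap calendar has no February 29, so the result contains one date fewer than the source (assumes cftime is installed):

    import pandas as pd
    import xarray as xr

    source = pd.date_range("2000-02-27", periods=4, freq="D")  # includes 2000-02-29
    print(xr.date_range_like(source, calendar="noleap"))
    # CFTimeIndex covering 2000-02-27, 2000-02-28 and 2000-03-01 in the noleap
    # calendar: the leap day does not exist in the target calendar.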
diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py
index 674e36be..ef01f4cc 100644
--- a/xarray/coding/cftimeindex.py
+++ b/xarray/coding/cftimeindex.py
@@ -1,72 +1,290 @@
"""DatetimeIndex analog for cftime.datetime objects"""
+
+# The pandas.Index subclass defined here was copied and adapted for
+# use with cftime.datetime objects based on the source code defining
+# pandas.DatetimeIndex.
+
+# For reference, here is a copy of the pandas copyright notice:
+
+# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2008-2011 AQR Capital Management, LLC
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the copyright holder nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
+
import math
import re
import warnings
from datetime import timedelta
from typing import TYPE_CHECKING, Any
+
import numpy as np
import pandas as pd
from packaging.version import Version
-from xarray.coding.times import _STANDARD_CALENDARS, cftime_to_nptime, infer_calendar_name
+
+from xarray.coding.times import (
+ _STANDARD_CALENDARS,
+ cftime_to_nptime,
+ infer_calendar_name,
+)
from xarray.core.common import _contains_cftime_datetimes
from xarray.core.options import OPTIONS
from xarray.core.utils import is_scalar
+
try:
import cftime
except ImportError:
cftime = None
+
if TYPE_CHECKING:
from xarray.coding.cftime_offsets import BaseCFTimeOffset
from xarray.core.types import Self
+
+
+# constants for cftimeindex.repr
CFTIME_REPR_LENGTH = 19
ITEMS_IN_REPR_MAX_ELSE_ELLIPSIS = 100
REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END = 10
+
+
OUT_OF_BOUNDS_TIMEDELTA_ERRORS: tuple[type[Exception], ...]
try:
- OUT_OF_BOUNDS_TIMEDELTA_ERRORS = (pd.errors.OutOfBoundsTimedelta,
- OverflowError)
+ OUT_OF_BOUNDS_TIMEDELTA_ERRORS = (pd.errors.OutOfBoundsTimedelta, OverflowError)
except AttributeError:
- OUT_OF_BOUNDS_TIMEDELTA_ERRORS = OverflowError,
-_BASIC_PATTERN = build_pattern(date_sep='', time_sep='')
+ OUT_OF_BOUNDS_TIMEDELTA_ERRORS = (OverflowError,)
+
+
+def named(name, pattern):
+ return "(?P<" + name + ">" + pattern + ")"
+
+
+def optional(x):
+ return "(?:" + x + ")?"
+
+
+def trailing_optional(xs):
+ if not xs:
+ return ""
+ return xs[0] + optional(trailing_optional(xs[1:]))
+
+
+def build_pattern(date_sep=r"\-", datetime_sep=r"T", time_sep=r"\:"):
+ pieces = [
+ (None, "year", r"\d{4}"),
+ (date_sep, "month", r"\d{2}"),
+ (date_sep, "day", r"\d{2}"),
+ (datetime_sep, "hour", r"\d{2}"),
+ (time_sep, "minute", r"\d{2}"),
+ (time_sep, "second", r"\d{2}"),
+ ]
+ pattern_list = []
+ for sep, name, sub_pattern in pieces:
+ pattern_list.append((sep if sep else "") + named(name, sub_pattern))
+ # TODO: allow timezone offsets?
+ return "^" + trailing_optional(pattern_list) + "$"
+
+
+_BASIC_PATTERN = build_pattern(date_sep="", time_sep="")
_EXTENDED_PATTERN = build_pattern()
-_CFTIME_PATTERN = build_pattern(datetime_sep=' ')
+_CFTIME_PATTERN = build_pattern(datetime_sep=" ")
_PATTERNS = [_BASIC_PATTERN, _EXTENDED_PATTERN, _CFTIME_PATTERN]
+def parse_iso8601_like(datetime_string):
+ for pattern in _PATTERNS:
+ match = re.match(pattern, datetime_string)
+ if match:
+ return match.groupdict()
+ raise ValueError(
+ f"no ISO-8601 or cftime-string-like match for string: {datetime_string}"
+ )
+
+
+def _parse_iso8601_with_reso(date_type, timestr):
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
+
+ default = date_type(1, 1, 1)
+ result = parse_iso8601_like(timestr)
+ replace = {}
+
+ for attr in ["year", "month", "day", "hour", "minute", "second"]:
+ value = result.get(attr, None)
+ if value is not None:
+ # Note ISO8601 conventions allow for fractional seconds.
+ # TODO: Consider adding support for sub-second resolution?
+ replace[attr] = int(value)
+ resolution = attr
+ return default.replace(**replace), resolution
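The parsing helpers accept progressively truncated ISO-8601 strings and report the coarsest resolution that was supplied; a small example with the internal helpers, assuming cftime is installed:

    from cftime import DatetimeNoLeap
    from xarray.coding.cftimeindex import _parse_iso8601_with_reso, parse_iso8601_like

    print(parse_iso8601_like("2000-03"))
    # {'year': '2000', 'month': '03', 'day': None, 'hour': None, ...}
    date, resolution = _parse_iso8601_with_reso(DatetimeNoLeap, "2000-03")
    print(date, resolution)  # 2000-03-01 00:00:00 month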
+
+
def _parsed_string_to_bounds(date_type, resolution, parsed):
"""Generalization of
pandas.tseries.index.DatetimeIndex._parsed_string_to_bounds
for use with non-standard calendars and cftime.datetime
objects.
"""
- pass
+ if resolution == "year":
+ return (
+ date_type(parsed.year, 1, 1),
+ date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1),
+ )
+ elif resolution == "month":
+ if parsed.month == 12:
+ end = date_type(parsed.year + 1, 1, 1) - timedelta(microseconds=1)
+ else:
+ end = date_type(parsed.year, parsed.month + 1, 1) - timedelta(
+ microseconds=1
+ )
+ return date_type(parsed.year, parsed.month, 1), end
+ elif resolution == "day":
+ start = date_type(parsed.year, parsed.month, parsed.day)
+ return start, start + timedelta(days=1, microseconds=-1)
+ elif resolution == "hour":
+ start = date_type(parsed.year, parsed.month, parsed.day, parsed.hour)
+ return start, start + timedelta(hours=1, microseconds=-1)
+ elif resolution == "minute":
+ start = date_type(
+ parsed.year, parsed.month, parsed.day, parsed.hour, parsed.minute
+ )
+ return start, start + timedelta(minutes=1, microseconds=-1)
+ elif resolution == "second":
+ start = date_type(
+ parsed.year,
+ parsed.month,
+ parsed.day,
+ parsed.hour,
+ parsed.minute,
+ parsed.second,
+ )
+ return start, start + timedelta(seconds=1, microseconds=-1)
+ else:
+ raise KeyError
def get_date_field(datetimes, field):
"""Adapted from pandas.tslib.get_date_field"""
- pass
+ return np.array([getattr(date, field) for date in datetimes], dtype=np.int64)
-def _field_accessor(name, docstring=None, min_cftime_version='0.0'):
+def _field_accessor(name, docstring=None, min_cftime_version="0.0"):
"""Adapted from pandas.tseries.index._field_accessor"""
- pass
+ def f(self, min_cftime_version=min_cftime_version):
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
-def format_row(times, indent=0, separator=', ', row_end=',\n'):
+ if Version(cftime.__version__) >= Version(min_cftime_version):
+ return get_date_field(self._data, name)
+ else:
+ raise ImportError(
+ f"The {name:!r} accessor requires a minimum "
+ f"version of cftime of {min_cftime_version}. Found an "
+ f"installed version of {cftime.__version__}."
+ )
+
+ f.__name__ = name
+ f.__doc__ = docstring
+ return property(f)
+
+
+def get_date_type(self):
+ if self._data.size:
+ return type(self._data[0])
+ else:
+ return None
+
+
+def assert_all_valid_date_type(data):
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
+
+ if len(data) > 0:
+ sample = data[0]
+ date_type = type(sample)
+ if not isinstance(sample, cftime.datetime):
+ raise TypeError(
+ "CFTimeIndex requires cftime.datetime "
+ f"objects. Got object of {date_type}."
+ )
+ if not all(isinstance(value, date_type) for value in data):
+ raise TypeError(
+ "CFTimeIndex requires using datetime "
+ f"objects of all the same type. Got\n{data}."
+ )
+
+
+def format_row(times, indent=0, separator=", ", row_end=",\n"):
"""Format a single row from format_times."""
- pass
+ return indent * " " + separator.join(map(str, times)) + row_end
+
+
+def format_times(
+ index,
+ max_width,
+ offset,
+ separator=", ",
+ first_row_offset=0,
+ intermediate_row_end=",\n",
+ last_row_end="",
+):
+ """Format values of cftimeindex as pd.Index."""
+ n_per_row = max(max_width // (CFTIME_REPR_LENGTH + len(separator)), 1)
+ n_rows = math.ceil(len(index) / n_per_row)
+ representation = ""
+ for row in range(n_rows):
+ indent = first_row_offset if row == 0 else offset
+ row_end = last_row_end if row == n_rows - 1 else intermediate_row_end
+ times_for_row = index[row * n_per_row : (row + 1) * n_per_row]
+ representation += format_row(
+ times_for_row, indent=indent, separator=separator, row_end=row_end
+ )
-def format_times(index, max_width, offset, separator=', ', first_row_offset
- =0, intermediate_row_end=',\n', last_row_end=''):
- """Format values of cftimeindex as pd.Index."""
- pass
+ return representation
-def format_attrs(index, separator=', '):
+def format_attrs(index, separator=", "):
"""Format attributes of CFTimeIndex for __repr__."""
- pass
+ attrs = {
+ "dtype": f"'{index.dtype}'",
+ "length": f"{len(index)}",
+ "calendar": f"{index.calendar!r}",
+ "freq": f"{index.freq!r}",
+ }
+
+ attrs_str = [f"{k}={v}" for k, v in attrs.items()]
+ attrs_str = f"{separator}".join(attrs_str)
+ return attrs_str
class CFTimeIndex(pd.Index):
@@ -85,28 +303,30 @@ class CFTimeIndex(pd.Index):
--------
cftime_range
"""
- year = _field_accessor('year', 'The year of the datetime')
- month = _field_accessor('month', 'The month of the datetime')
- day = _field_accessor('day', 'The days of the datetime')
- hour = _field_accessor('hour', 'The hours of the datetime')
- minute = _field_accessor('minute', 'The minutes of the datetime')
- second = _field_accessor('second', 'The seconds of the datetime')
- microsecond = _field_accessor('microsecond',
- 'The microseconds of the datetime')
- dayofyear = _field_accessor('dayofyr',
- 'The ordinal day of year of the datetime', '1.0.2.1')
- dayofweek = _field_accessor('dayofwk',
- 'The day of week of the datetime', '1.0.2.1')
- days_in_month = _field_accessor('daysinmonth',
- 'The number of days in the month of the datetime', '1.1.0.0')
+
+ year = _field_accessor("year", "The year of the datetime")
+ month = _field_accessor("month", "The month of the datetime")
+ day = _field_accessor("day", "The days of the datetime")
+ hour = _field_accessor("hour", "The hours of the datetime")
+ minute = _field_accessor("minute", "The minutes of the datetime")
+ second = _field_accessor("second", "The seconds of the datetime")
+ microsecond = _field_accessor("microsecond", "The microseconds of the datetime")
+ dayofyear = _field_accessor(
+ "dayofyr", "The ordinal day of year of the datetime", "1.0.2.1"
+ )
+ dayofweek = _field_accessor("dayofwk", "The day of week of the datetime", "1.0.2.1")
+ days_in_month = _field_accessor(
+ "daysinmonth", "The number of days in the month of the datetime", "1.1.0.0"
+ )
date_type = property(get_date_type)
def __new__(cls, data, name=None, **kwargs):
assert_all_valid_date_type(data)
- if name is None and hasattr(data, 'name'):
+ if name is None and hasattr(data, "name"):
name = data.name
+
result = object.__new__(cls)
- result._data = np.array(data, dtype='O')
+ result._data = np.array(data, dtype="O")
result.name = name
result._cache = {}
return result
@@ -116,26 +336,38 @@ class CFTimeIndex(pd.Index):
Return a string representation for this object.
"""
klass_name = type(self).__name__
- display_width = OPTIONS['display_width']
+ display_width = OPTIONS["display_width"]
offset = len(klass_name) + 2
+
if len(self) <= ITEMS_IN_REPR_MAX_ELSE_ELLIPSIS:
- datastr = format_times(self.values, display_width, offset=
- offset, first_row_offset=0)
+ datastr = format_times(
+ self.values, display_width, offset=offset, first_row_offset=0
+ )
else:
- front_str = format_times(self.values[:
- REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END], display_width, offset=
- offset, first_row_offset=0, last_row_end=',')
- end_str = format_times(self.values[-
- REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END:], display_width, offset
- =offset, first_row_offset=offset)
- datastr = '\n'.join([front_str, f"{' ' * offset}...", end_str])
+ front_str = format_times(
+ self.values[:REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END],
+ display_width,
+ offset=offset,
+ first_row_offset=0,
+ last_row_end=",",
+ )
+ end_str = format_times(
+ self.values[-REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END:],
+ display_width,
+ offset=offset,
+ first_row_offset=offset,
+ )
+ datastr = "\n".join([front_str, f"{' '*offset}...", end_str])
+
attrs_str = format_attrs(self)
- full_repr_str = f'{klass_name}([{datastr}], {attrs_str})'
+ # oneliner only if smaller than display_width
+ full_repr_str = f"{klass_name}([{datastr}], {attrs_str})"
if len(full_repr_str) > display_width:
+ # if attrs_str too long, one per line
if len(attrs_str) >= display_width - offset:
- attrs_str = attrs_str.replace(',', f",\n{' ' * (offset - 2)}")
- full_repr_str = (
- f"{klass_name}([{datastr}],\n{' ' * (offset - 1)}{attrs_str})")
+ attrs_str = attrs_str.replace(",", f",\n{' '*(offset-2)}")
+ full_repr_str = f"{klass_name}([{datastr}],\n{' '*(offset-1)}{attrs_str})"
+
return full_repr_str
def _partial_date_slice(self, resolution, parsed):
@@ -181,50 +413,115 @@ class CFTimeIndex(pd.Index):
Coordinates:
* time (time) datetime64[ns] 8B 2001-01-01T01:00:00
"""
- pass
+ start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed)
+
+ times = self._data
+
+ if self.is_monotonic_increasing:
+ if len(times) and (
+ (start < times[0] and end < times[0])
+ or (start > times[-1] and end > times[-1])
+ ):
+ # we are out of range
+ raise KeyError
+
+ # a monotonic (sorted) series can be sliced
+ left = times.searchsorted(start, side="left")
+ right = times.searchsorted(end, side="right")
+ return slice(left, right)
+
+ lhs_mask = times >= start
+ rhs_mask = times <= end
+ return np.flatnonzero(lhs_mask & rhs_mask)
def _get_string_slice(self, key):
"""Adapted from pandas.tseries.index.DatetimeIndex._get_string_slice"""
- pass
+ parsed, resolution = _parse_iso8601_with_reso(self.date_type, key)
+ try:
+ loc = self._partial_date_slice(resolution, parsed)
+ except KeyError:
+ raise KeyError(key)
+ return loc
def _get_nearest_indexer(self, target, limit, tolerance):
"""Adapted from pandas.Index._get_nearest_indexer"""
- pass
+ left_indexer = self.get_indexer(target, "pad", limit=limit)
+ right_indexer = self.get_indexer(target, "backfill", limit=limit)
+ left_distances = abs(self.values[left_indexer] - target.values)
+ right_distances = abs(self.values[right_indexer] - target.values)
+
+ if self.is_monotonic_increasing:
+ condition = (left_distances < right_distances) | (right_indexer == -1)
+ else:
+ condition = (left_distances <= right_distances) | (right_indexer == -1)
+ indexer = np.where(condition, left_indexer, right_indexer)
+
+ if tolerance is not None:
+ indexer = self._filter_indexer_tolerance(target, indexer, tolerance)
+ return indexer
def _filter_indexer_tolerance(self, target, indexer, tolerance):
"""Adapted from pandas.Index._filter_indexer_tolerance"""
- pass
+ if isinstance(target, pd.Index):
+ distance = abs(self.values[indexer] - target.values)
+ else:
+ distance = abs(self.values[indexer] - target)
+ indexer = np.where(distance <= tolerance, indexer, -1)
+ return indexer
def get_loc(self, key):
"""Adapted from pandas.tseries.index.DatetimeIndex.get_loc"""
- pass
+ if isinstance(key, str):
+ return self._get_string_slice(key)
+ else:
+ return super().get_loc(key)
def _maybe_cast_slice_bound(self, label, side):
"""Adapted from
pandas.tseries.index.DatetimeIndex._maybe_cast_slice_bound
"""
- pass
+ if not isinstance(label, str):
+ return label
+
+ parsed, resolution = _parse_iso8601_with_reso(self.date_type, label)
+ start, end = _parsed_string_to_bounds(self.date_type, resolution, parsed)
+ if self.is_monotonic_decreasing and len(self) > 1:
+ return end if side == "left" else start
+ return start if side == "left" else end
+ # TODO: Add ability to use integer range outside of iloc?
+ # e.g. series[1:5].
def get_value(self, series, key):
"""Adapted from pandas.tseries.index.DatetimeIndex.get_value"""
- pass
+ if np.asarray(key).dtype == np.dtype(bool):
+ return series.iloc[key]
+ elif isinstance(key, slice):
+ return series.iloc[self.slice_indexer(key.start, key.stop, key.step)]
+ else:
+ return series.iloc[self.get_loc(key)]
- def __contains__(self, key: Any) ->bool:
+ def __contains__(self, key: Any) -> bool:
"""Adapted from
pandas.tseries.base.DatetimeIndexOpsMixin.__contains__"""
try:
result = self.get_loc(key)
- return is_scalar(result) or isinstance(result, slice
- ) or isinstance(result, np.ndarray) and result.size > 0
+ return (
+ is_scalar(result)
+ or isinstance(result, slice)
+ or (isinstance(result, np.ndarray) and result.size > 0)
+ )
except (KeyError, TypeError, ValueError):
return False
- def contains(self, key: Any) ->bool:
+ def contains(self, key: Any) -> bool:
"""Needed for .loc based partial-string indexing"""
- pass
+ return self.__contains__(key)
- def shift(self, periods: (int | float), freq: (str | timedelta |
- BaseCFTimeOffset | None)=None) ->Self:
+ def shift( # type: ignore[override] # freq is typed Any, we are more precise
+ self,
+ periods: int | float,
+ freq: str | timedelta | BaseCFTimeOffset | None = None,
+ ) -> Self:
"""Shift the CFTimeIndex a multiple of the given frequency.
See the documentation for :py:func:`~xarray.cftime_range` for a
@@ -258,14 +555,32 @@ class CFTimeIndex(pd.Index):
CFTimeIndex([2000-02-01 12:00:00],
dtype='object', length=1, calendar='standard', freq=None)
"""
- pass
+ from xarray.coding.cftime_offsets import BaseCFTimeOffset
+
+ if freq is None:
+ # None type is required to be compatible with base pd.Index class
+ raise TypeError(
+ f"`freq` argument cannot be None for {type(self).__name__}.shift"
+ )
+
+ if isinstance(freq, timedelta):
+ return self + periods * freq
+
+ if isinstance(freq, (str, BaseCFTimeOffset)):
+ from xarray.coding.cftime_offsets import to_offset
- def __add__(self, other) ->Self:
+ return self + periods * to_offset(freq)
+
+ raise TypeError(
+ f"'freq' must be of type str or datetime.timedelta, got {type(freq)}."
+ )
+
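`shift` accepts either a frequency string / offset or a raw `datetime.timedelta`; for instance (assuming cftime is installed):

    from datetime import timedelta
    import xarray as xr

    index = xr.cftime_range("2000-01-01", periods=3, freq="D", calendar="noleap")
    print(index.shift(1, "D"))                  # each date moved forward one day
    print(index.shift(12, timedelta(hours=1)))  # each date moved forward twelve hours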
+ def __add__(self, other) -> Self:
if isinstance(other, pd.TimedeltaIndex):
other = other.to_pytimedelta()
return type(self)(np.array(self) + other)
- def __radd__(self, other) ->Self:
+ def __radd__(self, other) -> Self:
if isinstance(other, pd.TimedeltaIndex):
other = other.to_pytimedelta()
return type(self)(other + np.array(self))
@@ -280,8 +595,9 @@ class CFTimeIndex(pd.Index):
return pd.TimedeltaIndex(np.array(self) - np.array(other))
except OUT_OF_BOUNDS_TIMEDELTA_ERRORS:
raise ValueError(
- 'The time difference exceeds the range of values that can be expressed at the nanosecond resolution.'
- )
+ "The time difference exceeds the range of values "
+ "that can be expressed at the nanosecond resolution."
+ )
return NotImplemented
def __rsub__(self, other):
@@ -289,8 +605,9 @@ class CFTimeIndex(pd.Index):
return pd.TimedeltaIndex(other - np.array(self))
except OUT_OF_BOUNDS_TIMEDELTA_ERRORS:
raise ValueError(
- 'The time difference exceeds the range of values that can be expressed at the nanosecond resolution.'
- )
+ "The time difference exceeds the range of values "
+ "that can be expressed at the nanosecond resolution."
+ )
def to_datetimeindex(self, unsafe=False):
"""If possible, convert this index to a pandas.DatetimeIndex.
@@ -331,7 +648,23 @@ class CFTimeIndex(pd.Index):
>>> times.to_datetimeindex()
DatetimeIndex(['2000-01-01', '2000-01-02'], dtype='datetime64[ns]', freq=None)
"""
- pass
+
+ if not self._data.size:
+ return pd.DatetimeIndex([])
+
+ nptimes = cftime_to_nptime(self)
+ calendar = infer_calendar_name(self)
+ if calendar not in _STANDARD_CALENDARS and not unsafe:
+ warnings.warn(
+ "Converting a CFTimeIndex with dates from a non-standard "
+ f"calendar, {calendar!r}, to a pandas.DatetimeIndex, which uses dates "
+ "from the standard calendar. This may lead to subtle errors "
+ "in operations that depend on the length of time between "
+ "dates.",
+ RuntimeWarning,
+ stacklevel=2,
+ )
+ return pd.DatetimeIndex(nptimes)
def strftime(self, date_format):
"""
@@ -361,26 +694,61 @@ class CFTimeIndex(pd.Index):
'September 01, 2000, 12:00:00 AM'],
dtype='object')
"""
- pass
+ return pd.Index([date.strftime(date_format) for date in self._data])
@property
def asi8(self):
"""Convert to integers with units of microseconds since 1970-01-01."""
- pass
+ from xarray.core.resample_cftime import exact_cftime_datetime_difference
+
+ if not self._data.size:
+ return np.array([], dtype=np.int64)
+
+ epoch = self.date_type(1970, 1, 1)
+ return np.array(
+ [
+ _total_microseconds(exact_cftime_datetime_difference(epoch, date))
+ for date in self.values
+ ],
+ dtype=np.int64,
+ )
@property
def calendar(self):
"""The calendar used by the datetimes in the index."""
- pass
+ from xarray.coding.times import infer_calendar_name
+
+ if not self._data.size:
+ return None
+
+ return infer_calendar_name(self)
@property
def freq(self):
"""The frequency used by the dates in the index."""
- pass
+ from xarray.coding.frequencies import infer_freq
+
+ # min 3 elements required to determine freq
+ if self._data.size < 3:
+ return None
+
+ return infer_freq(self)
def _round_via_method(self, freq, method):
"""Round dates using a specified method."""
- pass
+ from xarray.coding.cftime_offsets import CFTIME_TICKS, to_offset
+
+ if not self._data.size:
+ return CFTimeIndex(np.array(self))
+
+ offset = to_offset(freq)
+ if not isinstance(offset, CFTIME_TICKS):
+ raise ValueError(f"{offset} is a non-fixed frequency")
+
+ unit = _total_microseconds(offset.as_timedelta())
+ values = self.asi8
+ rounded = method(values, unit)
+ return _cftimeindex_from_i8(rounded, self.date_type, self.name)
def floor(self, freq):
"""Round dates down to fixed frequency.
@@ -397,7 +765,7 @@ class CFTimeIndex(pd.Index):
-------
CFTimeIndex
"""
- pass
+ return self._round_via_method(freq, _floor_int)
def ceil(self, freq):
"""Round dates up to fixed frequency.
@@ -414,7 +782,7 @@ class CFTimeIndex(pd.Index):
-------
CFTimeIndex
"""
- pass
+ return self._round_via_method(freq, _ceil_int)
def round(self, freq):
"""Round dates to a fixed frequency.
@@ -431,7 +799,12 @@ class CFTimeIndex(pd.Index):
-------
CFTimeIndex
"""
- pass
+ return self._round_via_method(freq, _round_to_nearest_half_even)
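`floor`, `ceil` and `round` only accept fixed (Tick) frequencies, since they operate on integer microsecond offsets from 1970-01-01; a brief example:

    import xarray as xr

    times = xr.cftime_range("2000-01-01T11:07:00", periods=2, freq="h")
    print(times.floor("D"))   # both values floored to 2000-01-01 00:00:00
    print(times.ceil("12h"))  # 2000-01-01 12:00:00 and 2000-01-02 00:00:00
    # times.floor("MS") raises ValueError because MonthBegin is not a fixed frequency.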
+
+
+def _parse_iso8601_without_reso(date_type, datetime_str):
+ date, _ = _parse_iso8601_with_reso(date_type, datetime_str)
+ return date
def _parse_array_of_cftime_strings(strings, date_type):
@@ -451,12 +824,15 @@ def _parse_array_of_cftime_strings(strings, date_type):
-------
np.array
"""
- pass
+ return np.array(
+ [_parse_iso8601_without_reso(date_type, s) for s in strings.ravel()]
+ ).reshape(strings.shape)
def _contains_datetime_timedeltas(array):
"""Check if an input array contains datetime.timedelta objects."""
- pass
+ array = np.atleast_1d(array)
+ return isinstance(array[0], timedelta)
def _cftimeindex_from_i8(values, date_type, name):
@@ -475,7 +851,9 @@ def _cftimeindex_from_i8(values, date_type, name):
-------
CFTimeIndex
"""
- pass
+ epoch = date_type(1970, 1, 1)
+ dates = np.array([epoch + timedelta(microseconds=int(value)) for value in values])
+ return CFTimeIndex(dates, name=name)
def _total_microseconds(delta):
@@ -490,19 +868,26 @@ def _total_microseconds(delta):
-------
int
"""
- pass
+ return delta / timedelta(microseconds=1)
def _floor_int(values, unit):
"""Copied from pandas."""
- pass
+ return values - np.remainder(values, unit)
def _ceil_int(values, unit):
"""Copied from pandas."""
- pass
+ return values + np.remainder(-values, unit)
def _round_to_nearest_half_even(values, unit):
"""Copied from pandas."""
- pass
+ if unit % 2:
+ return _ceil_int(values - unit // 2, unit)
+ quotient, remainder = np.divmod(values, unit)
+ mask = np.logical_or(
+ remainder > (unit // 2), np.logical_and(remainder == (unit // 2), quotient % 2)
+ )
+ quotient[mask] += 1
+ return quotient * unit
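The rounding helpers operate on plain integer arrays (microseconds since the epoch), so the half-even rule can be seen directly on small numbers using the internal helper:

    import numpy as np
    from xarray.coding.cftimeindex import _round_to_nearest_half_even

    values = np.array([5, 10, 15, 25], dtype=np.int64)
    print(_round_to_nearest_half_even(values, 10))
    # [ 0 10 20 20]: exact halves (5, 15, 25) are rounded to the even multiple.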
diff --git a/xarray/coding/frequencies.py b/xarray/coding/frequencies.py
index 1ea6ec74..b912b9a1 100644
--- a/xarray/coding/frequencies.py
+++ b/xarray/coding/frequencies.py
@@ -1,10 +1,54 @@
"""FrequencyInferer analog for cftime.datetime objects"""
+
+# The infer_freq method and the _CFTimeFrequencyInferer
+# subclass defined here were copied and adapted for
+# use with cftime.datetime objects based on the source code in
+# pandas.tseries.Frequencies._FrequencyInferer
+
+# For reference, here is a copy of the pandas copyright notice:
+
+# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2008-2011 AQR Capital Management, LLC
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the copyright holder nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
+
import numpy as np
import pandas as pd
+
from xarray.coding.cftime_offsets import _MONTH_ABBREVIATIONS, _legacy_to_new_freq
from xarray.coding.cftimeindex import CFTimeIndex
from xarray.core.common import _contains_datetime_like_objects
+
_ONE_MICRO = 1
_ONE_MILLI = _ONE_MICRO * 1000
_ONE_SECOND = _ONE_MILLI * 1000
@@ -35,18 +79,41 @@ def infer_freq(index):
ValueError
If there are fewer than three values or the index is not 1D.
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.variable import Variable
+
+ if isinstance(index, (DataArray, pd.Series)):
+ if index.ndim != 1:
+ raise ValueError("'index' must be 1D")
+ elif not _contains_datetime_like_objects(Variable("dim", index)):
+ raise ValueError("'index' must contain datetime-like objects")
+ dtype = np.asarray(index).dtype
+ if dtype == "datetime64[ns]":
+ index = pd.DatetimeIndex(index.values)
+ elif dtype == "timedelta64[ns]":
+ index = pd.TimedeltaIndex(index.values)
+ else:
+ index = CFTimeIndex(index.values)
+
+ if isinstance(index, CFTimeIndex):
+ inferer = _CFTimeFrequencyInferer(index)
+ return inferer.get_freq()
+ return _legacy_to_new_freq(pd.infer_freq(index))
-class _CFTimeFrequencyInferer:
+class _CFTimeFrequencyInferer: # (pd.tseries.frequencies._FrequencyInferer):
def __init__(self, index):
self.index = index
self.values = index.asi8
+
if len(index) < 3:
- raise ValueError('Need at least 3 dates to infer frequency')
- self.is_monotonic = (self.index.is_monotonic_decreasing or self.
- index.is_monotonic_increasing)
+ raise ValueError("Need at least 3 dates to infer frequency")
+
+ self.is_monotonic = (
+ self.index.is_monotonic_decreasing or self.index.is_monotonic_increasing
+ )
+
self._deltas = None
self._year_deltas = None
self._month_deltas = None
@@ -60,37 +127,120 @@ class _CFTimeFrequencyInferer:
-------
str or None
"""
- pass
+ if not self.is_monotonic or not self.index.is_unique:
+ return None
+
+ delta = self.deltas[0] # Smallest delta
+ if _is_multiple(delta, _ONE_DAY):
+ return self._infer_daily_rule()
+ # There is no possible intraday frequency with a non-unique delta
+ # Different from pandas: we don't need to manage DST and business offsets in cftime
+ elif not len(self.deltas) == 1:
+ return None
+
+ if _is_multiple(delta, _ONE_HOUR):
+ return _maybe_add_count("h", delta / _ONE_HOUR)
+ elif _is_multiple(delta, _ONE_MINUTE):
+ return _maybe_add_count("min", delta / _ONE_MINUTE)
+ elif _is_multiple(delta, _ONE_SECOND):
+ return _maybe_add_count("s", delta / _ONE_SECOND)
+ elif _is_multiple(delta, _ONE_MILLI):
+ return _maybe_add_count("ms", delta / _ONE_MILLI)
+ else:
+ return _maybe_add_count("us", delta / _ONE_MICRO)
+
+ def _infer_daily_rule(self):
+ annual_rule = self._get_annual_rule()
+ if annual_rule:
+ nyears = self.year_deltas[0]
+ month = _MONTH_ABBREVIATIONS[self.index[0].month]
+ alias = f"{annual_rule}-{month}"
+ return _maybe_add_count(alias, nyears)
+
+ quartely_rule = self._get_quartely_rule()
+ if quartely_rule:
+ nquarters = self.month_deltas[0] / 3
+ mod_dict = {0: 12, 2: 11, 1: 10}
+ month = _MONTH_ABBREVIATIONS[mod_dict[self.index[0].month % 3]]
+ alias = f"{quartely_rule}-{month}"
+ return _maybe_add_count(alias, nquarters)
+
+ monthly_rule = self._get_monthly_rule()
+ if monthly_rule:
+ return _maybe_add_count(monthly_rule, self.month_deltas[0])
+
+ if len(self.deltas) == 1:
+ # Daily as there is no "Weekly" offsets with CFTime
+ days = self.deltas[0] / _ONE_DAY
+ return _maybe_add_count("D", days)
+
+ # CFTime has no business freq and no "week of month" (WOM)
+ return None
+
+ def _get_annual_rule(self):
+ if len(self.year_deltas) > 1:
+ return None
+
+ if len(np.unique(self.index.month)) > 1:
+ return None
+
+ return {"cs": "YS", "ce": "YE"}.get(month_anchor_check(self.index))
+
+ def _get_quartely_rule(self):
+ if len(self.month_deltas) > 1:
+ return None
+
+ if self.month_deltas[0] % 3 != 0:
+ return None
+
+ return {"cs": "QS", "ce": "QE"}.get(month_anchor_check(self.index))
+
+ def _get_monthly_rule(self):
+ if len(self.month_deltas) > 1:
+ return None
+
+ return {"cs": "MS", "ce": "ME"}.get(month_anchor_check(self.index))
@property
def deltas(self):
"""Sorted unique timedeltas as microseconds."""
- pass
+ if self._deltas is None:
+ self._deltas = _unique_deltas(self.values)
+ return self._deltas
@property
def year_deltas(self):
"""Sorted unique year deltas."""
- pass
+ if self._year_deltas is None:
+ self._year_deltas = _unique_deltas(self.index.year)
+ return self._year_deltas
@property
def month_deltas(self):
"""Sorted unique month deltas."""
- pass
+ if self._month_deltas is None:
+ self._month_deltas = _unique_deltas(self.index.year * 12 + self.index.month)
+ return self._month_deltas
def _unique_deltas(arr):
"""Sorted unique deltas of numpy array"""
- pass
+ return np.sort(np.unique(np.diff(arr)))
def _is_multiple(us, mult: int):
"""Whether us is a multiple of mult"""
- pass
+ return us % mult == 0
def _maybe_add_count(base: str, count: float):
"""If count is greater than 1, add it to the base offset string"""
- pass
+ if count != 1:
+ assert count == int(count)
+ count = int(count)
+ return f"{count}{base}"
+ else:
+ return base
def month_anchor_check(dates):
@@ -103,4 +253,22 @@ def month_anchor_check(dates):
Replicated pandas._libs.tslibs.resolution.month_position_check
but without business offset handling.
"""
- pass
+ calendar_end = True
+ calendar_start = True
+
+ for date in dates:
+ if calendar_start:
+ calendar_start &= date.day == 1
+
+ if calendar_end:
+ cal = date.day == date.daysinmonth
+ calendar_end &= cal
+ elif not calendar_start:
+ break
+
+ if calendar_end:
+ return "ce"
+ elif calendar_start:
+ return "cs"
+ else:
+ return None
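As a quick sanity check of the frequency inference implemented above, the sketch below (assuming cftime is installed) builds a month-start CFTimeIndex and infers its frequency back; it is illustrative only and not part of the patch:

    import xarray as xr

    # Month-start dates in a noleap calendar; infer_freq should recover "MS".
    index = xr.cftime_range("2000-01-01", periods=6, freq="MS", calendar="noleap")
    print(xr.infer_freq(index))  # -> "MS"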
diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py
index 8a99b2a2..d16ec52d 100644
--- a/xarray/coding/strings.py
+++ b/xarray/coding/strings.py
@@ -1,14 +1,49 @@
"""Coders for strings."""
+
from __future__ import annotations
+
from functools import partial
+
import numpy as np
-from xarray.coding.variables import VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, unpack_for_decoding, unpack_for_encoding
+
+from xarray.coding.variables import (
+ VariableCoder,
+ lazy_elemwise_func,
+ pop_to,
+ safe_setitem,
+ unpack_for_decoding,
+ unpack_for_encoding,
+)
from xarray.core import indexing
from xarray.core.utils import module_available
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
-HAS_NUMPY_2_0 = module_available('numpy', minversion='2.0.0.dev0')
+
+HAS_NUMPY_2_0 = module_available("numpy", minversion="2.0.0.dev0")
+
+
+def create_vlen_dtype(element_type):
+ if element_type not in (str, bytes):
+ raise TypeError(f"unsupported type for vlen_dtype: {element_type!r}")
+ # based on h5py.special_dtype
+ return np.dtype("O", metadata={"element_type": element_type})
+
+
+def check_vlen_dtype(dtype):
+ if dtype.kind != "O" or dtype.metadata is None:
+ return None
+ else:
+ # check xarray (element_type) as well as h5py (vlen)
+ return dtype.metadata.get("element_type", dtype.metadata.get("vlen"))
+
+
+def is_unicode_dtype(dtype):
+ return dtype.kind == "U" or check_vlen_dtype(dtype) is str
+
+
+def is_bytes_dtype(dtype):
+ return dtype.kind == "S" or check_vlen_dtype(dtype) is bytes
class EncodedStringCoder(VariableCoder):
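Before the EncodedStringCoder hunk below, a minimal illustration (not part of the patch) of the dtype-metadata mechanism the vlen helpers above rely on:

    import numpy as np

    # An object dtype tagged with its element type, mirroring create_vlen_dtype(str).
    vlen_str = np.dtype("O", metadata={"element_type": str})
    print(vlen_str.metadata["element_type"] is str)  # -> True
    # h5py stores the same information under a "vlen" key, which
    # check_vlen_dtype also accepts as a fallback.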
@@ -17,34 +52,168 @@ class EncodedStringCoder(VariableCoder):
def __init__(self, allows_unicode=True):
self.allows_unicode = allows_unicode
+ def encode(self, variable: Variable, name=None) -> Variable:
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+
+ contains_unicode = is_unicode_dtype(data.dtype)
+ encode_as_char = encoding.get("dtype") == "S1"
+ if encode_as_char:
+ del encoding["dtype"] # no longer relevant
+
+ if contains_unicode and (encode_as_char or not self.allows_unicode):
+ if "_FillValue" in attrs:
+ raise NotImplementedError(
+ f"variable {name!r} has a _FillValue specified, but "
+ "_FillValue is not yet supported on unicode strings: "
+ "https://github.com/pydata/xarray/issues/1647"
+ )
+
+ string_encoding = encoding.pop("_Encoding", "utf-8")
+ safe_setitem(attrs, "_Encoding", string_encoding, name=name)
+ # TODO: figure out how to handle this in a lazy way with dask
+ data = encode_string_array(data, string_encoding)
+
+ return Variable(dims, data, attrs, encoding)
+ else:
+ variable.encoding = encoding
+ return variable
+
+ def decode(self, variable: Variable, name=None) -> Variable:
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+
+ if "_Encoding" in attrs:
+ string_encoding = pop_to(attrs, encoding, "_Encoding")
+ func = partial(decode_bytes_array, encoding=string_encoding)
+ data = lazy_elemwise_func(data, func, np.dtype(object))
+
+ return Variable(dims, data, attrs, encoding)
+
-def ensure_fixed_length_bytes(var: Variable) ->Variable:
+def decode_bytes_array(bytes_array, encoding="utf-8"):
+ # This is faster than using np.char.decode() or np.vectorize()
+ bytes_array = np.asarray(bytes_array)
+ decoded = [x.decode(encoding) for x in bytes_array.ravel()]
+ return np.array(decoded, dtype=object).reshape(bytes_array.shape)
+
+
+def encode_string_array(string_array, encoding="utf-8"):
+ string_array = np.asarray(string_array)
+ encoded = [x.encode(encoding) for x in string_array.ravel()]
+ return np.array(encoded, dtype=bytes).reshape(string_array.shape)
+
+
+def ensure_fixed_length_bytes(var: Variable) -> Variable:
"""Ensure that a variable with vlen bytes is converted to fixed width."""
- pass
+ if check_vlen_dtype(var.dtype) is bytes:
+ dims, data, attrs, encoding = unpack_for_encoding(var)
+ # TODO: figure out how to handle this with dask
+ data = np.asarray(data, dtype=np.bytes_)
+ return Variable(dims, data, attrs, encoding)
+ else:
+ return var
class CharacterArrayCoder(VariableCoder):
"""Transforms between arrays containing bytes and character arrays."""
+ def encode(self, variable, name=None):
+ variable = ensure_fixed_length_bytes(variable)
+
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+ if data.dtype.kind == "S" and encoding.get("dtype") is not str:
+ data = bytes_to_char(data)
+ if "char_dim_name" in encoding.keys():
+ char_dim_name = encoding.pop("char_dim_name")
+ else:
+ char_dim_name = f"string{data.shape[-1]}"
+ dims = dims + (char_dim_name,)
+ return Variable(dims, data, attrs, encoding)
+
+ def decode(self, variable, name=None):
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+
+ if data.dtype == "S1" and dims:
+ encoding["char_dim_name"] = dims[-1]
+ dims = dims[:-1]
+ data = char_to_bytes(data)
+ return Variable(dims, data, attrs, encoding)
+
def bytes_to_char(arr):
"""Convert numpy/dask arrays from fixed width bytes to characters."""
- pass
+ if arr.dtype.kind != "S":
+ raise ValueError("argument must have a fixed-width bytes dtype")
+
+ if is_chunked_array(arr):
+ chunkmanager = get_chunked_array_type(arr)
+
+ return chunkmanager.map_blocks(
+ _numpy_bytes_to_char,
+ arr,
+ dtype="S1",
+ chunks=arr.chunks + ((arr.dtype.itemsize,)),
+ new_axis=[arr.ndim],
+ )
+ return _numpy_bytes_to_char(arr)
def _numpy_bytes_to_char(arr):
"""Like netCDF4.stringtochar, but faster and more flexible."""
- pass
+ # adapt handling of copy-kwarg to numpy 2.0
+ # see https://github.com/numpy/numpy/issues/25916
+ # and https://github.com/numpy/numpy/pull/25922
+ copy = None if HAS_NUMPY_2_0 else False
+ # ensure the array is contiguous
+ arr = np.array(arr, copy=copy, order="C", dtype=np.bytes_)
+ return arr.reshape(arr.shape + (1,)).view("S1")
def char_to_bytes(arr):
"""Convert numpy/dask arrays from characters to fixed width bytes."""
- pass
+ if arr.dtype != "S1":
+ raise ValueError("argument must have dtype='S1'")
+
+ if not arr.ndim:
+ # no dimension to concatenate along
+ return arr
+
+ size = arr.shape[-1]
+
+ if not size:
+ # can't make an S0 dtype
+ return np.zeros(arr.shape[:-1], dtype=np.bytes_)
+
+ if is_chunked_array(arr):
+ chunkmanager = get_chunked_array_type(arr)
+
+ if len(arr.chunks[-1]) > 1:
+ raise ValueError(
+ "cannot stacked dask character array with "
+ f"multiple chunks in the last dimension: {arr}"
+ )
+
+ dtype = np.dtype("S" + str(arr.shape[-1]))
+ return chunkmanager.map_blocks(
+ _numpy_char_to_bytes,
+ arr,
+ dtype=dtype,
+ chunks=arr.chunks[:-1],
+ drop_axis=[arr.ndim - 1],
+ )
+ else:
+ return StackedBytesArray(arr)
def _numpy_char_to_bytes(arr):
"""Like netCDF4.chartostring, but faster and more flexible."""
- pass
+ # adapt handling of copy-kwarg to numpy 2.0
+ # see https://github.com/numpy/numpy/issues/25916
+ # and https://github.com/numpy/numpy/pull/25922
+ copy = None if HAS_NUMPY_2_0 else False
+ # based on: http://stackoverflow.com/a/10984878/809705
+ arr = np.array(arr, copy=copy, order="C")
+ dtype = "S" + str(arr.shape[-1])
+ return arr.view(dtype).reshape(arr.shape[:-1])
class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin):
@@ -63,16 +232,32 @@ class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin):
array : array-like
Original array of values to wrap.
"""
- if array.dtype != 'S1':
+ if array.dtype != "S1":
raise ValueError(
- "can only use StackedBytesArray if argument has dtype='S1'")
+ "can only use StackedBytesArray if argument has dtype='S1'"
+ )
self.array = indexing.as_indexable(array)
+ @property
+ def dtype(self):
+ return np.dtype("S" + str(self.array.shape[-1]))
+
+ @property
+ def shape(self) -> tuple[int, ...]:
+ return self.array.shape[:-1]
+
def __repr__(self):
- return f'{type(self).__name__}({self.array!r})'
+ return f"{type(self).__name__}({self.array!r})"
+
+ def _vindex_get(self, key):
+ return _numpy_char_to_bytes(self.array.vindex[key])
+
+ def _oindex_get(self, key):
+ return _numpy_char_to_bytes(self.array.oindex[key])
def __getitem__(self, key):
+ # require slicing the last dimension completely
key = type(key)(indexing.expanded_indexer(key.tuple, self.array.ndim))
if key.tuple[-1] != slice(None):
- raise IndexError('too many indices')
+ raise IndexError("too many indices")
return _numpy_char_to_bytes(self.array[key])
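The character/byte conversions above reduce to NumPy reshape-and-view tricks; the standalone sketch below (illustrative only, not part of the patch) shows the round trip without the xarray wrappers:

    import numpy as np

    arr = np.array([b"abc", b"xyz"], dtype="S3")

    # fixed-width bytes -> one character per element, as in _numpy_bytes_to_char
    chars = arr.reshape(arr.shape + (1,)).view("S1")      # shape (2, 3), dtype S1

    # characters -> fixed-width bytes again, as in _numpy_char_to_bytes
    back = chars.view("S" + str(chars.shape[-1])).reshape(chars.shape[:-1])
    print((back == arr).all())  # -> True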
diff --git a/xarray/coding/times.py b/xarray/coding/times.py
index 956c93ca..badb9259 100644
--- a/xarray/coding/times.py
+++ b/xarray/coding/times.py
@@ -1,14 +1,25 @@
from __future__ import annotations
+
import re
import warnings
from collections.abc import Hashable
from datetime import datetime, timedelta
from functools import partial
from typing import Callable, Literal, Union, cast
+
import numpy as np
import pandas as pd
from pandas.errors import OutOfBoundsDatetime, OutOfBoundsTimedelta
-from xarray.coding.variables import SerializationWarning, VariableCoder, lazy_elemwise_func, pop_to, safe_setitem, unpack_for_decoding, unpack_for_encoding
+
+from xarray.coding.variables import (
+ SerializationWarning,
+ VariableCoder,
+ lazy_elemwise_func,
+ pop_to,
+ safe_setitem,
+ unpack_for_decoding,
+ unpack_for_encoding,
+)
from xarray.core import indexing
from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like
from xarray.core.duck_array_ops import asarray, ravel, reshape
@@ -19,28 +30,278 @@ from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import T_ChunkedArray, get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
from xarray.namedarray.utils import is_duck_dask_array
+
try:
import cftime
except ImportError:
cftime = None
+
from xarray.core.types import CFCalendar, NPDatetimeUnitOptions, T_DuckArray
+
T_Name = Union[Hashable, None]
-_STANDARD_CALENDARS = {'standard', 'gregorian', 'proleptic_gregorian'}
-_NS_PER_TIME_DELTA = {'ns': 1, 'us': int(1000.0), 'ms': int(1000000.0), 's':
- int(1000000000.0), 'm': int(1000000000.0) * 60, 'h': int(1000000000.0) *
- 60 * 60, 'D': int(1000000000.0) * 60 * 60 * 24}
-_US_PER_TIME_DELTA = {'microseconds': 1, 'milliseconds': 1000, 'seconds':
- 1000000, 'minutes': 60 * 1000000, 'hours': 60 * 60 * 1000000, 'days':
- 24 * 60 * 60 * 1000000}
-_NETCDF_TIME_UNITS_CFTIME = ['days', 'hours', 'minutes', 'seconds',
- 'milliseconds', 'microseconds']
-_NETCDF_TIME_UNITS_NUMPY = _NETCDF_TIME_UNITS_CFTIME + ['nanoseconds']
-TIME_UNITS = frozenset(['days', 'hours', 'minutes', 'seconds',
- 'milliseconds', 'microseconds', 'nanoseconds'])
-
-
-def decode_cf_datetime(num_dates, units: str, calendar: (str | None)=None,
- use_cftime: (bool | None)=None) ->np.ndarray:
+
+# standard calendars recognized by cftime
+_STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"}
+
+_NS_PER_TIME_DELTA = {
+ "ns": 1,
+ "us": int(1e3),
+ "ms": int(1e6),
+ "s": int(1e9),
+ "m": int(1e9) * 60,
+ "h": int(1e9) * 60 * 60,
+ "D": int(1e9) * 60 * 60 * 24,
+}
+
+_US_PER_TIME_DELTA = {
+ "microseconds": 1,
+ "milliseconds": 1_000,
+ "seconds": 1_000_000,
+ "minutes": 60 * 1_000_000,
+ "hours": 60 * 60 * 1_000_000,
+ "days": 24 * 60 * 60 * 1_000_000,
+}
+
+_NETCDF_TIME_UNITS_CFTIME = [
+ "days",
+ "hours",
+ "minutes",
+ "seconds",
+ "milliseconds",
+ "microseconds",
+]
+
+_NETCDF_TIME_UNITS_NUMPY = _NETCDF_TIME_UNITS_CFTIME + ["nanoseconds"]
+
+TIME_UNITS = frozenset(
+ [
+ "days",
+ "hours",
+ "minutes",
+ "seconds",
+ "milliseconds",
+ "microseconds",
+ "nanoseconds",
+ ]
+)
+
+
+def _is_standard_calendar(calendar: str) -> bool:
+ return calendar.lower() in _STANDARD_CALENDARS
+
+
+def _is_numpy_compatible_time_range(times):
+ if is_np_datetime_like(times.dtype):
+ return True
+ # times array contains cftime objects
+ times = np.asarray(times)
+ tmin = times.min()
+ tmax = times.max()
+ try:
+ convert_time_or_go_back(tmin, pd.Timestamp)
+ convert_time_or_go_back(tmax, pd.Timestamp)
+ except pd.errors.OutOfBoundsDatetime:
+ return False
+ except ValueError as err:
+ if err.args[0] == "year 0 is out of range":
+ return False
+ raise
+ else:
+ return True
+
+
+def _netcdf_to_numpy_timeunit(units: str) -> NPDatetimeUnitOptions:
+ units = units.lower()
+ if not units.endswith("s"):
+ units = f"{units}s"
+ return cast(
+ NPDatetimeUnitOptions,
+ {
+ "nanoseconds": "ns",
+ "microseconds": "us",
+ "milliseconds": "ms",
+ "seconds": "s",
+ "minutes": "m",
+ "hours": "h",
+ "days": "D",
+ }[units],
+ )
+
+
+def _numpy_to_netcdf_timeunit(units: NPDatetimeUnitOptions) -> str:
+ return {
+ "ns": "nanoseconds",
+ "us": "microseconds",
+ "ms": "milliseconds",
+ "s": "seconds",
+ "m": "minutes",
+ "h": "hours",
+ "D": "days",
+ }[units]
+
+
+def _ensure_padded_year(ref_date: str) -> str:
+ # Reference dates without a padded year (e.g. since 1-1-1 or since 2-3-4)
+ # are ambiguous (is it YMD or DMY?). This can lead to some very odd
+ # behaviour e.g. pandas (via dateutil) passes '1-1-1 00:00:0.0' as
+ # '2001-01-01 00:00:00' (because it assumes a) DMY and b) that year 1 is
+ # shorthand for 2001 (like 02 would be shorthand for year 2002)).
+
+ # Here we ensure that there is always a four-digit year, with the
+ # assumption being that year comes first if we get something ambiguous.
+ matches_year = re.match(r".*\d{4}.*", ref_date)
+ if matches_year:
+ # all good, return
+ return ref_date
+
+ # No four-digit strings, assume the first digits are the year and pad
+ # appropriately
+ matches_start_digits = re.match(r"(\d+)(.*)", ref_date)
+ if not matches_start_digits:
+ raise ValueError(f"invalid reference date for time units: {ref_date}")
+ ref_year, everything_else = (s for s in matches_start_digits.groups())
+ ref_date_padded = f"{int(ref_year):04d}{everything_else}"
+
+ warning_msg = (
+ f"Ambiguous reference date string: {ref_date}. The first value is "
+ "assumed to be the year hence will be padded with zeros to remove "
+ f"the ambiguity (the padded reference date string is: {ref_date_padded}). "
+ "To remove this message, remove the ambiguity by padding your reference "
+ "date strings with zeros."
+ )
+ warnings.warn(warning_msg, SerializationWarning)
+
+ return ref_date_padded
+
+
+def _unpack_netcdf_time_units(units: str) -> tuple[str, str]:
+ # CF datetime units follow the format: "UNIT since DATE"
+ # this parses out the unit and date allowing for extraneous
+ # whitespace. It also ensures that the year is padded with zeros
+ # so it will be correctly understood by pandas (via dateutil).
+ matches = re.match(r"(.+) since (.+)", units)
+ if not matches:
+ raise ValueError(f"invalid time units: {units}")
+
+ delta_units, ref_date = (s.strip() for s in matches.groups())
+ ref_date = _ensure_padded_year(ref_date)
+
+ return delta_units, ref_date
+
+
+def _unpack_time_units_and_ref_date(units: str) -> tuple[str, pd.Timestamp]:
+    # same as _unpack_netcdf_time_units but finalizes ref_date for
+ # processing in encode_cf_datetime
+ time_units, _ref_date = _unpack_netcdf_time_units(units)
+ # TODO: the strict enforcement of nanosecond precision Timestamps can be
+ # relaxed when addressing GitHub issue #7493.
+ ref_date = nanosecond_precision_timestamp(_ref_date)
+ # If the ref_date Timestamp is timezone-aware, convert to UTC and
+ # make it timezone-naive (GH 2649).
+ if ref_date.tz is not None:
+ ref_date = ref_date.tz_convert(None)
+ return time_units, ref_date
+
+
+def _decode_cf_datetime_dtype(
+ data, units: str, calendar: str, use_cftime: bool | None
+) -> np.dtype:
+ # Verify that at least the first and last date can be decoded
+ # successfully. Otherwise, tracebacks end up swallowed by
+ # Dataset.__repr__ when users try to view their lazily decoded array.
+ values = indexing.ImplicitToExplicitIndexingAdapter(indexing.as_indexable(data))
+ example_value = np.concatenate(
+ [first_n_items(values, 1) or [0], last_item(values) or [0]]
+ )
+
+ try:
+ result = decode_cf_datetime(example_value, units, calendar, use_cftime)
+ except Exception:
+ calendar_msg = (
+ "the default calendar" if calendar is None else f"calendar {calendar!r}"
+ )
+ msg = (
+ f"unable to decode time units {units!r} with {calendar_msg!r}. Try "
+ "opening your dataset with decode_times=False or installing cftime "
+ "if it is not installed."
+ )
+ raise ValueError(msg)
+ else:
+ dtype = getattr(result, "dtype", np.dtype("object"))
+
+ return dtype
+
+
+def _decode_datetime_with_cftime(
+ num_dates: np.ndarray, units: str, calendar: str
+) -> np.ndarray:
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
+ if num_dates.size > 0:
+ return np.asarray(
+ cftime.num2date(num_dates, units, calendar, only_use_cftime_datetimes=True)
+ )
+ else:
+ return np.array([], dtype=object)
+
+
+def _decode_datetime_with_pandas(
+ flat_num_dates: np.ndarray, units: str, calendar: str
+) -> np.ndarray:
+ if not _is_standard_calendar(calendar):
+ raise OutOfBoundsDatetime(
+ f"Cannot decode times from a non-standard calendar, {calendar!r}, using "
+ "pandas."
+ )
+
+ time_units, ref_date_str = _unpack_netcdf_time_units(units)
+ time_units = _netcdf_to_numpy_timeunit(time_units)
+ try:
+ # TODO: the strict enforcement of nanosecond precision Timestamps can be
+ # relaxed when addressing GitHub issue #7493.
+ ref_date = nanosecond_precision_timestamp(ref_date_str)
+ except ValueError:
+ # ValueError is raised by pd.Timestamp for non-ISO timestamp
+ # strings, in which case we fall back to using cftime
+ raise OutOfBoundsDatetime
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "invalid value encountered", RuntimeWarning)
+ if flat_num_dates.size > 0:
+ # avoid size 0 datetimes GH1329
+ pd.to_timedelta(flat_num_dates.min(), time_units) + ref_date
+ pd.to_timedelta(flat_num_dates.max(), time_units) + ref_date
+
+ # To avoid integer overflow when converting to nanosecond units for integer
+ # dtypes smaller than np.int64 cast all integer and unsigned integer dtype
+ # arrays to np.int64 (GH 2002, GH 6589). Note this is safe even in the case
+ # of np.uint64 values, because any np.uint64 value that would lead to
+ # overflow when converting to np.int64 would not be representable with a
+ # timedelta64 value, and therefore would raise an error in the lines above.
+ if flat_num_dates.dtype.kind in "iu":
+ flat_num_dates = flat_num_dates.astype(np.int64)
+ elif flat_num_dates.dtype.kind in "f":
+ flat_num_dates = flat_num_dates.astype(np.float64)
+
+ # Cast input ordinals to integers of nanoseconds because pd.to_timedelta
+ # works much faster when dealing with integers (GH 1399).
+ # properly handle NaN/NaT to prevent casting NaN to int
+ nan = np.isnan(flat_num_dates) | (flat_num_dates == np.iinfo(np.int64).min)
+ flat_num_dates = flat_num_dates * _NS_PER_TIME_DELTA[time_units]
+ flat_num_dates_ns_int = np.zeros_like(flat_num_dates, dtype=np.int64)
+ flat_num_dates_ns_int[nan] = np.iinfo(np.int64).min
+ flat_num_dates_ns_int[~nan] = flat_num_dates[~nan].astype(np.int64)
+
+ # Use pd.to_timedelta to safely cast integer values to timedeltas,
+ # and add those to a Timestamp to safely produce a DatetimeIndex. This
+ # ensures that we do not encounter integer overflow at any point in the
+ # process without raising OutOfBoundsDatetime.
+ return (pd.to_timedelta(flat_num_dates_ns_int, "ns") + ref_date).values
+
+
+def decode_cf_datetime(
+ num_dates, units: str, calendar: str | None = None, use_cftime: bool | None = None
+) -> np.ndarray:
"""Given an array of numeric dates in netCDF format, convert it into a
numpy array of date time objects.
@@ -55,55 +316,186 @@ def decode_cf_datetime(num_dates, units: str, calendar: (str | None)=None,
--------
cftime.num2date
"""
- pass
+ num_dates = np.asarray(num_dates)
+ flat_num_dates = ravel(num_dates)
+ if calendar is None:
+ calendar = "standard"
+
+ if use_cftime is None:
+ try:
+ dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar)
+ except (KeyError, OutOfBoundsDatetime, OutOfBoundsTimedelta, OverflowError):
+ dates = _decode_datetime_with_cftime(
+ flat_num_dates.astype(float), units, calendar
+ )
+
+ if (
+ dates[np.nanargmin(num_dates)].year < 1678
+ or dates[np.nanargmax(num_dates)].year >= 2262
+ ):
+ if _is_standard_calendar(calendar):
+ warnings.warn(
+ "Unable to decode time axis into full "
+ "numpy.datetime64 objects, continuing using "
+ "cftime.datetime objects instead, reason: dates out "
+ "of range",
+ SerializationWarning,
+ stacklevel=3,
+ )
+ else:
+ if _is_standard_calendar(calendar):
+ dates = cftime_to_nptime(dates)
+ elif use_cftime:
+ dates = _decode_datetime_with_cftime(flat_num_dates, units, calendar)
+ else:
+ dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar)
+ return reshape(dates, num_dates.shape)
-def decode_cf_timedelta(num_timedeltas, units: str) ->np.ndarray:
+
+def to_timedelta_unboxed(value, **kwargs):
+ result = pd.to_timedelta(value, **kwargs).to_numpy()
+ assert result.dtype == "timedelta64[ns]"
+ return result
+
+
+def to_datetime_unboxed(value, **kwargs):
+ result = pd.to_datetime(value, **kwargs).to_numpy()
+ assert result.dtype == "datetime64[ns]"
+ return result
+
+
+def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray:
"""Given an array of numeric timedeltas in netCDF format, convert it into a
numpy timedelta64[ns] array.
"""
- pass
+ num_timedeltas = np.asarray(num_timedeltas)
+ units = _netcdf_to_numpy_timeunit(units)
+ result = to_timedelta_unboxed(ravel(num_timedeltas), unit=units)
+ return reshape(result, num_timedeltas.shape)
+
+
+def _unit_timedelta_cftime(units: str) -> timedelta:
+ return timedelta(microseconds=_US_PER_TIME_DELTA[units])
+
+def _unit_timedelta_numpy(units: str) -> np.timedelta64:
+ numpy_units = _netcdf_to_numpy_timeunit(units)
+ return np.timedelta64(_NS_PER_TIME_DELTA[numpy_units], "ns")
-def infer_calendar_name(dates) ->CFCalendar:
+
+def _infer_time_units_from_diff(unique_timedeltas) -> str:
+ unit_timedelta: Callable[[str], timedelta] | Callable[[str], np.timedelta64]
+ zero_timedelta: timedelta | np.timedelta64
+ if unique_timedeltas.dtype == np.dtype("O"):
+ time_units = _NETCDF_TIME_UNITS_CFTIME
+ unit_timedelta = _unit_timedelta_cftime
+ zero_timedelta = timedelta(microseconds=0)
+ else:
+ time_units = _NETCDF_TIME_UNITS_NUMPY
+ unit_timedelta = _unit_timedelta_numpy
+ zero_timedelta = np.timedelta64(0, "ns")
+ for time_unit in time_units:
+ if np.all(unique_timedeltas % unit_timedelta(time_unit) == zero_timedelta):
+ return time_unit
+ return "seconds"
+
+
+def _time_units_to_timedelta64(units: str) -> np.timedelta64:
+ return np.timedelta64(1, _netcdf_to_numpy_timeunit(units)).astype("timedelta64[ns]")
+
+
+def infer_calendar_name(dates) -> CFCalendar:
"""Given an array of datetimes, infer the CF calendar name"""
- pass
+ if is_np_datetime_like(dates.dtype):
+ return "proleptic_gregorian"
+ elif dates.dtype == np.dtype("O") and dates.size > 0:
+ # Logic copied from core.common.contains_cftime_datetimes.
+ if cftime is not None:
+ sample = np.asarray(dates).flat[0]
+ if is_duck_dask_array(sample):
+ sample = sample.compute()
+ if isinstance(sample, np.ndarray):
+ sample = sample.item()
+ if isinstance(sample, cftime.datetime):
+ return sample.calendar
+
+    # Raise an error if dtype is neither datetime nor "O", if cftime is not importable, or if an element of an "O"-dtype array is not a cftime.datetime.
+ raise ValueError("Array does not contain datetime objects.")
-def infer_datetime_units(dates) ->str:
+def infer_datetime_units(dates) -> str:
"""Given an array of datetimes, returns a CF compatible time-unit string of
the form "{time_unit} since {date[0]}", where `time_unit` is 'days',
'hours', 'minutes' or 'seconds' (the first one that can evenly divide all
unique time deltas in `dates`)
"""
- pass
+ dates = ravel(np.asarray(dates))
+ if np.asarray(dates).dtype == "datetime64[ns]":
+ dates = to_datetime_unboxed(dates)
+ dates = dates[pd.notnull(dates)]
+ reference_date = dates[0] if len(dates) > 0 else "1970-01-01"
+ # TODO: the strict enforcement of nanosecond precision Timestamps can be
+ # relaxed when addressing GitHub issue #7493.
+ reference_date = nanosecond_precision_timestamp(reference_date)
+ else:
+ reference_date = dates[0] if len(dates) > 0 else "1970-01-01"
+ reference_date = format_cftime_datetime(reference_date)
+ unique_timedeltas = np.unique(np.diff(dates))
+ units = _infer_time_units_from_diff(unique_timedeltas)
+ return f"{units} since {reference_date}"
-def format_cftime_datetime(date) ->str:
+def format_cftime_datetime(date) -> str:
"""Converts a cftime.datetime object to a string with the format:
YYYY-MM-DD HH:MM:SS.UUUUUU
"""
- pass
+ return f"{date.year:04d}-{date.month:02d}-{date.day:02d} {date.hour:02d}:{date.minute:02d}:{date.second:02d}.{date.microsecond:06d}"
-def infer_timedelta_units(deltas) ->str:
+def infer_timedelta_units(deltas) -> str:
"""Given an array of timedeltas, returns a CF compatible time-unit from
{'days', 'hours', 'minutes' 'seconds'} (the first one that can evenly
divide all unique time deltas in `deltas`)
"""
- pass
+ deltas = to_timedelta_unboxed(ravel(np.asarray(deltas)))
+ unique_timedeltas = np.unique(deltas[pd.notnull(deltas)])
+ return _infer_time_units_from_diff(unique_timedeltas)
-def cftime_to_nptime(times, raise_on_invalid: bool=True) ->np.ndarray:
+def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray:
"""Given an array of cftime.datetime objects, return an array of
numpy.datetime64 objects of the same size
If raise_on_invalid is True (default), invalid dates trigger a ValueError.
Otherwise, the invalid element is replaced by np.NaT."""
- pass
+ times = np.asarray(times)
+ # TODO: the strict enforcement of nanosecond precision datetime values can
+ # be relaxed when addressing GitHub issue #7493.
+ new = np.empty(times.shape, dtype="M8[ns]")
+ dt: pd.Timestamp | Literal["NaT"]
+ for i, t in np.ndenumerate(times):
+ try:
+ # Use pandas.Timestamp in place of datetime.datetime, because
+            # NumPy casts it safely to np.datetime64[ns] for dates outside
+ # 1678 to 2262 (this is not currently the case for
+ # datetime.datetime).
+ dt = nanosecond_precision_timestamp(
+ t.year, t.month, t.day, t.hour, t.minute, t.second, t.microsecond
+ )
+ except ValueError as e:
+ if raise_on_invalid:
+ raise ValueError(
+ f"Cannot convert date {t} to a date in the "
+ f"standard calendar. Reason: {e}."
+ )
+ else:
+ dt = "NaT"
+ new[i] = np.datetime64(dt)
+ return new
-def convert_times(times, date_type, raise_on_invalid: bool=True) ->np.ndarray:
+def convert_times(times, date_type, raise_on_invalid: bool = True) -> np.ndarray:
"""Given an array of datetimes, return the same dates in another cftime or numpy date type.
Useful to convert between calendars in numpy and cftime or between cftime calendars.
@@ -111,7 +503,30 @@ def convert_times(times, date_type, raise_on_invalid: bool=True) ->np.ndarray:
If raise_on_valid is True (default), invalid dates trigger a ValueError.
Otherwise, the invalid element is replaced by np.nan for cftime types and np.NaT for np.datetime64.
"""
- pass
+ if date_type in (pd.Timestamp, np.datetime64) and not is_np_datetime_like(
+ times.dtype
+ ):
+ return cftime_to_nptime(times, raise_on_invalid=raise_on_invalid)
+ if is_np_datetime_like(times.dtype):
+ # Convert datetime64 objects to Timestamps since those have year, month, day, etc. attributes
+ times = pd.DatetimeIndex(times)
+ new = np.empty(times.shape, dtype="O")
+ for i, t in enumerate(times):
+ try:
+ dt = date_type(
+ t.year, t.month, t.day, t.hour, t.minute, t.second, t.microsecond
+ )
+ except ValueError as e:
+ if raise_on_invalid:
+ raise ValueError(
+ f"Cannot convert date {t} to a date in the "
+ f"{date_type(2000, 1, 1).calendar} calendar. Reason: {e}."
+ )
+ else:
+ dt = np.nan
+
+ new[i] = dt
+ return new
def convert_time_or_go_back(date, date_type):
@@ -122,11 +537,50 @@ def convert_time_or_go_back(date, date_type):
This is meant to convert end-of-month dates into a new calendar.
"""
- pass
+ # TODO: the strict enforcement of nanosecond precision Timestamps can be
+ # relaxed when addressing GitHub issue #7493.
+ if date_type == pd.Timestamp:
+ date_type = nanosecond_precision_timestamp
+ try:
+ return date_type(
+ date.year,
+ date.month,
+ date.day,
+ date.hour,
+ date.minute,
+ date.second,
+ date.microsecond,
+ )
+ except OutOfBoundsDatetime:
+ raise
+ except ValueError:
+ # Day is invalid, happens at the end of months, try again the day before
+ try:
+ return date_type(
+ date.year,
+ date.month,
+ date.day - 1,
+ date.hour,
+ date.minute,
+ date.second,
+ date.microsecond,
+ )
+ except ValueError:
+ # Still invalid, happens for 360_day to non-leap february. Try again 2 days before date.
+ return date_type(
+ date.year,
+ date.month,
+ date.day - 2,
+ date.hour,
+ date.minute,
+ date.second,
+ date.microsecond,
+ )
-def _should_cftime_be_used(source, target_calendar: str, use_cftime: (bool |
- None)) ->bool:
+def _should_cftime_be_used(
+ source, target_calendar: str, use_cftime: bool | None
+) -> bool:
"""Return whether conversion of the source to the target calendar should
result in a cftime-backed array.
@@ -134,22 +588,127 @@ def _should_cftime_be_used(source, target_calendar: str, use_cftime: (bool |
use_cftime is a boolean or None. If use_cftime is None, this returns True
if the source's range and target calendar are convertible to np.datetime64 objects.
"""
- pass
+ # Arguments Checks for target
+ if use_cftime is not True:
+ if _is_standard_calendar(target_calendar):
+ if _is_numpy_compatible_time_range(source):
+ # Conversion is possible with pandas, force False if it was None
+ return False
+ elif use_cftime is False:
+ raise ValueError(
+ "Source time range is not valid for numpy datetimes. Try using `use_cftime=True`."
+ )
+ elif use_cftime is False:
+ raise ValueError(
+ f"Calendar '{target_calendar}' is only valid with cftime. Try using `use_cftime=True`."
+ )
+ return True
+
+def _cleanup_netcdf_time_units(units: str) -> str:
+ time_units, ref_date = _unpack_netcdf_time_units(units)
+ time_units = time_units.lower()
+ if not time_units.endswith("s"):
+ time_units = f"{time_units}s"
+ try:
+ units = f"{time_units} since {format_timestamp(ref_date)}"
+ except (OutOfBoundsDatetime, ValueError):
+ # don't worry about reifying the units if they're out of bounds or
+ # formatted badly
+ pass
+ return units
-def _encode_datetime_with_cftime(dates, units: str, calendar: str
- ) ->np.ndarray:
+
+def _encode_datetime_with_cftime(dates, units: str, calendar: str) -> np.ndarray:
"""Fallback method for encoding dates using cftime.
This method is more flexible than xarray's parsing using datetime64[ns]
arrays but also slower because it loops over each element.
"""
- pass
+ if cftime is None:
+ raise ModuleNotFoundError("No module named 'cftime'")
+
+ if np.issubdtype(dates.dtype, np.datetime64):
+ # numpy's broken datetime conversion only works for us precision
+ dates = dates.astype("M8[us]").astype(datetime)
+
+ def encode_datetime(d):
+ # Since netCDF files do not support storing float128 values, we ensure
+ # that float64 values are used by setting longdouble=False in num2date.
+ # This try except logic can be removed when xarray's minimum version of
+ # cftime is at least 1.6.2.
+ try:
+ return (
+ np.nan
+ if d is None
+ else cftime.date2num(d, units, calendar, longdouble=False)
+ )
+ except TypeError:
+ return np.nan if d is None else cftime.date2num(d, units, calendar)
+
+ return reshape(np.array([encode_datetime(d) for d in ravel(dates)]), dates.shape)
+
+
+def cast_to_int_if_safe(num) -> np.ndarray:
+ int_num = np.asarray(num, dtype=np.int64)
+ if (num == int_num).all():
+ num = int_num
+ return num
+
+
+def _division(deltas, delta, floor):
+ if floor:
+ # calculate int64 floor division
+ # to preserve integer dtype if possible (GH 4045, GH7817).
+ num = deltas // delta.astype(np.int64)
+ num = num.astype(np.int64, copy=False)
+ else:
+ num = deltas / delta
+ return num
+
+def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray:
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", message="overflow")
+ cast_num = np.asarray(num, dtype=dtype)
-def encode_cf_datetime(dates: T_DuckArray, units: (str | None)=None,
- calendar: (str | None)=None, dtype: (np.dtype | None)=None) ->tuple[
- T_DuckArray, str, str]:
+ if np.issubdtype(dtype, np.integer):
+ if not (num == cast_num).all():
+ if np.issubdtype(num.dtype, np.floating):
+ raise ValueError(
+ f"Not possible to cast all encoded times from "
+ f"{num.dtype!r} to {dtype!r} without losing precision. "
+ f"Consider modifying the units such that integer values "
+ f"can be used, or removing the units and dtype encoding, "
+ f"at which point xarray will make an appropriate choice."
+ )
+ else:
+ raise OverflowError(
+ f"Not possible to cast encoded times from "
+ f"{num.dtype!r} to {dtype!r} without overflow. Consider "
+ f"removing the dtype encoding, at which point xarray will "
+ f"make an appropriate choice, or explicitly switching to "
+ "a larger integer dtype."
+ )
+ else:
+ if np.isinf(cast_num).any():
+ raise OverflowError(
+ f"Not possible to cast encoded times from {num.dtype!r} to "
+ f"{dtype!r} without overflow. Consider removing the dtype "
+ f"encoding, at which point xarray will make an appropriate "
+ f"choice, or explicitly switching to a larger floating point "
+ f"dtype."
+ )
+
+ return cast_num
+
+
+def encode_cf_datetime(
+ dates: T_DuckArray, # type: ignore
+ units: str | None = None,
+ calendar: str | None = None,
+ dtype: np.dtype | None = None,
+) -> tuple[T_DuckArray, str, str]:
"""Given an array of datetime objects, returns the tuple `(num, units,
calendar)` suitable for a CF compliant time variable.
@@ -159,14 +718,315 @@ def encode_cf_datetime(dates: T_DuckArray, units: (str | None)=None,
--------
cftime.date2num
"""
- pass
+ dates = asarray(dates)
+ if is_chunked_array(dates):
+ return _lazily_encode_cf_datetime(dates, units, calendar, dtype)
+ else:
+ return _eagerly_encode_cf_datetime(dates, units, calendar, dtype)
-class CFDatetimeCoder(VariableCoder):
+def _eagerly_encode_cf_datetime(
+ dates: T_DuckArray, # type: ignore
+ units: str | None = None,
+ calendar: str | None = None,
+ dtype: np.dtype | None = None,
+ allow_units_modification: bool = True,
+) -> tuple[T_DuckArray, str, str]:
+ dates = asarray(dates)
+
+ data_units = infer_datetime_units(dates)
+
+ if units is None:
+ units = data_units
+ else:
+ units = _cleanup_netcdf_time_units(units)
+
+ if calendar is None:
+ calendar = infer_calendar_name(dates)
+
+ try:
+ if not _is_standard_calendar(calendar) or dates.dtype.kind == "O":
+ # parse with cftime instead
+ raise OutOfBoundsDatetime
+ assert dates.dtype == "datetime64[ns]"
+
+ time_units, ref_date = _unpack_time_units_and_ref_date(units)
+ time_delta = _time_units_to_timedelta64(time_units)
+
+ # Wrap the dates in a DatetimeIndex to do the subtraction to ensure
+ # an OverflowError is raised if the ref_date is too far away from
+ # dates to be encoded (GH 2272).
+ dates_as_index = pd.DatetimeIndex(ravel(dates))
+ time_deltas = dates_as_index - ref_date
+
+ # retrieve needed units to faithfully encode to int64
+ needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units)
+ if data_units != units:
+ # this accounts for differences in the reference times
+ ref_delta = abs(data_ref_date - ref_date).to_timedelta64()
+ data_delta = _time_units_to_timedelta64(needed_units)
+ if (ref_delta % data_delta) > np.timedelta64(0, "ns"):
+ needed_units = _infer_time_units_from_diff(ref_delta)
+
+ # needed time delta to encode faithfully to int64
+ needed_time_delta = _time_units_to_timedelta64(needed_units)
+
+ floor_division = True
+ if time_delta > needed_time_delta:
+ floor_division = False
+ if dtype is None:
+ emit_user_level_warning(
+ f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
+ f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. "
+ f"Set encoding['dtype'] to integer dtype to serialize to int64. "
+ f"Set encoding['dtype'] to floating point dtype to silence this warning."
+ )
+ elif np.issubdtype(dtype, np.integer) and allow_units_modification:
+ new_units = f"{needed_units} since {format_timestamp(ref_date)}"
+ emit_user_level_warning(
+ f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
+ f"Serializing with units {new_units!r} instead. "
+ f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
+ f"Set encoding['units'] to {new_units!r} to silence this warning ."
+ )
+ units = new_units
+ time_delta = needed_time_delta
+ floor_division = True
+
+ num = _division(time_deltas, time_delta, floor_division)
+ num = reshape(num.values, dates.shape)
+
+ except (OutOfBoundsDatetime, OverflowError, ValueError):
+ num = _encode_datetime_with_cftime(dates, units, calendar)
+ # do it now only for cftime-based flow
+ # we already covered for this in pandas-based flow
+ num = cast_to_int_if_safe(num)
+
+ if dtype is not None:
+ num = _cast_to_dtype_if_safe(num, dtype)
+
+ return num, units, calendar
+
+
+def _encode_cf_datetime_within_map_blocks(
+ dates: T_DuckArray, # type: ignore
+ units: str,
+ calendar: str,
+ dtype: np.dtype,
+) -> T_DuckArray:
+ num, *_ = _eagerly_encode_cf_datetime(
+ dates, units, calendar, dtype, allow_units_modification=False
+ )
+ return num
+
+
+def _lazily_encode_cf_datetime(
+ dates: T_ChunkedArray,
+ units: str | None = None,
+ calendar: str | None = None,
+ dtype: np.dtype | None = None,
+) -> tuple[T_ChunkedArray, str, str]:
+ if calendar is None:
+ # This will only trigger minor compute if dates is an object dtype array.
+ calendar = infer_calendar_name(dates)
+
+ if units is None and dtype is None:
+ if dates.dtype == "O":
+ units = "microseconds since 1970-01-01"
+ dtype = np.dtype("int64")
+ else:
+ units = "nanoseconds since 1970-01-01"
+ dtype = np.dtype("int64")
+
+ if units is None or dtype is None:
+ raise ValueError(
+ f"When encoding chunked arrays of datetime values, both the units "
+ f"and dtype must be prescribed or both must be unprescribed. "
+ f"Prescribing only one or the other is not currently supported. "
+ f"Got a units encoding of {units} and a dtype encoding of {dtype}."
+ )
+
+ chunkmanager = get_chunked_array_type(dates)
+ num = chunkmanager.map_blocks(
+ _encode_cf_datetime_within_map_blocks,
+ dates,
+ units,
+ calendar,
+ dtype,
+ dtype=dtype,
+ )
+ return num, units, calendar
- def __init__(self, use_cftime: (bool | None)=None) ->None:
+
+def encode_cf_timedelta(
+ timedeltas: T_DuckArray, # type: ignore
+ units: str | None = None,
+ dtype: np.dtype | None = None,
+) -> tuple[T_DuckArray, str]:
+ timedeltas = asarray(timedeltas)
+ if is_chunked_array(timedeltas):
+ return _lazily_encode_cf_timedelta(timedeltas, units, dtype)
+ else:
+ return _eagerly_encode_cf_timedelta(timedeltas, units, dtype)
+
+
+def _eagerly_encode_cf_timedelta(
+ timedeltas: T_DuckArray, # type: ignore
+ units: str | None = None,
+ dtype: np.dtype | None = None,
+ allow_units_modification: bool = True,
+) -> tuple[T_DuckArray, str]:
+ data_units = infer_timedelta_units(timedeltas)
+
+ if units is None:
+ units = data_units
+
+ time_delta = _time_units_to_timedelta64(units)
+ time_deltas = pd.TimedeltaIndex(ravel(timedeltas))
+
+ # retrieve needed units to faithfully encode to int64
+ needed_units = data_units
+ if data_units != units:
+ needed_units = _infer_time_units_from_diff(np.unique(time_deltas.dropna()))
+
+ # needed time delta to encode faithfully to int64
+ needed_time_delta = _time_units_to_timedelta64(needed_units)
+
+ floor_division = True
+ if time_delta > needed_time_delta:
+ floor_division = False
+ if dtype is None:
+ emit_user_level_warning(
+ f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. "
+ f"Resolution of {needed_units!r} needed. Serializing timeseries to floating point instead. "
+ f"Set encoding['dtype'] to integer dtype to serialize to int64. "
+ f"Set encoding['dtype'] to floating point dtype to silence this warning."
+ )
+ elif np.issubdtype(dtype, np.integer) and allow_units_modification:
+ emit_user_level_warning(
+ f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
+ f"Serializing with units {needed_units!r} instead. "
+ f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
+ f"Set encoding['units'] to {needed_units!r} to silence this warning ."
+ )
+ units = needed_units
+ time_delta = needed_time_delta
+ floor_division = True
+
+ num = _division(time_deltas, time_delta, floor_division)
+ num = reshape(num.values, timedeltas.shape)
+
+ if dtype is not None:
+ num = _cast_to_dtype_if_safe(num, dtype)
+
+ return num, units
+
+
+def _encode_cf_timedelta_within_map_blocks(
+ timedeltas: T_DuckArray, # type:ignore
+ units: str,
+ dtype: np.dtype,
+) -> T_DuckArray:
+ num, _ = _eagerly_encode_cf_timedelta(
+ timedeltas, units, dtype, allow_units_modification=False
+ )
+ return num
+
+
+def _lazily_encode_cf_timedelta(
+ timedeltas: T_ChunkedArray, units: str | None = None, dtype: np.dtype | None = None
+) -> tuple[T_ChunkedArray, str]:
+ if units is None and dtype is None:
+ units = "nanoseconds"
+ dtype = np.dtype("int64")
+
+ if units is None or dtype is None:
+ raise ValueError(
+ f"When encoding chunked arrays of timedelta values, both the "
+ f"units and dtype must be prescribed or both must be "
+ f"unprescribed. Prescribing only one or the other is not "
+ f"currently supported. Got a units encoding of {units} and a "
+ f"dtype encoding of {dtype}."
+ )
+
+ chunkmanager = get_chunked_array_type(timedeltas)
+ num = chunkmanager.map_blocks(
+ _encode_cf_timedelta_within_map_blocks,
+ timedeltas,
+ units,
+ dtype,
+ dtype=dtype,
+ )
+ return num, units
+
+
+class CFDatetimeCoder(VariableCoder):
+ def __init__(self, use_cftime: bool | None = None) -> None:
self.use_cftime = use_cftime
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if np.issubdtype(
+ variable.data.dtype, np.datetime64
+ ) or contains_cftime_datetimes(variable):
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+
+ units = encoding.pop("units", None)
+ calendar = encoding.pop("calendar", None)
+ dtype = encoding.get("dtype", None)
+ (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)
+
+ safe_setitem(attrs, "units", units, name=name)
+ safe_setitem(attrs, "calendar", calendar, name=name)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ units = variable.attrs.get("units", None)
+ if isinstance(units, str) and "since" in units:
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+
+ units = pop_to(attrs, encoding, "units")
+ calendar = pop_to(attrs, encoding, "calendar")
+ dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
+ transform = partial(
+ decode_cf_datetime,
+ units=units,
+ calendar=calendar,
+ use_cftime=self.use_cftime,
+ )
+ data = lazy_elemwise_func(data, transform, dtype)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
class CFTimedeltaCoder(VariableCoder):
- pass
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if np.issubdtype(variable.data.dtype, np.timedelta64):
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+
+ data, units = encode_cf_timedelta(
+ data, encoding.pop("units", None), encoding.get("dtype", None)
+ )
+ safe_setitem(attrs, "units", units, name=name)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ units = variable.attrs.get("units", None)
+ if isinstance(units, str) and units in TIME_UNITS:
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+
+ units = pop_to(attrs, encoding, "units")
+ transform = partial(decode_cf_timedelta, units=units)
+ dtype = np.dtype("timedelta64[ns]")
+ data = lazy_elemwise_func(data, transform, dtype=dtype)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
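As a small end-to-end check of the CF time codecs above, using only the signatures shown in this patch (illustrative only, not part of it):

    import numpy as np
    from xarray.coding.times import decode_cf_datetime, encode_cf_datetime

    units = "days since 2000-01-01"
    nums = np.array([0, 1, 2])

    # numeric CF times -> datetime64[ns]
    dates = decode_cf_datetime(nums, units, calendar="standard")
    print(dates.dtype)  # -> datetime64[ns]

    # ... and back; the round trip recovers the original ordinals
    num, out_units, out_calendar = encode_cf_datetime(dates, units=units, calendar="standard")
    print(num.tolist())  # -> [0, 1, 2]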
diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py
index efe7890c..8a3afe65 100644
--- a/xarray/coding/variables.py
+++ b/xarray/coding/variables.py
@@ -1,15 +1,20 @@
"""Coders for individual Variable objects."""
+
from __future__ import annotations
+
import warnings
from collections.abc import Hashable, MutableMapping
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union
+
import numpy as np
import pandas as pd
+
from xarray.core import dtypes, duck_array_ops, indexing
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
+
if TYPE_CHECKING:
T_VarTuple = tuple[tuple[Hashable, ...], Any, dict, dict]
T_Name = Union[Hashable, None]
@@ -36,13 +41,13 @@ class VariableCoder:
variables in the underlying store.
"""
- def encode(self, variable: Variable, name: T_Name=None) ->Variable:
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
"""Convert an encoded variable to a decoded variable"""
- pass
+ raise NotImplementedError()
- def decode(self, variable: Variable, name: T_Name=None) ->Variable:
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
"""Convert an decoded variable to a encoded variable"""
- pass
+ raise NotImplementedError()
class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin):
@@ -59,13 +64,24 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin):
self.func = func
self._dtype = dtype
+ @property
+ def dtype(self) -> np.dtype:
+ return np.dtype(self._dtype)
+
+ def _oindex_get(self, key):
+ return type(self)(self.array.oindex[key], self.func, self.dtype)
+
+ def _vindex_get(self, key):
+ return type(self)(self.array.vindex[key], self.func, self.dtype)
+
def __getitem__(self, key):
return type(self)(self.array[key], self.func, self.dtype)
- def __repr__(self) ->str:
- return (
- f'{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})'
- )
+ def get_duck_array(self):
+ return self.func(self.array.get_duck_array())
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})"
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
@@ -87,12 +103,23 @@ class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
>>> NativeEndiannessArray(x)[indexer].dtype
dtype('int16')
"""
- __slots__ = 'array',
- def __init__(self, array) ->None:
+ __slots__ = ("array",)
+
+ def __init__(self, array) -> None:
self.array = indexing.as_indexable(array)
- def __getitem__(self, key) ->np.ndarray:
+ @property
+ def dtype(self) -> np.dtype:
+ return np.dtype(self.array.dtype.kind + str(self.array.dtype.itemsize))
+
+ def _oindex_get(self, key):
+ return np.asarray(self.array.oindex[key], dtype=self.dtype)
+
+ def _vindex_get(self, key):
+ return np.asarray(self.array.vindex[key], dtype=self.dtype)
+
+ def __getitem__(self, key) -> np.ndarray:
return np.asarray(self.array[key], dtype=self.dtype)
@@ -114,12 +141,23 @@ class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
>>> BoolTypeArray(x)[indexer].dtype
dtype('bool')
"""
- __slots__ = 'array',
- def __init__(self, array) ->None:
+ __slots__ = ("array",)
+
+ def __init__(self, array) -> None:
self.array = indexing.as_indexable(array)
- def __getitem__(self, key) ->np.ndarray:
+ @property
+ def dtype(self) -> np.dtype:
+ return np.dtype("bool")
+
+ def _oindex_get(self, key):
+ return np.asarray(self.array.oindex[key], dtype=self.dtype)
+
+ def _vindex_get(self, key):
+ return np.asarray(self.array.vindex[key], dtype=self.dtype)
+
+ def __getitem__(self, key) -> np.ndarray:
return np.asarray(self.array[key], dtype=self.dtype)
@@ -138,23 +176,88 @@ def lazy_elemwise_func(array, func: Callable, dtype: np.typing.DTypeLike):
-------
Either a dask.array.Array or _ElementwiseFunctionArray.
"""
- pass
+ if is_chunked_array(array):
+ chunkmanager = get_chunked_array_type(array)
+
+ return chunkmanager.map_blocks(func, array, dtype=dtype) # type: ignore[arg-type]
+ else:
+ return _ElementwiseFunctionArray(array, func, dtype)
+
+
+def unpack_for_encoding(var: Variable) -> T_VarTuple:
+ return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
+
+
+def unpack_for_decoding(var: Variable) -> T_VarTuple:
+ return var.dims, var._data, var.attrs.copy(), var.encoding.copy()
+
+def safe_setitem(dest, key: Hashable, value, name: T_Name = None):
+ if key in dest:
+ var_str = f" on variable {name!r}" if name else ""
+ raise ValueError(
+ f"failed to prevent overwriting existing key {key} in attrs{var_str}. "
+ "This is probably an encoding field used by xarray to describe "
+ "how a variable is serialized. To proceed, remove this key from "
+ "the variable's attributes manually."
+ )
+ dest[key] = value
-def pop_to(source: MutableMapping, dest: MutableMapping, key: Hashable,
- name: T_Name=None) ->Any:
+
+def pop_to(
+ source: MutableMapping, dest: MutableMapping, key: Hashable, name: T_Name = None
+) -> Any:
"""
A convenience function which pops a key k from source to dest.
None values are not passed on. If k already exists in dest an
error is raised.
"""
- pass
-
-
-def _apply_mask(data: np.ndarray, encoded_fill_values: list,
- decoded_fill_value: Any, dtype: np.typing.DTypeLike) ->np.ndarray:
+ value = source.pop(key, None)
+ if value is not None:
+ safe_setitem(dest, key, value, name=name)
+ return value
+
+
+def _apply_mask(
+ data: np.ndarray,
+ encoded_fill_values: list,
+ decoded_fill_value: Any,
+ dtype: np.typing.DTypeLike,
+) -> np.ndarray:
"""Mask all matching values in a NumPy arrays."""
- pass
+ data = np.asarray(data, dtype=dtype)
+ condition = False
+ for fv in encoded_fill_values:
+ condition |= data == fv
+ return np.where(condition, decoded_fill_value, data)
+
+
+def _is_time_like(units):
+ # test for time-like
+ if units is None:
+ return False
+ time_strings = [
+ "days",
+ "hours",
+ "minutes",
+ "seconds",
+ "milliseconds",
+ "microseconds",
+ "nanoseconds",
+ ]
+ units = str(units)
+ # to prevent detecting units like `days accumulated` as time-like
+ # special casing for datetime-units and timedelta-units (GH-8269)
+ if "since" in units:
+ from xarray.coding.times import _unpack_netcdf_time_units
+
+ try:
+ _unpack_netcdf_time_units(units)
+ except ValueError:
+ return False
+ return True
+ else:
+ return any(tstr == units for tstr in time_strings)
def _check_fill_values(attrs, name, dtype):
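Before the _check_fill_values and CFMaskCoder hunk below, a NumPy-only sketch of the masking step that _apply_mask performs during decoding (the fill value and dtypes here are made up for illustration):

    import numpy as np

    # Raw on-disk integers with a hypothetical fill value of -999.
    data = np.array([1, 2, -999, 4], dtype="int16")

    # Promote to float and replace every fill value with NaN, which is what
    # _apply_mask does lazily via lazy_elemwise_func.
    decoded = data.astype("float32")
    decoded = np.where(decoded == -999, np.nan, decoded)
    print(decoded)  # the -999 entry is now NaN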
@@ -164,17 +267,187 @@ def _check_fill_values(attrs, name, dtype):
Issue SerializationWarning if appropriate.
"""
- pass
+ raw_fill_dict = {}
+ [
+ pop_to(attrs, raw_fill_dict, attr, name=name)
+ for attr in ("missing_value", "_FillValue")
+ ]
+ encoded_fill_values = set()
+ for k in list(raw_fill_dict):
+ v = raw_fill_dict[k]
+ kfill = {fv for fv in np.ravel(v) if not pd.isnull(fv)}
+ if not kfill and np.issubdtype(dtype, np.integer):
+ warnings.warn(
+ f"variable {name!r} has non-conforming {k!r} "
+ f"{v!r} defined, dropping {k!r} entirely.",
+ SerializationWarning,
+ stacklevel=3,
+ )
+ del raw_fill_dict[k]
+ else:
+ encoded_fill_values |= kfill
+
+ if len(encoded_fill_values) > 1:
+ warnings.warn(
+ f"variable {name!r} has multiple fill values "
+ f"{encoded_fill_values} defined, decoding all values to NaN.",
+ SerializationWarning,
+ stacklevel=3,
+ )
+
+ return raw_fill_dict, encoded_fill_values
class CFMaskCoder(VariableCoder):
"""Mask or unmask fill values according to CF conventions."""
+ def encode(self, variable: Variable, name: T_Name = None):
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
-def _choose_float_dtype(dtype: np.dtype, mapping: MutableMapping) ->type[np
- .floating[Any]]:
+ dtype = np.dtype(encoding.get("dtype", data.dtype))
+ fv = encoding.get("_FillValue")
+ mv = encoding.get("missing_value")
+ # to properly handle _FillValue/missing_value below [a], [b]
+ # we need to check if unsigned data is written as signed data
+ unsigned = encoding.get("_Unsigned") is not None
+
+ fv_exists = fv is not None
+ mv_exists = mv is not None
+
+ if not fv_exists and not mv_exists:
+ return variable
+
+ if fv_exists and mv_exists and not duck_array_ops.allclose_or_equiv(fv, mv):
+ raise ValueError(
+ f"Variable {name!r} has conflicting _FillValue ({fv}) and missing_value ({mv}). Cannot encode data."
+ )
+
+ if fv_exists:
+ # Ensure _FillValue is cast to same dtype as data's
+ # [a] need to skip this if _Unsigned is available
+ if not unsigned:
+ encoding["_FillValue"] = dtype.type(fv)
+ fill_value = pop_to(encoding, attrs, "_FillValue", name=name)
+
+ if mv_exists:
+ # try to use _FillValue, if it exists to align both values
+ # or use missing_value and ensure it's cast to same dtype as data's
+ # [b] need to provide mv verbatim if _Unsigned is available
+ encoding["missing_value"] = attrs.get(
+ "_FillValue",
+ (dtype.type(mv) if not unsigned else mv),
+ )
+ fill_value = pop_to(encoding, attrs, "missing_value", name=name)
+
+ # apply fillna
+ if not pd.isnull(fill_value):
+ # special case DateTime to properly handle NaT
+ if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu":
+ data = duck_array_ops.where(
+ data != np.iinfo(np.int64).min, data, fill_value
+ )
+ else:
+ data = duck_array_ops.fillna(data, fill_value)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+
+ def decode(self, variable: Variable, name: T_Name = None):
+ raw_fill_dict, encoded_fill_values = _check_fill_values(
+ variable.attrs, name, variable.dtype
+ )
+
+ if raw_fill_dict:
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+ [
+ safe_setitem(encoding, attr, value, name=name)
+ for attr, value in raw_fill_dict.items()
+ ]
+
+ if encoded_fill_values:
+ # special case DateTime to properly handle NaT
+ dtype: np.typing.DTypeLike
+ decoded_fill_value: Any
+ if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu":
+ dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min
+ else:
+ if "scale_factor" not in attrs and "add_offset" not in attrs:
+ dtype, decoded_fill_value = dtypes.maybe_promote(data.dtype)
+ else:
+ dtype, decoded_fill_value = (
+ _choose_float_dtype(data.dtype, attrs),
+ np.nan,
+ )
+
+ transform = partial(
+ _apply_mask,
+ encoded_fill_values=encoded_fill_values,
+ decoded_fill_value=decoded_fill_value,
+ dtype=dtype,
+ )
+ data = lazy_elemwise_func(data, transform, dtype)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+
+def _scale_offset_decoding(data, scale_factor, add_offset, dtype: np.typing.DTypeLike):
+ data = data.astype(dtype=dtype, copy=True)
+ if scale_factor is not None:
+ data *= scale_factor
+ if add_offset is not None:
+ data += add_offset
+ return data
+
+
+def _choose_float_dtype(
+ dtype: np.dtype, mapping: MutableMapping
+) -> type[np.floating[Any]]:
"""Return a float dtype that can losslessly represent `dtype` values."""
- pass
+ # check scale/offset first to derive wanted float dtype
+ # see https://github.com/pydata/xarray/issues/5597#issuecomment-879561954
+ scale_factor = mapping.get("scale_factor")
+ add_offset = mapping.get("add_offset")
+ if scale_factor is not None or add_offset is not None:
+ # get the type from scale_factor/add_offset to determine
+ # the needed floating point type
+ if scale_factor is not None:
+ scale_type = np.dtype(type(scale_factor))
+ if add_offset is not None:
+ offset_type = np.dtype(type(add_offset))
+ # CF conforming, both scale_factor and add-offset are given and
+ # of same floating point type (float32/64)
+ if (
+ add_offset is not None
+ and scale_factor is not None
+ and offset_type == scale_type
+ and scale_type in [np.float32, np.float64]
+ ):
+ # in case of int32 -> we need upcast to float64
+ # due to precision issues
+ if dtype.itemsize == 4 and np.issubdtype(dtype, np.integer):
+ return np.float64
+ return scale_type.type
+ # Not CF conforming and add_offset given:
+ # A scale factor is entirely safe (vanishing into the mantissa),
+ # but a large integer offset could lead to loss of precision.
+ # Sensitivity analysis can be tricky, so we just use a float64
+ # if there's any offset at all - better unoptimised than wrong!
+ if add_offset is not None:
+ return np.float64
+ # return dtype depending on given scale_factor
+ return scale_type.type
+ # If no scale_factor or add_offset is given, use some general rules.
+ # Keep float32 as-is. Upcast half-precision to single-precision,
+ # because float16 is "intended for storage but not computation"
+ if dtype.itemsize <= 4 and np.issubdtype(dtype, np.floating):
+ return np.float32
+ # float32 can exactly represent all integers up to 24 bits
+ if dtype.itemsize <= 2 and np.issubdtype(dtype, np.integer):
+ return np.float32
+ # For all other types and circumstances, we just use float64.
+ # (safe because eg. complex numbers are not supported in NetCDF)
+ return np.float64
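A few representative inputs and the dtype the rules above select (a sketch, not exhaustive):

```python
import numpy as np

# no scale/offset: small floats and narrow ints stay float32, everything else float64
_choose_float_dtype(np.dtype("float16"), {})  # np.float32
_choose_float_dtype(np.dtype("int16"), {})    # np.float32
_choose_float_dtype(np.dtype("int64"), {})    # np.float64

# CF-conforming float32 scale_factor/add_offset keep float32 ...
_choose_float_dtype(
    np.dtype("int16"),
    {"scale_factor": np.float32(0.1), "add_offset": np.float32(10.0)},
)  # np.float32

# ... except for 32-bit integers, which are upcast to avoid precision loss
_choose_float_dtype(
    np.dtype("int32"),
    {"scale_factor": np.float32(0.1), "add_offset": np.float32(10.0)},
)  # np.float64
```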
class CFScaleOffsetCoder(VariableCoder):
@@ -184,30 +457,247 @@ class CFScaleOffsetCoder(VariableCoder):
decode_values = encoded_values * scale_factor + add_offset
"""
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+
+ if "scale_factor" in encoding or "add_offset" in encoding:
+ # if we have a _FillValue/masked_value we do not want to cast now
+ # but leave that to CFMaskCoder
+ dtype = data.dtype
+ if "_FillValue" not in encoding and "missing_value" not in encoding:
+ dtype = _choose_float_dtype(data.dtype, encoding)
+            # but we still need a copy to prevent changing the original data
+ data = duck_array_ops.astype(data, dtype=dtype, copy=True)
+ if "add_offset" in encoding:
+ data -= pop_to(encoding, attrs, "add_offset", name=name)
+ if "scale_factor" in encoding:
+ data /= pop_to(encoding, attrs, "scale_factor", name=name)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ _attrs = variable.attrs
+ if "scale_factor" in _attrs or "add_offset" in _attrs:
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+
+ scale_factor = pop_to(attrs, encoding, "scale_factor", name=name)
+ add_offset = pop_to(attrs, encoding, "add_offset", name=name)
+ if np.ndim(scale_factor) > 0:
+ scale_factor = np.asarray(scale_factor).item()
+ if np.ndim(add_offset) > 0:
+ add_offset = np.asarray(add_offset).item()
+ # if we have a _FillValue/masked_value we already have the wanted
+ # floating point dtype here (via CFMaskCoder), so no check is necessary
+ # only check in other cases
+ dtype = data.dtype
+ if "_FillValue" not in encoding and "missing_value" not in encoding:
+ dtype = _choose_float_dtype(dtype, encoding)
+
+ transform = partial(
+ _scale_offset_decoding,
+ scale_factor=scale_factor,
+ add_offset=add_offset,
+ dtype=dtype,
+ )
+ data = lazy_elemwise_func(data, transform, dtype)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
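A minimal end-to-end sketch of the unpacking rule `decoded = encoded * scale_factor + add_offset`, using the public `xr.decode_cf` entry point (variable names and values are made up):

```python
import numpy as np
import xarray as xr

packed = xr.Dataset(
    {
        "t": (
            "x",
            np.array([0, 100, 200], dtype="int16"),
            {"scale_factor": np.float32(0.1), "add_offset": np.float32(250.0)},
        )
    }
)
unpacked = xr.decode_cf(packed)
# unpacked["t"].values -> approximately [250.0, 260.0, 270.0], dtype float32
```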
class UnsignedIntegerCoder(VariableCoder):
- pass
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ # from netCDF best practices
+ # https://docs.unidata.ucar.edu/nug/current/best_practices.html#bp_Unsigned-Data
+ # "_Unsigned = "true" to indicate that
+ # integer data should be treated as unsigned"
+ if variable.encoding.get("_Unsigned", "false") == "true":
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+
+ pop_to(encoding, attrs, "_Unsigned")
+ # we need the on-disk type here
+ # trying to get it from encoding, resort to an int with the same precision as data.dtype if not available
+ signed_dtype = np.dtype(encoding.get("dtype", f"i{data.dtype.itemsize}"))
+ if "_FillValue" in attrs:
+ try:
+ # user provided the on-disk signed fill
+ new_fill = signed_dtype.type(attrs["_FillValue"])
+ except OverflowError:
+ # user provided the in-memory unsigned fill, convert to signed type
+ unsigned_dtype = np.dtype(f"u{signed_dtype.itemsize}")
+ # use view here to prevent OverflowError
+ new_fill = (
+ np.array(attrs["_FillValue"], dtype=unsigned_dtype)
+ .view(signed_dtype)
+ .item()
+ )
+ attrs["_FillValue"] = new_fill
+ data = duck_array_ops.astype(duck_array_ops.around(data), signed_dtype)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if "_Unsigned" in variable.attrs:
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+ unsigned = pop_to(attrs, encoding, "_Unsigned")
+
+ if data.dtype.kind == "i":
+ if unsigned == "true":
+ unsigned_dtype = np.dtype(f"u{data.dtype.itemsize}")
+ transform = partial(np.asarray, dtype=unsigned_dtype)
+ if "_FillValue" in attrs:
+ new_fill = np.array(attrs["_FillValue"], dtype=data.dtype)
+ # use view here to prevent OverflowError
+ attrs["_FillValue"] = new_fill.view(unsigned_dtype).item()
+ data = lazy_elemwise_func(data, transform, unsigned_dtype)
+ elif data.dtype.kind == "u":
+ if unsigned == "false":
+ signed_dtype = np.dtype(f"i{data.dtype.itemsize}")
+ transform = partial(np.asarray, dtype=signed_dtype)
+ data = lazy_elemwise_func(data, transform, signed_dtype)
+ if "_FillValue" in attrs:
+ new_fill = signed_dtype.type(attrs["_FillValue"])
+ attrs["_FillValue"] = new_fill
+ else:
+ warnings.warn(
+ f"variable {name!r} has _Unsigned attribute but is not "
+ "of integer type. Ignoring attribute.",
+ SerializationWarning,
+ stacklevel=3,
+ )
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
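A small sketch of the `_Unsigned` convention handled above: netCDF-3 has no unsigned integer types, so unsigned bytes are stored as signed bytes plus `_Unsigned = "true"` (made-up variable name):

```python
import numpy as np
import xarray as xr

on_disk = xr.Dataset(
    {"counts": ("x", np.array([0, 10, -1], dtype="int8"), {"_Unsigned": "true"})}
)
decoded = xr.decode_cf(on_disk)
# decoded["counts"].dtype  -> uint8
# decoded["counts"].values -> [0, 10, 255]   (the signed -1 reinterpreted)
```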
class DefaultFillvalueCoder(VariableCoder):
"""Encode default _FillValue if needed."""
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+ # make NaN the fill value for float types
+ if (
+ "_FillValue" not in attrs
+ and "_FillValue" not in encoding
+ and np.issubdtype(variable.dtype, np.floating)
+ ):
+ attrs["_FillValue"] = variable.dtype.type(np.nan)
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ raise NotImplementedError()
+
class BooleanCoder(VariableCoder):
"""Code boolean values."""
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if (
+ (variable.dtype == bool)
+ and ("dtype" not in variable.encoding)
+ and ("dtype" not in variable.attrs)
+ ):
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+ attrs["dtype"] = "bool"
+ data = duck_array_ops.astype(data, dtype="i1", copy=True)
+
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if variable.attrs.get("dtype", False) == "bool":
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+ # overwrite (!) dtype in encoding, and remove from attrs
+ # needed for correct subsequent encoding
+ encoding["dtype"] = attrs.pop("dtype")
+ data = BoolTypeArray(data)
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
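For illustration, booleans round-trip through an int8 on-disk representation plus a `dtype="bool"` attribute; a minimal sketch using the coder directly (assuming the `xarray.coding.variables` module path shown in this patch):

```python
import numpy as np
from xarray import Variable
from xarray.coding.variables import BooleanCoder

var = Variable(("x",), np.array([True, False, True]))
encoded = BooleanCoder().encode(var)
# encoded.dtype -> int8, with encoded.attrs["dtype"] == "bool"

decoded = BooleanCoder().decode(encoded)
# decoded.dtype -> bool again; "dtype" moved from attrs to encoding
```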
class EndianCoder(VariableCoder):
"""Decode Endianness to native."""
+ def encode(self):
+ raise NotImplementedError()
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ dims, data, attrs, encoding = unpack_for_decoding(variable)
+ if not data.dtype.isnative:
+ data = NativeEndiannessArray(data)
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
class NonStringCoder(VariableCoder):
"""Encode NonString variables if dtypes differ."""
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if "dtype" in variable.encoding and variable.encoding["dtype"] not in (
+ "S1",
+ str,
+ ):
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+ dtype = np.dtype(encoding.pop("dtype"))
+ if dtype != variable.dtype:
+ if np.issubdtype(dtype, np.integer):
+ if (
+ np.issubdtype(variable.dtype, np.floating)
+ and "_FillValue" not in variable.attrs
+ and "missing_value" not in variable.attrs
+ ):
+ warnings.warn(
+ f"saving variable {name} with floating "
+ "point data as an integer dtype without "
+ "any _FillValue to use for NaNs",
+ SerializationWarning,
+ stacklevel=10,
+ )
+ data = np.around(data)
+ data = data.astype(dtype=dtype)
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self):
+ raise NotImplementedError()
+
class ObjectVLenStringCoder(VariableCoder):
- pass
+ def encode(self):
+ raise NotImplementedError
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if variable.dtype.kind == "O" and variable.encoding.get("dtype", False) is str:
+ variable = variable.astype(variable.encoding["dtype"])
+ return variable
+ else:
+ return variable
class NativeEnumCoder(VariableCoder):
"""Encode Enum into variable dtype metadata."""
+
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if (
+ "dtype" in variable.encoding
+ and np.dtype(variable.encoding["dtype"]).metadata
+ and "enum" in variable.encoding["dtype"].metadata
+ ):
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+ data = data.astype(dtype=variable.encoding.pop("dtype"))
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ raise NotImplementedError()
diff --git a/xarray/conventions.py b/xarray/conventions.py
index de83ae8f..d572b215 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -1,23 +1,45 @@
from __future__ import annotations
+
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping, MutableMapping
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union
+
import numpy as np
import pandas as pd
+
from xarray.coding import strings, times, variables
from xarray.coding.variables import SerializationWarning, pop_to
from xarray.core import indexing
-from xarray.core.common import _contains_datetime_like_objects, contains_cftime_datetimes
+from xarray.core.common import (
+ _contains_datetime_like_objects,
+ contains_cftime_datetimes,
+)
from xarray.core.utils import emit_user_level_warning
from xarray.core.variable import IndexVariable, Variable
from xarray.namedarray.utils import is_duck_dask_array
-CF_RELATED_DATA = ('bounds', 'grid_mapping', 'climatology', 'geometry',
- 'node_coordinates', 'node_count', 'part_node_count', 'interior_ring',
- 'cell_measures', 'formula_terms')
-CF_RELATED_DATA_NEEDS_PARSING = 'cell_measures', 'formula_terms'
+
+CF_RELATED_DATA = (
+ "bounds",
+ "grid_mapping",
+ "climatology",
+ "geometry",
+ "node_coordinates",
+ "node_count",
+ "part_node_count",
+ "interior_ring",
+ "cell_measures",
+ "formula_terms",
+)
+CF_RELATED_DATA_NEEDS_PARSING = (
+ "cell_measures",
+ "formula_terms",
+)
+
+
if TYPE_CHECKING:
from xarray.backends.common import AbstractDataStore
from xarray.core.dataset import Dataset
+
T_VarTuple = tuple[tuple[Hashable, ...], Any, dict, dict]
T_Name = Union[Hashable, None]
T_Variables = Mapping[Any, Variable]
@@ -28,7 +50,51 @@ if TYPE_CHECKING:
def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
- pass
+ if array.dtype.kind != "O":
+ raise TypeError("infer_type must be called on a dtype=object array")
+
+ if array.size == 0:
+ return np.dtype(float)
+
+ native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel()))
+ if len(native_dtypes) > 1 and native_dtypes != {bytes, str}:
+ raise ValueError(
+ "unable to infer dtype on variable {!r}; object array "
+ "contains mixed native types: {}".format(
+ name, ", ".join(x.__name__ for x in native_dtypes)
+ )
+ )
+
+ element = array[(0,) * array.ndim]
+ # We use the base types to avoid subclasses of bytes and str (which might
+ # not play nice with e.g. hdf5 datatypes), such as those from numpy
+ if isinstance(element, bytes):
+ return strings.create_vlen_dtype(bytes)
+ elif isinstance(element, str):
+ return strings.create_vlen_dtype(str)
+
+ dtype = np.array(element).dtype
+ if dtype.kind != "O":
+ return dtype
+
+ raise ValueError(
+ f"unable to infer dtype on variable {name!r}; xarray "
+ "cannot serialize arbitrary Python objects"
+ )
+
+
+def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
+ # only the pandas multi-index dimension coordinate cannot be serialized (tuple values)
+ if isinstance(var._data, indexing.PandasMultiIndexingAdapter):
+ if name is None and isinstance(var, IndexVariable):
+ name = var.name
+ if var.dims == (name,):
+ raise NotImplementedError(
+ f"variable {name!r} is a MultiIndex, which cannot yet be "
+ "serialized. Instead, either use reset_index() "
+ "to convert MultiIndex levels into coordinate variables instead "
+ "or use https://cf-xarray.readthedocs.io/en/latest/coding.html."
+ )
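The guard above fires when a pandas MultiIndex dimension coordinate reaches the encoder; a sketch of the failure mode and the suggested workaround (names are illustrative, and `Coordinates.from_pandas_multiindex` assumes a recent xarray):

```python
import pandas as pd
import xarray as xr

midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["letter", "num"])
coords = xr.Coordinates.from_pandas_multiindex(midx, "sample")
ds = xr.Dataset({"v": ("sample", [1.0, 2.0, 3.0, 4.0])}, coords=coords)

# ds.to_netcdf("out.nc") would raise the NotImplementedError above.
# Flattening the MultiIndex into ordinary coordinate variables avoids it:
flat = ds.reset_index("sample")
```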
def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
@@ -37,11 +103,64 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
We use this instead of np.array() to ensure that custom object dtypes end
up on the resulting array.
"""
- pass
-
-
-def encode_cf_variable(var: Variable, needs_copy: bool=True, name: T_Name=None
- ) ->Variable:
+ result = np.empty(data.shape, dtype)
+ result[...] = data
+ return result
+
+
+def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
+ # TODO: move this from conventions to backends? (it's not CF related)
+ if var.dtype.kind == "O":
+ dims, data, attrs, encoding = variables.unpack_for_encoding(var)
+
+ # leave vlen dtypes unchanged
+ if strings.check_vlen_dtype(data.dtype) is not None:
+ return var
+
+ if is_duck_dask_array(data):
+ emit_user_level_warning(
+ f"variable {name} has data in the form of a dask array with "
+ "dtype=object, which means it is being loaded into memory "
+ "to determine a data type that can be safely stored on disk. "
+ "To avoid this, coerce this variable to a fixed-size dtype "
+ "with astype() before saving it.",
+ category=SerializationWarning,
+ )
+ data = data.compute()
+
+ missing = pd.isnull(data)
+ if missing.any():
+ # nb. this will fail for dask.array data
+ non_missing_values = data[~missing]
+ inferred_dtype = _infer_dtype(non_missing_values, name)
+
+ # There is no safe bit-pattern for NA in typical binary string
+ # formats, we so can't set a fill_value. Unfortunately, this means
+ # we can't distinguish between missing values and empty strings.
+ fill_value: bytes | str
+ if strings.is_bytes_dtype(inferred_dtype):
+ fill_value = b""
+ elif strings.is_unicode_dtype(inferred_dtype):
+ fill_value = ""
+ else:
+ # insist on using float for numeric values
+ if not np.issubdtype(inferred_dtype, np.floating):
+ inferred_dtype = np.dtype(float)
+ fill_value = inferred_dtype.type(np.nan)
+
+ data = _copy_with_dtype(data, dtype=inferred_dtype)
+ data[missing] = fill_value
+ else:
+ data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
+
+ assert data.dtype.kind != "O" or data.dtype.metadata
+ var = Variable(dims, data, attrs, encoding, fastpath=True)
+ return var
+
+
+def encode_cf_variable(
+ var: Variable, needs_copy: bool = True, name: T_Name = None
+) -> Variable:
"""
Converts a Variable into a Variable which follows some
of the CF conventions:
@@ -61,13 +180,40 @@ def encode_cf_variable(var: Variable, needs_copy: bool=True, name: T_Name=None
out : Variable
A variable which has been encoded as described above.
"""
- pass
-
-
-def decode_cf_variable(name: Hashable, var: Variable, concat_characters:
- bool=True, mask_and_scale: bool=True, decode_times: bool=True,
- decode_endianness: bool=True, stack_char_dim: bool=True, use_cftime: (
- bool | None)=None, decode_timedelta: (bool | None)=None) ->Variable:
+ ensure_not_multiindex(var, name=name)
+
+ for coder in [
+ times.CFDatetimeCoder(),
+ times.CFTimedeltaCoder(),
+ variables.CFScaleOffsetCoder(),
+ variables.CFMaskCoder(),
+ variables.UnsignedIntegerCoder(),
+ variables.NativeEnumCoder(),
+ variables.NonStringCoder(),
+ variables.DefaultFillvalueCoder(),
+ variables.BooleanCoder(),
+ ]:
+ var = coder.encode(var, name=name)
+
+ # TODO(kmuehlbauer): check if ensure_dtype_not_object can be moved to backends:
+ var = ensure_dtype_not_object(var, name=name)
+
+ for attr_name in CF_RELATED_DATA:
+ pop_to(var.encoding, var.attrs, attr_name)
+ return var
+
+
+def decode_cf_variable(
+ name: Hashable,
+ var: Variable,
+ concat_characters: bool = True,
+ mask_and_scale: bool = True,
+ decode_times: bool = True,
+ decode_endianness: bool = True,
+ stack_char_dim: bool = True,
+ use_cftime: bool | None = None,
+ decode_timedelta: bool | None = None,
+) -> Variable:
"""
Decodes a variable which may hold CF encoded information.
@@ -113,10 +259,54 @@ def decode_cf_variable(name: Hashable, var: Variable, concat_characters:
out : Variable
A variable holding the decoded equivalent of var.
"""
- pass
+ # Ensure datetime-like Variables are passed through unmodified (GH 6453)
+ if _contains_datetime_like_objects(var):
+ return var
+
+ original_dtype = var.dtype
+
+ if decode_timedelta is None:
+ decode_timedelta = decode_times
+
+ if concat_characters:
+ if stack_char_dim:
+ var = strings.CharacterArrayCoder().decode(var, name=name)
+ var = strings.EncodedStringCoder().decode(var)
+
+ if original_dtype.kind == "O":
+ var = variables.ObjectVLenStringCoder().decode(var)
+ original_dtype = var.dtype
+
+ if mask_and_scale:
+ for coder in [
+ variables.UnsignedIntegerCoder(),
+ variables.CFMaskCoder(),
+ variables.CFScaleOffsetCoder(),
+ ]:
+ var = coder.decode(var, name=name)
+
+ if decode_timedelta:
+ var = times.CFTimedeltaCoder().decode(var, name=name)
+ if decode_times:
+ var = times.CFDatetimeCoder(use_cftime=use_cftime).decode(var, name=name)
+
+ if decode_endianness and not var.dtype.isnative:
+ var = variables.EndianCoder().decode(var)
+ original_dtype = var.dtype
+ var = variables.BooleanCoder().decode(var)
-def _update_bounds_attributes(variables: T_Variables) ->None:
+ dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
+
+ encoding.setdefault("dtype", original_dtype)
+
+ if not is_duck_dask_array(data):
+ data = indexing.LazilyIndexedArray(data)
+
+ return Variable(dimensions, data, attributes, encoding=encoding, fastpath=True)
+
+
+def _update_bounds_attributes(variables: T_Variables) -> None:
"""Adds time attributes to time bounds variables.
Variables handling time bounds ("Cell boundaries" in the CF
@@ -131,10 +321,21 @@ def _update_bounds_attributes(variables: T_Variables) ->None:
https://github.com/pydata/xarray/issues/2565
"""
- pass
+ # For all time variables with bounds
+ for v in variables.values():
+ attrs = v.attrs
+ units = attrs.get("units")
+ has_date_units = isinstance(units, str) and "since" in units
+ if has_date_units and "bounds" in attrs:
+ if attrs["bounds"] in variables:
+ bounds_attrs = variables[attrs["bounds"]].attrs
+ bounds_attrs.setdefault("units", attrs["units"])
+ if "calendar" in attrs:
+ bounds_attrs.setdefault("calendar", attrs["calendar"])
-def _update_bounds_encoding(variables: T_Variables) ->None:
+
+def _update_bounds_encoding(variables: T_Variables) -> None:
"""Adds time encoding to time bounds variables.
Variables handling time bounds ("Cell boundaries" in the CF
@@ -149,40 +350,177 @@ def _update_bounds_encoding(variables: T_Variables) ->None:
https://github.com/pydata/xarray/issues/2565
"""
- pass
-
-
-T = TypeVar('T')
-
-def _item_or_default(obj: (Mapping[Any, T] | T), key: Hashable, default: T
- ) ->T:
+ # For all time variables with bounds
+ for name, v in variables.items():
+ attrs = v.attrs
+ encoding = v.encoding
+ has_date_units = "units" in encoding and "since" in encoding["units"]
+ is_datetime_type = np.issubdtype(
+ v.dtype, np.datetime64
+ ) or contains_cftime_datetimes(v)
+
+ if (
+ is_datetime_type
+ and not has_date_units
+ and "bounds" in attrs
+ and attrs["bounds"] in variables
+ ):
+ emit_user_level_warning(
+ f"Variable {name:s} has datetime type and a "
+ f"bounds variable but {name:s}.encoding does not have "
+ f"units specified. The units encodings for {name:s} "
+ f"and {attrs['bounds']} will be determined independently "
+ "and may not be equal, counter to CF-conventions. "
+ "If this is a concern, specify a units encoding for "
+ f"{name:s} before writing to a file.",
+ )
+
+ if has_date_units and "bounds" in attrs:
+ if attrs["bounds"] in variables:
+ bounds_encoding = variables[attrs["bounds"]].encoding
+ bounds_encoding.setdefault("units", encoding["units"])
+ if "calendar" in encoding:
+ bounds_encoding.setdefault("calendar", encoding["calendar"])
+
+
+T = TypeVar("T")
+
+
+def _item_or_default(obj: Mapping[Any, T] | T, key: Hashable, default: T) -> T:
"""
Return item by key if obj is mapping and key is present, else return default value.
"""
- pass
-
-
-def decode_cf_variables(variables: T_Variables, attributes: T_Attrs,
- concat_characters: (bool | Mapping[str, bool])=True, mask_and_scale: (
- bool | Mapping[str, bool])=True, decode_times: (bool | Mapping[str,
- bool])=True, decode_coords: (bool | Literal['coordinates', 'all'])=True,
- drop_variables: T_DropVariables=None, use_cftime: (bool | Mapping[str,
- bool] | None)=None, decode_timedelta: (bool | Mapping[str, bool] | None
- )=None) ->tuple[T_Variables, T_Attrs, set[Hashable]]:
+ return obj.get(key, default) if isinstance(obj, Mapping) else obj
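This is how per-variable options such as `mask_and_scale={"temp": False}` are resolved below; for example:

```python
_item_or_default({"temp": False, "salt": True}, "temp", True)  # False
_item_or_default({"temp": False}, "pressure", True)            # True (key missing)
_item_or_default(True, "anything", None)                       # True (not a mapping)
```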
+
+
+def decode_cf_variables(
+ variables: T_Variables,
+ attributes: T_Attrs,
+ concat_characters: bool | Mapping[str, bool] = True,
+ mask_and_scale: bool | Mapping[str, bool] = True,
+ decode_times: bool | Mapping[str, bool] = True,
+ decode_coords: bool | Literal["coordinates", "all"] = True,
+ drop_variables: T_DropVariables = None,
+ use_cftime: bool | Mapping[str, bool] | None = None,
+ decode_timedelta: bool | Mapping[str, bool] | None = None,
+) -> tuple[T_Variables, T_Attrs, set[Hashable]]:
"""
Decode several CF encoded variables.
See: decode_cf_variable
"""
- pass
-
-
-def decode_cf(obj: T_DatasetOrAbstractstore, concat_characters: bool=True,
- mask_and_scale: bool=True, decode_times: bool=True, decode_coords: (
- bool | Literal['coordinates', 'all'])=True, drop_variables:
- T_DropVariables=None, use_cftime: (bool | None)=None, decode_timedelta:
- (bool | None)=None) ->Dataset:
+ dimensions_used_by = defaultdict(list)
+ for v in variables.values():
+ for d in v.dims:
+ dimensions_used_by[d].append(v)
+
+ def stackable(dim: Hashable) -> bool:
+ # figure out if a dimension can be concatenated over
+ if dim in variables:
+ return False
+ for v in dimensions_used_by[dim]:
+ if v.dtype.kind != "S" or dim != v.dims[-1]:
+ return False
+ return True
+
+ coord_names = set()
+
+ if isinstance(drop_variables, str):
+ drop_variables = [drop_variables]
+ elif drop_variables is None:
+ drop_variables = []
+ drop_variables = set(drop_variables)
+
+ # Time bounds coordinates might miss the decoding attributes
+ if decode_times:
+ _update_bounds_attributes(variables)
+
+ new_vars = {}
+ for k, v in variables.items():
+ if k in drop_variables:
+ continue
+ stack_char_dim = (
+ _item_or_default(concat_characters, k, True)
+ and v.dtype == "S1"
+ and v.ndim > 0
+ and stackable(v.dims[-1])
+ )
+ try:
+ new_vars[k] = decode_cf_variable(
+ k,
+ v,
+ concat_characters=_item_or_default(concat_characters, k, True),
+ mask_and_scale=_item_or_default(mask_and_scale, k, True),
+ decode_times=_item_or_default(decode_times, k, True),
+ stack_char_dim=stack_char_dim,
+ use_cftime=_item_or_default(use_cftime, k, None),
+ decode_timedelta=_item_or_default(decode_timedelta, k, None),
+ )
+ except Exception as e:
+ raise type(e)(f"Failed to decode variable {k!r}: {e}") from e
+ if decode_coords in [True, "coordinates", "all"]:
+ var_attrs = new_vars[k].attrs
+ if "coordinates" in var_attrs:
+ var_coord_names = [
+ c for c in var_attrs["coordinates"].split() if c in variables
+ ]
+ # propagate as is
+ new_vars[k].encoding["coordinates"] = var_attrs["coordinates"]
+ del var_attrs["coordinates"]
+ # but only use as coordinate if existing
+ if var_coord_names:
+ coord_names.update(var_coord_names)
+
+ if decode_coords == "all":
+ for attr_name in CF_RELATED_DATA:
+ if attr_name in var_attrs:
+ attr_val = var_attrs[attr_name]
+ if attr_name not in CF_RELATED_DATA_NEEDS_PARSING:
+ var_names = attr_val.split()
+ else:
+ roles_and_names = [
+ role_or_name
+ for part in attr_val.split(":")
+ for role_or_name in part.split()
+ ]
+ if len(roles_and_names) % 2 == 1:
+ emit_user_level_warning(
+ f"Attribute {attr_name:s} malformed"
+ )
+ var_names = roles_and_names[1::2]
+ if all(var_name in variables for var_name in var_names):
+ new_vars[k].encoding[attr_name] = attr_val
+ coord_names.update(var_names)
+ else:
+ referenced_vars_not_in_variables = [
+ proj_name
+ for proj_name in var_names
+ if proj_name not in variables
+ ]
+ emit_user_level_warning(
+ f"Variable(s) referenced in {attr_name:s} not in variables: {referenced_vars_not_in_variables!s}",
+ )
+ del var_attrs[attr_name]
+
+ if decode_coords and isinstance(attributes.get("coordinates", None), str):
+ attributes = dict(attributes)
+ crds = attributes.pop("coordinates")
+ coord_names.update(crds.split())
+
+ return new_vars, attributes, coord_names
+
+
+def decode_cf(
+ obj: T_DatasetOrAbstractstore,
+ concat_characters: bool = True,
+ mask_and_scale: bool = True,
+ decode_times: bool = True,
+ decode_coords: bool | Literal["coordinates", "all"] = True,
+ drop_variables: T_DropVariables = None,
+ use_cftime: bool | None = None,
+ decode_timedelta: bool | None = None,
+) -> Dataset:
"""Decode the given Dataset or Datastore according to CF conventions into
a new Dataset.
@@ -231,12 +569,51 @@ def decode_cf(obj: T_DatasetOrAbstractstore, concat_characters: bool=True,
-------
decoded : Dataset
"""
- pass
-
+ from xarray.backends.common import AbstractDataStore
+ from xarray.core.dataset import Dataset
-def cf_decoder(variables: T_Variables, attributes: T_Attrs,
- concat_characters: bool=True, mask_and_scale: bool=True, decode_times:
- bool=True) ->tuple[T_Variables, T_Attrs]:
+ vars: T_Variables
+ attrs: T_Attrs
+ if isinstance(obj, Dataset):
+ vars = obj._variables
+ attrs = obj.attrs
+ extra_coords = set(obj.coords)
+ close = obj._close
+ encoding = obj.encoding
+ elif isinstance(obj, AbstractDataStore):
+ vars, attrs = obj.load()
+ extra_coords = set()
+ close = obj.close
+ encoding = obj.get_encoding()
+ else:
+ raise TypeError("can only decode Dataset or DataStore objects")
+
+ vars, attrs, coord_names = decode_cf_variables(
+ vars,
+ attrs,
+ concat_characters,
+ mask_and_scale,
+ decode_times,
+ decode_coords,
+ drop_variables=drop_variables,
+ use_cftime=use_cftime,
+ decode_timedelta=decode_timedelta,
+ )
+ ds = Dataset(vars, attrs=attrs)
+ ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars))
+ ds.set_close(close)
+ ds.encoding = encoding
+
+ return ds
+
+
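Typical usage of the public entry point on a raw, still-encoded dataset (all names, attributes, and values below are made up for illustration):

```python
import numpy as np
import xarray as xr

raw = xr.Dataset(
    {
        "temp": (
            "time",
            np.array([0, 5, -9999], dtype="int16"),
            {"scale_factor": 0.5, "add_offset": 273.0, "_FillValue": -9999},
        )
    },
    coords={"time": ("time", [0, 1, 2], {"units": "days since 2000-01-01"})},
)
ds = xr.decode_cf(raw)
# ds["temp"].values -> [273.0, 275.5, nan]
# ds["time"]        -> a datetime64 coordinate starting at 2000-01-01
```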
+def cf_decoder(
+ variables: T_Variables,
+ attributes: T_Attrs,
+ concat_characters: bool = True,
+ mask_and_scale: bool = True,
+ decode_times: bool = True,
+) -> tuple[T_Variables, T_Attrs]:
"""
Decode a set of CF encoded variables and attributes.
@@ -266,7 +643,111 @@ def cf_decoder(variables: T_Variables, attributes: T_Attrs,
--------
decode_cf_variable
"""
- pass
+ variables, attributes, _ = decode_cf_variables(
+ variables,
+ attributes,
+ concat_characters,
+ mask_and_scale,
+ decode_times,
+ )
+ return variables, attributes
+
+
+def _encode_coordinates(
+ variables: T_Variables, attributes: T_Attrs, non_dim_coord_names
+):
+ # calculate global and variable specific coordinates
+ non_dim_coord_names = set(non_dim_coord_names)
+
+ for name in list(non_dim_coord_names):
+ if isinstance(name, str) and " " in name:
+ emit_user_level_warning(
+ f"coordinate {name!r} has a space in its name, which means it "
+ "cannot be marked as a coordinate on disk and will be "
+ "saved as a data variable instead",
+ category=SerializationWarning,
+ )
+ non_dim_coord_names.discard(name)
+
+ global_coordinates = non_dim_coord_names.copy()
+ variable_coordinates = defaultdict(set)
+ not_technically_coordinates = set()
+ for coord_name in non_dim_coord_names:
+ target_dims = variables[coord_name].dims
+ for k, v in variables.items():
+ if (
+ k not in non_dim_coord_names
+ and k not in v.dims
+ and set(target_dims) <= set(v.dims)
+ ):
+ variable_coordinates[k].add(coord_name)
+
+ if any(
+ coord_name in v.encoding.get(attr_name, tuple())
+ for attr_name in CF_RELATED_DATA
+ ):
+ not_technically_coordinates.add(coord_name)
+ global_coordinates.discard(coord_name)
+
+ variables = {k: v.copy(deep=False) for k, v in variables.items()}
+
+ # keep track of variable names written to file under the "coordinates" attributes
+ written_coords = set()
+ for name, var in variables.items():
+ encoding = var.encoding
+ attrs = var.attrs
+ if "coordinates" in attrs and "coordinates" in encoding:
+ raise ValueError(
+ f"'coordinates' found in both attrs and encoding for variable {name!r}."
+ )
+
+ # if coordinates set to None, don't write coordinates attribute
+ if (
+ "coordinates" in attrs
+ and attrs.get("coordinates") is None
+ or "coordinates" in encoding
+ and encoding.get("coordinates") is None
+ ):
+ # make sure "coordinates" is removed from attrs/encoding
+ attrs.pop("coordinates", None)
+ encoding.pop("coordinates", None)
+ continue
+
+        # this moves "coordinates" from encoding to attrs if it was set in encoding;
+        # after the next line, "coordinates" is never in encoding, and
+        # we get support for attrs["coordinates"] for free.
+ coords_str = pop_to(encoding, attrs, "coordinates") or attrs.get("coordinates")
+ if not coords_str and variable_coordinates[name]:
+ coordinates_text = " ".join(
+ str(coord_name)
+ for coord_name in sorted(variable_coordinates[name])
+ if coord_name not in not_technically_coordinates
+ )
+ if coordinates_text:
+ attrs["coordinates"] = coordinates_text
+ if "coordinates" in attrs:
+ written_coords.update(attrs["coordinates"].split())
+
+ # These coordinates are not associated with any particular variables, so we
+ # save them under a global 'coordinates' attribute so xarray can roundtrip
+ # the dataset faithfully. Because this serialization goes beyond CF
+ # conventions, only do it if necessary.
+ # Reference discussion:
+ # http://mailman.cgd.ucar.edu/pipermail/cf-metadata/2014/007571.html
+ global_coordinates.difference_update(written_coords)
+ if global_coordinates:
+ attributes = dict(attributes)
+ if "coordinates" in attributes:
+ emit_user_level_warning(
+ f"cannot serialize global coordinates {global_coordinates!r} because the global "
+ f"attribute 'coordinates' already exists. This may prevent faithful roundtripping"
+ f"of xarray datasets",
+ category=SerializationWarning,
+ )
+ else:
+ attributes["coordinates"] = " ".join(sorted(map(str, global_coordinates)))
+
+ return variables, attributes
def encode_dataset_coordinates(dataset: Dataset):
@@ -285,7 +766,10 @@ def encode_dataset_coordinates(dataset: Dataset):
variables : dict
attrs : dict
"""
- pass
+ non_dim_coord_names = set(dataset.coords) - set(dataset.dims)
+ return _encode_coordinates(
+ dataset._variables, dataset.attrs, non_dim_coord_names=non_dim_coord_names
+ )
def cf_encoder(variables: T_Variables, attributes: T_Attrs):
@@ -314,4 +798,30 @@ def cf_encoder(variables: T_Variables, attributes: T_Attrs):
--------
decode_cf_variable, encode_cf_variable
"""
- pass
+
+ # add encoding for time bounds variables if present.
+ _update_bounds_encoding(variables)
+
+ new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
+
+ # Remove attrs from bounds variables (issue #2921)
+ for var in new_vars.values():
+ bounds = var.attrs["bounds"] if "bounds" in var.attrs else None
+ if bounds and bounds in new_vars:
+ # see http://cfconventions.org/cf-conventions/cf-conventions.html#cell-boundaries
+ for attr in [
+ "units",
+ "standard_name",
+ "axis",
+ "positive",
+ "calendar",
+ "long_name",
+ "leap_month",
+ "leap_year",
+ "month_lengths",
+ ]:
+ if attr in new_vars[bounds].attrs and attr in var.attrs:
+ if new_vars[bounds].attrs[attr] == var.attrs[attr]:
+ new_vars[bounds].attrs.pop(attr)
+
+ return new_vars, attributes
diff --git a/xarray/convert.py b/xarray/convert.py
index d29fc8f6..b8d81ccf 100644
--- a/xarray/convert.py
+++ b/xarray/convert.py
@@ -1,61 +1,209 @@
"""Functions for converting to and from xarray objects
"""
+
from collections import Counter
+
import numpy as np
+
from xarray.coding.times import CFDatetimeCoder, CFTimedeltaCoder
from xarray.conventions import decode_cf
from xarray.core import duck_array_ops
from xarray.core.dataarray import DataArray
from xarray.core.dtypes import get_fill_value
from xarray.namedarray.pycompat import array_type
-iris_forbidden_keys = {'standard_name', 'long_name', 'units', 'bounds',
- 'axis', 'calendar', 'leap_month', 'leap_year', 'month_lengths',
- 'coordinates', 'grid_mapping', 'climatology', 'cell_methods',
- 'formula_terms', 'compress', 'missing_value', 'add_offset',
- 'scale_factor', 'valid_max', 'valid_min', 'valid_range', '_FillValue'}
-cell_methods_strings = {'point', 'sum', 'maximum', 'median', 'mid_range',
- 'minimum', 'mean', 'mode', 'standard_deviation', 'variance'}
+
+iris_forbidden_keys = {
+ "standard_name",
+ "long_name",
+ "units",
+ "bounds",
+ "axis",
+ "calendar",
+ "leap_month",
+ "leap_year",
+ "month_lengths",
+ "coordinates",
+ "grid_mapping",
+ "climatology",
+ "cell_methods",
+ "formula_terms",
+ "compress",
+ "missing_value",
+ "add_offset",
+ "scale_factor",
+ "valid_max",
+ "valid_min",
+ "valid_range",
+ "_FillValue",
+}
+cell_methods_strings = {
+ "point",
+ "sum",
+ "maximum",
+ "median",
+ "mid_range",
+ "minimum",
+ "mean",
+ "mode",
+ "standard_deviation",
+ "variance",
+}
+
+
+def encode(var):
+ return CFTimedeltaCoder().encode(CFDatetimeCoder().encode(var.variable))
def _filter_attrs(attrs, ignored_attrs):
"""Return attrs that are not in ignored_attrs"""
- pass
+ return {k: v for k, v in attrs.items() if k not in ignored_attrs}
def _pick_attrs(attrs, keys):
"""Return attrs with keys in keys list"""
- pass
+ return {k: v for k, v in attrs.items() if k in keys}
def _get_iris_args(attrs):
"""Converts the xarray attrs into args that can be passed into Iris"""
- pass
+ # iris.unit is deprecated in Iris v1.9
+ import cf_units
+ args = {"attributes": _filter_attrs(attrs, iris_forbidden_keys)}
+ args.update(_pick_attrs(attrs, ("standard_name", "long_name")))
+ unit_args = _pick_attrs(attrs, ("calendar",))
+ if "units" in attrs:
+ args["units"] = cf_units.Unit(attrs["units"], **unit_args)
+ return args
+
+# TODO: Add converting bounds from xarray to Iris and back
def to_iris(dataarray):
"""Convert a DataArray into a Iris Cube"""
- pass
+ # Iris not a hard dependency
+ import iris
+ from iris.fileformats.netcdf import parse_cell_methods
+
+ dim_coords = []
+ aux_coords = []
+
+ for coord_name in dataarray.coords:
+ coord = encode(dataarray.coords[coord_name])
+ coord_args = _get_iris_args(coord.attrs)
+ coord_args["var_name"] = coord_name
+ axis = None
+ if coord.dims:
+ axis = dataarray.get_axis_num(coord.dims)
+ if coord_name in dataarray.dims:
+ try:
+ iris_coord = iris.coords.DimCoord(coord.values, **coord_args)
+ dim_coords.append((iris_coord, axis))
+ except ValueError:
+ iris_coord = iris.coords.AuxCoord(coord.values, **coord_args)
+ aux_coords.append((iris_coord, axis))
+ else:
+ iris_coord = iris.coords.AuxCoord(coord.values, **coord_args)
+ aux_coords.append((iris_coord, axis))
+
+ args = _get_iris_args(dataarray.attrs)
+ args["var_name"] = dataarray.name
+ args["dim_coords_and_dims"] = dim_coords
+ args["aux_coords_and_dims"] = aux_coords
+ if "cell_methods" in dataarray.attrs:
+ args["cell_methods"] = parse_cell_methods(dataarray.attrs["cell_methods"])
+
+ masked_data = duck_array_ops.masked_invalid(dataarray.data)
+ cube = iris.cube.Cube(masked_data, **args)
+
+ return cube
def _iris_obj_to_attrs(obj):
"""Return a dictionary of attrs when given a Iris object"""
- pass
+ attrs = {"standard_name": obj.standard_name, "long_name": obj.long_name}
+ if obj.units.calendar:
+ attrs["calendar"] = obj.units.calendar
+ if obj.units.origin != "1" and not obj.units.is_unknown():
+ attrs["units"] = obj.units.origin
+ attrs.update(obj.attributes)
+ return {k: v for k, v in attrs.items() if v is not None}
def _iris_cell_methods_to_str(cell_methods_obj):
"""Converts a Iris cell methods into a string"""
- pass
-
-
-def _name(iris_obj, default='unknown'):
+ cell_methods = []
+ for cell_method in cell_methods_obj:
+ names = "".join(f"{n}: " for n in cell_method.coord_names)
+ intervals = " ".join(
+ f"interval: {interval}" for interval in cell_method.intervals
+ )
+ comments = " ".join(f"comment: {comment}" for comment in cell_method.comments)
+ extra = " ".join([intervals, comments]).strip()
+ if extra:
+ extra = f" ({extra})"
+ cell_methods.append(names + cell_method.method + extra)
+ return " ".join(cell_methods)
+
+
+def _name(iris_obj, default="unknown"):
"""Mimics `iris_obj.name()` but with different name resolution order.
Similar to iris_obj.name() method, but using iris_obj.var_name first to
enable roundtripping.
"""
- pass
+ return iris_obj.var_name or iris_obj.standard_name or iris_obj.long_name or default
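The resolution order (var_name, then standard_name, then long_name, then the default) can be seen with a minimal stand-in object (purely illustrative, not an actual Iris class):

```python
class FakeIrisCoord:  # stand-in with the three name attributes _name() inspects
    var_name = None
    standard_name = "air_temperature"
    long_name = "Air Temperature"

_name(FakeIrisCoord())         # "air_temperature"
_name(FakeIrisCoord(), "t2m")  # still "air_temperature"; the default only applies
                               # when all three names are None
```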
def from_iris(cube):
"""Convert a Iris cube into an DataArray"""
- pass
+ import iris.exceptions
+
+ name = _name(cube)
+ if name == "unknown":
+ name = None
+ dims = []
+ for i in range(cube.ndim):
+ try:
+ dim_coord = cube.coord(dim_coords=True, dimensions=(i,))
+ dims.append(_name(dim_coord))
+ except iris.exceptions.CoordinateNotFoundError:
+ dims.append(f"dim_{i}")
+
+ if len(set(dims)) != len(dims):
+ duplicates = [k for k, v in Counter(dims).items() if v > 1]
+ raise ValueError(f"Duplicate coordinate name {duplicates}.")
+
+ coords = {}
+
+ for coord in cube.coords():
+ coord_attrs = _iris_obj_to_attrs(coord)
+ coord_dims = [dims[i] for i in cube.coord_dims(coord)]
+ if coord_dims:
+ coords[_name(coord)] = (coord_dims, coord.points, coord_attrs)
+ else:
+ coords[_name(coord)] = ((), coord.points.item(), coord_attrs)
+
+ array_attrs = _iris_obj_to_attrs(cube)
+ cell_methods = _iris_cell_methods_to_str(cube.cell_methods)
+ if cell_methods:
+ array_attrs["cell_methods"] = cell_methods
+
+ # Deal with iris 1.* and 2.*
+ cube_data = cube.core_data() if hasattr(cube, "core_data") else cube.data
+
+ # Deal with dask and numpy masked arrays
+ dask_array_type = array_type("dask")
+ if isinstance(cube_data, dask_array_type):
+ from dask.array import ma as dask_ma
+
+ filled_data = dask_ma.filled(cube_data, get_fill_value(cube.dtype))
+ elif isinstance(cube_data, np.ma.MaskedArray):
+ filled_data = np.ma.filled(cube_data, get_fill_value(cube.dtype))
+ else:
+ filled_data = cube_data
+
+ dataarray = DataArray(
+ filled_data, coords=coords, name=name, attrs=array_attrs, dims=dims
+ )
+ decoded_ds = decode_cf(dataarray._to_temp_dataset())
+ return dataarray._from_temp_dataset(decoded_ds)
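These converters back the public `DataArray.to_iris` / `DataArray.from_iris` methods; a roundtrip sketch (requires Iris to be installed, names and values made up):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(6.0).reshape(2, 3),
    dims=("lat", "lon"),
    coords={"lat": [10.0, 20.0], "lon": [0.0, 10.0, 20.0]},
    name="tas",
    attrs={"units": "K"},
)
cube = da.to_iris()                    # -> iris.cube.Cube
roundtripped = xr.DataArray.from_iris(cube)
```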
diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py
index 0e529472..acc534d8 100644
--- a/xarray/core/_aggregations.py
+++ b/xarray/core/_aggregations.py
@@ -1,22 +1,46 @@
"""Mixin classes with reduction operations."""
+
+# This file was generated using xarray.util.generate_aggregations. Do not edit manually.
+
from __future__ import annotations
+
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable
+
from xarray.core import duck_array_ops
from xarray.core.options import OPTIONS
from xarray.core.types import Dims, Self
from xarray.core.utils import contains_only_chunked_or_numpy, module_available
+
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
-flox_available = module_available('flox')
+
+flox_available = module_available("flox")
class DatasetAggregations:
__slots__ = ()
- def count(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Self:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> Self:
+ raise NotImplementedError()
+
+ def count(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``count`` along some dimension(s).
@@ -74,10 +98,21 @@ class DatasetAggregations:
Data variables:
da int64 8B 5
"""
- pass
-
- def all(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.count,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def all(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``all`` along some dimension(s).
@@ -135,10 +170,21 @@ class DatasetAggregations:
Data variables:
da bool 1B False
"""
- pass
-
- def any(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.array_all,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def any(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``any`` along some dimension(s).
@@ -196,10 +242,22 @@ class DatasetAggregations:
Data variables:
da bool 1B True
"""
- pass
-
- def max(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.array_any,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def max(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``max`` along some dimension(s).
@@ -270,10 +328,23 @@ class DatasetAggregations:
Data variables:
da float64 8B nan
"""
- pass
-
- def min(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.max,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def min(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``min`` along some dimension(s).
@@ -344,10 +415,23 @@ class DatasetAggregations:
Data variables:
da float64 8B nan
"""
- pass
-
- def mean(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.min,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``mean`` along some dimension(s).
@@ -422,11 +506,24 @@ class DatasetAggregations:
Data variables:
da float64 8B nan
"""
- pass
-
- def prod(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Self:
+ return self.reduce(
+ duck_array_ops.mean,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def prod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``prod`` along some dimension(s).
@@ -515,11 +612,25 @@ class DatasetAggregations:
Data variables:
da float64 8B 0.0
"""
- pass
-
- def sum(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Self:
+ return self.reduce(
+ duck_array_ops.prod,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``sum`` along some dimension(s).
@@ -608,10 +719,25 @@ class DatasetAggregations:
Data variables:
da float64 8B 8.0
"""
- pass
-
- def std(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.sum,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``std`` along some dimension(s).
@@ -697,10 +823,25 @@ class DatasetAggregations:
Data variables:
da float64 8B 1.14
"""
- pass
-
- def var(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.std,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``var`` along some dimension(s).
@@ -786,10 +927,24 @@ class DatasetAggregations:
Data variables:
da float64 8B 1.3
"""
- pass
-
- def median(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.var,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def median(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``median`` along some dimension(s).
@@ -864,10 +1019,23 @@ class DatasetAggregations:
Data variables:
da float64 8B nan
"""
- pass
-
- def cumsum(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.median,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumsum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``cumsum`` along some dimension(s).
@@ -944,10 +1112,23 @@ class DatasetAggregations:
Data variables:
da (time) float64 48B 1.0 3.0 6.0 6.0 8.0 nan
"""
- pass
-
- def cumprod(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.cumsum,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumprod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this Dataset's data by applying ``cumprod`` along some dimension(s).
@@ -1024,14 +1205,38 @@ class DatasetAggregations:
Data variables:
da (time) float64 48B 1.0 2.0 6.0 0.0 0.0 nan
"""
- pass
+ return self.reduce(
+ duck_array_ops.cumprod,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
class DataArrayAggregations:
__slots__ = ()
- def count(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Self:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> Self:
+ raise NotImplementedError()
+
+ def count(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``count`` along some dimension(s).
@@ -1084,10 +1289,20 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(5)
"""
- pass
-
- def all(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.count,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def all(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``all`` along some dimension(s).
@@ -1140,10 +1355,20 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 1B
array(False)
"""
- pass
-
- def any(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.array_all,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def any(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``any`` along some dimension(s).
@@ -1196,10 +1421,21 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 1B
array(True)
"""
- pass
-
- def max(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.array_any,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def max(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``max`` along some dimension(s).
@@ -1263,10 +1499,22 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(nan)
"""
- pass
-
- def min(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.max,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def min(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``min`` along some dimension(s).
@@ -1330,10 +1578,22 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(nan)
"""
- pass
-
- def mean(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.min,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``mean`` along some dimension(s).
@@ -1401,11 +1661,23 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(nan)
"""
- pass
-
- def prod(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Self:
+ return self.reduce(
+ duck_array_ops.mean,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def prod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``prod`` along some dimension(s).
@@ -1485,11 +1757,24 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(0.)
"""
- pass
-
- def sum(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Self:
+ return self.reduce(
+ duck_array_ops.prod,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``sum`` along some dimension(s).
@@ -1569,10 +1854,24 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(8.)
"""
- pass
-
- def std(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.sum,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``std`` along some dimension(s).
@@ -1649,10 +1948,24 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(1.14017543)
"""
- pass
-
- def var(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.std,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``var`` along some dimension(s).
@@ -1729,10 +2042,23 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(1.3)
"""
- pass
-
- def median(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.var,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def median(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``median`` along some dimension(s).
@@ -1800,10 +2126,22 @@ class DataArrayAggregations:
<xarray.DataArray ()> Size: 8B
array(nan)
"""
- pass
-
- def cumsum(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.median,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumsum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``cumsum`` along some dimension(s).
@@ -1877,10 +2215,22 @@ class DataArrayAggregations:
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
- pass
-
- def cumprod(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.cumsum,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumprod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this DataArray's data by applying ``cumprod`` along some dimension(s).
@@ -1954,14 +2304,44 @@ class DataArrayAggregations:
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
- pass
+ return self.reduce(
+ duck_array_ops.cumprod,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
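
Each `DataArrayAggregations` stub above is replaced by a thin wrapper that forwards to `self.reduce` with the matching `duck_array_ops` function, so the public call signatures stay the same. A minimal sketch of how these wrappers are exercised from the public API (the array values and the `"time"` dimension name are made up purely for illustration):

```python
import numpy as np
import xarray as xr

# Illustrative values only; any DataArray behaves the same way.
da = xr.DataArray([1.0, 2.0, 3.0, 0.0, 2.0, np.nan], dims="time")

da.sum(dim="time", skipna=True)                # -> reduce(duck_array_ops.sum, ...)
da.prod(dim="time", skipna=True, min_count=2)  # min_count exists only on sum/prod
da.std(dim="time", ddof=1)                     # ddof exists only on std/var
da.cumsum(dim="time")                          # cumulative ops keep the "time" dim
```
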
class DatasetGroupByAggregations:
_obj: Dataset
- def count(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Dataset:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> Dataset:
+ raise NotImplementedError()
+
+ def _flox_reduce(
+ self,
+ dim: Dims,
+ **kwargs: Any,
+ ) -> Dataset:
+ raise NotImplementedError()
+
+ def count(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``count`` along some dimension(s).
@@ -2029,10 +2409,35 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) int64 24B 1 2 2
"""
- pass
-
- def all(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="count",
+ dim=dim,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.count,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def all(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``all`` along some dimension(s).
@@ -2100,10 +2505,35 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) bool 3B False True True
"""
- pass
-
- def any(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="all",
+ dim=dim,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_all,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def any(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``any`` along some dimension(s).
@@ -2171,10 +2601,36 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) bool 3B True True True
"""
- pass
-
- def max(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="any",
+ dim=dim,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_any,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def max(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``max`` along some dimension(s).
@@ -2257,10 +2713,38 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 2.0 3.0
"""
- pass
-
- def min(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="max",
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.max,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def min(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``min`` along some dimension(s).
@@ -2343,10 +2827,38 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 2.0 0.0
"""
- pass
-
- def mean(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="min",
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.min,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``mean`` along some dimension(s).
@@ -2431,11 +2943,39 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 2.0 1.5
"""
- pass
-
- def prod(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="mean",
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.mean,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def prod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``prod`` along some dimension(s).
@@ -2536,11 +3076,41 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 4.0 0.0
"""
- pass
-
- def sum(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="prod",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.prod,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``sum`` along some dimension(s).
@@ -2641,10 +3211,41 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 4.0 3.0
"""
- pass
-
- def std(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="sum",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.sum,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``std`` along some dimension(s).
@@ -2742,10 +3343,41 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 0.0 2.121
"""
- pass
-
- def var(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="std",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.std,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``var`` along some dimension(s).
@@ -2843,10 +3475,40 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 0.0 4.5
"""
- pass
-
- def median(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="var",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.var,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def median(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``median`` along some dimension(s).
@@ -2931,10 +3593,23 @@ class DatasetGroupByAggregations:
Data variables:
da (labels) float64 24B nan 2.0 1.5
"""
- pass
-
- def cumsum(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ return self.reduce(
+ duck_array_ops.median,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumsum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``cumsum`` along some dimension(s).
@@ -3017,10 +3692,23 @@ class DatasetGroupByAggregations:
Data variables:
da (time) float64 48B 1.0 2.0 3.0 3.0 4.0 nan
"""
- pass
-
- def cumprod(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ return self.reduce(
+ duck_array_ops.cumsum,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumprod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``cumprod`` along some dimension(s).
@@ -3103,14 +3791,45 @@ class DatasetGroupByAggregations:
Data variables:
da (time) float64 48B 1.0 2.0 3.0 0.0 4.0 nan
"""
- pass
+ return self.reduce(
+ duck_array_ops.cumprod,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
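
The `DatasetGroupByAggregations` wrappers above add a fast path: when flox is importable, the `use_flox` option is enabled, and the grouped object contains only chunked or numpy arrays, they call `_flox_reduce`; otherwise they fall back to the generic `reduce`. A small sketch of toggling that choice through the public option (data values are illustrative):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"da": ("time", [1.0, 2.0, 3.0, 0.0, 2.0, np.nan])},
    coords={"labels": ("time", list("abccba"))},
)

# Takes the flox fast path when flox is installed and the option is on ...
with xr.set_options(use_flox=True):
    ds.groupby("labels").mean()

# ... and falls back to Dataset.reduce(duck_array_ops.mean, ...) otherwise.
with xr.set_options(use_flox=False):
    ds.groupby("labels").mean()
```

Note that `median`, `cumsum`, and `cumprod` always go through `reduce`; only the simple reductions have a flox branch in the generated code above.
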
class DatasetResampleAggregations:
_obj: Dataset
- def count(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Dataset:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> Dataset:
+ raise NotImplementedError()
+
+ def _flox_reduce(
+ self,
+ dim: Dims,
+ **kwargs: Any,
+ ) -> Dataset:
+ raise NotImplementedError()
+
+ def count(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``count`` along some dimension(s).
@@ -3178,10 +3897,35 @@ class DatasetResampleAggregations:
Data variables:
da (time) int64 24B 1 3 1
"""
- pass
-
- def all(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="count",
+ dim=dim,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.count,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def all(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``all`` along some dimension(s).
@@ -3249,10 +3993,35 @@ class DatasetResampleAggregations:
Data variables:
da (time) bool 3B True True False
"""
- pass
-
- def any(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="all",
+ dim=dim,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_all,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def any(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``any`` along some dimension(s).
@@ -3320,10 +4089,36 @@ class DatasetResampleAggregations:
Data variables:
da (time) bool 3B True True True
"""
- pass
-
- def max(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="any",
+ dim=dim,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_any,
+ dim=dim,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def max(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``max`` along some dimension(s).
@@ -3406,10 +4201,38 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B 1.0 3.0 nan
"""
- pass
-
- def min(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="max",
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.max,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def min(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``min`` along some dimension(s).
@@ -3492,10 +4315,38 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B 1.0 0.0 nan
"""
- pass
-
- def mean(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="min",
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.min,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``mean`` along some dimension(s).
@@ -3580,11 +4431,39 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B 1.0 1.667 nan
"""
- pass
-
- def prod(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="mean",
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.mean,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def prod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``prod`` along some dimension(s).
@@ -3685,11 +4564,41 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B nan 0.0 nan
"""
- pass
-
- def sum(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="prod",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.prod,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``sum`` along some dimension(s).
@@ -3790,10 +4699,41 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B nan 5.0 nan
"""
- pass
-
- def std(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="sum",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.sum,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``std`` along some dimension(s).
@@ -3891,10 +4831,41 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B nan 1.528 nan
"""
- pass
-
- def var(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="std",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.std,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``var`` along some dimension(s).
@@ -3992,10 +4963,40 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B nan 2.333 nan
"""
- pass
-
- def median(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="var",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.var,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def median(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``median`` along some dimension(s).
@@ -4080,10 +5081,23 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 24B 1.0 2.0 nan
"""
- pass
-
- def cumsum(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ return self.reduce(
+ duck_array_ops.median,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumsum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``cumsum`` along some dimension(s).
@@ -4166,10 +5180,23 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 48B 1.0 2.0 5.0 5.0 2.0 nan
"""
- pass
-
- def cumprod(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->Dataset:
+ return self.reduce(
+ duck_array_ops.cumsum,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumprod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""
Reduce this Dataset's data by applying ``cumprod`` along some dimension(s).
@@ -4252,14 +5279,45 @@ class DatasetResampleAggregations:
Data variables:
da (time) float64 48B 1.0 2.0 6.0 0.0 2.0 nan
"""
- pass
+ return self.reduce(
+ duck_array_ops.cumprod,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=True,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
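
The Dataset groupby/resample wrappers also differ in the `numeric_only` flag they pass: `count`, `all`, `any`, `max`, and `min` use `numeric_only=False`, while `mean`, `prod`, `sum`, `std`, `var`, `median`, `cumsum`, and `cumprod` use `numeric_only=True`, so non-numeric data variables are dropped only for the second group. An illustrative sketch (the variable names and values here are hypothetical):

```python
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2001-01-01", periods=6, freq="D")
ds = xr.Dataset(
    {
        "da": ("time", [1.0, 2.0, 3.0, 0.0, 2.0, np.nan]),
        "label": ("time", list("abccba")),  # non-numeric data variable
    },
    coords={"time": time},
)

r = ds.resample(time="2D")
r.count()  # numeric_only=False: "label" is counted as well
r.mean()   # numeric_only=True: "label" is dropped from the result
```
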
class DataArrayGroupByAggregations:
_obj: DataArray
- def count(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> DataArray:
+ raise NotImplementedError()
+
+ def _flox_reduce(
+ self,
+ dim: Dims,
+ **kwargs: Any,
+ ) -> DataArray:
+ raise NotImplementedError()
+
+ def count(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``count`` along some dimension(s).
@@ -4322,10 +5380,33 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def all(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="count",
+ dim=dim,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.count,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def all(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``all`` along some dimension(s).
@@ -4388,10 +5469,33 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def any(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="all",
+ dim=dim,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_all,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def any(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``any`` along some dimension(s).
@@ -4454,10 +5558,34 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def max(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="any",
+ dim=dim,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_any,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def max(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``max`` along some dimension(s).
@@ -4533,10 +5661,36 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def min(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="max",
+ dim=dim,
+ skipna=skipna,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.max,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def min(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``min`` along some dimension(s).
@@ -4612,10 +5766,36 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def mean(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="min",
+ dim=dim,
+ skipna=skipna,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.min,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``mean`` along some dimension(s).
@@ -4693,11 +5873,37 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def prod(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="mean",
+ dim=dim,
+ skipna=skipna,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.mean,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def prod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``prod`` along some dimension(s).
@@ -4789,11 +5995,39 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def sum(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="prod",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.prod,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``sum`` along some dimension(s).
@@ -4885,10 +6119,39 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def std(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="sum",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.sum,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``std`` along some dimension(s).
@@ -4977,10 +6240,39 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def var(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="std",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.std,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``var`` along some dimension(s).
@@ -5069,10 +6361,38 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def median(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="var",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.var,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def median(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``median`` along some dimension(s).
@@ -5150,10 +6470,22 @@ class DataArrayGroupByAggregations:
Coordinates:
* labels (labels) object 24B 'a' 'b' 'c'
"""
- pass
-
- def cumsum(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ return self.reduce(
+ duck_array_ops.median,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumsum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``cumsum`` along some dimension(s).
@@ -5233,10 +6565,22 @@ class DataArrayGroupByAggregations:
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
- pass
-
- def cumprod(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ return self.reduce(
+ duck_array_ops.cumsum,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumprod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``cumprod`` along some dimension(s).
@@ -5316,14 +6660,44 @@ class DataArrayGroupByAggregations:
* time (time) datetime64[ns] 48B 2001-01-31 2001-02-28 ... 2001-06-30
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
"""
- pass
+ return self.reduce(
+ duck_array_ops.cumprod,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
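
The `DataArrayGroupByAggregations` wrappers mirror the Dataset ones but omit `numeric_only`, since a single array is either numeric or not. As above, the cumulative and `median` reductions skip the flox branch entirely. A minimal usage sketch with made-up data:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    [1.0, 2.0, 3.0, 0.0, 2.0, np.nan],
    dims="time",
    coords={"labels": ("time", list("abccba"))},
)

grouped = da.groupby("labels")
grouped.max(skipna=False)  # flox path if available, else reduce(duck_array_ops.max)
grouped.median()           # always reduce(duck_array_ops.median)
grouped.cumsum()           # cumulative result keeps the original "time" dim
```
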
class DataArrayResampleAggregations:
_obj: DataArray
- def count(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> DataArray:
+ raise NotImplementedError()
+
+ def _flox_reduce(
+ self,
+ dim: Dims,
+ **kwargs: Any,
+ ) -> DataArray:
+ raise NotImplementedError()
+
+ def count(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``count`` along some dimension(s).
@@ -5386,10 +6760,33 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def all(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="count",
+ dim=dim,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.count,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def all(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``all`` along some dimension(s).
@@ -5452,10 +6849,33 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def any(self, dim: Dims=None, *, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="all",
+ dim=dim,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_all,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def any(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``any`` along some dimension(s).
@@ -5518,10 +6938,34 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def max(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="any",
+ dim=dim,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.array_any,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def max(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``max`` along some dimension(s).
@@ -5597,10 +7041,36 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def min(self, dim: Dims=None, *, skipna: (bool | None)=None, keep_attrs:
- (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="max",
+ dim=dim,
+ skipna=skipna,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.max,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def min(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``min`` along some dimension(s).
@@ -5676,10 +7146,36 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def mean(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="min",
+ dim=dim,
+ skipna=skipna,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.min,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``mean`` along some dimension(s).
@@ -5757,11 +7253,37 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def prod(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="mean",
+ dim=dim,
+ skipna=skipna,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.mean,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def prod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``prod`` along some dimension(s).
@@ -5853,11 +7375,39 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def sum(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, keep_attrs: (bool | None)=None, **kwargs: Any
- ) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="prod",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.prod,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``sum`` along some dimension(s).
@@ -5949,10 +7499,39 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def std(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="sum",
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.sum,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``std`` along some dimension(s).
@@ -6041,10 +7620,39 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def var(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="std",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.std,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``var`` along some dimension(s).
@@ -6133,10 +7741,38 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def median(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ if (
+ flox_available
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="var",
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.var,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def median(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``median`` along some dimension(s).
@@ -6214,10 +7850,22 @@ class DataArrayResampleAggregations:
Coordinates:
* time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31
"""
- pass
-
- def cumsum(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ return self.reduce(
+ duck_array_ops.median,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumsum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``cumsum`` along some dimension(s).
@@ -6297,10 +7945,22 @@ class DataArrayResampleAggregations:
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
Dimensions without coordinates: time
"""
- pass
-
- def cumprod(self, dim: Dims=None, *, skipna: (bool | None)=None,
- keep_attrs: (bool | None)=None, **kwargs: Any) ->DataArray:
+ return self.reduce(
+ duck_array_ops.cumsum,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def cumprod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""
Reduce this DataArray's data by applying ``cumprod`` along some dimension(s).
@@ -6380,4 +8040,10 @@ class DataArrayResampleAggregations:
labels (time) <U1 24B 'a' 'b' 'c' 'c' 'b' 'a'
Dimensions without coordinates: time
"""
- pass
+ return self.reduce(
+ duck_array_ops.cumprod,
+ dim=dim,
+ skipna=skipna,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
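
`DataArrayResampleAggregations` follows the same pattern over time bins instead of group labels. A short sketch, again with hypothetical values and a hypothetical `"3D"` resampling frequency:

```python
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2001-01-01", periods=6, freq="D")
da = xr.DataArray(
    [1.0, 2.0, 3.0, 1.0, 2.0, np.nan], dims="time", coords={"time": time}
)

resampled = da.resample(time="3D")
resampled.sum(skipna=True, min_count=1)  # flox path or reduce(duck_array_ops.sum)
resampled.std(ddof=1)
resampled.cumprod()                      # reduce-only, like median and cumsum
```
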
diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py
index 0059f74a..61aa1846 100644
--- a/xarray/core/_typed_ops.py
+++ b/xarray/core/_typed_ops.py
@@ -1,9 +1,21 @@
"""Mixin classes with arithmetic operators."""
+
+# This file was generated using xarray.util.generate_ops. Do not edit manually.
+
from __future__ import annotations
+
import operator
from typing import TYPE_CHECKING, Any, Callable, overload
+
from xarray.core import nputils, ops
-from xarray.core.types import DaCompatible, DsCompatible, Self, T_Xarray, VarCompatible
+from xarray.core.types import (
+ DaCompatible,
+ DsCompatible,
+ Self,
+ T_Xarray,
+ VarCompatible,
+)
+
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
@@ -13,138 +25,165 @@ if TYPE_CHECKING:
class DatasetOpsMixin:
__slots__ = ()
- def __add__(self, other: DsCompatible) ->Self:
+ def _binary_op(
+ self, other: DsCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
+ raise NotImplementedError
+
+ def __add__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.add)
- def __sub__(self, other: DsCompatible) ->Self:
+ def __sub__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.sub)
- def __mul__(self, other: DsCompatible) ->Self:
+ def __mul__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mul)
- def __pow__(self, other: DsCompatible) ->Self:
+ def __pow__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other: DsCompatible) ->Self:
+ def __truediv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other: DsCompatible) ->Self:
+ def __floordiv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other: DsCompatible) ->Self:
+ def __mod__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mod)
- def __and__(self, other: DsCompatible) ->Self:
+ def __and__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.and_)
- def __xor__(self, other: DsCompatible) ->Self:
+ def __xor__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.xor)
- def __or__(self, other: DsCompatible) ->Self:
+ def __or__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other: DsCompatible) ->Self:
+ def __lshift__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other: DsCompatible) ->Self:
+ def __rshift__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other: DsCompatible) ->Self:
+ def __lt__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.lt)
- def __le__(self, other: DsCompatible) ->Self:
+ def __le__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.le)
- def __gt__(self, other: DsCompatible) ->Self:
+ def __gt__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.gt)
- def __ge__(self, other: DsCompatible) ->Self:
+ def __ge__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.ge)
- def __eq__(self, other: DsCompatible) ->Self:
+ def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other: DsCompatible) ->Self:
+ def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- __hash__: None
- def __radd__(self, other: DsCompatible) ->Self:
+ # When __eq__ is defined but __hash__ is not, then an object is unhashable,
+ # and it should be declared as follows:
+ __hash__: None # type:ignore[assignment]
+
+ def __radd__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other: DsCompatible) ->Self:
+ def __rsub__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other: DsCompatible) ->Self:
+ def __rmul__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other: DsCompatible) ->Self:
+ def __rpow__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other: DsCompatible) ->Self:
+ def __rtruediv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other: DsCompatible) ->Self:
+ def __rfloordiv__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other: DsCompatible) ->Self:
+ def __rmod__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other: DsCompatible) ->Self:
+ def __rand__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other: DsCompatible) ->Self:
+ def __rxor__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other: DsCompatible) ->Self:
+ def __ror__(self, other: DsCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)
- def __iadd__(self, other: DsCompatible) ->Self:
+ def _inplace_binary_op(self, other: DsCompatible, f: Callable) -> Self:
+ raise NotImplementedError
+
+ def __iadd__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.iadd)
- def __isub__(self, other: DsCompatible) ->Self:
+ def __isub__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.isub)
- def __imul__(self, other: DsCompatible) ->Self:
+ def __imul__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.imul)
- def __ipow__(self, other: DsCompatible) ->Self:
+ def __ipow__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ipow)
- def __itruediv__(self, other: DsCompatible) ->Self:
+ def __itruediv__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.itruediv)
- def __ifloordiv__(self, other: DsCompatible) ->Self:
+ def __ifloordiv__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ifloordiv)
- def __imod__(self, other: DsCompatible) ->Self:
+ def __imod__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.imod)
- def __iand__(self, other: DsCompatible) ->Self:
+ def __iand__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.iand)
- def __ixor__(self, other: DsCompatible) ->Self:
+ def __ixor__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ixor)
- def __ior__(self, other: DsCompatible) ->Self:
+ def __ior__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ior)
- def __ilshift__(self, other: DsCompatible) ->Self:
+ def __ilshift__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.ilshift)
- def __irshift__(self, other: DsCompatible) ->Self:
+ def __irshift__(self, other: DsCompatible) -> Self:
return self._inplace_binary_op(other, operator.irshift)
- def __neg__(self) ->Self:
+ def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
+ raise NotImplementedError
+
+ def __neg__(self) -> Self:
return self._unary_op(operator.neg)
- def __pos__(self) ->Self:
+ def __pos__(self) -> Self:
return self._unary_op(operator.pos)
- def __abs__(self) ->Self:
+ def __abs__(self) -> Self:
return self._unary_op(operator.abs)
- def __invert__(self) ->Self:
+ def __invert__(self) -> Self:
return self._unary_op(operator.invert)
+
+ def round(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.round_, *args, **kwargs)
+
+ def argsort(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.argsort, *args, **kwargs)
+
+ def conj(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.conj, *args, **kwargs)
+
+ def conjugate(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.conjugate, *args, **kwargs)
+
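+    # The regenerated mixin above declares _binary_op / _inplace_binary_op /
+    # _unary_op as hooks that raise NotImplementedError; Dataset, DataArray and
+    # Variable override them. A hypothetical, self-contained sketch (not
+    # xarray's classes) of how that dispatch pattern works:
+    #
+    #     import operator
+    #     from typing import Any, Callable
+    #
+    #     class OpsMixin:
+    #         def _binary_op(self, other: Any, f: Callable,
+    #                        reflexive: bool = False) -> "OpsMixin":
+    #             raise NotImplementedError
+    #
+    #         def __add__(self, other: Any) -> "OpsMixin":
+    #             return self._binary_op(other, operator.add)
+    #
+    #         def __radd__(self, other: Any) -> "OpsMixin":
+    #             return self._binary_op(other, operator.add, reflexive=True)
+    #
+    #     class Wrapped(OpsMixin):
+    #         """Tiny stand-in for Dataset/DataArray that wraps a float."""
+    #
+    #         def __init__(self, value: float) -> None:
+    #             self.value = value
+    #
+    #         def _binary_op(self, other: Any, f: Callable,
+    #                        reflexive: bool = False) -> "Wrapped":
+    #             o = other.value if isinstance(other, Wrapped) else other
+    #             return Wrapped(f(o, self.value) if reflexive
+    #                            else f(self.value, o))
+    #
+    #     assert (Wrapped(2.0) + 3.0).value == 5.0
+    #     assert (3.0 + Wrapped(2.0)).value == 5.0  # __radd__, reflexive=True
+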
__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
@@ -198,138 +237,165 @@ class DatasetOpsMixin:
class DataArrayOpsMixin:
__slots__ = ()
- def __add__(self, other: DaCompatible) ->Self:
+ def _binary_op(
+ self, other: DaCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
+ raise NotImplementedError
+
+ def __add__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.add)
- def __sub__(self, other: DaCompatible) ->Self:
+ def __sub__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.sub)
- def __mul__(self, other: DaCompatible) ->Self:
+ def __mul__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mul)
- def __pow__(self, other: DaCompatible) ->Self:
+ def __pow__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other: DaCompatible) ->Self:
+ def __truediv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other: DaCompatible) ->Self:
+ def __floordiv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other: DaCompatible) ->Self:
+ def __mod__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mod)
- def __and__(self, other: DaCompatible) ->Self:
+ def __and__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.and_)
- def __xor__(self, other: DaCompatible) ->Self:
+ def __xor__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.xor)
- def __or__(self, other: DaCompatible) ->Self:
+ def __or__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other: DaCompatible) ->Self:
+ def __lshift__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other: DaCompatible) ->Self:
+ def __rshift__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other: DaCompatible) ->Self:
+ def __lt__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.lt)
- def __le__(self, other: DaCompatible) ->Self:
+ def __le__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.le)
- def __gt__(self, other: DaCompatible) ->Self:
+ def __gt__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.gt)
- def __ge__(self, other: DaCompatible) ->Self:
+ def __ge__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.ge)
- def __eq__(self, other: DaCompatible) ->Self:
+ def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other: DaCompatible) ->Self:
+ def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- __hash__: None
- def __radd__(self, other: DaCompatible) ->Self:
+ # When __eq__ is defined but __hash__ is not, then an object is unhashable,
+ # and it should be declared as follows:
+ __hash__: None # type:ignore[assignment]
+
+ def __radd__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other: DaCompatible) ->Self:
+ def __rsub__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other: DaCompatible) ->Self:
+ def __rmul__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other: DaCompatible) ->Self:
+ def __rpow__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other: DaCompatible) ->Self:
+ def __rtruediv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other: DaCompatible) ->Self:
+ def __rfloordiv__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other: DaCompatible) ->Self:
+ def __rmod__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other: DaCompatible) ->Self:
+ def __rand__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other: DaCompatible) ->Self:
+ def __rxor__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other: DaCompatible) ->Self:
+ def __ror__(self, other: DaCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)
- def __iadd__(self, other: DaCompatible) ->Self:
+ def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self:
+ raise NotImplementedError
+
+ def __iadd__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.iadd)
- def __isub__(self, other: DaCompatible) ->Self:
+ def __isub__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.isub)
- def __imul__(self, other: DaCompatible) ->Self:
+ def __imul__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.imul)
- def __ipow__(self, other: DaCompatible) ->Self:
+ def __ipow__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ipow)
- def __itruediv__(self, other: DaCompatible) ->Self:
+ def __itruediv__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.itruediv)
- def __ifloordiv__(self, other: DaCompatible) ->Self:
+ def __ifloordiv__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ifloordiv)
- def __imod__(self, other: DaCompatible) ->Self:
+ def __imod__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.imod)
- def __iand__(self, other: DaCompatible) ->Self:
+ def __iand__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.iand)
- def __ixor__(self, other: DaCompatible) ->Self:
+ def __ixor__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ixor)
- def __ior__(self, other: DaCompatible) ->Self:
+ def __ior__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ior)
- def __ilshift__(self, other: DaCompatible) ->Self:
+ def __ilshift__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.ilshift)
- def __irshift__(self, other: DaCompatible) ->Self:
+ def __irshift__(self, other: DaCompatible) -> Self:
return self._inplace_binary_op(other, operator.irshift)
- def __neg__(self) ->Self:
+ def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
+ raise NotImplementedError
+
+ def __neg__(self) -> Self:
return self._unary_op(operator.neg)
- def __pos__(self) ->Self:
+ def __pos__(self) -> Self:
return self._unary_op(operator.pos)
- def __abs__(self) ->Self:
+ def __abs__(self) -> Self:
return self._unary_op(operator.abs)
- def __invert__(self) ->Self:
+ def __invert__(self) -> Self:
return self._unary_op(operator.invert)
+
+ def round(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.round_, *args, **kwargs)
+
+ def argsort(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.argsort, *args, **kwargs)
+
+ def conj(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.conj, *args, **kwargs)
+
+ def conjugate(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.conjugate, *args, **kwargs)
+
__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
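
The new comment about `__hash__` restates standard Python behaviour: a class that defines `__eq__` but not `__hash__` gets `__hash__ = None`, so its instances are unhashable and the attribute has to be annotated accordingly. A quick illustrative check with a throwaway `EqOnly` class:

```python
class EqOnly:
    def __eq__(self, other):  # defining __eq__ alone sets __hash__ to None
        return NotImplemented


print(EqOnly.__hash__)  # None
try:
    hash(EqOnly())
except TypeError as err:
    print(err)  # unhashable type: 'EqOnly'
```
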
@@ -383,282 +449,273 @@ class DataArrayOpsMixin:
class VariableOpsMixin:
__slots__ = ()
+ def _binary_op(
+ self, other: VarCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
+ raise NotImplementedError
+
@overload
- def __add__(self, other: T_DA) ->T_DA:
- ...
+ def __add__(self, other: T_DA) -> T_DA: ...
@overload
- def __add__(self, other: VarCompatible) ->Self:
- ...
+ def __add__(self, other: VarCompatible) -> Self: ...
- def __add__(self, other: VarCompatible) ->(Self | T_DA):
+ def __add__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.add)
@overload
- def __sub__(self, other: T_DA) ->T_DA:
- ...
+ def __sub__(self, other: T_DA) -> T_DA: ...
@overload
- def __sub__(self, other: VarCompatible) ->Self:
- ...
+ def __sub__(self, other: VarCompatible) -> Self: ...
- def __sub__(self, other: VarCompatible) ->(Self | T_DA):
+ def __sub__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.sub)
@overload
- def __mul__(self, other: T_DA) ->T_DA:
- ...
+ def __mul__(self, other: T_DA) -> T_DA: ...
@overload
- def __mul__(self, other: VarCompatible) ->Self:
- ...
+ def __mul__(self, other: VarCompatible) -> Self: ...
- def __mul__(self, other: VarCompatible) ->(Self | T_DA):
+ def __mul__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.mul)
@overload
- def __pow__(self, other: T_DA) ->T_DA:
- ...
+ def __pow__(self, other: T_DA) -> T_DA: ...
@overload
- def __pow__(self, other: VarCompatible) ->Self:
- ...
+ def __pow__(self, other: VarCompatible) -> Self: ...
- def __pow__(self, other: VarCompatible) ->(Self | T_DA):
+ def __pow__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.pow)
@overload
- def __truediv__(self, other: T_DA) ->T_DA:
- ...
+ def __truediv__(self, other: T_DA) -> T_DA: ...
@overload
- def __truediv__(self, other: VarCompatible) ->Self:
- ...
+ def __truediv__(self, other: VarCompatible) -> Self: ...
- def __truediv__(self, other: VarCompatible) ->(Self | T_DA):
+ def __truediv__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.truediv)
@overload
- def __floordiv__(self, other: T_DA) ->T_DA:
- ...
+ def __floordiv__(self, other: T_DA) -> T_DA: ...
@overload
- def __floordiv__(self, other: VarCompatible) ->Self:
- ...
+ def __floordiv__(self, other: VarCompatible) -> Self: ...
- def __floordiv__(self, other: VarCompatible) ->(Self | T_DA):
+ def __floordiv__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.floordiv)
@overload
- def __mod__(self, other: T_DA) ->T_DA:
- ...
+ def __mod__(self, other: T_DA) -> T_DA: ...
@overload
- def __mod__(self, other: VarCompatible) ->Self:
- ...
+ def __mod__(self, other: VarCompatible) -> Self: ...
- def __mod__(self, other: VarCompatible) ->(Self | T_DA):
+ def __mod__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.mod)
@overload
- def __and__(self, other: T_DA) ->T_DA:
- ...
+ def __and__(self, other: T_DA) -> T_DA: ...
@overload
- def __and__(self, other: VarCompatible) ->Self:
- ...
+ def __and__(self, other: VarCompatible) -> Self: ...
- def __and__(self, other: VarCompatible) ->(Self | T_DA):
+ def __and__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.and_)
@overload
- def __xor__(self, other: T_DA) ->T_DA:
- ...
+ def __xor__(self, other: T_DA) -> T_DA: ...
@overload
- def __xor__(self, other: VarCompatible) ->Self:
- ...
+ def __xor__(self, other: VarCompatible) -> Self: ...
- def __xor__(self, other: VarCompatible) ->(Self | T_DA):
+ def __xor__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.xor)
@overload
- def __or__(self, other: T_DA) ->T_DA:
- ...
+ def __or__(self, other: T_DA) -> T_DA: ...
@overload
- def __or__(self, other: VarCompatible) ->Self:
- ...
+ def __or__(self, other: VarCompatible) -> Self: ...
- def __or__(self, other: VarCompatible) ->(Self | T_DA):
+ def __or__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.or_)
@overload
- def __lshift__(self, other: T_DA) ->T_DA:
- ...
+ def __lshift__(self, other: T_DA) -> T_DA: ...
@overload
- def __lshift__(self, other: VarCompatible) ->Self:
- ...
+ def __lshift__(self, other: VarCompatible) -> Self: ...
- def __lshift__(self, other: VarCompatible) ->(Self | T_DA):
+ def __lshift__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.lshift)
@overload
- def __rshift__(self, other: T_DA) ->T_DA:
- ...
+ def __rshift__(self, other: T_DA) -> T_DA: ...
@overload
- def __rshift__(self, other: VarCompatible) ->Self:
- ...
+ def __rshift__(self, other: VarCompatible) -> Self: ...
- def __rshift__(self, other: VarCompatible) ->(Self | T_DA):
+ def __rshift__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.rshift)
@overload
- def __lt__(self, other: T_DA) ->T_DA:
- ...
+ def __lt__(self, other: T_DA) -> T_DA: ...
@overload
- def __lt__(self, other: VarCompatible) ->Self:
- ...
+ def __lt__(self, other: VarCompatible) -> Self: ...
- def __lt__(self, other: VarCompatible) ->(Self | T_DA):
+ def __lt__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.lt)
@overload
- def __le__(self, other: T_DA) ->T_DA:
- ...
+ def __le__(self, other: T_DA) -> T_DA: ...
@overload
- def __le__(self, other: VarCompatible) ->Self:
- ...
+ def __le__(self, other: VarCompatible) -> Self: ...
- def __le__(self, other: VarCompatible) ->(Self | T_DA):
+ def __le__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.le)
@overload
- def __gt__(self, other: T_DA) ->T_DA:
- ...
+ def __gt__(self, other: T_DA) -> T_DA: ...
@overload
- def __gt__(self, other: VarCompatible) ->Self:
- ...
+ def __gt__(self, other: VarCompatible) -> Self: ...
- def __gt__(self, other: VarCompatible) ->(Self | T_DA):
+ def __gt__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.gt)
@overload
- def __ge__(self, other: T_DA) ->T_DA:
- ...
+ def __ge__(self, other: T_DA) -> T_DA: ...
@overload
- def __ge__(self, other: VarCompatible) ->Self:
- ...
+ def __ge__(self, other: VarCompatible) -> Self: ...
- def __ge__(self, other: VarCompatible) ->(Self | T_DA):
+ def __ge__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, operator.ge)
- @overload
- def __eq__(self, other: T_DA) ->T_DA:
- ...
+ @overload # type:ignore[override]
+ def __eq__(self, other: T_DA) -> T_DA: ...
@overload
- def __eq__(self, other: VarCompatible) ->Self:
- ...
+ def __eq__(self, other: VarCompatible) -> Self: ...
- def __eq__(self, other: VarCompatible) ->(Self | T_DA):
+ def __eq__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, nputils.array_eq)
- @overload
- def __ne__(self, other: T_DA) ->T_DA:
- ...
+ @overload # type:ignore[override]
+ def __ne__(self, other: T_DA) -> T_DA: ...
@overload
- def __ne__(self, other: VarCompatible) ->Self:
- ...
+ def __ne__(self, other: VarCompatible) -> Self: ...
- def __ne__(self, other: VarCompatible) ->(Self | T_DA):
+ def __ne__(self, other: VarCompatible) -> Self | T_DA:
return self._binary_op(other, nputils.array_ne)
- __hash__: None
- def __radd__(self, other: VarCompatible) ->Self:
+ # When __eq__ is defined but __hash__ is not, then an object is unhashable,
+ # and it should be declared as follows:
+ __hash__: None # type:ignore[assignment]
+
+ def __radd__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other: VarCompatible) ->Self:
+ def __rsub__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other: VarCompatible) ->Self:
+ def __rmul__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other: VarCompatible) ->Self:
+ def __rpow__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other: VarCompatible) ->Self:
+ def __rtruediv__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other: VarCompatible) ->Self:
+ def __rfloordiv__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other: VarCompatible) ->Self:
+ def __rmod__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other: VarCompatible) ->Self:
+ def __rand__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other: VarCompatible) ->Self:
+ def __rxor__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other: VarCompatible) ->Self:
+ def __ror__(self, other: VarCompatible) -> Self:
return self._binary_op(other, operator.or_, reflexive=True)
- def __iadd__(self, other: VarCompatible) ->Self:
+ def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self:
+ raise NotImplementedError
+
+ def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.iadd)
- def __isub__(self, other: VarCompatible) ->Self:
+ def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.isub)
- def __imul__(self, other: VarCompatible) ->Self:
+ def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.imul)
- def __ipow__(self, other: VarCompatible) ->Self:
+ def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ipow)
- def __itruediv__(self, other: VarCompatible) ->Self:
+ def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.itruediv)
- def __ifloordiv__(self, other: VarCompatible) ->Self:
+ def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ifloordiv)
- def __imod__(self, other: VarCompatible) ->Self:
+ def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.imod)
- def __iand__(self, other: VarCompatible) ->Self:
+ def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.iand)
- def __ixor__(self, other: VarCompatible) ->Self:
+ def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ixor)
- def __ior__(self, other: VarCompatible) ->Self:
+ def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ior)
- def __ilshift__(self, other: VarCompatible) ->Self:
+ def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.ilshift)
- def __irshift__(self, other: VarCompatible) ->Self:
+ def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc]
return self._inplace_binary_op(other, operator.irshift)
- def __neg__(self) ->Self:
+ def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
+ raise NotImplementedError
+
+ def __neg__(self) -> Self:
return self._unary_op(operator.neg)
- def __pos__(self) ->Self:
+ def __pos__(self) -> Self:
return self._unary_op(operator.pos)
- def __abs__(self) ->Self:
+ def __abs__(self) -> Self:
return self._unary_op(operator.abs)
- def __invert__(self) ->Self:
+ def __invert__(self) -> Self:
return self._unary_op(operator.invert)
+
+ def round(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.round_, *args, **kwargs)
+
+ def argsort(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.argsort, *args, **kwargs)
+
+ def conj(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.conj, *args, **kwargs)
+
+ def conjugate(self, *args: Any, **kwargs: Any) -> Self:
+ return self._unary_op(ops.conjugate, *args, **kwargs)
+
__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
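
The paired `@overload` declarations above give the type checker two answers for a single runtime implementation: a DataArray-like operand produces that array type, anything else produces `Self`. A hedged sketch of the same idea, with hypothetical `FancyArray` / `PlainVar` names standing in for DataArray and Variable:

```python
from __future__ import annotations

from typing import TypeVar, overload


class FancyArray:  # stand-in for DataArray in this illustration
    ...


T_Fancy = TypeVar("T_Fancy", bound=FancyArray)


class PlainVar:  # stand-in for Variable
    @overload
    def __add__(self, other: T_Fancy) -> T_Fancy: ...
    @overload
    def __add__(self, other: int | float | PlainVar) -> PlainVar: ...
    def __add__(self, other):  # the single implementation that actually runs
        raise NotImplementedError
```
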
@@ -712,90 +769,99 @@ class VariableOpsMixin:
class DatasetGroupByOpsMixin:
__slots__ = ()
- def __add__(self, other: (Dataset | DataArray)) ->Dataset:
+ def _binary_op(
+ self, other: Dataset | DataArray, f: Callable, reflexive: bool = False
+ ) -> Dataset:
+ raise NotImplementedError
+
+ def __add__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.add)
- def __sub__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __sub__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.sub)
- def __mul__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __mul__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mul)
- def __pow__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __pow__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __truediv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __floordiv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __mod__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mod)
- def __and__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __and__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.and_)
- def __xor__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __xor__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.xor)
- def __or__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __or__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __lshift__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rshift__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __lt__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.lt)
- def __le__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __le__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.le)
- def __gt__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __gt__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.gt)
- def __ge__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __ge__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.ge)
- def __eq__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __eq__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __ne__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- __hash__: None
- def __radd__(self, other: (Dataset | DataArray)) ->Dataset:
+ # When __eq__ is defined but __hash__ is not, then an object is unhashable,
+ # and it should be declared as follows:
+ __hash__: None # type:ignore[assignment]
+
+ def __radd__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rsub__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rmul__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rpow__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rtruediv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rfloordiv__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rmod__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rand__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __rxor__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other: (Dataset | DataArray)) ->Dataset:
+ def __ror__(self, other: Dataset | DataArray) -> Dataset:
return self._binary_op(other, operator.or_, reflexive=True)
+
__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
@@ -829,90 +895,99 @@ class DatasetGroupByOpsMixin:
class DataArrayGroupByOpsMixin:
__slots__ = ()
- def __add__(self, other: T_Xarray) ->T_Xarray:
+ def _binary_op(
+ self, other: T_Xarray, f: Callable, reflexive: bool = False
+ ) -> T_Xarray:
+ raise NotImplementedError
+
+ def __add__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.add)
- def __sub__(self, other: T_Xarray) ->T_Xarray:
+ def __sub__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.sub)
- def __mul__(self, other: T_Xarray) ->T_Xarray:
+ def __mul__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mul)
- def __pow__(self, other: T_Xarray) ->T_Xarray:
+ def __pow__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.pow)
- def __truediv__(self, other: T_Xarray) ->T_Xarray:
+ def __truediv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.truediv)
- def __floordiv__(self, other: T_Xarray) ->T_Xarray:
+ def __floordiv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.floordiv)
- def __mod__(self, other: T_Xarray) ->T_Xarray:
+ def __mod__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mod)
- def __and__(self, other: T_Xarray) ->T_Xarray:
+ def __and__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.and_)
- def __xor__(self, other: T_Xarray) ->T_Xarray:
+ def __xor__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.xor)
- def __or__(self, other: T_Xarray) ->T_Xarray:
+ def __or__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.or_)
- def __lshift__(self, other: T_Xarray) ->T_Xarray:
+ def __lshift__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.lshift)
- def __rshift__(self, other: T_Xarray) ->T_Xarray:
+ def __rshift__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.rshift)
- def __lt__(self, other: T_Xarray) ->T_Xarray:
+ def __lt__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.lt)
- def __le__(self, other: T_Xarray) ->T_Xarray:
+ def __le__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.le)
- def __gt__(self, other: T_Xarray) ->T_Xarray:
+ def __gt__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.gt)
- def __ge__(self, other: T_Xarray) ->T_Xarray:
+ def __ge__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.ge)
- def __eq__(self, other: T_Xarray) ->T_Xarray:
+ def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override]
return self._binary_op(other, nputils.array_eq)
- def __ne__(self, other: T_Xarray) ->T_Xarray:
+ def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override]
return self._binary_op(other, nputils.array_ne)
- __hash__: None
- def __radd__(self, other: T_Xarray) ->T_Xarray:
+ # When __eq__ is defined but __hash__ is not, then an object is unhashable,
+ # and it should be declared as follows:
+ __hash__: None # type:ignore[assignment]
+
+ def __radd__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.add, reflexive=True)
- def __rsub__(self, other: T_Xarray) ->T_Xarray:
+ def __rsub__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.sub, reflexive=True)
- def __rmul__(self, other: T_Xarray) ->T_Xarray:
+ def __rmul__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mul, reflexive=True)
- def __rpow__(self, other: T_Xarray) ->T_Xarray:
+ def __rpow__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.pow, reflexive=True)
- def __rtruediv__(self, other: T_Xarray) ->T_Xarray:
+ def __rtruediv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.truediv, reflexive=True)
- def __rfloordiv__(self, other: T_Xarray) ->T_Xarray:
+ def __rfloordiv__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.floordiv, reflexive=True)
- def __rmod__(self, other: T_Xarray) ->T_Xarray:
+ def __rmod__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.mod, reflexive=True)
- def __rand__(self, other: T_Xarray) ->T_Xarray:
+ def __rand__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.and_, reflexive=True)
- def __rxor__(self, other: T_Xarray) ->T_Xarray:
+ def __rxor__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.xor, reflexive=True)
- def __ror__(self, other: T_Xarray) ->T_Xarray:
+ def __ror__(self, other: T_Xarray) -> T_Xarray:
return self._binary_op(other, operator.or_, reflexive=True)
+
__add__.__doc__ = operator.add.__doc__
__sub__.__doc__ = operator.sub.__doc__
__mul__.__doc__ = operator.mul.__doc__
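
The GroupBy mixins above are what route arithmetic between a grouped object and a Dataset or DataArray through `_binary_op`; the canonical use case is subtracting a climatology. A short usage sketch with the public xarray API:

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01", periods=24, freq="MS")
da = xr.DataArray(np.arange(24.0), coords={"time": times}, dims="time")

climatology = da.groupby("time.month").mean()
anomaly = da.groupby("time.month") - climatology  # dispatches through _binary_op
print(anomaly.values[:3])  # [-6. -6. -6.]
```
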
diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py
index 87474ec9..41b982d2 100644
--- a/xarray/core/accessor_dt.py
+++ b/xarray/core/accessor_dt.py
@@ -1,16 +1,25 @@
from __future__ import annotations
+
import warnings
from typing import TYPE_CHECKING, Generic
+
import numpy as np
import pandas as pd
+
from xarray.coding.times import infer_calendar_name
from xarray.core import duck_array_ops
-from xarray.core.common import _contains_datetime_like_objects, is_np_datetime_like, is_np_timedelta_like
+from xarray.core.common import (
+ _contains_datetime_like_objects,
+ is_np_datetime_like,
+ is_np_timedelta_like,
+)
from xarray.core.types import T_DataArray
from xarray.core.variable import IndexVariable
from xarray.namedarray.utils import is_duck_dask_array
+
if TYPE_CHECKING:
from numpy.typing import DTypeLike
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import CFCalendar
@@ -18,21 +27,77 @@ if TYPE_CHECKING:
def _season_from_months(months):
"""Compute season (DJF, MAM, JJA, SON) from month ordinal"""
- pass
+ # TODO: Move "season" accessor upstream into pandas
+ seasons = np.array(["DJF", "MAM", "JJA", "SON", "nan"])
+ months = np.asarray(months)
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore", message="invalid value encountered in floor_divide"
+ )
+ warnings.filterwarnings(
+ "ignore", message="invalid value encountered in remainder"
+ )
+ idx = (months // 3) % 4
+
+ idx[np.isnan(idx)] = 4
+ return seasons[idx.astype(int)]
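
A worked example of the mapping implemented above: `(month // 3) % 4` folds Dec/Jan/Feb to index 0, Mar/Apr/May to 1, and so on, while NaN months are later routed to the extra "nan" slot:

```python
import numpy as np

seasons = np.array(["DJF", "MAM", "JJA", "SON"])
months = np.array([12, 1, 2, 3, 6, 9])
print(seasons[(months // 3) % 4])  # ['DJF' 'DJF' 'DJF' 'MAM' 'JJA' 'SON']
```
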
def _access_through_cftimeindex(values, name):
"""Coerce an array of datetime-like values to a CFTimeIndex
and access requested datetime component
"""
- pass
+ from xarray.coding.cftimeindex import CFTimeIndex
+
+ if not isinstance(values, CFTimeIndex):
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
+ else:
+ values_as_cftimeindex = values
+ if name == "season":
+ months = values_as_cftimeindex.month
+ field_values = _season_from_months(months)
+ elif name == "date":
+ raise AttributeError(
+ "'CFTimeIndex' object has no attribute `date`. Consider using the floor method "
+ "instead, for instance: `.time.dt.floor('D')`."
+ )
+ else:
+ field_values = getattr(values_as_cftimeindex, name)
+ return field_values.reshape(values.shape)
def _access_through_series(values, name):
"""Coerce an array of datetime-like values to a pandas Series and
access requested datetime component
"""
- pass
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
+ if name == "season":
+ months = values_as_series.dt.month.values
+ field_values = _season_from_months(months)
+ elif name == "total_seconds":
+ field_values = values_as_series.dt.total_seconds().values
+ elif name == "isocalendar":
+ # special NaT-handling can be removed when
+ # https://github.com/pandas-dev/pandas/issues/54657 is resolved
+ field_values = values_as_series.dt.isocalendar()
+ # test for <NA> and apply needed dtype
+ hasna = any(field_values.year.isnull())
+ if hasna:
+ field_values = np.dstack(
+ [
+ getattr(field_values, name).astype(np.float64, copy=False).values
+ for name in ["year", "week", "day"]
+ ]
+ )
+ else:
+ field_values = np.array(field_values, dtype=np.int64)
+ # isocalendar returns iso- year, week, and weekday -> reshape
+ return field_values.T.reshape(3, *values.shape)
+ else:
+ field_values = getattr(values_as_series.dt, name).values
+
+ return field_values.reshape(values.shape)
def _get_date_field(values, name, dtype):
@@ -54,14 +119,48 @@ def _get_date_field(values, name, dtype):
Array-like of datetime fields accessed for each element in values
"""
- pass
+ if is_np_datetime_like(values.dtype):
+ access_method = _access_through_series
+ else:
+ access_method = _access_through_cftimeindex
+
+ if is_duck_dask_array(values):
+ from dask.array import map_blocks
+
+ new_axis = chunks = None
+ # isocalendar adds an axis
+ if name == "isocalendar":
+ chunks = (3,) + values.chunksize
+ new_axis = 0
+
+ return map_blocks(
+ access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks
+ )
+ else:
+ out = access_method(values, name)
+ # cast only for integer types to keep float64 in presence of NaT
+ # see https://github.com/pydata/xarray/issues/7928
+ if np.issubdtype(out.dtype, np.integer):
+ out = out.astype(dtype, copy=False)
+ return out
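
The integer-only cast above preserves pandas' NaT handling: when missing timestamps are present, the date fields come back as float64 with NaN and are left uncast. A small usage sketch of that behaviour:

```python
import pandas as pd
import xarray as xr

da = xr.DataArray(pd.to_datetime(["2000-03-01", "NaT"]), dims="t")
print(da.dt.year.dtype)   # float64 (int64 when no NaT is present)
print(da.dt.year.values)  # [2000.   nan]
```
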
def _round_through_series_or_index(values, name, freq):
"""Coerce an array of datetime-like values to a pandas Series or xarray
CFTimeIndex and apply requested rounding
"""
- pass
+ from xarray.coding.cftimeindex import CFTimeIndex
+
+ if is_np_datetime_like(values.dtype):
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
+ method = getattr(values_as_series.dt, name)
+ else:
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
+ method = getattr(values_as_cftimeindex, name)
+
+ field_values = method(freq=freq).values
+
+ return field_values.reshape(values.shape)
def _round_field(values, name, freq):
@@ -83,30 +182,77 @@ def _round_field(values, name, freq):
Array-like of datetime fields accessed for each element in values
"""
- pass
+ if is_duck_dask_array(values):
+ from dask.array import map_blocks
+
+ dtype = np.datetime64 if is_np_datetime_like(values.dtype) else np.dtype("O")
+ return map_blocks(
+ _round_through_series_or_index, values, name, freq=freq, dtype=dtype
+ )
+ else:
+ return _round_through_series_or_index(values, name, freq)
def _strftime_through_cftimeindex(values, date_format: str):
"""Coerce an array of cftime-like values to a CFTimeIndex
and access requested datetime component
"""
- pass
+ from xarray.coding.cftimeindex import CFTimeIndex
+
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
+
+ field_values = values_as_cftimeindex.strftime(date_format)
+ return field_values.values.reshape(values.shape)
def _strftime_through_series(values, date_format: str):
"""Coerce an array of datetime-like values to a pandas Series and
apply string formatting
"""
- pass
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
+ strs = values_as_series.dt.strftime(date_format)
+ return strs.values.reshape(values.shape)
+
+
+def _strftime(values, date_format):
+ if is_np_datetime_like(values.dtype):
+ access_method = _strftime_through_series
+ else:
+ access_method = _strftime_through_cftimeindex
+ if is_duck_dask_array(values):
+ from dask.array import map_blocks
+
+ return map_blocks(access_method, values, date_format)
+ else:
+ return access_method(values, date_format)
+
+
+def _index_or_data(obj):
+ if isinstance(obj.variable, IndexVariable):
+ return obj.to_index()
+ else:
+ return obj.data
class TimeAccessor(Generic[T_DataArray]):
- __slots__ = '_obj',
+ __slots__ = ("_obj",)
- def __init__(self, obj: T_DataArray) ->None:
+ def __init__(self, obj: T_DataArray) -> None:
self._obj = obj
- def floor(self, freq: str) ->T_DataArray:
+ def _date_field(self, name: str, dtype: DTypeLike) -> T_DataArray:
+ if dtype is None:
+ dtype = self._obj.dtype
+ result = _get_date_field(_index_or_data(self._obj), name, dtype)
+ newvar = self._obj.variable.copy(data=result, deep=False)
+ return self._obj._replace(newvar, name=name)
+
+ def _tslib_round_accessor(self, name: str, freq: str) -> T_DataArray:
+ result = _round_field(_index_or_data(self._obj), name, freq)
+ newvar = self._obj.variable.copy(data=result, deep=False)
+ return self._obj._replace(newvar, name=name)
+
+ def floor(self, freq: str) -> T_DataArray:
"""
Round timestamps downward to specified frequency resolution.
@@ -120,9 +266,10 @@ class TimeAccessor(Generic[T_DataArray]):
floor-ed timestamps : same type as values
Array-like of datetime fields accessed for each element in values
"""
- pass
- def ceil(self, freq: str) ->T_DataArray:
+ return self._tslib_round_accessor("floor", freq)
+
+ def ceil(self, freq: str) -> T_DataArray:
"""
Round timestamps upward to specified frequency resolution.
@@ -136,9 +283,9 @@ class TimeAccessor(Generic[T_DataArray]):
ceil-ed timestamps : same type as values
Array-like of datetime fields accessed for each element in values
"""
- pass
+ return self._tslib_round_accessor("ceil", freq)
- def round(self, freq: str) ->T_DataArray:
+ def round(self, freq: str) -> T_DataArray:
"""
Round timestamps to specified frequency resolution.
@@ -152,7 +299,7 @@ class TimeAccessor(Generic[T_DataArray]):
rounded timestamps : same type as values
Array-like of datetime fields accessed for each element in values
"""
- pass
+ return self._tslib_round_accessor("round", freq)
class DatetimeAccessor(TimeAccessor[T_DataArray]):
@@ -190,7 +337,7 @@ class DatetimeAccessor(TimeAccessor[T_DataArray]):
"""
- def strftime(self, date_format: str) ->T_DataArray:
+ def strftime(self, date_format: str) -> T_DataArray:
"""
Return an array of formatted strings specified by date_format, which
supports the same string format as the python standard library. Details
@@ -215,143 +362,176 @@ class DatetimeAccessor(TimeAccessor[T_DataArray]):
<xarray.DataArray 'strftime' ()> Size: 8B
array('January 01, 2000, 12:00:00 AM', dtype=object)
"""
- pass
+ obj_type = type(self._obj)
+
+ result = _strftime(self._obj.data, date_format)
- def isocalendar(self) ->Dataset:
+ return obj_type(
+ result, name="strftime", coords=self._obj.coords, dims=self._obj.dims
+ )
+
+ def isocalendar(self) -> Dataset:
"""Dataset containing ISO year, week number, and weekday.
Notes
-----
The iso year and weekday differ from the nominal year and weekday.
"""
- pass
+
+ from xarray.core.dataset import Dataset
+
+ if not is_np_datetime_like(self._obj.data.dtype):
+ raise AttributeError("'CFTimeIndex' object has no attribute 'isocalendar'")
+
+ values = _get_date_field(self._obj.data, "isocalendar", np.int64)
+
+ obj_type = type(self._obj)
+ data_vars = {}
+ for i, name in enumerate(["year", "week", "weekday"]):
+ data_vars[name] = obj_type(
+ values[i], name=name, coords=self._obj.coords, dims=self._obj.dims
+ )
+
+ return Dataset(data_vars)
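
Unlike the scalar date fields, `isocalendar()` as implemented above returns a whole Dataset holding ISO year, week and weekday. A short usage sketch:

```python
import pandas as pd
import xarray as xr

da = xr.DataArray(pd.to_datetime(["2000-01-01", "2000-01-03"]), dims="time")
iso = da.dt.isocalendar()
print(list(iso.data_vars))  # ['year', 'week', 'weekday']
print(iso.week.values)      # [52  1]  (2000-01-01 falls in ISO week 52 of 1999)
```
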
@property
- def year(self) ->T_DataArray:
+ def year(self) -> T_DataArray:
"""The year of the datetime"""
- pass
+ return self._date_field("year", np.int64)
@property
- def month(self) ->T_DataArray:
+ def month(self) -> T_DataArray:
"""The month as January=1, December=12"""
- pass
+ return self._date_field("month", np.int64)
@property
- def day(self) ->T_DataArray:
+ def day(self) -> T_DataArray:
"""The days of the datetime"""
- pass
+ return self._date_field("day", np.int64)
@property
- def hour(self) ->T_DataArray:
+ def hour(self) -> T_DataArray:
"""The hours of the datetime"""
- pass
+ return self._date_field("hour", np.int64)
@property
- def minute(self) ->T_DataArray:
+ def minute(self) -> T_DataArray:
"""The minutes of the datetime"""
- pass
+ return self._date_field("minute", np.int64)
@property
- def second(self) ->T_DataArray:
+ def second(self) -> T_DataArray:
"""The seconds of the datetime"""
- pass
+ return self._date_field("second", np.int64)
@property
- def microsecond(self) ->T_DataArray:
+ def microsecond(self) -> T_DataArray:
"""The microseconds of the datetime"""
- pass
+ return self._date_field("microsecond", np.int64)
@property
- def nanosecond(self) ->T_DataArray:
+ def nanosecond(self) -> T_DataArray:
"""The nanoseconds of the datetime"""
- pass
+ return self._date_field("nanosecond", np.int64)
@property
- def weekofyear(self) ->DataArray:
- """The week ordinal of the year"""
- pass
+ def weekofyear(self) -> DataArray:
+ "The week ordinal of the year"
+
+ warnings.warn(
+ "dt.weekofyear and dt.week have been deprecated. Please use "
+ "dt.isocalendar().week instead.",
+ FutureWarning,
+ )
+
+ weekofyear = self.isocalendar().week
+
+ return weekofyear
+
week = weekofyear
@property
- def dayofweek(self) ->T_DataArray:
+ def dayofweek(self) -> T_DataArray:
"""The day of the week with Monday=0, Sunday=6"""
- pass
+ return self._date_field("dayofweek", np.int64)
+
weekday = dayofweek
@property
- def dayofyear(self) ->T_DataArray:
+ def dayofyear(self) -> T_DataArray:
"""The ordinal day of the year"""
- pass
+ return self._date_field("dayofyear", np.int64)
@property
- def quarter(self) ->T_DataArray:
+ def quarter(self) -> T_DataArray:
"""The quarter of the date"""
- pass
+ return self._date_field("quarter", np.int64)
@property
- def days_in_month(self) ->T_DataArray:
+ def days_in_month(self) -> T_DataArray:
"""The number of days in the month"""
- pass
+ return self._date_field("days_in_month", np.int64)
+
daysinmonth = days_in_month
@property
- def season(self) ->T_DataArray:
+ def season(self) -> T_DataArray:
"""Season of the year"""
- pass
+ return self._date_field("season", object)
@property
- def time(self) ->T_DataArray:
+ def time(self) -> T_DataArray:
"""Timestamps corresponding to datetimes"""
- pass
+ return self._date_field("time", object)
@property
- def date(self) ->T_DataArray:
+ def date(self) -> T_DataArray:
"""Date corresponding to datetimes"""
- pass
+ return self._date_field("date", object)
@property
- def is_month_start(self) ->T_DataArray:
+ def is_month_start(self) -> T_DataArray:
"""Indicate whether the date is the first day of the month"""
- pass
+ return self._date_field("is_month_start", bool)
@property
- def is_month_end(self) ->T_DataArray:
+ def is_month_end(self) -> T_DataArray:
"""Indicate whether the date is the last day of the month"""
- pass
+ return self._date_field("is_month_end", bool)
@property
- def is_quarter_start(self) ->T_DataArray:
+ def is_quarter_start(self) -> T_DataArray:
"""Indicate whether the date is the first day of a quarter"""
- pass
+ return self._date_field("is_quarter_start", bool)
@property
- def is_quarter_end(self) ->T_DataArray:
+ def is_quarter_end(self) -> T_DataArray:
"""Indicate whether the date is the last day of a quarter"""
- pass
+ return self._date_field("is_quarter_end", bool)
@property
- def is_year_start(self) ->T_DataArray:
+ def is_year_start(self) -> T_DataArray:
"""Indicate whether the date is the first day of a year"""
- pass
+ return self._date_field("is_year_start", bool)
@property
- def is_year_end(self) ->T_DataArray:
+ def is_year_end(self) -> T_DataArray:
"""Indicate whether the date is the last day of the year"""
- pass
+ return self._date_field("is_year_end", bool)
@property
- def is_leap_year(self) ->T_DataArray:
+ def is_leap_year(self) -> T_DataArray:
"""Indicate if the date belongs to a leap year"""
- pass
+ return self._date_field("is_leap_year", bool)
@property
- def calendar(self) ->CFCalendar:
+ def calendar(self) -> CFCalendar:
"""The name of the calendar of the dates.
Only relevant for arrays of :py:class:`cftime.datetime` objects,
returns "proleptic_gregorian" for arrays of :py:class:`numpy.datetime64` values.
"""
- pass
+ return infer_calendar_name(self._obj.data)
class TimedeltaAccessor(TimeAccessor[T_DataArray]):
@@ -402,39 +582,52 @@ class TimedeltaAccessor(TimeAccessor[T_DataArray]):
"""
@property
- def days(self) ->T_DataArray:
+ def days(self) -> T_DataArray:
"""Number of days for each element"""
- pass
+ return self._date_field("days", np.int64)
@property
- def seconds(self) ->T_DataArray:
+ def seconds(self) -> T_DataArray:
"""Number of seconds (>= 0 and less than 1 day) for each element"""
- pass
+ return self._date_field("seconds", np.int64)
@property
- def microseconds(self) ->T_DataArray:
+ def microseconds(self) -> T_DataArray:
"""Number of microseconds (>= 0 and less than 1 second) for each element"""
- pass
+ return self._date_field("microseconds", np.int64)
@property
- def nanoseconds(self) ->T_DataArray:
+ def nanoseconds(self) -> T_DataArray:
"""Number of nanoseconds (>= 0 and less than 1 microsecond) for each element"""
- pass
+ return self._date_field("nanoseconds", np.int64)
- def total_seconds(self) ->T_DataArray:
+ # Not defined as a property in order to match the Pandas API
+ def total_seconds(self) -> T_DataArray:
"""Total duration of each element expressed in seconds."""
- pass
+ return self._date_field("total_seconds", np.float64)
-class CombinedDatetimelikeAccessor(DatetimeAccessor[T_DataArray],
- TimedeltaAccessor[T_DataArray]):
-
- def __new__(cls, obj: T_DataArray) ->CombinedDatetimelikeAccessor:
+class CombinedDatetimelikeAccessor(
+ DatetimeAccessor[T_DataArray], TimedeltaAccessor[T_DataArray]
+):
+ def __new__(cls, obj: T_DataArray) -> CombinedDatetimelikeAccessor:
+ # CombinedDatetimelikeAccessor isn't really instantiated. Instead
+ # we need to choose which parent (datetime or timedelta) is
+ # appropriate. Since we're checking the dtypes anyway, we'll just
+ # do all the validation here.
if not _contains_datetime_like_objects(obj.variable):
+ # We use an AttributeError here so that `obj.dt` raises an error that
+ # `getattr` expects; https://github.com/pydata/xarray/issues/8718. It's a
+ # bit unusual in a `__new__`, but that's the only case where we use this
+ # class.
raise AttributeError(
- "'.dt' accessor only available for DataArray with datetime64 timedelta64 dtype or for arrays containing cftime datetime objects."
- )
+ "'.dt' accessor only available for "
+ "DataArray with datetime64 timedelta64 dtype or "
+ "for arrays containing cftime datetime "
+ "objects."
+ )
+
if is_np_timedelta_like(obj.dtype):
- return TimedeltaAccessor(obj)
+ return TimedeltaAccessor(obj) # type: ignore[return-value]
else:
- return DatetimeAccessor(obj)
+ return DatetimeAccessor(obj) # type: ignore[return-value]
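
A usage sketch of the dispatch performed in `__new__` above: `.dt` on datetime64 data behaves as a `DatetimeAccessor`, while timedelta64 data gets the `TimedeltaAccessor` fields:

```python
import pandas as pd
import xarray as xr

dt = xr.DataArray(pd.to_datetime(["2000-01-15", "2000-07-15"]), dims="x")
td = xr.DataArray(pd.to_timedelta(["1 days 06:00:00", "2 days"]), dims="x")

print(dt.dt.season.values)  # ['DJF' 'JJA']  -> DatetimeAccessor path
print(td.dt.days.values)    # [1 2]          -> TimedeltaAccessor path
```
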
diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py
index a8a779e5..a48fbc91 100644
--- a/xarray/core/accessor_str.py
+++ b/xarray/core/accessor_str.py
@@ -1,4 +1,44 @@
+# The StringAccessor class defined below is an adaptation of the
+# pandas string methods source code (see pd.core.strings)
+
+# For reference, here is a copy of the pandas copyright notice:
+
+# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2008-2011 AQR Capital Management, LLC
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the copyright holder nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
from __future__ import annotations
+
import codecs
import re
import textwrap
@@ -8,41 +48,100 @@ from operator import or_ as set_union
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, Generic
from unicodedata import normalize
+
import numpy as np
+
from xarray.core import duck_array_ops
from xarray.core.computation import apply_ufunc
from xarray.core.types import T_DataArray
+
if TYPE_CHECKING:
from numpy.typing import DTypeLike
+
from xarray.core.dataarray import DataArray
-_cpython_optimized_encoders = ('utf-8', 'utf8', 'latin-1', 'latin1',
- 'iso-8859-1', 'mbcs', 'ascii')
-_cpython_optimized_decoders = _cpython_optimized_encoders + ('utf-16', 'utf-32'
- )
+_cpython_optimized_encoders = (
+ "utf-8",
+ "utf8",
+ "latin-1",
+ "latin1",
+ "iso-8859-1",
+ "mbcs",
+ "ascii",
+)
+_cpython_optimized_decoders = _cpython_optimized_encoders + ("utf-16", "utf-32")
-def _contains_obj_type(*, pat: Any, checker: Any) ->bool:
+
+def _contains_obj_type(*, pat: Any, checker: Any) -> bool:
"""Determine if the object fits some rule or is array of objects that do so."""
- pass
+ if isinstance(checker, type):
+ targtype = checker
+ checker = lambda x: isinstance(x, targtype)
+
+ if checker(pat):
+ return True
+
+ # If it is not an object array it can't contain compiled re
+ if getattr(pat, "dtype", "no") != np.object_:
+ return False
+
+ return _apply_str_ufunc(func=checker, obj=pat).all()
-def _contains_str_like(pat: Any) ->bool:
+def _contains_str_like(pat: Any) -> bool:
"""Determine if the object is a str-like or array of str-like."""
- pass
+ if isinstance(pat, (str, bytes)):
+ return True
+ if not hasattr(pat, "dtype"):
+ return False
-def _contains_compiled_re(pat: Any) ->bool:
+ return pat.dtype.kind in ["U", "S"]
+
+
+def _contains_compiled_re(pat: Any) -> bool:
"""Determine if the object is a compiled re or array of compiled re."""
- pass
+ return _contains_obj_type(pat=pat, checker=re.Pattern)
-def _contains_callable(pat: Any) ->bool:
+def _contains_callable(pat: Any) -> bool:
"""Determine if the object is a callable or array of callables."""
- pass
+ return _contains_obj_type(pat=pat, checker=callable)
+
+
+def _apply_str_ufunc(
+ *,
+ func: Callable,
+ obj: Any,
+ dtype: DTypeLike = None,
+ output_core_dims: list | tuple = ((),),
+ output_sizes: Mapping[Any, int] | None = None,
+ func_args: tuple = (),
+ func_kwargs: Mapping = {},
+) -> Any:
+ # TODO handling of na values ?
+ if dtype is None:
+ dtype = obj.dtype
+
+ dask_gufunc_kwargs = dict()
+ if output_sizes is not None:
+ dask_gufunc_kwargs["output_sizes"] = output_sizes
+
+ return apply_ufunc(
+ func,
+ obj,
+ *func_args,
+ vectorize=True,
+ dask="parallelized",
+ output_dtypes=[dtype],
+ output_core_dims=output_core_dims,
+ dask_gufunc_kwargs=dask_gufunc_kwargs,
+ **func_kwargs,
+ )
class StringAccessor(Generic[T_DataArray]):
- """Vectorized string functions for string-like arrays.
+ r"""Vectorized string functions for string-like arrays.
Similar to pandas, fields can be accessed through the `.str` attribute
for applicable DataArrays.
@@ -99,24 +198,88 @@ class StringAccessor(Generic[T_DataArray]):
>>> da2 = xr.DataArray([1, 2, 3], dims=["Y"])
>>> da1 % {"a": da2}
<xarray.DataArray (X: 1)> Size: 8B
- array(['<xarray.DataArray (Y: 3)> Size: 24B\\narray([1, 2, 3])\\nDimensions without coordinates: Y'],
+ array(['<xarray.DataArray (Y: 3)> Size: 24B\narray([1, 2, 3])\nDimensions without coordinates: Y'],
dtype=object)
Dimensions without coordinates: X
"""
- __slots__ = '_obj',
- def __init__(self, obj: T_DataArray) ->None:
+ __slots__ = ("_obj",)
+
+ def __init__(self, obj: T_DataArray) -> None:
self._obj = obj
- def _stringify(self, invar: Any) ->(str | bytes | Any):
+ def _stringify(self, invar: Any) -> str | bytes | Any:
"""
Convert a string-like to the correct string/bytes type.
This is mostly here to tell mypy a pattern is a str/bytes not a re.Pattern.
"""
- pass
+ if hasattr(invar, "astype"):
+ return invar.astype(self._obj.dtype.kind)
+ else:
+ return self._obj.dtype.type(invar)
+
+ def _apply(
+ self,
+ *,
+ func: Callable,
+ dtype: DTypeLike = None,
+ output_core_dims: list | tuple = ((),),
+ output_sizes: Mapping[Any, int] | None = None,
+ func_args: tuple = (),
+ func_kwargs: Mapping = {},
+ ) -> T_DataArray:
+ return _apply_str_ufunc(
+ obj=self._obj,
+ func=func,
+ dtype=dtype,
+ output_core_dims=output_core_dims,
+ output_sizes=output_sizes,
+ func_args=func_args,
+ func_kwargs=func_kwargs,
+ )
+
+ def _re_compile(
+ self,
+ *,
+ pat: str | bytes | Pattern | Any,
+ flags: int = 0,
+ case: bool | None = None,
+ ) -> Pattern | Any:
+ is_compiled_re = isinstance(pat, re.Pattern)
+
+ if is_compiled_re and flags != 0:
+ raise ValueError("Flags cannot be set when pat is a compiled regex.")
+
+ if is_compiled_re and case is not None:
+ raise ValueError("Case cannot be set when pat is a compiled regex.")
+
+ if is_compiled_re:
+ # no-op, needed to tell mypy this isn't a string
+ return re.compile(pat)
+
+ if case is None:
+ case = True
+
+ # The case is handled by the re flags internally.
+ # Add it to the flags if necessary.
+ if not case:
+ flags |= re.IGNORECASE
+
+ if getattr(pat, "dtype", None) != np.object_:
+ pat = self._stringify(pat)
+
+ def func(x):
+ return re.compile(x, flags=flags)
+
+ if isinstance(pat, np.ndarray):
+ # apply_ufunc doesn't work for numpy arrays with output object dtypes
+ func_ = np.vectorize(func)
+ return func_(pat)
+ else:
+ return _apply_str_ufunc(func=func, obj=pat, dtype=np.object_)
- def len(self) ->T_DataArray:
+ def len(self) -> T_DataArray:
"""
Compute the length of each string in the array.
@@ -124,21 +287,30 @@ class StringAccessor(Generic[T_DataArray]):
-------
lengths array : array of int
"""
- pass
+ return self._apply(func=len, dtype=int)
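
A brief usage sketch of the vectorized string namespace these methods implement; `.str` mirrors the pandas API and broadcasts elementwise:

```python
import xarray as xr

da = xr.DataArray(["temperature", "pressure"], dims="x")
print(da.str.len().values)     # [11  8]
print(da.str.upper().values)   # ['TEMPERATURE' 'PRESSURE']
print((da.str + "_K").values)  # ['temperature_K' 'pressure_K']
```
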
- def __getitem__(self, key: (int | slice)) ->T_DataArray:
+ def __getitem__(
+ self,
+ key: int | slice,
+ ) -> T_DataArray:
if isinstance(key, slice):
return self.slice(start=key.start, stop=key.stop, step=key.step)
else:
return self.get(key)
- def __add__(self, other: Any) ->T_DataArray:
- return self.cat(other, sep='')
+ def __add__(self, other: Any) -> T_DataArray:
+ return self.cat(other, sep="")
- def __mul__(self, num: (int | Any)) ->T_DataArray:
+ def __mul__(
+ self,
+ num: int | Any,
+ ) -> T_DataArray:
return self.repeat(num)
- def __mod__(self, other: Any) ->T_DataArray:
+ def __mod__(
+ self,
+ other: Any,
+ ) -> T_DataArray:
if isinstance(other, dict):
other = {key: self._stringify(val) for key, val in other.items()}
return self._apply(func=lambda x: x % other)
@@ -148,7 +320,11 @@ class StringAccessor(Generic[T_DataArray]):
else:
return self._apply(func=lambda x, y: x % y, func_args=(other,))
- def get(self, i: (int | Any), default: (str | bytes)='') ->T_DataArray:
+ def get(
+ self,
+ i: int | Any,
+ default: str | bytes = "",
+ ) -> T_DataArray:
"""
Extract character number `i` from each string in the array.
@@ -167,10 +343,21 @@ class StringAccessor(Generic[T_DataArray]):
-------
items : array of object
"""
- pass
- def slice(self, start: (int | Any | None)=None, stop: (int | Any | None
- )=None, step: (int | Any | None)=None) ->T_DataArray:
+ def f(x, iind):
+ islice = slice(-1, None) if iind == -1 else slice(iind, iind + 1)
+ item = x[islice]
+
+ return item if item else default
+
+ return self._apply(func=f, func_args=(i,))
+
+ def slice(
+ self,
+ start: int | Any | None = None,
+ stop: int | Any | None = None,
+ step: int | Any | None = None,
+ ) -> T_DataArray:
"""
Slice substrings from each string in the array.
@@ -193,10 +380,15 @@ class StringAccessor(Generic[T_DataArray]):
-------
sliced strings : same type as values
"""
- pass
+ f = lambda x, istart, istop, istep: x[slice(istart, istop, istep)]
+ return self._apply(func=f, func_args=(start, stop, step))
- def slice_replace(self, start: (int | Any | None)=None, stop: (int |
- Any | None)=None, repl: (str | bytes | Any)='') ->T_DataArray:
+ def slice_replace(
+ self,
+ start: int | Any | None = None,
+ stop: int | Any | None = None,
+ repl: str | bytes | Any = "",
+ ) -> T_DataArray:
"""
Replace a positional slice of a string with another value.
@@ -221,9 +413,24 @@ class StringAccessor(Generic[T_DataArray]):
-------
replaced : same type as values
"""
- pass
+ repl = self._stringify(repl)
- def cat(self, *others, sep: (str | bytes | Any)='') ->T_DataArray:
+ def func(x, istart, istop, irepl):
+ if len(x[istart:istop]) == 0:
+ local_stop = istart
+ else:
+ local_stop = istop
+ y = self._stringify("")
+ if istart is not None:
+ y += x[:istart]
+ y += irepl
+ if istop is not None:
+ y += x[local_stop:]
+ return y
+
+ return self._apply(func=func, func_args=(start, stop, repl))
+
+ def cat(self, *others, sep: str | bytes | Any = "") -> T_DataArray:
"""
Concatenate strings elementwise in the DataArray with other strings.
@@ -291,10 +498,24 @@ class StringAccessor(Generic[T_DataArray]):
pandas.Series.str.cat
str.join
"""
- pass
+ sep = self._stringify(sep)
+ others = tuple(self._stringify(x) for x in others)
+ others = others + (sep,)
+
+ # sep will go at the end of the input arguments.
+ func = lambda *x: x[-1].join(x[:-1])
- def join(self, dim: Hashable=None, sep: (str | bytes | Any)=''
- ) ->T_DataArray:
+ return self._apply(
+ func=func,
+ func_args=others,
+ dtype=self._obj.dtype.kind,
+ )
+
+ def join(
+ self,
+ dim: Hashable = None,
+ sep: str | bytes | Any = "",
+ ) -> T_DataArray:
"""
Concatenate strings in a DataArray along a particular dimension.
@@ -345,9 +566,27 @@ class StringAccessor(Generic[T_DataArray]):
pandas.Series.str.join
str.join
"""
- pass
+ if self._obj.ndim > 1 and dim is None:
+ raise ValueError("Dimension must be specified for multidimensional arrays.")
- def format(self, *args: Any, **kwargs: Any) ->T_DataArray:
+ if self._obj.ndim > 1:
+ # Move the target dimension to the start and split along it
+ dimshifted = list(self._obj.transpose(dim, ...))
+ elif self._obj.ndim == 1:
+ dimshifted = list(self._obj)
+ else:
+ dimshifted = [self._obj]
+
+ start, *others = dimshifted
+
+ # concatenate the resulting arrays
+ return start.str.cat(*others, sep=sep)
+
+ def format(
+ self,
+ *args: Any,
+ **kwargs: Any,
+ ) -> T_DataArray:
"""
Perform python string formatting on each element of the DataArray.
@@ -420,9 +659,14 @@ class StringAccessor(Generic[T_DataArray]):
--------
str.format
"""
- pass
+ args = tuple(self._stringify(x) for x in args)
+ kwargs = {key: self._stringify(val) for key, val in kwargs.items()}
+ func = lambda x, *args, **kwargs: self._obj.dtype.type.format(
+ x, *args, **kwargs
+ )
+ return self._apply(func=func, func_args=args, func_kwargs={"kwargs": kwargs})
- def capitalize(self) ->T_DataArray:
+ def capitalize(self) -> T_DataArray:
"""
Convert strings in the array to be capitalized.
@@ -447,9 +691,9 @@ class StringAccessor(Generic[T_DataArray]):
dtype='<U14')
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.capitalize())
- def lower(self) ->T_DataArray:
+ def lower(self) -> T_DataArray:
"""
Convert strings in the array to lowercase.
@@ -470,9 +714,9 @@ class StringAccessor(Generic[T_DataArray]):
array(['temperature', 'pressure'], dtype='<U11')
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.lower())
- def swapcase(self) ->T_DataArray:
+ def swapcase(self) -> T_DataArray:
"""
Convert strings in the array to be swapcased.
@@ -494,9 +738,9 @@ class StringAccessor(Generic[T_DataArray]):
array(['TEMPERATURE', 'pressure', 'hUmIdItY'], dtype='<U11')
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.swapcase())
- def title(self) ->T_DataArray:
+ def title(self) -> T_DataArray:
"""
Convert strings in the array to titlecase.
@@ -517,9 +761,9 @@ class StringAccessor(Generic[T_DataArray]):
array(['Temperature', 'Pressure', 'Humidity'], dtype='<U11')
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.title())
- def upper(self) ->T_DataArray:
+ def upper(self) -> T_DataArray:
"""
Convert strings in the array to uppercase.
@@ -540,9 +784,9 @@ class StringAccessor(Generic[T_DataArray]):
array(['TEMPERATURE', 'HUMIDITY'], dtype='<U11')
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.upper())
- def casefold(self) ->T_DataArray:
+ def casefold(self) -> T_DataArray:
"""
Convert strings in the array to be casefolded.
@@ -581,9 +825,12 @@ class StringAccessor(Generic[T_DataArray]):
array(['ss', 'i̇'], dtype='<U2')
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.casefold())
- def normalize(self, form: str) ->T_DataArray:
+ def normalize(
+ self,
+ form: str,
+ ) -> T_DataArray:
"""
        Return the Unicode normal form for the strings in the DataArray.
@@ -600,9 +847,9 @@ class StringAccessor(Generic[T_DataArray]):
normalized : same type as values
"""
- pass
+ return self._apply(func=lambda x: normalize(form, x))
- def isalnum(self) ->T_DataArray:
+ def isalnum(self) -> T_DataArray:
"""
Check whether all characters in each string are alphanumeric.
@@ -624,9 +871,9 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.isalnum(), dtype=bool)
- def isalpha(self) ->T_DataArray:
+ def isalpha(self) -> T_DataArray:
"""
Check whether all characters in each string are alphabetic.
@@ -648,9 +895,9 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, False])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.isalpha(), dtype=bool)
- def isdecimal(self) ->T_DataArray:
+ def isdecimal(self) -> T_DataArray:
"""
Check whether all characters in each string are decimal.
@@ -672,9 +919,9 @@ class StringAccessor(Generic[T_DataArray]):
array([False, True, True])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.isdecimal(), dtype=bool)
- def isdigit(self) ->T_DataArray:
+ def isdigit(self) -> T_DataArray:
"""
Check whether all characters in each string are digits.
@@ -696,9 +943,9 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, True, False, False])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.isdigit(), dtype=bool)
- def islower(self) ->T_DataArray:
+ def islower(self) -> T_DataArray:
"""
Check whether all characters in each string are lowercase.
@@ -721,9 +968,9 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, False])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.islower(), dtype=bool)
- def isnumeric(self) ->T_DataArray:
+ def isnumeric(self) -> T_DataArray:
"""
Check whether all characters in each string are numeric.
@@ -745,9 +992,9 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, False, False, False])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.isnumeric(), dtype=bool)
- def isspace(self) ->T_DataArray:
+ def isspace(self) -> T_DataArray:
"""
Check whether all characters in each string are spaces.
@@ -769,9 +1016,9 @@ class StringAccessor(Generic[T_DataArray]):
array([False, True, True, True])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.isspace(), dtype=bool)
- def istitle(self) ->T_DataArray:
+ def istitle(self) -> T_DataArray:
"""
Check whether all characters in each string are titlecase.
@@ -801,9 +1048,9 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, False])
Dimensions without coordinates: title
"""
- pass
+ return self._apply(func=lambda x: x.istitle(), dtype=bool)
- def isupper(self) ->T_DataArray:
+ def isupper(self) -> T_DataArray:
"""
Check whether all characters in each string are uppercase.
@@ -825,10 +1072,11 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, False])
Dimensions without coordinates: x
"""
- pass
+ return self._apply(func=lambda x: x.isupper(), dtype=bool)
- def count(self, pat: (str | bytes | Pattern | Any), flags: int=0, case:
- (bool | None)=None) ->T_DataArray:
+ def count(
+ self, pat: str | bytes | Pattern | Any, flags: int = 0, case: bool | None = None
+ ) -> T_DataArray:
"""
Count occurrences of pattern in each string of the array.
@@ -899,9 +1147,12 @@ class StringAccessor(Generic[T_DataArray]):
[0, 1]])
Dimensions without coordinates: x, y
"""
- pass
+ pat = self._re_compile(pat=pat, flags=flags, case=case)
+
+ func = lambda x, ipat: len(ipat.findall(x))
+ return self._apply(func=func, func_args=(pat,), dtype=int)
- def startswith(self, pat: (str | bytes | Any)) ->T_DataArray:
+ def startswith(self, pat: str | bytes | Any) -> T_DataArray:
"""
Test if the start of each string in the array matches a pattern.
@@ -933,9 +1184,11 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, False])
Dimensions without coordinates: x
"""
- pass
+ pat = self._stringify(pat)
+ func = lambda x, y: x.startswith(y)
+ return self._apply(func=func, func_args=(pat,), dtype=bool)
- def endswith(self, pat: (str | bytes | Any)) ->T_DataArray:
+ def endswith(self, pat: str | bytes | Any) -> T_DataArray:
"""
Test if the end of each string in the array matches a pattern.
@@ -967,10 +1220,16 @@ class StringAccessor(Generic[T_DataArray]):
array([ True, False, False])
Dimensions without coordinates: x
"""
- pass
+ pat = self._stringify(pat)
+ func = lambda x, y: x.endswith(y)
+ return self._apply(func=func, func_args=(pat,), dtype=bool)
- def pad(self, width: (int | Any), side: str='left', fillchar: (str |
- bytes | Any)=' ') ->T_DataArray:
+ def pad(
+ self,
+ width: int | Any,
+ side: str = "left",
+ fillchar: str | bytes | Any = " ",
+ ) -> T_DataArray:
"""
Pad strings in the array up to width.
@@ -1054,17 +1313,39 @@ class StringAccessor(Generic[T_DataArray]):
['0000NZ39', '----NZ39']], dtype='<U8')
Dimensions without coordinates: x, y
"""
- pass
-
- def _padder(self, *, func: Callable, width: (int | Any), fillchar: (str |
- bytes | Any)=' ') ->T_DataArray:
+ if side == "left":
+ func = self.rjust
+ elif side == "right":
+ func = self.ljust
+ elif side == "both":
+ func = self.center
+ else: # pragma: no cover
+ raise ValueError("Invalid side")
+
+ return func(width=width, fillchar=fillchar)
+
+ def _padder(
+ self,
+ *,
+ func: Callable,
+ width: int | Any,
+ fillchar: str | bytes | Any = " ",
+ ) -> T_DataArray:
"""
Wrapper function to handle padding operations
"""
- pass
+ fillchar = self._stringify(fillchar)
+
+ def overfunc(x, iwidth, ifillchar):
+ if len(ifillchar) != 1:
+ raise TypeError("fillchar must be a character, not str")
+ return func(x, int(iwidth), ifillchar)
+
+ return self._apply(func=overfunc, func_args=(width, fillchar))
- def center(self, width: (int | Any), fillchar: (str | bytes | Any)=' '
- ) ->T_DataArray:
+ def center(
+ self, width: int | Any, fillchar: str | bytes | Any = " "
+ ) -> T_DataArray:
"""
Pad left and right side of each string in the array.
@@ -1084,10 +1365,14 @@ class StringAccessor(Generic[T_DataArray]):
-------
filled : same type as values
"""
- pass
+ func = self._obj.dtype.type.center
+ return self._padder(func=func, width=width, fillchar=fillchar)
- def ljust(self, width: (int | Any), fillchar: (str | bytes | Any)=' '
- ) ->T_DataArray:
+ def ljust(
+ self,
+ width: int | Any,
+ fillchar: str | bytes | Any = " ",
+ ) -> T_DataArray:
"""
Pad right side of each string in the array.
@@ -1107,10 +1392,14 @@ class StringAccessor(Generic[T_DataArray]):
-------
filled : same type as values
"""
- pass
+ func = self._obj.dtype.type.ljust
+ return self._padder(func=func, width=width, fillchar=fillchar)
- def rjust(self, width: (int | Any), fillchar: (str | bytes | Any)=' '
- ) ->T_DataArray:
+ def rjust(
+ self,
+ width: int | Any,
+ fillchar: str | bytes | Any = " ",
+ ) -> T_DataArray:
"""
Pad left side of each string in the array.
@@ -1130,9 +1419,10 @@ class StringAccessor(Generic[T_DataArray]):
-------
filled : same type as values
"""
- pass
+ func = self._obj.dtype.type.rjust
+ return self._padder(func=func, width=width, fillchar=fillchar)
- def zfill(self, width: (int | Any)) ->T_DataArray:
+ def zfill(self, width: int | Any) -> T_DataArray:
"""
Pad each string in the array by prepending '0' characters.
@@ -1153,10 +1443,15 @@ class StringAccessor(Generic[T_DataArray]):
-------
filled : same type as values
"""
- pass
+ return self.rjust(width, fillchar="0")
- def contains(self, pat: (str | bytes | Pattern | Any), case: (bool |
- None)=None, flags: int=0, regex: bool=True) ->T_DataArray:
+ def contains(
+ self,
+ pat: str | bytes | Pattern | Any,
+ case: bool | None = None,
+ flags: int = 0,
+ regex: bool = True,
+ ) -> T_DataArray:
"""
Test if pattern or regex is contained within each string of the array.
@@ -1193,10 +1488,42 @@ class StringAccessor(Generic[T_DataArray]):
given pattern is contained within the string of each element
of the array.
"""
- pass
+ is_compiled_re = _contains_compiled_re(pat)
+ if is_compiled_re and not regex:
+ raise ValueError(
+ "Must use regular expression matching for regular expression object."
+ )
+
+ if regex:
+ if not is_compiled_re:
+ pat = self._re_compile(pat=pat, flags=flags, case=case)
+
+ def func(x, ipat):
+ if ipat.groups > 0: # pragma: no cover
+ raise ValueError("This pattern has match groups.")
+ return bool(ipat.search(x))
- def match(self, pat: (str | bytes | Pattern | Any), case: (bool | None)
- =None, flags: int=0) ->T_DataArray:
+ else:
+ pat = self._stringify(pat)
+ if case or case is None:
+ func = lambda x, ipat: ipat in x
+ elif self._obj.dtype.char == "U":
+ uppered = self.casefold()
+ uppat = StringAccessor(pat).casefold() # type: ignore[type-var] # hack?
+ return uppered.str.contains(uppat, regex=False) # type: ignore[return-value]
+ else:
+ uppered = self.upper()
+ uppat = StringAccessor(pat).upper() # type: ignore[type-var] # hack?
+ return uppered.str.contains(uppat, regex=False) # type: ignore[return-value]
+
+ return self._apply(func=func, func_args=(pat,), dtype=bool)
+
+ def match(
+ self,
+ pat: str | bytes | Pattern | Any,
+ case: bool | None = None,
+ flags: int = 0,
+ ) -> T_DataArray:
"""
Determine if each string in the array matches a regular expression.
@@ -1223,10 +1550,14 @@ class StringAccessor(Generic[T_DataArray]):
-------
matched : array of bool
"""
- pass
+ pat = self._re_compile(pat=pat, flags=flags, case=case)
+
+ func = lambda x, ipat: bool(ipat.match(x))
+ return self._apply(func=func, func_args=(pat,), dtype=bool)
- def strip(self, to_strip: (str | bytes | Any)=None, side: str='both'
- ) ->T_DataArray:
+ def strip(
+ self, to_strip: str | bytes | Any = None, side: str = "both"
+ ) -> T_DataArray:
"""
Remove leading and trailing characters.
@@ -1249,9 +1580,21 @@ class StringAccessor(Generic[T_DataArray]):
-------
stripped : same type as values
"""
- pass
+ if to_strip is not None:
+ to_strip = self._stringify(to_strip)
+
+ if side == "both":
+ func = lambda x, y: x.strip(y)
+ elif side == "left":
+ func = lambda x, y: x.lstrip(y)
+ elif side == "right":
+ func = lambda x, y: x.rstrip(y)
+ else: # pragma: no cover
+ raise ValueError("Invalid side")
- def lstrip(self, to_strip: (str | bytes | Any)=None) ->T_DataArray:
+ return self._apply(func=func, func_args=(to_strip,))
+
+ def lstrip(self, to_strip: str | bytes | Any = None) -> T_DataArray:
"""
Remove leading characters.
@@ -1272,9 +1615,9 @@ class StringAccessor(Generic[T_DataArray]):
-------
stripped : same type as values
"""
- pass
+ return self.strip(to_strip, side="left")
- def rstrip(self, to_strip: (str | bytes | Any)=None) ->T_DataArray:
+ def rstrip(self, to_strip: str | bytes | Any = None) -> T_DataArray:
"""
Remove trailing characters.
@@ -1295,9 +1638,9 @@ class StringAccessor(Generic[T_DataArray]):
-------
stripped : same type as values
"""
- pass
+ return self.strip(to_strip, side="right")
- def wrap(self, width: (int | Any), **kwargs) ->T_DataArray:
+ def wrap(self, width: int | Any, **kwargs) -> T_DataArray:
"""
Wrap long strings in the array in paragraphs with length less than `width`.
@@ -1319,10 +1662,13 @@ class StringAccessor(Generic[T_DataArray]):
-------
wrapped : same type as values
"""
- pass
+ ifunc = lambda x: textwrap.TextWrapper(width=x, **kwargs)
+ tw = StringAccessor(width)._apply(func=ifunc, dtype=np.object_) # type: ignore[type-var] # hack?
+ func = lambda x, itw: "\n".join(itw.wrap(x))
+ return self._apply(func=func, func_args=(tw,))
- def translate(self, table: Mapping[Any, str | bytes | int | None]
- ) ->T_DataArray:
+ # Mapping is only covariant in its values, maybe use a custom CovariantMapping?
+ def translate(self, table: Mapping[Any, str | bytes | int | None]) -> T_DataArray:
"""
Map characters of each string through the given mapping table.
@@ -1338,9 +1684,13 @@ class StringAccessor(Generic[T_DataArray]):
-------
translated : same type as values
"""
- pass
+ func = lambda x: x.translate(table)
+ return self._apply(func=func)
- def repeat(self, repeats: (int | Any)) ->T_DataArray:
+ def repeat(
+ self,
+ repeats: int | Any,
+ ) -> T_DataArray:
"""
Repeat each string in the array.
@@ -1358,10 +1708,16 @@ class StringAccessor(Generic[T_DataArray]):
repeated : same type as values
Array of repeated string objects.
"""
- pass
+ func = lambda x, y: x * y
+ return self._apply(func=func, func_args=(repeats,))
- def find(self, sub: (str | bytes | Any), start: (int | Any)=0, end: (
- int | Any)=None, side: str='left') ->T_DataArray:
+ def find(
+ self,
+ sub: str | bytes | Any,
+ start: int | Any = 0,
+ end: int | Any = None,
+ side: str = "left",
+ ) -> T_DataArray:
"""
        Return lowest or highest indexes in each string in the array
where the substring is fully contained between [start:end].
@@ -1388,10 +1744,24 @@ class StringAccessor(Generic[T_DataArray]):
-------
found : array of int
"""
- pass
+ sub = self._stringify(sub)
+
+ if side == "left":
+ method = "find"
+ elif side == "right":
+ method = "rfind"
+ else: # pragma: no cover
+ raise ValueError("Invalid side")
- def rfind(self, sub: (str | bytes | Any), start: (int | Any)=0, end: (
- int | Any)=None) ->T_DataArray:
+ func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend)
+ return self._apply(func=func, func_args=(sub, start, end), dtype=int)
+
+ def rfind(
+ self,
+ sub: str | bytes | Any,
+ start: int | Any = 0,
+ end: int | Any = None,
+ ) -> T_DataArray:
"""
        Return highest indexes in each string in the array
where the substring is fully contained between [start:end].
@@ -1416,10 +1786,15 @@ class StringAccessor(Generic[T_DataArray]):
-------
found : array of int
"""
- pass
+ return self.find(sub, start=start, end=end, side="right")
- def index(self, sub: (str | bytes | Any), start: (int | Any)=0, end: (
- int | Any)=None, side: str='left') ->T_DataArray:
+ def index(
+ self,
+ sub: str | bytes | Any,
+ start: int | Any = 0,
+ end: int | Any = None,
+ side: str = "left",
+ ) -> T_DataArray:
"""
        Return lowest or highest indexes in each string where the substring is
fully contained between [start:end]. This is the same as
@@ -1452,10 +1827,24 @@ class StringAccessor(Generic[T_DataArray]):
ValueError
substring is not found
"""
- pass
+ sub = self._stringify(sub)
+
+ if side == "left":
+ method = "index"
+ elif side == "right":
+ method = "rindex"
+ else: # pragma: no cover
+ raise ValueError("Invalid side")
+
+ func = lambda x, isub, istart, iend: getattr(x, method)(isub, istart, iend)
+ return self._apply(func=func, func_args=(sub, start, end), dtype=int)
- def rindex(self, sub: (str | bytes | Any), start: (int | Any)=0, end: (
- int | Any)=None) ->T_DataArray:
+ def rindex(
+ self,
+ sub: str | bytes | Any,
+ start: int | Any = 0,
+ end: int | Any = None,
+ ) -> T_DataArray:
"""
        Return highest indexes in each string where the substring is
fully contained between [start:end]. This is the same as
@@ -1486,11 +1875,17 @@ class StringAccessor(Generic[T_DataArray]):
ValueError
substring is not found
"""
- pass
+ return self.index(sub, start=start, end=end, side="right")
- def replace(self, pat: (str | bytes | Pattern | Any), repl: (str |
- bytes | Callable | Any), n: (int | Any)=-1, case: (bool | None)=
- None, flags: int=0, regex: bool=True) ->T_DataArray:
+ def replace(
+ self,
+ pat: str | bytes | Pattern | Any,
+ repl: str | bytes | Callable | Any,
+ n: int | Any = -1,
+ case: bool | None = None,
+ flags: int = 0,
+ regex: bool = True,
+ ) -> T_DataArray:
"""
Replace occurrences of pattern/regex in the array with some string.
@@ -1531,11 +1926,38 @@ class StringAccessor(Generic[T_DataArray]):
A copy of the object with all matching occurrences of `pat`
replaced by `repl`.
"""
- pass
-
- def extract(self, pat: (str | bytes | Pattern | Any), dim: Hashable,
- case: (bool | None)=None, flags: int=0) ->T_DataArray:
- """
+ if _contains_str_like(repl):
+ repl = self._stringify(repl)
+ elif not _contains_callable(repl): # pragma: no cover
+ raise TypeError("repl must be a string or callable")
+
+ is_compiled_re = _contains_compiled_re(pat)
+ if not regex and is_compiled_re:
+ raise ValueError(
+ "Cannot use a compiled regex as replacement pattern with regex=False"
+ )
+
+ if not regex and callable(repl):
+ raise ValueError("Cannot use a callable replacement when regex=False")
+
+ if regex:
+ pat = self._re_compile(pat=pat, flags=flags, case=case)
+ func = lambda x, ipat, irepl, i_n: ipat.sub(
+ repl=irepl, string=x, count=i_n if i_n >= 0 else 0
+ )
+ else:
+ pat = self._stringify(pat)
+ func = lambda x, ipat, irepl, i_n: x.replace(ipat, irepl, i_n)
+ return self._apply(func=func, func_args=(pat, repl, n))
+
+ def extract(
+ self,
+ pat: str | bytes | Pattern | Any,
+ dim: Hashable,
+ case: bool | None = None,
+ flags: int = 0,
+ ) -> T_DataArray:
+ r"""
Extract the first match of capture groups in the regex pat as a new
dimension in a DataArray.
@@ -1601,7 +2023,7 @@ class StringAccessor(Generic[T_DataArray]):
Extract matches
- >>> value.str.extract(r"(\\w+)_Xy_(\\d*)", dim="match")
+ >>> value.str.extract(r"(\w+)_Xy_(\d*)", dim="match")
<xarray.DataArray (X: 2, Y: 3, match: 2)> Size: 288B
array([[['a', '0'],
['bab', '110'],
@@ -1620,12 +2042,70 @@ class StringAccessor(Generic[T_DataArray]):
re.search
pandas.Series.str.extract
"""
- pass
+ pat = self._re_compile(pat=pat, flags=flags, case=case)
- def extractall(self, pat: (str | bytes | Pattern | Any), group_dim:
- Hashable, match_dim: Hashable, case: (bool | None)=None, flags: int=0
- ) ->T_DataArray:
- """
+ if isinstance(pat, re.Pattern):
+ maxgroups = pat.groups
+ else:
+ maxgroups = (
+ _apply_str_ufunc(obj=pat, func=lambda x: x.groups, dtype=np.int_)
+ .max()
+ .data.tolist()
+ )
+
+ if maxgroups == 0:
+ raise ValueError("No capture groups found in pattern.")
+
+ if dim is None and maxgroups != 1:
+ raise ValueError(
+ "Dimension must be specified if more than one capture group is given."
+ )
+
+ if dim is not None and dim in self._obj.dims:
+ raise KeyError(f"Dimension '{dim}' already present in DataArray.")
+
+ def _get_res_single(val, pat):
+ match = pat.search(val)
+ if match is None:
+ return ""
+ res = match.group(1)
+ if res is None:
+ res = ""
+ return res
+
+ def _get_res_multi(val, pat):
+ match = pat.search(val)
+ if match is None:
+ return np.array([""], val.dtype)
+ match = match.groups()
+ match = [grp if grp is not None else "" for grp in match]
+ return np.array(match, val.dtype)
+
+ if dim is None:
+ return self._apply(func=_get_res_single, func_args=(pat,))
+ else:
+ # dtype MUST be object or strings can be truncated
+ # See: https://github.com/numpy/numpy/issues/8352
+ return duck_array_ops.astype(
+ self._apply(
+ func=_get_res_multi,
+ func_args=(pat,),
+ dtype=np.object_,
+ output_core_dims=[[dim]],
+ output_sizes={dim: maxgroups},
+ ),
+ self._obj.dtype.kind,
+ )
+
+ def extractall(
+ self,
+ pat: str | bytes | Pattern | Any,
+ group_dim: Hashable,
+ match_dim: Hashable,
+ case: bool | None = None,
+ flags: int = 0,
+ ) -> T_DataArray:
+ r"""
Extract all matches of capture groups in the regex pat as new
dimensions in a DataArray.
@@ -1696,7 +2176,7 @@ class StringAccessor(Generic[T_DataArray]):
Extract matches
>>> value.str.extractall(
- ... r"(\\w+)_Xy_(\\d*)", group_dim="group", match_dim="match"
+ ... r"(\w+)_Xy_(\d*)", group_dim="group", match_dim="match"
... )
<xarray.DataArray (X: 2, Y: 3, group: 3, match: 2)> Size: 1kB
array([[[['a', '0'],
@@ -1733,11 +2213,75 @@ class StringAccessor(Generic[T_DataArray]):
re.findall
pandas.Series.str.extractall
"""
- pass
+ pat = self._re_compile(pat=pat, flags=flags, case=case)
- def findall(self, pat: (str | bytes | Pattern | Any), case: (bool |
- None)=None, flags: int=0) ->T_DataArray:
- """
+ if group_dim in self._obj.dims:
+ raise KeyError(
+ f"Group dimension '{group_dim}' already present in DataArray."
+ )
+
+ if match_dim in self._obj.dims:
+ raise KeyError(
+ f"Match dimension '{match_dim}' already present in DataArray."
+ )
+
+ if group_dim == match_dim:
+ raise KeyError(
+ f"Group dimension '{group_dim}' is the same as match dimension '{match_dim}'."
+ )
+
+ _get_count = lambda x, ipat: len(ipat.findall(x))
+ maxcount = (
+ self._apply(func=_get_count, func_args=(pat,), dtype=np.int_)
+ .max()
+ .data.tolist()
+ )
+
+ if isinstance(pat, re.Pattern):
+ maxgroups = pat.groups
+ else:
+ maxgroups = (
+ _apply_str_ufunc(obj=pat, func=lambda x: x.groups, dtype=np.int_)
+ .max()
+ .data.tolist()
+ )
+
+ def _get_res(val, ipat, imaxcount=maxcount, dtype=self._obj.dtype):
+ if ipat.groups == 0:
+ raise ValueError("No capture groups found in pattern.")
+ matches = ipat.findall(val)
+ res = np.zeros([maxcount, ipat.groups], dtype)
+
+ if ipat.groups == 1:
+ for imatch, match in enumerate(matches):
+ res[imatch, 0] = match
+ else:
+ for imatch, match in enumerate(matches):
+ for jmatch, submatch in enumerate(match):
+ res[imatch, jmatch] = submatch
+
+ return res
+
+ return duck_array_ops.astype(
+ self._apply(
+ # dtype MUST be object or strings can be truncated
+ # See: https://github.com/numpy/numpy/issues/8352
+ func=_get_res,
+ func_args=(pat,),
+ dtype=np.object_,
+ output_core_dims=[[group_dim, match_dim]],
+ output_sizes={group_dim: maxgroups, match_dim: maxcount},
+ ),
+ self._obj.dtype.kind,
+ )
+
+ def findall(
+ self,
+ pat: str | bytes | Pattern | Any,
+ case: bool | None = None,
+ flags: int = 0,
+ ) -> T_DataArray:
+ r"""
Find all occurrences of pattern or regular expression in the DataArray.
Equivalent to applying re.findall() to all the elements in the DataArray.
@@ -1797,7 +2341,7 @@ class StringAccessor(Generic[T_DataArray]):
Extract matches
- >>> value.str.findall(r"(\\w+)_Xy_(\\d*)")
+ >>> value.str.findall(r"(\w+)_Xy_(\d*)")
<xarray.DataArray (X: 2, Y: 3)> Size: 48B
array([[list([('a', '0')]), list([('bab', '110'), ('baab', '1100')]),
list([('abc', '01'), ('cbc', '2210')])],
@@ -1814,17 +2358,56 @@ class StringAccessor(Generic[T_DataArray]):
re.findall
pandas.Series.str.findall
"""
- pass
+ pat = self._re_compile(pat=pat, flags=flags, case=case)
+
+ def func(x, ipat):
+ if ipat.groups == 0:
+ raise ValueError("No capture groups found in pattern.")
+
+ return ipat.findall(x)
+
+ return self._apply(func=func, func_args=(pat,), dtype=np.object_)
- def _partitioner(self, *, func: Callable, dim: (Hashable | None), sep:
- (str | bytes | Any | None)) ->T_DataArray:
+ def _partitioner(
+ self,
+ *,
+ func: Callable,
+ dim: Hashable | None,
+ sep: str | bytes | Any | None,
+ ) -> T_DataArray:
"""
Implements logic for `partition` and `rpartition`.
"""
- pass
-
- def partition(self, dim: (Hashable | None), sep: (str | bytes | Any)=' '
- ) ->T_DataArray:
+ sep = self._stringify(sep)
+
+ if dim is None:
+ listfunc = lambda x, isep: list(func(x, isep))
+ return self._apply(func=listfunc, func_args=(sep,), dtype=np.object_)
+
+ # _apply breaks on an empty array in this case
+ if not self._obj.size:
+ return self._obj.copy().expand_dims({dim: 0}, axis=-1)
+
+ arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype)
+
+ # dtype MUST be object or strings can be truncated
+ # See: https://github.com/numpy/numpy/issues/8352
+ return duck_array_ops.astype(
+ self._apply(
+ func=arrfunc,
+ func_args=(sep,),
+ dtype=np.object_,
+ output_core_dims=[[dim]],
+ output_sizes={dim: 3},
+ ),
+ self._obj.dtype.kind,
+ )
+
+ def partition(
+ self,
+ dim: Hashable | None,
+ sep: str | bytes | Any = " ",
+ ) -> T_DataArray:
"""
Split the strings in the DataArray at the first occurrence of separator `sep`.
@@ -1856,10 +2439,13 @@ class StringAccessor(Generic[T_DataArray]):
str.partition
pandas.Series.str.partition
"""
- pass
+ return self._partitioner(func=self._obj.dtype.type.partition, dim=dim, sep=sep)
- def rpartition(self, dim: (Hashable | None), sep: (str | bytes | Any)=' '
- ) ->T_DataArray:
+ def rpartition(
+ self,
+ dim: Hashable | None,
+ sep: str | bytes | Any = " ",
+ ) -> T_DataArray:
"""
Split the strings in the DataArray at the last occurrence of separator `sep`.
@@ -1891,18 +2477,67 @@ class StringAccessor(Generic[T_DataArray]):
str.rpartition
pandas.Series.str.rpartition
"""
- pass
+ return self._partitioner(func=self._obj.dtype.type.rpartition, dim=dim, sep=sep)
- def _splitter(self, *, func: Callable, pre: bool, dim: Hashable, sep: (
- str | bytes | Any | None), maxsplit: int) ->DataArray:
+ def _splitter(
+ self,
+ *,
+ func: Callable,
+ pre: bool,
+ dim: Hashable,
+ sep: str | bytes | Any | None,
+ maxsplit: int,
+ ) -> DataArray:
"""
Implements logic for `split` and `rsplit`.
"""
- pass
-
- def split(self, dim: (Hashable | None), sep: (str | bytes | Any)=None,
- maxsplit: int=-1) ->DataArray:
- """
+ if sep is not None:
+ sep = self._stringify(sep)
+
+ if dim is None:
+ f_none = lambda x, isep: func(x, isep, maxsplit)
+ return self._apply(func=f_none, func_args=(sep,), dtype=np.object_)
+
+ # _apply breaks on an empty array in this case
+ if not self._obj.size:
+ return self._obj.copy().expand_dims({dim: 0}, axis=-1)
+
+ f_count = lambda x, isep: max(len(func(x, isep, maxsplit)), 1)
+ maxsplit = (
+ self._apply(func=f_count, func_args=(sep,), dtype=np.int_).max().data.item()
+ - 1
+ )
+
+ def _dosplit(mystr, sep, maxsplit=maxsplit, dtype=self._obj.dtype):
+ res = func(mystr, sep, maxsplit)
+ if len(res) < maxsplit + 1:
+ pad = [""] * (maxsplit + 1 - len(res))
+ if pre:
+ res += pad
+ else:
+ res = pad + res
+ return np.array(res, dtype=dtype)
+
+ # dtype MUST be object or strings can be truncated
+ # See: https://github.com/numpy/numpy/issues/8352
+ return duck_array_ops.astype(
+ self._apply(
+ func=_dosplit,
+ func_args=(sep,),
+ dtype=np.object_,
+ output_core_dims=[[dim]],
+ output_sizes={dim: maxsplit},
+ ),
+ self._obj.dtype.kind,
+ )
+
+ def split(
+ self,
+ dim: Hashable | None,
+ sep: str | bytes | Any = None,
+ maxsplit: int = -1,
+ ) -> DataArray:
+ r"""
Split strings in a DataArray around the given separator/delimiter `sep`.
Splits the string in the DataArray from the beginning,
@@ -1933,8 +2568,8 @@ class StringAccessor(Generic[T_DataArray]):
>>> values = xr.DataArray(
... [
- ... ["abc def", "spam\\t\\teggs\\tswallow", "red_blue"],
- ... ["test0\\ntest1\\ntest2\\n\\ntest3", "", "abra ka\\nda\\tbra"],
+ ... ["abc def", "spam\t\teggs\tswallow", "red_blue"],
+ ... ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"],
... ],
... dims=["X", "Y"],
... )
@@ -1944,12 +2579,12 @@ class StringAccessor(Generic[T_DataArray]):
>>> values.str.split(dim="splitted", maxsplit=1)
<xarray.DataArray (X: 2, Y: 3, splitted: 2)> Size: 864B
array([[['abc', 'def'],
- ['spam', 'eggs\\tswallow'],
+ ['spam', 'eggs\tswallow'],
['red_blue', '']],
<BLANKLINE>
- [['test0', 'test1\\ntest2\\n\\ntest3'],
+ [['test0', 'test1\ntest2\n\ntest3'],
['', ''],
- ['abra', 'ka\\nda\\tbra']]], dtype='<U18')
+ ['abra', 'ka\nda\tbra']]], dtype='<U18')
Dimensions without coordinates: X, Y, splitted
Split as many times as needed and put the results in a new dimension
@@ -1969,10 +2604,10 @@ class StringAccessor(Generic[T_DataArray]):
>>> values.str.split(dim=None, maxsplit=1)
<xarray.DataArray (X: 2, Y: 3)> Size: 48B
- array([[list(['abc', 'def']), list(['spam', 'eggs\\tswallow']),
+ array([[list(['abc', 'def']), list(['spam', 'eggs\tswallow']),
list(['red_blue'])],
- [list(['test0', 'test1\\ntest2\\n\\ntest3']), list([]),
- list(['abra', 'ka\\nda\\tbra'])]], dtype=object)
+ [list(['test0', 'test1\ntest2\n\ntest3']), list([]),
+ list(['abra', 'ka\nda\tbra'])]], dtype=object)
Dimensions without coordinates: X, Y
Split as many times as needed and put the results in a list
@@ -1990,12 +2625,12 @@ class StringAccessor(Generic[T_DataArray]):
>>> values.str.split(dim="splitted", sep=" ")
<xarray.DataArray (X: 2, Y: 3, splitted: 3)> Size: 2kB
array([[['abc', 'def', ''],
- ['spam\\t\\teggs\\tswallow', '', ''],
+ ['spam\t\teggs\tswallow', '', ''],
['red_blue', '', '']],
<BLANKLINE>
- [['test0\\ntest1\\ntest2\\n\\ntest3', '', ''],
+ [['test0\ntest1\ntest2\n\ntest3', '', ''],
['', '', ''],
- ['abra', '', 'ka\\nda\\tbra']]], dtype='<U24')
+ ['abra', '', 'ka\nda\tbra']]], dtype='<U24')
Dimensions without coordinates: X, Y, splitted
See Also
@@ -2004,11 +2639,21 @@ class StringAccessor(Generic[T_DataArray]):
str.split
pandas.Series.str.split
"""
- pass
-
- def rsplit(self, dim: (Hashable | None), sep: (str | bytes | Any)=None,
- maxsplit: (int | Any)=-1) ->DataArray:
- """
+ return self._splitter(
+ func=self._obj.dtype.type.split,
+ pre=True,
+ dim=dim,
+ sep=sep,
+ maxsplit=maxsplit,
+ )
+
+ def rsplit(
+ self,
+ dim: Hashable | None,
+ sep: str | bytes | Any = None,
+ maxsplit: int | Any = -1,
+ ) -> DataArray:
+ r"""
Split strings in a DataArray around the given separator/delimiter `sep`.
Splits the string in the DataArray from the end,
@@ -2041,8 +2686,8 @@ class StringAccessor(Generic[T_DataArray]):
>>> values = xr.DataArray(
... [
- ... ["abc def", "spam\\t\\teggs\\tswallow", "red_blue"],
- ... ["test0\\ntest1\\ntest2\\n\\ntest3", "", "abra ka\\nda\\tbra"],
+ ... ["abc def", "spam\t\teggs\tswallow", "red_blue"],
+ ... ["test0\ntest1\ntest2\n\ntest3", "", "abra ka\nda\tbra"],
... ],
... dims=["X", "Y"],
... )
@@ -2052,12 +2697,12 @@ class StringAccessor(Generic[T_DataArray]):
>>> values.str.rsplit(dim="splitted", maxsplit=1)
<xarray.DataArray (X: 2, Y: 3, splitted: 2)> Size: 816B
array([[['abc', 'def'],
- ['spam\\t\\teggs', 'swallow'],
+ ['spam\t\teggs', 'swallow'],
['', 'red_blue']],
<BLANKLINE>
- [['test0\\ntest1\\ntest2', 'test3'],
+ [['test0\ntest1\ntest2', 'test3'],
['', ''],
- ['abra ka\\nda', 'bra']]], dtype='<U17')
+ ['abra ka\nda', 'bra']]], dtype='<U17')
Dimensions without coordinates: X, Y, splitted
Split as many times as needed and put the results in a new dimension
@@ -2077,10 +2722,10 @@ class StringAccessor(Generic[T_DataArray]):
>>> values.str.rsplit(dim=None, maxsplit=1)
<xarray.DataArray (X: 2, Y: 3)> Size: 48B
- array([[list(['abc', 'def']), list(['spam\\t\\teggs', 'swallow']),
+ array([[list(['abc', 'def']), list(['spam\t\teggs', 'swallow']),
list(['red_blue'])],
- [list(['test0\\ntest1\\ntest2', 'test3']), list([]),
- list(['abra ka\\nda', 'bra'])]], dtype=object)
+ [list(['test0\ntest1\ntest2', 'test3']), list([]),
+ list(['abra ka\nda', 'bra'])]], dtype=object)
Dimensions without coordinates: X, Y
Split as many times as needed and put the results in a list
@@ -2098,12 +2743,12 @@ class StringAccessor(Generic[T_DataArray]):
>>> values.str.rsplit(dim="splitted", sep=" ")
<xarray.DataArray (X: 2, Y: 3, splitted: 3)> Size: 2kB
array([[['', 'abc', 'def'],
- ['', '', 'spam\\t\\teggs\\tswallow'],
+ ['', '', 'spam\t\teggs\tswallow'],
['', '', 'red_blue']],
<BLANKLINE>
- [['', '', 'test0\\ntest1\\ntest2\\n\\ntest3'],
+ [['', '', 'test0\ntest1\ntest2\n\ntest3'],
['', '', ''],
- ['abra', '', 'ka\\nda\\tbra']]], dtype='<U24')
+ ['abra', '', 'ka\nda\tbra']]], dtype='<U24')
Dimensions without coordinates: X, Y, splitted
See Also
@@ -2112,10 +2757,19 @@ class StringAccessor(Generic[T_DataArray]):
str.rsplit
pandas.Series.str.rsplit
"""
- pass
+ return self._splitter(
+ func=self._obj.dtype.type.rsplit,
+ pre=False,
+ dim=dim,
+ sep=sep,
+ maxsplit=maxsplit,
+ )
- def get_dummies(self, dim: Hashable, sep: (str | bytes | Any)='|'
- ) ->DataArray:
+ def get_dummies(
+ self,
+ dim: Hashable,
+ sep: str | bytes | Any = "|",
+ ) -> DataArray:
"""
Return DataArray of dummy/indicator variables.
@@ -2170,9 +2824,27 @@ class StringAccessor(Generic[T_DataArray]):
--------
pandas.Series.str.get_dummies
"""
- pass
+ # _apply breaks on an empty array in this case
+ if not self._obj.size:
+ return self._obj.copy().expand_dims({dim: 0}, axis=-1)
- def decode(self, encoding: str, errors: str='strict') ->T_DataArray:
+ sep = self._stringify(sep)
+ f_set = lambda x, isep: set(x.split(isep)) - {self._stringify("")}
+ setarr = self._apply(func=f_set, func_args=(sep,), dtype=np.object_)
+ vals = sorted(reduce(set_union, setarr.data.ravel()))
+
+ func = lambda x: np.array([val in x for val in vals], dtype=np.bool_)
+ res = _apply_str_ufunc(
+ func=func,
+ obj=setarr,
+ output_core_dims=[[dim]],
+ output_sizes={dim: len(vals)},
+ dtype=np.bool_,
+ )
+ res.coords[dim] = vals
+ return res
+
+ def decode(self, encoding: str, errors: str = "strict") -> T_DataArray:
"""
Decode character string in the array using indicated encoding.
@@ -2191,9 +2863,14 @@ class StringAccessor(Generic[T_DataArray]):
-------
decoded : same type as values
"""
- pass
+ if encoding in _cpython_optimized_decoders:
+ func = lambda x: x.decode(encoding, errors)
+ else:
+ decoder = codecs.getdecoder(encoding)
+ func = lambda x: decoder(x, errors)[0]
+ return self._apply(func=func, dtype=np.str_)
- def encode(self, encoding: str, errors: str='strict') ->T_DataArray:
+ def encode(self, encoding: str, errors: str = "strict") -> T_DataArray:
"""
Encode character string in the array using indicated encoding.
@@ -2212,4 +2889,9 @@ class StringAccessor(Generic[T_DataArray]):
-------
encoded : same type as values
"""
- pass
+ if encoding in _cpython_optimized_encoders:
+ func = lambda x: x.encode(encoding, errors)
+ else:
+ encoder = codecs.getencoder(encoding)
+ func = lambda x: encoder(x, errors)[0]
+ return self._apply(func=func, dtype=np.bytes_)
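For orientation, here is a minimal usage sketch of the string accessor methods restored above (illustrative values only; it assumes a build of xarray that includes this patch):

```python
import xarray as xr

names = xr.DataArray(["temperature", "pressure"], dims="x")

names.str.len()                        # element lengths: [11, 8]
names.str.upper()                      # elementwise upper-casing
names.str.cat(names, sep="_")          # elementwise concatenation with another array
names.str.split(dim="parts", sep="e")  # split into a new "parts" dimension
```

Each of these calls funnels through `_apply`, which vectorizes the underlying `str`/`bytes` method across the array.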
diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py
index 3ec2a057..44fc7319 100644
--- a/xarray/core/alignment.py
+++ b/xarray/core/alignment.py
@@ -1,33 +1,93 @@
from __future__ import annotations
+
import functools
import operator
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload
+
import numpy as np
import pandas as pd
+
from xarray.core import dtypes
-from xarray.core.indexes import Index, Indexes, PandasIndex, PandasMultiIndex, indexes_all_equal, safe_cast_to_index
+from xarray.core.indexes import (
+ Index,
+ Indexes,
+ PandasIndex,
+ PandasMultiIndex,
+ indexes_all_equal,
+ safe_cast_to_index,
+)
from xarray.core.types import T_Alignable
from xarray.core.utils import is_dict_like, is_full_slice
from xarray.core.variable import Variable, as_compatible_data, calculate_dimensions
+
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
- from xarray.core.types import Alignable, JoinOptions, T_DataArray, T_Dataset, T_DuckArray
-
-
-def reindex_variables(variables: Mapping[Any, Variable], dim_pos_indexers:
- Mapping[Any, Any], copy: bool=True, fill_value: Any=dtypes.NA, sparse:
- bool=False) ->dict[Hashable, Variable]:
+ from xarray.core.types import (
+ Alignable,
+ JoinOptions,
+ T_DataArray,
+ T_Dataset,
+ T_DuckArray,
+ )
+
+
+def reindex_variables(
+ variables: Mapping[Any, Variable],
+ dim_pos_indexers: Mapping[Any, Any],
+ copy: bool = True,
+ fill_value: Any = dtypes.NA,
+ sparse: bool = False,
+) -> dict[Hashable, Variable]:
"""Conform a dictionary of variables onto a new set of variables reindexed
with dimension positional indexers and possibly filled with missing values.
Not public API.
"""
- pass
+ new_variables = {}
+ dim_sizes = calculate_dimensions(variables)
+
+ masked_dims = set()
+ unchanged_dims = set()
+ for dim, indxr in dim_pos_indexers.items():
+ # Negative values in dim_pos_indexers mean values missing in the new index
+ # See ``Index.reindex_like``.
+ if (indxr < 0).any():
+ masked_dims.add(dim)
+ elif np.array_equal(indxr, np.arange(dim_sizes.get(dim, 0))):
+ unchanged_dims.add(dim)
+
+ for name, var in variables.items():
+ if isinstance(fill_value, dict):
+ fill_value_ = fill_value.get(name, dtypes.NA)
+ else:
+ fill_value_ = fill_value
+
+ if sparse:
+ var = var._as_sparse(fill_value=fill_value_)
+ indxr = tuple(
+ slice(None) if d in unchanged_dims else dim_pos_indexers.get(d, slice(None))
+ for d in var.dims
+ )
+ needs_masking = any(d in masked_dims for d in var.dims)
+
+ if needs_masking:
+ new_var = var._getitem_with_mask(indxr, fill_value=fill_value_)
+ elif all(is_full_slice(k) for k in indxr):
+ # no reindexing necessary
+ # here we need to manually deal with copying data, since
+ # we neither created a new ndarray nor used fancy indexing
+ new_var = var.copy(deep=copy)
+ else:
+ new_var = var[indxr]
+
+ new_variables[name] = new_var
+
+ return new_variables
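As a rough illustration of what the negative entries in `dim_pos_indexers` encode: labels present in the new index but absent from the old one end up filled with `fill_value`. The public reindexing path, which ultimately relies on this helper, makes that visible (illustrative values only):

```python
import xarray as xr

da = xr.DataArray([10, 20, 30], dims="x", coords={"x": [0, 1, 2]})

da.reindex(x=[1, 2, 3])                   # label 3 has no source value -> NaN
da.reindex(x=[1, 2, 3], fill_value=-999)  # or an explicit fill value
```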
CoordNamesAndDims = tuple[tuple[Hashable, tuple[Hashable, ...]], ...]
@@ -48,6 +108,7 @@ class Aligner(Generic[T_Alignable]):
aligned_objects = aligner.results
"""
+
objects: tuple[T_Alignable, ...]
results: tuple[T_Alignable, ...]
objects_matching_indexes: tuple[dict[MatchingIndexKey, Index], ...]
@@ -68,51 +129,161 @@ class Aligner(Generic[T_Alignable]):
unindexed_dim_sizes: dict[Hashable, set]
new_indexes: Indexes[Index]
- def __init__(self, objects: Iterable[T_Alignable], join: str='inner',
- indexes: (Mapping[Any, Any] | None)=None, exclude_dims: (str |
- Iterable[Hashable])=frozenset(), exclude_vars: Iterable[Hashable]=
- frozenset(), method: (str | None)=None, tolerance: (float |
- Iterable[float] | str | None)=None, copy: bool=True, fill_value:
- Any=dtypes.NA, sparse: bool=False):
+ def __init__(
+ self,
+ objects: Iterable[T_Alignable],
+ join: str = "inner",
+ indexes: Mapping[Any, Any] | None = None,
+ exclude_dims: str | Iterable[Hashable] = frozenset(),
+ exclude_vars: Iterable[Hashable] = frozenset(),
+ method: str | None = None,
+ tolerance: float | Iterable[float] | str | None = None,
+ copy: bool = True,
+ fill_value: Any = dtypes.NA,
+ sparse: bool = False,
+ ):
self.objects = tuple(objects)
self.objects_matching_indexes = ()
- if join not in ['inner', 'outer', 'override', 'exact', 'left', 'right'
- ]:
- raise ValueError(f'invalid value for join: {join}')
+
+ if join not in ["inner", "outer", "override", "exact", "left", "right"]:
+ raise ValueError(f"invalid value for join: {join}")
self.join = join
+
self.copy = copy
self.fill_value = fill_value
self.sparse = sparse
+
if method is None and tolerance is None:
self.reindex_kwargs = {}
else:
- self.reindex_kwargs = {'method': method, 'tolerance': tolerance}
+ self.reindex_kwargs = {"method": method, "tolerance": tolerance}
+
if isinstance(exclude_dims, str):
exclude_dims = [exclude_dims]
self.exclude_dims = frozenset(exclude_dims)
self.exclude_vars = frozenset(exclude_vars)
+
if indexes is None:
indexes = {}
self.indexes, self.index_vars = self._normalize_indexes(indexes)
+
self.all_indexes = {}
self.all_index_vars = {}
self.unindexed_dim_sizes = {}
+
self.aligned_indexes = {}
self.aligned_index_vars = {}
self.reindex = {}
+
self.results = tuple()
- def _normalize_indexes(self, indexes: Mapping[Any, Any | T_DuckArray]
- ) ->tuple[NormalizedIndexes, NormalizedIndexVars]:
+ def _normalize_indexes(
+ self,
+ indexes: Mapping[Any, Any | T_DuckArray],
+ ) -> tuple[NormalizedIndexes, NormalizedIndexVars]:
"""Normalize the indexes/indexers used for re-indexing or alignment.
Return dictionaries of xarray Index objects and coordinate variables
such that we can group matching indexes based on the dictionary keys.
"""
- pass
-
- def assert_no_index_conflict(self) ->None:
+ if isinstance(indexes, Indexes):
+ xr_variables = dict(indexes.variables)
+ else:
+ xr_variables = {}
+
+ xr_indexes: dict[Hashable, Index] = {}
+ for k, idx in indexes.items():
+ if not isinstance(idx, Index):
+ if getattr(idx, "dims", (k,)) != (k,):
+ raise ValueError(
+ f"Indexer has dimensions {idx.dims} that are different "
+ f"from that to be indexed along '{k}'"
+ )
+ data: T_DuckArray = as_compatible_data(idx)
+ pd_idx = safe_cast_to_index(data)
+ pd_idx.name = k
+ if isinstance(pd_idx, pd.MultiIndex):
+ idx = PandasMultiIndex(pd_idx, k)
+ else:
+ idx = PandasIndex(pd_idx, k, coord_dtype=data.dtype)
+ xr_variables.update(idx.create_variables())
+ xr_indexes[k] = idx
+
+ normalized_indexes = {}
+ normalized_index_vars = {}
+ for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index():
+ coord_names_and_dims = []
+ all_dims: set[Hashable] = set()
+
+ for name, var in index_vars.items():
+ dims = var.dims
+ coord_names_and_dims.append((name, dims))
+ all_dims.update(dims)
+
+ exclude_dims = all_dims & self.exclude_dims
+ if exclude_dims == all_dims:
+ continue
+ elif exclude_dims:
+ excl_dims_str = ", ".join(str(d) for d in exclude_dims)
+ incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims)
+ raise ValueError(
+ f"cannot exclude dimension(s) {excl_dims_str} from alignment because "
+ "these are used by an index together with non-excluded dimensions "
+ f"{incl_dims_str}"
+ )
+
+ key = (tuple(coord_names_and_dims), type(idx))
+ normalized_indexes[key] = idx
+ normalized_index_vars[key] = index_vars
+
+ return normalized_indexes, normalized_index_vars
+
+ def find_matching_indexes(self) -> None:
+ all_indexes: dict[MatchingIndexKey, list[Index]]
+ all_index_vars: dict[MatchingIndexKey, list[dict[Hashable, Variable]]]
+ all_indexes_dim_sizes: dict[MatchingIndexKey, dict[Hashable, set]]
+ objects_matching_indexes: list[dict[MatchingIndexKey, Index]]
+
+ all_indexes = defaultdict(list)
+ all_index_vars = defaultdict(list)
+ all_indexes_dim_sizes = defaultdict(lambda: defaultdict(set))
+ objects_matching_indexes = []
+
+ for obj in self.objects:
+ obj_indexes, obj_index_vars = self._normalize_indexes(obj.xindexes)
+ objects_matching_indexes.append(obj_indexes)
+ for key, idx in obj_indexes.items():
+ all_indexes[key].append(idx)
+ for key, index_vars in obj_index_vars.items():
+ all_index_vars[key].append(index_vars)
+ for dim, size in calculate_dimensions(index_vars).items():
+ all_indexes_dim_sizes[key][dim].add(size)
+
+ self.objects_matching_indexes = tuple(objects_matching_indexes)
+ self.all_indexes = all_indexes
+ self.all_index_vars = all_index_vars
+
+ if self.join == "override":
+ for dim_sizes in all_indexes_dim_sizes.values():
+ for dim, sizes in dim_sizes.items():
+ if len(sizes) > 1:
+ raise ValueError(
+ "cannot align objects with join='override' with matching indexes "
+ f"along dimension {dim!r} that don't have the same size"
+ )
+
+ def find_matching_unindexed_dims(self) -> None:
+ unindexed_dim_sizes = defaultdict(set)
+
+ for obj in self.objects:
+ for dim in obj.dims:
+ if dim not in self.exclude_dims and dim not in obj.xindexes.dims:
+ unindexed_dim_sizes[dim].add(obj.sizes[dim])
+
+ self.unindexed_dim_sizes = unindexed_dim_sizes
+
+ def assert_no_index_conflict(self) -> None:
"""Check for uniqueness of both coordinate and dimension names across all sets
of matching indexes.
@@ -126,9 +297,34 @@ class Aligner(Generic[T_Alignable]):
(ref: https://github.com/pydata/xarray/issues/1603#issuecomment-442965602)
"""
- pass
-
- def _need_reindex(self, dim, cmp_indexes) ->bool:
+ matching_keys = set(self.all_indexes) | set(self.indexes)
+
+ coord_count: dict[Hashable, int] = defaultdict(int)
+ dim_count: dict[Hashable, int] = defaultdict(int)
+ for coord_names_dims, _ in matching_keys:
+ dims_set: set[Hashable] = set()
+ for name, dims in coord_names_dims:
+ coord_count[name] += 1
+ dims_set.update(dims)
+ for dim in dims_set:
+ dim_count[dim] += 1
+
+ for count, msg in [(coord_count, "coordinates"), (dim_count, "dimensions")]:
+ dup = {k: v for k, v in count.items() if v > 1}
+ if dup:
+ items_msg = ", ".join(
+ f"{k!r} ({v} conflicting indexes)" for k, v in dup.items()
+ )
+ raise ValueError(
+ "cannot re-index or align objects with conflicting indexes found for "
+ f"the following {msg}: {items_msg}\n"
+ "Conflicting indexes may occur when\n"
+ "- they relate to different sets of coordinate and/or dimension names\n"
+ "- they don't have the same type\n"
+ "- they may be used to reindex data along common dimensions"
+ )
+
+ def _need_reindex(self, dim, cmp_indexes) -> bool:
"""Whether or not we need to reindex variables for a set of
matching indexes.
@@ -139,23 +335,354 @@ class Aligner(Generic[T_Alignable]):
pandas). This is useful, e.g., for overwriting such duplicate indexes.
"""
- pass
+ if not indexes_all_equal(cmp_indexes):
+ # always reindex when matching indexes are not equal
+ return True
+
+ unindexed_dims_sizes = {}
+ for d in dim:
+ if d in self.unindexed_dim_sizes:
+ sizes = self.unindexed_dim_sizes[d]
+ if len(sizes) > 1:
+ # reindex if different sizes are found for unindexed dims
+ return True
+ else:
+ unindexed_dims_sizes[d] = next(iter(sizes))
+
+ if unindexed_dims_sizes:
+ indexed_dims_sizes = {}
+ for cmp in cmp_indexes:
+ index_vars = cmp[1]
+ for var in index_vars.values():
+ indexed_dims_sizes.update(var.sizes)
+
+ for d, size in unindexed_dims_sizes.items():
+ if indexed_dims_sizes.get(d, -1) != size:
+ # reindex if unindexed dimension size doesn't match
+ return True
+
+ return False
+
+ def _get_index_joiner(self, index_cls) -> Callable:
+ if self.join in ["outer", "inner"]:
+ return functools.partial(
+ functools.reduce,
+ functools.partial(index_cls.join, how=self.join),
+ )
+ elif self.join == "left":
+ return operator.itemgetter(0)
+ elif self.join == "right":
+ return operator.itemgetter(-1)
+ elif self.join == "override":
+ # We rewrite all indexes and then use join='left'
+ return operator.itemgetter(0)
+ else:
+ # join='exact' return dummy lambda (error is raised)
+ return lambda _: None
- def align_indexes(self) ->None:
+ def align_indexes(self) -> None:
"""Compute all aligned indexes and their corresponding coordinate variables."""
- pass
-
-
-T_Obj1 = TypeVar('T_Obj1', bound='Alignable')
-T_Obj2 = TypeVar('T_Obj2', bound='Alignable')
-T_Obj3 = TypeVar('T_Obj3', bound='Alignable')
-T_Obj4 = TypeVar('T_Obj4', bound='Alignable')
-T_Obj5 = TypeVar('T_Obj5', bound='Alignable')
-
-def align(*objects: T_Alignable, join: JoinOptions='inner', copy: bool=True,
- indexes=None, exclude: (str | Iterable[Hashable])=frozenset(),
- fill_value=dtypes.NA) ->tuple[T_Alignable, ...]:
+ aligned_indexes = {}
+ aligned_index_vars = {}
+ reindex = {}
+ new_indexes = {}
+ new_index_vars = {}
+
+ for key, matching_indexes in self.all_indexes.items():
+ matching_index_vars = self.all_index_vars[key]
+ dims = {d for coord in matching_index_vars[0].values() for d in coord.dims}
+ index_cls = key[1]
+
+ if self.join == "override":
+ joined_index = matching_indexes[0]
+ joined_index_vars = matching_index_vars[0]
+ need_reindex = False
+ elif key in self.indexes:
+ joined_index = self.indexes[key]
+ joined_index_vars = self.index_vars[key]
+ cmp_indexes = list(
+ zip(
+ [joined_index] + matching_indexes,
+ [joined_index_vars] + matching_index_vars,
+ )
+ )
+ need_reindex = self._need_reindex(dims, cmp_indexes)
+ else:
+ if len(matching_indexes) > 1:
+ need_reindex = self._need_reindex(
+ dims,
+ list(zip(matching_indexes, matching_index_vars)),
+ )
+ else:
+ need_reindex = False
+ if need_reindex:
+ if self.join == "exact":
+ raise ValueError(
+ "cannot align objects with join='exact' where "
+ "index/labels/sizes are not equal along "
+ "these coordinates (dimensions): "
+ + ", ".join(f"{name!r} {dims!r}" for name, dims in key[0])
+ )
+ joiner = self._get_index_joiner(index_cls)
+ joined_index = joiner(matching_indexes)
+ if self.join == "left":
+ joined_index_vars = matching_index_vars[0]
+ elif self.join == "right":
+ joined_index_vars = matching_index_vars[-1]
+ else:
+ joined_index_vars = joined_index.create_variables()
+ else:
+ joined_index = matching_indexes[0]
+ joined_index_vars = matching_index_vars[0]
+
+ reindex[key] = need_reindex
+ aligned_indexes[key] = joined_index
+ aligned_index_vars[key] = joined_index_vars
+
+ for name, var in joined_index_vars.items():
+ new_indexes[name] = joined_index
+ new_index_vars[name] = var
+
+ # Explicitly provided indexes that are not found in objects to align
+ # may relate to unindexed dimensions so we add them too
+ for key, idx in self.indexes.items():
+ if key not in aligned_indexes:
+ index_vars = self.index_vars[key]
+ reindex[key] = False
+ aligned_indexes[key] = idx
+ aligned_index_vars[key] = index_vars
+ for name, var in index_vars.items():
+ new_indexes[name] = idx
+ new_index_vars[name] = var
+
+ self.aligned_indexes = aligned_indexes
+ self.aligned_index_vars = aligned_index_vars
+ self.reindex = reindex
+ self.new_indexes = Indexes(new_indexes, new_index_vars)
+
+ def assert_unindexed_dim_sizes_equal(self) -> None:
+ for dim, sizes in self.unindexed_dim_sizes.items():
+ index_size = self.new_indexes.dims.get(dim)
+ if index_size is not None:
+ sizes.add(index_size)
+ add_err_msg = (
+ f" (note: an index is found along that dimension "
+ f"with size={index_size!r})"
+ )
+ else:
+ add_err_msg = ""
+ if len(sizes) > 1:
+ raise ValueError(
+ f"cannot reindex or align along dimension {dim!r} "
+ f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg
+ )
+
+ def override_indexes(self) -> None:
+ objects = list(self.objects)
+
+ for i, obj in enumerate(objects[1:]):
+ new_indexes = {}
+ new_variables = {}
+ matching_indexes = self.objects_matching_indexes[i + 1]
+
+ for key, aligned_idx in self.aligned_indexes.items():
+ obj_idx = matching_indexes.get(key)
+ if obj_idx is not None:
+ for name, var in self.aligned_index_vars[key].items():
+ new_indexes[name] = aligned_idx
+ new_variables[name] = var.copy(deep=self.copy)
+
+ objects[i + 1] = obj._overwrite_indexes(new_indexes, new_variables)
+
+ self.results = tuple(objects)
+
+ def _get_dim_pos_indexers(
+ self,
+ matching_indexes: dict[MatchingIndexKey, Index],
+ ) -> dict[Hashable, Any]:
+ dim_pos_indexers = {}
+
+ for key, aligned_idx in self.aligned_indexes.items():
+ obj_idx = matching_indexes.get(key)
+ if obj_idx is not None:
+ if self.reindex[key]:
+ indexers = obj_idx.reindex_like(aligned_idx, **self.reindex_kwargs)
+ dim_pos_indexers.update(indexers)
+
+ return dim_pos_indexers
+
+ def _get_indexes_and_vars(
+ self,
+ obj: T_Alignable,
+ matching_indexes: dict[MatchingIndexKey, Index],
+ ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+ new_indexes = {}
+ new_variables = {}
+
+ for key, aligned_idx in self.aligned_indexes.items():
+ index_vars = self.aligned_index_vars[key]
+ obj_idx = matching_indexes.get(key)
+ if obj_idx is None:
+ # add the index if it relates to unindexed dimensions in obj
+ index_vars_dims = {d for var in index_vars.values() for d in var.dims}
+ if index_vars_dims <= set(obj.dims):
+ obj_idx = aligned_idx
+ if obj_idx is not None:
+ for name, var in index_vars.items():
+ new_indexes[name] = aligned_idx
+ new_variables[name] = var.copy(deep=self.copy)
+
+ return new_indexes, new_variables
+
+ def _reindex_one(
+ self,
+ obj: T_Alignable,
+ matching_indexes: dict[MatchingIndexKey, Index],
+ ) -> T_Alignable:
+ new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes)
+ dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes)
+
+ return obj._reindex_callback(
+ self,
+ dim_pos_indexers,
+ new_variables,
+ new_indexes,
+ self.fill_value,
+ self.exclude_dims,
+ self.exclude_vars,
+ )
+
+ def reindex_all(self) -> None:
+ self.results = tuple(
+ self._reindex_one(obj, matching_indexes)
+ for obj, matching_indexes in zip(
+ self.objects, self.objects_matching_indexes
+ )
+ )
+
+ def align(self) -> None:
+ if not self.indexes and len(self.objects) == 1:
+ # fast path for the trivial case
+ (obj,) = self.objects
+ self.results = (obj.copy(deep=self.copy),)
+ return
+
+ self.find_matching_indexes()
+ self.find_matching_unindexed_dims()
+ self.assert_no_index_conflict()
+ self.align_indexes()
+ self.assert_unindexed_dim_sizes_equal()
+
+ if self.join == "override":
+ self.override_indexes()
+ elif self.join == "exact" and not self.copy:
+ self.results = self.objects
+ else:
+ self.reindex_all()
+
+
+T_Obj1 = TypeVar("T_Obj1", bound="Alignable")
+T_Obj2 = TypeVar("T_Obj2", bound="Alignable")
+T_Obj3 = TypeVar("T_Obj3", bound="Alignable")
+T_Obj4 = TypeVar("T_Obj4", bound="Alignable")
+T_Obj5 = TypeVar("T_Obj5", bound="Alignable")
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1]: ...
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2]: ...
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3]: ...
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: ...
+
+
+@overload
+def align(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ obj5: T_Obj5,
+ /,
+ *,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: ...
+
+
+@overload
+def align(
+ *objects: T_Alignable,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Alignable, ...]: ...
+
+
+def align(
+ *objects: T_Alignable,
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ fill_value=dtypes.NA,
+) -> tuple[T_Alignable, ...]:
"""
Given any number of Dataset and/or DataArray objects, returns new
objects with aligned indexes and dimension sizes.
@@ -344,44 +871,279 @@ def align(*objects: T_Alignable, join: JoinOptions='inner', copy: bool=True,
* lon (lon) float64 16B 100.0 120.0
"""
- pass
-
-
-def deep_align(objects: Iterable[Any], join: JoinOptions='inner', copy:
- bool=True, indexes=None, exclude: (str | Iterable[Hashable])=frozenset(
- ), raise_on_invalid: bool=True, fill_value=dtypes.NA) ->list[Any]:
+ aligner = Aligner(
+ objects,
+ join=join,
+ copy=copy,
+ indexes=indexes,
+ exclude_dims=exclude,
+ fill_value=fill_value,
+ )
+ aligner.align()
+ return aligner.results
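
For reference, a minimal sketch of how the restored `align` behaves with the different `join` options (the coordinate values here are arbitrary):

```python
import numpy as np
import xarray as xr

a = xr.DataArray(np.arange(3), dims="x", coords={"x": [0, 1, 2]})
b = xr.DataArray(np.arange(3), dims="x", coords={"x": [1, 2, 3]})

# "inner" keeps only the shared labels ...
a_inner, b_inner = xr.align(a, b, join="inner")
assert a_inner.x.values.tolist() == [1, 2]

# ... while "outer" takes the union and fills the gaps with NaN,
# which also promotes the integer data to float.
a_outer, b_outer = xr.align(a, b, join="outer")
assert a_outer.x.values.tolist() == [0, 1, 2, 3]
```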
+
+
+def deep_align(
+ objects: Iterable[Any],
+ join: JoinOptions = "inner",
+ copy: bool = True,
+ indexes=None,
+ exclude: str | Iterable[Hashable] = frozenset(),
+ raise_on_invalid: bool = True,
+ fill_value=dtypes.NA,
+) -> list[Any]:
"""Align objects for merging, recursing into dictionary values.
This function is not public API.
"""
- pass
+ from xarray.core.coordinates import Coordinates
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ if indexes is None:
+ indexes = {}
+
+ def is_alignable(obj):
+ return isinstance(obj, (Coordinates, DataArray, Dataset))
+
+ positions: list[int] = []
+ keys: list[type[object] | Hashable] = []
+ out: list[Any] = []
+ targets: list[Alignable] = []
+ no_key: Final = object()
+ not_replaced: Final = object()
+ for position, variables in enumerate(objects):
+ if is_alignable(variables):
+ positions.append(position)
+ keys.append(no_key)
+ targets.append(variables)
+ out.append(not_replaced)
+ elif is_dict_like(variables):
+ current_out = {}
+ for k, v in variables.items():
+ if is_alignable(v) and k not in indexes:
+ # Skip variables in indexes for alignment, because these
+                    # should be overwritten instead:
+ # https://github.com/pydata/xarray/issues/725
+ # https://github.com/pydata/xarray/issues/3377
+ # TODO(shoyer): doing this here feels super-hacky -- can we
+ # move it explicitly into merge instead?
+ positions.append(position)
+ keys.append(k)
+ targets.append(v)
+ current_out[k] = not_replaced
+ else:
+ current_out[k] = v
+ out.append(current_out)
+ elif raise_on_invalid:
+ raise ValueError(
+ "object to align is neither an xarray.Dataset, "
+ f"an xarray.DataArray nor a dictionary: {variables!r}"
+ )
+ else:
+ out.append(variables)
+
+ aligned = align(
+ *targets,
+ join=join,
+ copy=copy,
+ indexes=indexes,
+ exclude=exclude,
+ fill_value=fill_value,
+ )
+
+ for position, key, aligned_obj in zip(positions, keys, aligned):
+ if key is no_key:
+ out[position] = aligned_obj
+ else:
+ out[position][key] = aligned_obj
+
+ return out
-def reindex(obj: T_Alignable, indexers: Mapping[Any, Any], method: (str |
- None)=None, tolerance: (float | Iterable[float] | str | None)=None,
- copy: bool=True, fill_value: Any=dtypes.NA, sparse: bool=False,
- exclude_vars: Iterable[Hashable]=frozenset()) ->T_Alignable:
+def reindex(
+ obj: T_Alignable,
+ indexers: Mapping[Any, Any],
+ method: str | None = None,
+ tolerance: float | Iterable[float] | str | None = None,
+ copy: bool = True,
+ fill_value: Any = dtypes.NA,
+ sparse: bool = False,
+ exclude_vars: Iterable[Hashable] = frozenset(),
+) -> T_Alignable:
"""Re-index either a Dataset or a DataArray.
Not public API.
"""
- pass
+
+ # TODO: (benbovy - explicit indexes): uncomment?
+ # --> from reindex docstrings: "any mis-matched dimension is simply ignored"
+ # bad_keys = [k for k in indexers if k not in obj._indexes and k not in obj.dims]
+ # if bad_keys:
+ # raise ValueError(
+ # f"indexer keys {bad_keys} do not correspond to any indexed coordinate "
+ # "or unindexed dimension in the object to reindex"
+ # )
+
+ aligner = Aligner(
+ (obj,),
+ indexes=indexers,
+ method=method,
+ tolerance=tolerance,
+ copy=copy,
+ fill_value=fill_value,
+ sparse=sparse,
+ exclude_vars=exclude_vars,
+ )
+ aligner.align()
+ return aligner.results[0]
-def reindex_like(obj: T_Alignable, other: (Dataset | DataArray), method: (
- str | None)=None, tolerance: (float | Iterable[float] | str | None)=
- None, copy: bool=True, fill_value: Any=dtypes.NA) ->T_Alignable:
+def reindex_like(
+ obj: T_Alignable,
+ other: Dataset | DataArray,
+ method: str | None = None,
+ tolerance: float | Iterable[float] | str | None = None,
+ copy: bool = True,
+ fill_value: Any = dtypes.NA,
+) -> T_Alignable:
"""Re-index either a Dataset or a DataArray like another Dataset/DataArray.
Not public API.
"""
- pass
-
+ if not other._indexes:
+ # This check is not performed in Aligner.
+ for dim in other.dims:
+ if dim in obj.dims:
+ other_size = other.sizes[dim]
+ obj_size = obj.sizes[dim]
+ if other_size != obj_size:
+ raise ValueError(
+ "different size for unlabeled "
+ f"dimension on argument {dim!r}: {other_size!r} vs {obj_size!r}"
+ )
+
+ return reindex(
+ obj,
+ indexers=other.xindexes,
+ method=method,
+ tolerance=tolerance,
+ copy=copy,
+ fill_value=fill_value,
+ )
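
`reindex` and `reindex_like` are internal helpers here, but they back the public `DataArray.reindex_like` method; a small sketch with made-up values:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([10, 20, 30], dims="x", coords={"x": [0, 1, 2]})
target = xr.DataArray(np.zeros(4), dims="x", coords={"x": [0, 1, 2, 3]})

# Without a fill method the new label x=3 becomes NaN (and the data is
# promoted to float); with method="ffill" it takes the previous value.
gap = da.reindex_like(target)
assert bool(np.isnan(gap.sel(x=3)))

filled = da.reindex_like(target, method="ffill")
assert filled.sel(x=3).item() == 30
```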
+
+
+def _get_broadcast_dims_map_common_coords(args, exclude):
+ common_coords = {}
+ dims_map = {}
+ for arg in args:
+ for dim in arg.dims:
+ if dim not in common_coords and dim not in exclude:
+ dims_map[dim] = arg.sizes[dim]
+ if dim in arg._indexes:
+ common_coords.update(arg.xindexes.get_all_coords(dim))
+
+ return dims_map, common_coords
+
+
+def _broadcast_helper(
+ arg: T_Alignable, exclude, dims_map, common_coords
+) -> T_Alignable:
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
-def broadcast(*args: T_Alignable, exclude: (str | Iterable[Hashable] | None
- )=None) ->tuple[T_Alignable, ...]:
+ def _set_dims(var):
+ # Add excluded dims to a copy of dims_map
+ var_dims_map = dims_map.copy()
+ for dim in exclude:
+ with suppress(ValueError):
+ # ignore dim not in var.dims
+ var_dims_map[dim] = var.shape[var.dims.index(dim)]
+
+ return var.set_dims(var_dims_map)
+
+ def _broadcast_array(array: T_DataArray) -> T_DataArray:
+ data = _set_dims(array.variable)
+ coords = dict(array.coords)
+ coords.update(common_coords)
+ return array.__class__(
+ data, coords, data.dims, name=array.name, attrs=array.attrs
+ )
+
+ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
+ data_vars = {k: _set_dims(ds.variables[k]) for k in ds.data_vars}
+ coords = dict(ds.coords)
+ coords.update(common_coords)
+ return ds.__class__(data_vars, coords, ds.attrs)
+
+ # remove casts once https://github.com/python/mypy/issues/12800 is resolved
+ if isinstance(arg, DataArray):
+ return cast(T_Alignable, _broadcast_array(arg))
+ elif isinstance(arg, Dataset):
+ return cast(T_Alignable, _broadcast_dataset(arg))
+ else:
+ raise ValueError("all input must be Dataset or DataArray objects")
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1, /, *, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Obj1]: ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Obj1, T_Obj2]: ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ /,
+ *,
+ exclude: str | Iterable[Hashable] | None = None,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3]: ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ /,
+ *,
+ exclude: str | Iterable[Hashable] | None = None,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: ...
+
+
+@overload
+def broadcast(
+ obj1: T_Obj1,
+ obj2: T_Obj2,
+ obj3: T_Obj3,
+ obj4: T_Obj4,
+ obj5: T_Obj5,
+ /,
+ *,
+ exclude: str | Iterable[Hashable] | None = None,
+) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: ...
+
+
+@overload
+def broadcast(
+ *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Alignable, ...]: ...
+
+
+def broadcast(
+ *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None
+) -> tuple[T_Alignable, ...]:
"""Explicitly broadcast any number of DataArray or Dataset objects against
one another.
@@ -444,4 +1206,12 @@ def broadcast(*args: T_Alignable, exclude: (str | Iterable[Hashable] | None
a (x, y) int64 48B 1 1 2 2 3 3
b (x, y) int64 48B 5 6 5 6 5 6
"""
- pass
+
+ if exclude is None:
+ exclude = set()
+ args = align(*args, join="outer", copy=False, exclude=exclude)
+
+ dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude)
+ result = [_broadcast_helper(arg, exclude, dims_map, common_coords) for arg in args]
+
+ return tuple(result)
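
A quick illustration of the broadcast path above (illustrative values):

```python
import xarray as xr

a = xr.DataArray([1, 2, 3], dims="x", coords={"x": [10, 20, 30]})
b = xr.DataArray([5, 6], dims="y", coords={"y": [1, 2]})

# broadcast first aligns with join="outer", then expands every input to
# the union of dimensions, so both results end up 2-D.
a2, b2 = xr.broadcast(a, b)
assert a2.dims == ("x", "y") and b2.dims == ("x", "y")
```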
diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py
index 27115444..734d7b32 100644
--- a/xarray/core/arithmetic.py
+++ b/xarray/core/arithmetic.py
@@ -1,8 +1,19 @@
"""Base classes implementing arithmetic for xarray objects."""
+
from __future__ import annotations
+
import numbers
+
import numpy as np
-from xarray.core._typed_ops import DataArrayGroupByOpsMixin, DataArrayOpsMixin, DatasetGroupByOpsMixin, DatasetOpsMixin, VariableOpsMixin
+
+# _typed_ops.py is a generated file
+from xarray.core._typed_ops import (
+ DataArrayGroupByOpsMixin,
+ DataArrayOpsMixin,
+ DatasetGroupByOpsMixin,
+ DatasetOpsMixin,
+ VariableOpsMixin,
+)
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.ops import IncludeNumpySameMethods, IncludeReduceMethods
from xarray.core.options import OPTIONS, _get_keep_attrs
@@ -14,58 +25,117 @@ class SupportsArithmetic:
Used by Dataset, DataArray, Variable and GroupBy.
"""
+
__slots__ = ()
- _HANDLED_TYPES = np.generic, numbers.Number, bytes, str
+
+ # TODO: implement special methods for arithmetic here rather than injecting
+ # them in xarray/core/ops.py. Ideally, do so by inheriting from
+ # numpy.lib.mixins.NDArrayOperatorsMixin.
+
+ # TODO: allow extending this with some sort of registration system
+ _HANDLED_TYPES = (
+ np.generic,
+ numbers.Number,
+ bytes,
+ str,
+ )
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
from xarray.core.computation import apply_ufunc
- out = kwargs.get('out', ())
- for x in (inputs + out):
- if not is_duck_array(x) and not isinstance(x, self.
- _HANDLED_TYPES + (SupportsArithmetic,)):
+
+ # See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
+ out = kwargs.get("out", ())
+ for x in inputs + out:
+ if not is_duck_array(x) and not isinstance(
+ x, self._HANDLED_TYPES + (SupportsArithmetic,)
+ ):
return NotImplemented
+
if ufunc.signature is not None:
raise NotImplementedError(
- f'{ufunc} not supported: xarray objects do not directly implement generalized ufuncs. Instead, use xarray.apply_ufunc or explicitly convert to xarray objects to NumPy arrays (e.g., with `.values`).'
- )
- if method != '__call__':
+ f"{ufunc} not supported: xarray objects do not directly implement "
+ "generalized ufuncs. Instead, use xarray.apply_ufunc or "
+ "explicitly convert to xarray objects to NumPy arrays "
+ "(e.g., with `.values`)."
+ )
+
+ if method != "__call__":
+ # TODO: support other methods, e.g., reduce and accumulate.
raise NotImplementedError(
- f'{method} method for ufunc {ufunc} is not implemented on xarray objects, which currently only support the __call__ method. As an alternative, consider explicitly converting xarray objects to NumPy arrays (e.g., with `.values`).'
- )
+ f"{method} method for ufunc {ufunc} is not implemented on xarray objects, "
+ "which currently only support the __call__ method. As an "
+ "alternative, consider explicitly converting xarray objects "
+ "to NumPy arrays (e.g., with `.values`)."
+ )
+
if any(isinstance(o, SupportsArithmetic) for o in out):
+ # TODO: implement this with logic like _inplace_binary_op. This
+ # will be necessary to use NDArrayOperatorsMixin.
raise NotImplementedError(
- 'xarray objects are not yet supported in the `out` argument for ufuncs. As an alternative, consider explicitly converting xarray objects to NumPy arrays (e.g., with `.values`).'
- )
- join = dataset_join = OPTIONS['arithmetic_join']
- return apply_ufunc(ufunc, *inputs, input_core_dims=((),) * ufunc.
- nin, output_core_dims=((),) * ufunc.nout, join=join,
- dataset_join=dataset_join, dataset_fill_value=np.nan, kwargs=
- kwargs, dask='allowed', keep_attrs=_get_keep_attrs(default=True))
-
-
-class VariableArithmetic(ImplementsArrayReduce, IncludeNumpySameMethods,
- SupportsArithmetic, VariableOpsMixin):
+ "xarray objects are not yet supported in the `out` argument "
+ "for ufuncs. As an alternative, consider explicitly "
+ "converting xarray objects to NumPy arrays (e.g., with "
+ "`.values`)."
+ )
+
+ join = dataset_join = OPTIONS["arithmetic_join"]
+
+ return apply_ufunc(
+ ufunc,
+ *inputs,
+ input_core_dims=((),) * ufunc.nin,
+ output_core_dims=((),) * ufunc.nout,
+ join=join,
+ dataset_join=dataset_join,
+ dataset_fill_value=np.nan,
+ kwargs=kwargs,
+ dask="allowed",
+ keep_attrs=_get_keep_attrs(default=True),
+ )
+
+
+class VariableArithmetic(
+ ImplementsArrayReduce,
+ IncludeNumpySameMethods,
+ SupportsArithmetic,
+ VariableOpsMixin,
+):
__slots__ = ()
+ # prioritize our operations over those of numpy.ndarray (priority=0)
__array_priority__ = 50
-class DatasetArithmetic(ImplementsDatasetReduce, SupportsArithmetic,
- DatasetOpsMixin):
+class DatasetArithmetic(
+ ImplementsDatasetReduce,
+ SupportsArithmetic,
+ DatasetOpsMixin,
+):
__slots__ = ()
__array_priority__ = 50
-class DataArrayArithmetic(ImplementsArrayReduce, IncludeNumpySameMethods,
- SupportsArithmetic, DataArrayOpsMixin):
+class DataArrayArithmetic(
+ ImplementsArrayReduce,
+ IncludeNumpySameMethods,
+ SupportsArithmetic,
+ DataArrayOpsMixin,
+):
__slots__ = ()
+ # priority must be higher than Variable to properly work with binary ufuncs
__array_priority__ = 60
-class DataArrayGroupbyArithmetic(SupportsArithmetic, DataArrayGroupByOpsMixin):
+class DataArrayGroupbyArithmetic(
+ SupportsArithmetic,
+ DataArrayGroupByOpsMixin,
+):
__slots__ = ()
-class DatasetGroupbyArithmetic(SupportsArithmetic, DatasetGroupByOpsMixin):
+class DatasetGroupbyArithmetic(
+ SupportsArithmetic,
+ DatasetGroupByOpsMixin,
+):
__slots__ = ()
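
The `__array_ufunc__` hook restored above is what makes plain NumPy ufuncs return xarray objects instead of bare arrays; roughly:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, 4.0, 9.0], dims="x", attrs={"units": "m^2"})

# The ufunc call is routed through apply_ufunc, so labels are preserved
# and, with the default options, attrs are kept as well.
result = np.sqrt(da)
assert isinstance(result, xr.DataArray)
assert result.values.tolist() == [1.0, 2.0, 3.0]
```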
diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py
index a22b22b1..3a94513d 100644
--- a/xarray/core/array_api_compat.py
+++ b/xarray/core/array_api_compat.py
@@ -1 +1,44 @@
import numpy as np
+
+
+def is_weak_scalar_type(t):
+ return isinstance(t, (bool, int, float, complex, str, bytes))
+
+
+def _future_array_api_result_type(*arrays_and_dtypes, xp):
+ # fallback implementation for `xp.result_type` with python scalars. Can be removed once a
+ # version of the Array API that includes https://github.com/data-apis/array-api/issues/805
+ # can be required
+ strongly_dtyped = [t for t in arrays_and_dtypes if not is_weak_scalar_type(t)]
+ weakly_dtyped = [t for t in arrays_and_dtypes if is_weak_scalar_type(t)]
+
+ if not strongly_dtyped:
+ strongly_dtyped = [
+ xp.asarray(x) if not isinstance(x, type) else x for x in weakly_dtyped
+ ]
+ weakly_dtyped = []
+
+ dtype = xp.result_type(*strongly_dtyped)
+ if not weakly_dtyped:
+ return dtype
+
+ possible_dtypes = {
+ complex: "complex64",
+ float: "float32",
+ int: "int8",
+ bool: "bool",
+ str: "str",
+ bytes: "bytes",
+ }
+ dtypes = [possible_dtypes.get(type(x), "object") for x in weakly_dtyped]
+
+ return xp.result_type(dtype, *dtypes)
+
+
+def result_type(*arrays_and_dtypes, xp) -> np.dtype:
+ if xp is np or any(
+ isinstance(getattr(t, "dtype", t), np.dtype) for t in arrays_and_dtypes
+ ):
+ return xp.result_type(*arrays_and_dtypes)
+ else:
+ return _future_array_api_result_type(*arrays_and_dtypes, xp=xp)
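
`result_type` is an internal helper (not public API); with NumPy it simply defers to `np.result_type`, while for other Array API namespaces it applies the weak-scalar rules above. A quick sanity sketch:

```python
import numpy as np
from xarray.core.array_api_compat import result_type

# NumPy namespace: defer straight to np.result_type.
assert result_type(np.dtype("int32"), np.dtype("float64"), xp=np) == np.dtype("float64")

# Python scalars are "weakly" typed: combined with a strongly typed
# array they do not upcast beyond what the array's dtype requires.
assert result_type(np.zeros(3, dtype="float32"), 1.0, xp=np) == np.dtype("float32")
```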
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index d78cb540..5cb0a341 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -1,19 +1,27 @@
from __future__ import annotations
+
import itertools
from collections import Counter
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Literal, Union
+
import pandas as pd
+
from xarray.core import dtypes
from xarray.core.concat import concat
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.merge import merge
from xarray.core.utils import iterate_nested
+
if TYPE_CHECKING:
from xarray.core.types import CombineAttrsOptions, CompatOptions, JoinOptions
+def _infer_concat_order_from_positions(datasets):
+ return dict(_infer_tile_ids_from_nested_list(datasets, ()))
+
+
def _infer_tile_ids_from_nested_list(entry, current_pos):
"""
    Given a list of lists (of lists...) of objects, returns an iterator
@@ -35,7 +43,112 @@ def _infer_tile_ids_from_nested_list(entry, current_pos):
-------
combined_tile_ids : dict[tuple(int, ...), obj]
"""
- pass
+
+ if isinstance(entry, list):
+ for i, item in enumerate(entry):
+ yield from _infer_tile_ids_from_nested_list(item, current_pos + (i,))
+ else:
+ yield current_pos, entry
+
+
+def _ensure_same_types(series, dim):
+ if series.dtype == object:
+ types = set(series.map(type))
+ if len(types) > 1:
+ try:
+ import cftime
+
+ cftimes = any(issubclass(t, cftime.datetime) for t in types)
+ except ImportError:
+ cftimes = False
+
+ types = ", ".join(t.__name__ for t in types)
+
+ error_msg = (
+ f"Cannot combine along dimension '{dim}' with mixed types."
+ f" Found: {types}."
+ )
+ if cftimes:
+ error_msg = (
+ f"{error_msg} If importing data directly from a file then "
+ f"setting `use_cftime=True` may fix this issue."
+ )
+
+ raise TypeError(error_msg)
+
+
+def _infer_concat_order_from_coords(datasets):
+ concat_dims = []
+ tile_ids = [() for ds in datasets]
+
+ # All datasets have same variables because they've been grouped as such
+ ds0 = datasets[0]
+ for dim in ds0.dims:
+ # Check if dim is a coordinate dimension
+ if dim in ds0:
+ # Need to read coordinate values to do ordering
+ indexes = [ds._indexes.get(dim) for ds in datasets]
+ if any(index is None for index in indexes):
+ raise ValueError(
+ "Every dimension needs a coordinate for "
+ "inferring concatenation order"
+ )
+
+ # TODO (benbovy, flexible indexes): support flexible indexes?
+ indexes = [index.to_pandas_index() for index in indexes]
+
+ # If dimension coordinate values are same on every dataset then
+ # should be leaving this dimension alone (it's just a "bystander")
+ if not all(index.equals(indexes[0]) for index in indexes[1:]):
+ # Infer order datasets should be arranged in along this dim
+ concat_dims.append(dim)
+
+ if all(index.is_monotonic_increasing for index in indexes):
+ ascending = True
+ elif all(index.is_monotonic_decreasing for index in indexes):
+ ascending = False
+ else:
+ raise ValueError(
+ f"Coordinate variable {dim} is neither "
+ "monotonically increasing nor "
+ "monotonically decreasing on all datasets"
+ )
+
+ # Assume that any two datasets whose coord along dim starts
+ # with the same value have the same coord values throughout.
+ if any(index.size == 0 for index in indexes):
+ raise ValueError("Cannot handle size zero dimensions")
+ first_items = pd.Index([index[0] for index in indexes])
+
+ series = first_items.to_series()
+
+ # ensure series does not contain mixed types, e.g. cftime calendars
+ _ensure_same_types(series, dim)
+
+ # Sort datasets along dim
+ # We want rank but with identical elements given identical
+ # position indices - they should be concatenated along another
+ # dimension, not along this one
+ rank = series.rank(
+ method="dense", ascending=ascending, numeric_only=False
+ )
+ order = rank.astype(int).values - 1
+
+ # Append positions along extra dimension to structure which
+ # encodes the multi-dimensional concatenation order
+ tile_ids = [
+ tile_id + (position,) for tile_id, position in zip(tile_ids, order)
+ ]
+
+ if len(datasets) > 1 and not concat_dims:
+ raise ValueError(
+ "Could not find any dimension coordinates to use to "
+ "order the datasets for concatenation"
+ )
+
+ combined_ids = dict(zip(tile_ids, datasets))
+
+ return combined_ids, concat_dims
def _check_dimension_depth_tile_ids(combined_tile_ids):
@@ -43,17 +156,43 @@ def _check_dimension_depth_tile_ids(combined_tile_ids):
Check all tuples are the same length, i.e. check that all lists are
nested to the same depth.
"""
- pass
+ tile_ids = combined_tile_ids.keys()
+ nesting_depths = [len(tile_id) for tile_id in tile_ids]
+ if not nesting_depths:
+ nesting_depths = [0]
+ if set(nesting_depths) != {nesting_depths[0]}:
+ raise ValueError(
+ "The supplied objects do not form a hypercube because"
+ " sub-lists do not have consistent depths"
+ )
+ # return these just to be reused in _check_shape_tile_ids
+ return tile_ids, nesting_depths
def _check_shape_tile_ids(combined_tile_ids):
"""Check all lists along one dimension are same length."""
- pass
-
-
-def _combine_nd(combined_ids, concat_dims, data_vars='all', coords=
- 'different', compat: CompatOptions='no_conflicts', fill_value=dtypes.NA,
- join: JoinOptions='outer', combine_attrs: CombineAttrsOptions='drop'):
+ tile_ids, nesting_depths = _check_dimension_depth_tile_ids(combined_tile_ids)
+ for dim in range(nesting_depths[0]):
+ indices_along_dim = [tile_id[dim] for tile_id in tile_ids]
+ occurrences = Counter(indices_along_dim)
+ if len(set(occurrences.values())) != 1:
+ raise ValueError(
+ "The supplied objects do not form a hypercube "
+ "because sub-lists do not have consistent "
+ f"lengths along dimension {dim}"
+ )
+
+
+def _combine_nd(
+ combined_ids,
+ concat_dims,
+ data_vars="all",
+ coords="different",
+ compat: CompatOptions = "no_conflicts",
+ fill_value=dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "drop",
+):
"""
Combines an N-dimensional structure of datasets into one by applying a
series of either concat and merge operations along each dimension.
@@ -76,27 +215,171 @@ def _combine_nd(combined_ids, concat_dims, data_vars='all', coords=
-------
combined_ds : xarray.Dataset
"""
- pass
-
-def _combine_1d(datasets, concat_dim, compat: CompatOptions='no_conflicts',
- data_vars='all', coords='different', fill_value=dtypes.NA, join:
- JoinOptions='outer', combine_attrs: CombineAttrsOptions='drop'):
+ example_tile_id = next(iter(combined_ids.keys()))
+
+ n_dims = len(example_tile_id)
+ if len(concat_dims) != n_dims:
+ raise ValueError(
+ f"concat_dims has length {len(concat_dims)} but the datasets "
+ f"passed are nested in a {n_dims}-dimensional structure"
+ )
+
+ # Each iteration of this loop reduces the length of the tile_ids tuples
+ # by one. It always combines along the first dimension, removing the first
+ # element of the tuple
+ for concat_dim in concat_dims:
+ combined_ids = _combine_all_along_first_dim(
+ combined_ids,
+ dim=concat_dim,
+ data_vars=data_vars,
+ coords=coords,
+ compat=compat,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+ (combined_ds,) = combined_ids.values()
+ return combined_ds
+
+
+def _combine_all_along_first_dim(
+ combined_ids,
+ dim,
+ data_vars,
+ coords,
+ compat: CompatOptions,
+ fill_value=dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "drop",
+):
+ # Group into lines of datasets which must be combined along dim
+ # need to sort by _new_tile_id first for groupby to work
+    # TODO: is the sort needed?
+ combined_ids = dict(sorted(combined_ids.items(), key=_new_tile_id))
+ grouped = itertools.groupby(combined_ids.items(), key=_new_tile_id)
+
+ # Combine all of these datasets along dim
+ new_combined_ids = {}
+ for new_id, group in grouped:
+ combined_ids = dict(sorted(group))
+ datasets = combined_ids.values()
+ new_combined_ids[new_id] = _combine_1d(
+ datasets, dim, compat, data_vars, coords, fill_value, join, combine_attrs
+ )
+ return new_combined_ids
+
+
+def _combine_1d(
+ datasets,
+ concat_dim,
+ compat: CompatOptions = "no_conflicts",
+ data_vars="all",
+ coords="different",
+ fill_value=dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "drop",
+):
"""
Applies either concat or merge to 1D list of datasets depending on value
of concat_dim
"""
- pass
-
-
-DATASET_HYPERCUBE = Union[Dataset, Iterable['DATASET_HYPERCUBE']]
-
-def combine_nested(datasets: DATASET_HYPERCUBE, concat_dim: (str |
- DataArray | None | Sequence[str | DataArray | pd.Index | None]), compat:
- str='no_conflicts', data_vars: str='all', coords: str='different',
- fill_value: object=dtypes.NA, join: JoinOptions='outer', combine_attrs:
- CombineAttrsOptions='drop') ->Dataset:
+ if concat_dim is not None:
+ try:
+ combined = concat(
+ datasets,
+ dim=concat_dim,
+ data_vars=data_vars,
+ coords=coords,
+ compat=compat,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+ except ValueError as err:
+ if "encountered unexpected variable" in str(err):
+ raise ValueError(
+ "These objects cannot be combined using only "
+ "xarray.combine_nested, instead either use "
+ "xarray.combine_by_coords, or do it manually "
+ "with xarray.concat, xarray.merge and "
+ "xarray.align"
+ )
+ else:
+ raise
+ else:
+ combined = merge(
+ datasets,
+ compat=compat,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+
+ return combined
+
+
+def _new_tile_id(single_id_ds_pair):
+ tile_id, ds = single_id_ds_pair
+ return tile_id[1:]
+
+
+def _nested_combine(
+ datasets,
+ concat_dims,
+ compat,
+ data_vars,
+ coords,
+ ids,
+ fill_value=dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "drop",
+):
+ if len(datasets) == 0:
+ return Dataset()
+
+ # Arrange datasets for concatenation
+ # Use information from the shape of the user input
+ if not ids:
+ # Determine tile_IDs by structure of input in N-D
+ # (i.e. ordering in list-of-lists)
+ combined_ids = _infer_concat_order_from_positions(datasets)
+ else:
+ # Already sorted so just use the ids already passed
+ combined_ids = dict(zip(ids, datasets))
+
+ # Check that the inferred shape is combinable
+ _check_shape_tile_ids(combined_ids)
+
+ # Apply series of concatenate or merge operations along each dimension
+ combined = _combine_nd(
+ combined_ids,
+ concat_dims,
+ compat=compat,
+ data_vars=data_vars,
+ coords=coords,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+ return combined
+
+
+# Define type for arbitrarily-nested list of lists recursively:
+DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]]
+
+
+def combine_nested(
+ datasets: DATASET_HYPERCUBE,
+ concat_dim: str | DataArray | None | Sequence[str | DataArray | pd.Index | None],
+ compat: str = "no_conflicts",
+ data_vars: str = "all",
+ coords: str = "different",
+ fill_value: object = dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "drop",
+) -> Dataset:
"""
Explicitly combine an N-dimensional grid of datasets into one by using a
succession of concat and merge operations along each dimension of the grid.
@@ -130,7 +413,8 @@ def combine_nested(datasets: DATASET_HYPERCUBE, concat_dim: (str |
nested-list input along which to merge.
Must be the same length as the depth of the list passed to
``datasets``.
- compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional
+ compat : {"identical", "equals", "broadcast_equals", \
+ "no_conflicts", "override"}, optional
String indicating how to compare variables of the same name for
potential merge conflicts:
@@ -164,7 +448,8 @@ def combine_nested(datasets: DATASET_HYPERCUBE, concat_dim: (str |
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "drop"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "drop"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -276,13 +561,45 @@ def combine_nested(datasets: DATASET_HYPERCUBE, concat_dim: (str |
concat
merge
"""
- pass
-
-
-def _combine_single_variable_hypercube(datasets, fill_value=dtypes.NA,
- data_vars='all', coords='different', compat: CompatOptions=
- 'no_conflicts', join: JoinOptions='outer', combine_attrs:
- CombineAttrsOptions='no_conflicts'):
+ mixed_datasets_and_arrays = any(
+ isinstance(obj, Dataset) for obj in iterate_nested(datasets)
+ ) and any(
+ isinstance(obj, DataArray) and obj.name is None
+ for obj in iterate_nested(datasets)
+ )
+ if mixed_datasets_and_arrays:
+ raise ValueError("Can't combine datasets with unnamed arrays.")
+
+ if isinstance(concat_dim, (str, DataArray)) or concat_dim is None:
+ concat_dim = [concat_dim]
+
+ # The IDs argument tells _nested_combine that datasets aren't yet sorted
+ return _nested_combine(
+ datasets,
+ concat_dims=concat_dim,
+ compat=compat,
+ data_vars=data_vars,
+ coords=coords,
+ ids=False,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
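
A small sketch of the nested-combine path: the nesting of the input list mirrors a 2x2 grid of tiles, and `concat_dim` names one dimension per nesting level (values are made up):

```python
import numpy as np
import xarray as xr

def tile(x0, y0):
    # One 2x2 tile of a larger grid, starting at (x0, y0).
    return xr.Dataset(
        {"t": (("x", "y"), np.zeros((2, 2)))},
        coords={"x": [x0, x0 + 1], "y": [y0, y0 + 1]},
    )

# Outer nesting level -> "x", inner level -> "y".
ds = xr.combine_nested(
    [[tile(0, 0), tile(0, 2)], [tile(2, 0), tile(2, 2)]],
    concat_dim=["x", "y"],
)
assert dict(ds.sizes) == {"x": 4, "y": 4}
```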
+
+
+def vars_as_keys(ds):
+ return tuple(sorted(ds))
+
+
+def _combine_single_variable_hypercube(
+ datasets,
+ fill_value=dtypes.NA,
+ data_vars="all",
+ coords="different",
+ compat: CompatOptions = "no_conflicts",
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "no_conflicts",
+):
"""
Attempt to combine a list of Datasets into a hypercube using their
coordinates.
@@ -293,14 +610,55 @@ def _combine_single_variable_hypercube(datasets, fill_value=dtypes.NA,
This function is NOT part of the public API.
"""
- pass
-
-
-def combine_by_coords(data_objects: Iterable[Dataset | DataArray]=[],
- compat: CompatOptions='no_conflicts', data_vars: (Literal['all',
- 'minimal', 'different'] | list[str])='all', coords: str='different',
- fill_value: object=dtypes.NA, join: JoinOptions='outer', combine_attrs:
- CombineAttrsOptions='no_conflicts') ->(Dataset | DataArray):
+ if len(datasets) == 0:
+ raise ValueError(
+ "At least one Dataset is required to resolve variable names "
+ "for combined hypercube."
+ )
+
+ combined_ids, concat_dims = _infer_concat_order_from_coords(list(datasets))
+
+ if fill_value is None:
+ # check that datasets form complete hypercube
+ _check_shape_tile_ids(combined_ids)
+ else:
+ # check only that all datasets have same dimension depth for these
+ # vars
+ _check_dimension_depth_tile_ids(combined_ids)
+
+ # Concatenate along all of concat_dims one by one to create single ds
+ concatenated = _combine_nd(
+ combined_ids,
+ concat_dims=concat_dims,
+ data_vars=data_vars,
+ coords=coords,
+ compat=compat,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+
+ # Check the overall coordinates are monotonically increasing
+ for dim in concat_dims:
+ indexes = concatenated.indexes.get(dim)
+ if not (indexes.is_monotonic_increasing or indexes.is_monotonic_decreasing):
+ raise ValueError(
+ "Resulting object does not have monotonic"
+ f" global indexes along dimension {dim}"
+ )
+
+ return concatenated
+
+
+def combine_by_coords(
+ data_objects: Iterable[Dataset | DataArray] = [],
+ compat: CompatOptions = "no_conflicts",
+ data_vars: Literal["all", "minimal", "different"] | list[str] = "all",
+ coords: str = "different",
+ fill_value: object = dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "no_conflicts",
+) -> Dataset | DataArray:
"""
Attempt to auto-magically combine the given datasets (or data arrays)
@@ -380,7 +738,8 @@ def combine_by_coords(data_objects: Iterable[Dataset | DataArray]=[],
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "no_conflicts"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "no_conflicts"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -551,4 +910,68 @@ def combine_by_coords(data_objects: Iterable[Dataset | DataArray]=[],
Finally, if you attempt to combine a mix of unnamed DataArrays with either named
DataArrays or Datasets, a ValueError will be raised (as this is an ambiguous operation).
"""
- pass
+
+ if not data_objects:
+ return Dataset()
+
+ objs_are_unnamed_dataarrays = [
+ isinstance(data_object, DataArray) and data_object.name is None
+ for data_object in data_objects
+ ]
+ if any(objs_are_unnamed_dataarrays):
+ if all(objs_are_unnamed_dataarrays):
+ # Combine into a single larger DataArray
+ temp_datasets = [
+ unnamed_dataarray._to_temp_dataset()
+ for unnamed_dataarray in data_objects
+ ]
+
+ combined_temp_dataset = _combine_single_variable_hypercube(
+ temp_datasets,
+ fill_value=fill_value,
+ data_vars=data_vars,
+ coords=coords,
+ compat=compat,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+ return DataArray()._from_temp_dataset(combined_temp_dataset)
+ else:
+ # Must be a mix of unnamed dataarrays with either named dataarrays or with datasets
+ # Can't combine these as we wouldn't know whether to merge or concatenate the arrays
+ raise ValueError(
+ "Can't automatically combine unnamed DataArrays with either named DataArrays or Datasets."
+ )
+ else:
+ # Promote any named DataArrays to single-variable Datasets to simplify combining
+ data_objects = [
+ obj.to_dataset() if isinstance(obj, DataArray) else obj
+ for obj in data_objects
+ ]
+
+ # Group by data vars
+ sorted_datasets = sorted(data_objects, key=vars_as_keys)
+ grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys)
+
+ # Perform the multidimensional combine on each group of data variables
+ # before merging back together
+ concatenated_grouped_by_data_vars = tuple(
+ _combine_single_variable_hypercube(
+ tuple(datasets_with_same_vars),
+ fill_value=fill_value,
+ data_vars=data_vars,
+ coords=coords,
+ compat=compat,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
+ for vars, datasets_with_same_vars in grouped_by_vars
+ )
+
+ return merge(
+ concatenated_grouped_by_data_vars,
+ compat=compat,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ )
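
And the coordinate-based path: `combine_by_coords` infers the concatenation order from the coordinate values, so the input order does not matter (again with made-up data):

```python
import numpy as np
import xarray as xr

first = xr.Dataset({"v": ("time", np.arange(3))}, coords={"time": [0, 1, 2]})
second = xr.Dataset({"v": ("time", np.arange(3))}, coords={"time": [3, 4, 5]})

# Passing the pieces out of order still yields a monotonic result.
combined = xr.combine_by_coords([second, first])
assert combined.time.values.tolist() == [0, 1, 2, 3, 4, 5]
```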
diff --git a/xarray/core/common.py b/xarray/core/common.py
index c819d798..1e9c8ed8 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -1,69 +1,139 @@
from __future__ import annotations
+
import warnings
from collections.abc import Hashable, Iterable, Iterator, Mapping
from contextlib import suppress
from html import escape
from textwrap import dedent
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
+
import numpy as np
import pandas as pd
+
from xarray.core import dtypes, duck_array_ops, formatting, formatting_html, ops
from xarray.core.indexing import BasicIndexer, ExplicitlyIndexed
from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.utils import Frozen, either_dict_or_kwargs, is_scalar
+from xarray.core.utils import (
+ Frozen,
+ either_dict_or_kwargs,
+ is_scalar,
+)
from xarray.namedarray.core import _raise_if_any_duplicate_dimensions
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.namedarray.pycompat import is_chunked_array
+
try:
import cftime
except ImportError:
cftime = None
+
+# Used as a sentinel value to indicate all dimensions
ALL_DIMS = ...
+
+
if TYPE_CHECKING:
import datetime
+
from numpy.typing import DTypeLike
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index
from xarray.core.resample import Resample
from xarray.core.rolling_exp import RollingExp
- from xarray.core.types import DatetimeLike, DTypeLikeSave, ScalarOrArray, Self, SideOptions, T_Chunks, T_DataWithCoords, T_Variable
+ from xarray.core.types import (
+ DatetimeLike,
+ DTypeLikeSave,
+ ScalarOrArray,
+ Self,
+ SideOptions,
+ T_Chunks,
+ T_DataWithCoords,
+ T_Variable,
+ )
from xarray.core.variable import Variable
from xarray.groupers import Resampler
+
DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]]
-T_Resample = TypeVar('T_Resample', bound='Resample')
-C = TypeVar('C')
-T = TypeVar('T')
+
+
+T_Resample = TypeVar("T_Resample", bound="Resample")
+C = TypeVar("C")
+T = TypeVar("T")
class ImplementsArrayReduce:
__slots__ = ()
+
+ @classmethod
+ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool):
+ if include_skipna:
+
+ def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):
+ return self.reduce(
+ func=func, dim=dim, axis=axis, skipna=skipna, **kwargs
+ )
+
+ else:
+
+ def wrapped_func(self, dim=None, axis=None, **kwargs): # type: ignore[misc]
+ return self.reduce(func=func, dim=dim, axis=axis, **kwargs)
+
+ return wrapped_func
+
_reduce_extra_args_docstring = dedent(
- """ dim : str or sequence of str, optional
+ """\
+ dim : str or sequence of str, optional
Dimension(s) over which to apply `{name}`.
axis : int or sequence of int, optional
Axis(es) over which to apply `{name}`. Only one of the 'dim'
and 'axis' arguments can be supplied. If neither are supplied, then
`{name}` is calculated over axes."""
- )
+ )
+
_cum_extra_args_docstring = dedent(
- """ dim : str or sequence of str, optional
+ """\
+ dim : str or sequence of str, optional
Dimension over which to apply `{name}`.
axis : int or sequence of int, optional
Axis over which to apply `{name}`. Only one of the 'dim'
and 'axis' arguments can be supplied."""
- )
+ )
class ImplementsDatasetReduce:
__slots__ = ()
+
+ @classmethod
+ def _reduce_method(cls, func: Callable, include_skipna: bool, numeric_only: bool):
+ if include_skipna:
+
+ def wrapped_func(self, dim=None, skipna=None, **kwargs):
+ return self.reduce(
+ func=func,
+ dim=dim,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+ else:
+
+ def wrapped_func(self, dim=None, **kwargs): # type: ignore[misc]
+ return self.reduce(
+ func=func, dim=dim, numeric_only=numeric_only, **kwargs
+ )
+
+ return wrapped_func
+
_reduce_extra_args_docstring = dedent(
"""
dim : str or sequence of str, optional
Dimension(s) over which to apply `{name}`. By default `{name}` is
applied over all dimensions.
"""
- ).strip()
+ ).strip()
+
_cum_extra_args_docstring = dedent(
"""
dim : str or sequence of str, optional
@@ -72,49 +142,70 @@ class ImplementsDatasetReduce:
Axis over which to apply `{name}`. Only one of the 'dim'
and 'axis' arguments can be supplied.
"""
- ).strip()
+ ).strip()
class AbstractArray:
"""Shared base class for DataArray and Variable."""
+
__slots__ = ()
- def __bool__(self: Any) ->bool:
+ def __bool__(self: Any) -> bool:
return bool(self.values)
- def __float__(self: Any) ->float:
+ def __float__(self: Any) -> float:
return float(self.values)
- def __int__(self: Any) ->int:
+ def __int__(self: Any) -> int:
return int(self.values)
- def __complex__(self: Any) ->complex:
+ def __complex__(self: Any) -> complex:
return complex(self.values)
- def __array__(self: Any, dtype: (DTypeLike | None)=None) ->np.ndarray:
+ def __array__(self: Any, dtype: DTypeLike | None = None) -> np.ndarray:
return np.asarray(self.values, dtype=dtype)
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
return formatting.array_repr(self)
- def __format__(self: Any, format_spec: str='') ->str:
- if format_spec != '':
+ def _repr_html_(self):
+ if OPTIONS["display_style"] == "text":
+ return f"<pre>{escape(repr(self))}</pre>"
+ return formatting_html.array_repr(self)
+
+ def __format__(self: Any, format_spec: str = "") -> str:
+ if format_spec != "":
if self.shape == ():
+                # Scalar values might be OK to format with format_spec instead of repr:
return self.data.__format__(format_spec)
else:
+                # TODO: If it's an array, formatting.array_repr(self) should
+                # take format_spec as an input. If we only used self.data we
+                # would lose all the information about coords, for example,
+                # which is important:
raise NotImplementedError(
- f'Using format_spec is only supported when shape is (). Got shape = {self.shape}.'
- )
+ "Using format_spec is only supported"
+ f" when shape is (). Got shape = {self.shape}."
+ )
else:
return self.__repr__()
- def __iter__(self: Any) ->Iterator[Any]:
+ def _iter(self: Any) -> Iterator[Any]:
+ for n in range(len(self)):
+ yield self[n]
+
+ def __iter__(self: Any) -> Iterator[Any]:
if self.ndim == 0:
- raise TypeError('iteration over a 0-d array')
+ raise TypeError("iteration over a 0-d array")
return self._iter()
- def get_axis_num(self, dim: (Hashable | Iterable[Hashable])) ->(int |
- tuple[int, ...]):
+ @overload
+ def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ...
+
+ @overload
+ def get_axis_num(self, dim: Hashable) -> int: ...
+
+ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
"""Return axis number(s) corresponding to dimension(s) in this array.
Parameters
@@ -127,10 +218,20 @@ class AbstractArray:
int or tuple of int
Axis number or numbers corresponding to the given dimensions.
"""
- pass
+ if not isinstance(dim, str) and isinstance(dim, Iterable):
+ return tuple(self._get_axis_num(d) for d in dim)
+ else:
+ return self._get_axis_num(dim)
+
+ def _get_axis_num(self: Any, dim: Hashable) -> int:
+ _raise_if_any_duplicate_dimensions(self.dims)
+ try:
+ return self.dims.index(dim)
+ except ValueError:
+ raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}")
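
A quick illustration of `get_axis_num` on an array with two dimensions:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((2, 3)), dims=("y", "x"))

# A single dimension name gives an int, an iterable gives a tuple.
assert da.get_axis_num("x") == 1
assert da.get_axis_num(["x", "y"]) == (1, 0)
```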
@property
- def sizes(self: Any) ->Mapping[Hashable, int]:
+ def sizes(self: Any) -> Mapping[Hashable, int]:
"""Ordered mapping from dimension names to lengths.
Immutable.
@@ -139,11 +240,12 @@ class AbstractArray:
--------
Dataset.sizes
"""
- pass
+ return Frozen(dict(zip(self.dims, self.shape)))
class AttrAccessMixin:
"""Mixin class that allows getting keys with attribute access"""
+
__slots__ = ()
def __init_subclass__(cls, **kwargs):
@@ -151,41 +253,59 @@ class AttrAccessMixin:
raise error in the core xarray module and a FutureWarning in third-party
extensions.
"""
- if not hasattr(object.__new__(cls), '__dict__'):
+ if not hasattr(object.__new__(cls), "__dict__"):
pass
- elif cls.__module__.startswith('xarray.'):
- raise AttributeError(
- f'{cls.__name__} must explicitly define __slots__')
+ elif cls.__module__.startswith("xarray."):
+ raise AttributeError(f"{cls.__name__} must explicitly define __slots__")
else:
cls.__setattr__ = cls._setattr_dict
warnings.warn(
- f'xarray subclass {cls.__name__} should explicitly define __slots__'
- , FutureWarning, stacklevel=2)
+ f"xarray subclass {cls.__name__} should explicitly define __slots__",
+ FutureWarning,
+ stacklevel=2,
+ )
super().__init_subclass__(**kwargs)
@property
- def _attr_sources(self) ->Iterable[Mapping[Hashable, Any]]:
+ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for attribute-style access"""
- pass
+ yield from ()
@property
- def _item_sources(self) ->Iterable[Mapping[Hashable, Any]]:
+ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for key-autocompletion"""
- pass
+ yield from ()
- def __getattr__(self, name: str) ->Any:
- if name not in {'__dict__', '__setstate__'}:
+ def __getattr__(self, name: str) -> Any:
+ if name not in {"__dict__", "__setstate__"}:
+ # this avoids an infinite loop when pickle looks for the
+ # __setstate__ attribute before the xarray object is initialized
for source in self._attr_sources:
with suppress(KeyError):
return source[name]
raise AttributeError(
- f'{type(self).__name__!r} object has no attribute {name!r}')
+ f"{type(self).__name__!r} object has no attribute {name!r}"
+ )
- def _setattr_dict(self, name: str, value: Any) ->None:
+ # This complicated two-method design boosts overall performance of simple operations
+ # - particularly DataArray methods that perform a _to_temp_dataset() round-trip - by
+ # a whopping 8% compared to a single method that checks hasattr(self, "__dict__") at
+ # runtime before every single assignment. All of this is just temporary until the
+ # FutureWarning can be changed into a hard crash.
+ def _setattr_dict(self, name: str, value: Any) -> None:
"""Deprecated third party subclass (see ``__init_subclass__`` above)"""
- pass
-
- def __setattr__(self, name: str, value: Any) ->None:
+ object.__setattr__(self, name, value)
+ if name in self.__dict__:
+ # Custom, non-slotted attr, or improperly assigned variable?
+ warnings.warn(
+ f"Setting attribute {name!r} on a {type(self).__name__!r} object. Explicitly define __slots__ "
+ "to suppress this warning for legitimate custom attributes and "
+ "raise an error when attempting variables assignments.",
+ FutureWarning,
+ stacklevel=2,
+ )
+
+ def __setattr__(self, name: str, value: Any) -> None:
"""Objects with ``__slots__`` raise AttributeError if you try setting an
undeclared attribute. This is desirable, but the error message could use some
improvement.
@@ -193,31 +313,47 @@ class AttrAccessMixin:
try:
object.__setattr__(self, name, value)
except AttributeError as e:
- if str(e
- ) != f'{type(self).__name__!r} object has no attribute {name!r}':
+ # Don't accidentally shadow custom AttributeErrors, e.g.
+ # DataArray.dims.setter
+ if str(e) != f"{type(self).__name__!r} object has no attribute {name!r}":
raise
raise AttributeError(
- f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ styleassignment (e.g., `ds['name'] = ...`) instead of assigning variables."
- ) from e
+                f"cannot set attribute {name!r} on a {type(self).__name__!r} object. Use __setitem__ style "
+                "assignment (e.g., `ds['name'] = ...`) instead of assigning variables."
+ ) from e
- def __dir__(self) ->list[str]:
+ def __dir__(self) -> list[str]:
"""Provide method name lookup and completion. Only provide 'public'
methods.
"""
- extra_attrs = {item for source in self._attr_sources for item in
- source if isinstance(item, str)}
+ extra_attrs = {
+ item
+ for source in self._attr_sources
+ for item in source
+ if isinstance(item, str)
+ }
return sorted(set(dir(type(self))) | extra_attrs)
- def _ipython_key_completions_(self) ->list[str]:
+ def _ipython_key_completions_(self) -> list[str]:
"""Provide method for the key-autocompletions in IPython.
See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion
For the details.
"""
- pass
+ items = {
+ item
+ for source in self._item_sources
+ for item in source
+ if isinstance(item, str)
+ }
+ return list(items)
class TreeAttrAccessMixin(AttrAccessMixin):
"""Mixin class that allows getting keys with attribute access"""
+
+ # TODO: Ensure ipython tab completion can include both child datatrees and
+ # variables from Dataset objects on relevant nodes.
+
__slots__ = ()
def __init_subclass__(cls, **kwargs):
@@ -226,24 +362,57 @@ class TreeAttrAccessMixin(AttrAccessMixin):
``DataTree`` has some dynamically defined attributes in addition to those
defined in ``__slots__``. (GH9068)
"""
- if not hasattr(object.__new__(cls), '__dict__'):
+ if not hasattr(object.__new__(cls), "__dict__"):
pass
-def get_squeeze_dims(xarray_obj, dim: (Hashable | Iterable[Hashable] | None
- )=None, axis: (int | Iterable[int] | None)=None) ->list[Hashable]:
+def get_squeeze_dims(
+ xarray_obj,
+ dim: Hashable | Iterable[Hashable] | None = None,
+ axis: int | Iterable[int] | None = None,
+) -> list[Hashable]:
"""Get a list of dimensions to squeeze out."""
- pass
+ if dim is not None and axis is not None:
+ raise ValueError("cannot use both parameters `axis` and `dim`")
+ if dim is None and axis is None:
+ return [d for d, s in xarray_obj.sizes.items() if s == 1]
+
+ if isinstance(dim, Iterable) and not isinstance(dim, str):
+ dim = list(dim)
+ elif dim is not None:
+ dim = [dim]
+ else:
+ assert axis is not None
+ if isinstance(axis, int):
+ axis = [axis]
+ axis = list(axis)
+ if any(not isinstance(a, int) for a in axis):
+ raise TypeError("parameter `axis` must be int or iterable of int.")
+ alldims = list(xarray_obj.sizes.keys())
+ dim = [alldims[a] for a in axis]
+
+ if any(xarray_obj.sizes[k] > 1 for k in dim):
+ raise ValueError(
+ "cannot select a dimension to squeeze out "
+ "which has length greater than one"
+ )
+ return dim
class DataWithCoords(AttrAccessMixin):
"""Shared base class for Dataset and DataArray."""
+
_close: Callable[[], None] | None
_indexes: dict[Hashable, Index]
- __slots__ = '_close',
- def squeeze(self, dim: (Hashable | Iterable[Hashable] | None)=None,
- drop: bool=False, axis: (int | Iterable[int] | None)=None) ->Self:
+ __slots__ = ("_close",)
+
+ def squeeze(
+ self,
+ dim: Hashable | Iterable[Hashable] | None = None,
+ drop: bool = False,
+ axis: int | Iterable[int] | None = None,
+ ) -> Self:
"""Return a new object with squeezed data.
Parameters
@@ -268,10 +437,16 @@ class DataWithCoords(AttrAccessMixin):
--------
numpy.squeeze
"""
- pass
-
- def clip(self, min: (ScalarOrArray | None)=None, max: (ScalarOrArray |
- None)=None, *, keep_attrs: (bool | None)=None) ->Self:
+ dims = get_squeeze_dims(self, dim, axis)
+ return self.isel(drop=drop, **{d: 0 for d in dims})
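
`squeeze` only ever drops length-1 dimensions; naming a longer dimension raises. For example:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((1, 3, 1)), dims=("a", "b", "c"))

# No arguments: drop every length-1 dimension. With an explicit name,
# only that dimension is dropped.
assert da.squeeze().dims == ("b",)
assert da.squeeze("a").dims == ("b", "c")
```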
+
+ def clip(
+ self,
+ min: ScalarOrArray | None = None,
+ max: ScalarOrArray | None = None,
+ *,
+ keep_attrs: bool | None = None,
+ ) -> Self:
"""
Return an array whose values are limited to ``[min, max]``.
At least one of max or min must be given.
@@ -297,14 +472,37 @@ class DataWithCoords(AttrAccessMixin):
--------
numpy.clip : equivalent function
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ if keep_attrs is None:
+ # When this was a unary func, the default was True, so retaining the
+ # default.
+ keep_attrs = _get_keep_attrs(default=True)
- def get_index(self, key: Hashable) ->pd.Index:
+ return apply_ufunc(
+ np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed"
+ )
+
+ def get_index(self, key: Hashable) -> pd.Index:
"""Get an index for a dimension, with fall-back to a default RangeIndex"""
- pass
+ if key not in self.dims:
+ raise KeyError(key)
- def assign_coords(self, coords: (Mapping | None)=None, **coords_kwargs: Any
- ) ->Self:
+ try:
+ return self._indexes[key].to_pandas_index()
+ except KeyError:
+ return pd.Index(range(self.sizes[key]), name=key)
+
+ def _calc_assign_results(
+ self: C, kwargs: Mapping[Any, T | Callable[[C], T]]
+ ) -> dict[Hashable, T]:
+ return {k: v(self) if callable(v) else v for k, v in kwargs.items()}
+
+ def assign_coords(
+ self,
+ coords: Mapping | None = None,
+ **coords_kwargs: Any,
+ ) -> Self:
"""Assign new coordinates to this object.
Returns a new object with all the original data in addition to the new
@@ -433,9 +631,21 @@ class DataWithCoords(AttrAccessMixin):
Dataset.swap_dims
Dataset.set_coords
"""
- pass
+ from xarray.core.coordinates import Coordinates
- def assign_attrs(self, *args: Any, **kwargs: Any) ->Self:
+ coords_combined = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords")
+ data = self.copy(deep=False)
+
+ results: Coordinates | dict[Hashable, Any]
+ if isinstance(coords, Coordinates):
+ results = coords
+ else:
+ results = self._calc_assign_results(coords_combined)
+
+ data.coords.update(results)
+ return data
+
+ def assign_attrs(self, *args: Any, **kwargs: Any) -> Self:
"""Assign new attrs to this object.
Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``.
@@ -486,10 +696,16 @@ class DataWithCoords(AttrAccessMixin):
--------
Dataset.assign
"""
- pass
-
- def pipe(self, func: (Callable[..., T] | tuple[Callable[..., T], str]),
- *args: Any, **kwargs: Any) ->T:
+ out = self.copy(deep=False)
+ out.attrs.update(*args, **kwargs)
+ return out
+
+ def pipe(
+ self,
+ func: Callable[..., T] | tuple[Callable[..., T], str],
+ *args: Any,
+ **kwargs: Any,
+ ) -> T:
"""
Apply ``func(self, *args, **kwargs)``
@@ -605,11 +821,23 @@ class DataWithCoords(AttrAccessMixin):
--------
pandas.DataFrame.pipe
"""
- pass
-
- def rolling_exp(self: T_DataWithCoords, window: (Mapping[Any, int] |
- None)=None, window_type: str='span', **window_kwargs) ->RollingExp[
- T_DataWithCoords]:
+ if isinstance(func, tuple):
+ func, target = func
+ if target in kwargs:
+ raise ValueError(
+ f"{target} is both the pipe target and a keyword argument"
+ )
+ kwargs[target] = self
+ return func(*args, **kwargs)
+ else:
+ return func(self, *args, **kwargs)
+
+ def rolling_exp(
+ self: T_DataWithCoords,
+ window: Mapping[Any, int] | None = None,
+ window_type: str = "span",
+ **window_kwargs,
+ ) -> RollingExp[T_DataWithCoords]:
"""
Exponentially-weighted moving window.
Similar to EWM in pandas
@@ -633,14 +861,31 @@ class DataWithCoords(AttrAccessMixin):
--------
core.rolling_exp.RollingExp
"""
- pass
-
- def _resample(self, resample_cls: type[T_Resample], indexer: (Mapping[
- Hashable, str | Resampler] | None), skipna: (bool | None), closed:
- (SideOptions | None), label: (SideOptions | None), offset: (pd.
- Timedelta | datetime.timedelta | str | None), origin: (str |
- DatetimeLike), restore_coord_dims: (bool | None), **indexer_kwargs:
- (str | Resampler)) ->T_Resample:
+ from xarray.core import rolling_exp
+
+ if "keep_attrs" in window_kwargs:
+ warnings.warn(
+ "Passing ``keep_attrs`` to ``rolling_exp`` has no effect. Pass"
+ " ``keep_attrs`` directly to the applied function, e.g."
+ " ``rolling_exp(...).mean(keep_attrs=False)``."
+ )
+
+ window = either_dict_or_kwargs(window, window_kwargs, "rolling_exp")
+
+ return rolling_exp.RollingExp(self, window, window_type)
+
+ def _resample(
+ self,
+ resample_cls: type[T_Resample],
+ indexer: Mapping[Hashable, str | Resampler] | None,
+ skipna: bool | None,
+ closed: SideOptions | None,
+ label: SideOptions | None,
+ offset: pd.Timedelta | datetime.timedelta | str | None,
+ origin: str | DatetimeLike,
+ restore_coord_dims: bool | None,
+ **indexer_kwargs: str | Resampler,
+ ) -> T_Resample:
"""Returns a Resample object for performing resampling operations.
Handles both downsampling and upsampling. The resampled
@@ -799,9 +1044,46 @@ class DataWithCoords(AttrAccessMixin):
----------
.. [1] https://pandas.pydata.org/docs/user_guide/timeseries.html#dateoffset-objects
"""
- pass
+ # TODO support non-string indexer after removing the old API.
+
+ from xarray.core.dataarray import DataArray
+ from xarray.core.groupby import ResolvedGrouper
+ from xarray.core.resample import RESAMPLE_DIM
+ from xarray.groupers import Resampler, TimeResampler
+
+ indexer = either_dict_or_kwargs(indexer, indexer_kwargs, "resample")
+ if len(indexer) != 1:
+ raise ValueError("Resampling only supported along single dimensions.")
+ dim, freq = next(iter(indexer.items()))
+
+ dim_name: Hashable = dim
+ dim_coord = self[dim]
+
+ group = DataArray(
+ dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM
+ )
- def where(self, cond: Any, other: Any=dtypes.NA, drop: bool=False) ->Self:
+ grouper: Resampler
+ if isinstance(freq, str):
+ grouper = TimeResampler(
+ freq=freq, closed=closed, label=label, origin=origin, offset=offset
+ )
+ elif isinstance(freq, Resampler):
+ grouper = freq
+ else:
+ raise ValueError("freq must be a str or a Resampler object")
+
+ rgrouper = ResolvedGrouper(grouper, group, self)
+
+ return resample_cls(
+ self,
+ (rgrouper,),
+ dim=dim_name,
+ resample_dim=RESAMPLE_DIM,
+ restore_coord_dims=restore_coord_dims,
+ )
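
`_resample` backs the public `resample` method; a minimal downsampling sketch with synthetic daily data:

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01", periods=6, freq="D")
da = xr.DataArray(np.arange(6.0), dims="time", coords={"time": times})

# Group daily values into 2-day bins and average each bin.
result = da.resample(time="2D").mean()
assert result.sizes["time"] == 3
assert result.values.tolist() == [0.5, 2.5, 4.5]
```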
+
+ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
"""Filter elements from this object according to a condition.
Returns elements from 'DataArray', where 'cond' is True,
@@ -889,9 +1171,47 @@ class DataWithCoords(AttrAccessMixin):
numpy.where : corresponding numpy function
where : equivalent function
"""
- pass
+ from xarray.core.alignment import align
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ if callable(cond):
+ cond = cond(self)
+ if callable(other):
+ other = other(self)
+
+ if drop:
+ if not isinstance(cond, (Dataset, DataArray)):
+ raise TypeError(
+                    f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable that returns one)."
+ )
+
+ self, cond = align(self, cond)
+
+ def _dataarray_indexer(dim: Hashable) -> DataArray:
+ return cond.any(dim=(d for d in cond.dims if d != dim))
+
+ def _dataset_indexer(dim: Hashable) -> DataArray:
+ cond_wdim = cond.drop_vars(
+ var for var in cond if dim not in cond[var].dims
+ )
+ keepany = cond_wdim.any(dim=(d for d in cond.dims if d != dim))
+ return keepany.to_dataarray().any("variable")
+
+ _get_indexer = (
+ _dataarray_indexer if isinstance(cond, DataArray) else _dataset_indexer
+ )
+
+ indexers = {}
+ for dim in cond.sizes.keys():
+ indexers[dim] = _get_indexer(dim)
+
+ self = self.isel(**indexers)
+ cond = cond.isel(**indexers)
+
+ return ops.where_method(self, cond, other)
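
The `drop=True` branch above trims any position where the condition is all-False along the other dimensions; for example:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(12).reshape(3, 4), dims=("x", "y"))

# By default non-matching entries become NaN; with drop=True the rows
# where the condition is False everywhere are removed as well.
masked = da.where(da >= 6)
trimmed = da.where(da >= 6, drop=True)
assert dict(trimmed.sizes) == {"x": 2, "y": 4}
```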
- def set_close(self, close: (Callable[[], None] | None)) ->None:
+ def set_close(self, close: Callable[[], None] | None) -> None:
"""Register the function that releases any resources linked to this object.
This method controls how xarray cleans up resources associated
@@ -905,13 +1225,15 @@ class DataWithCoords(AttrAccessMixin):
The function that when called like ``close()`` releases
any resources linked to this object.
"""
- pass
+ self._close = close
- def close(self) ->None:
+ def close(self) -> None:
"""Release any resources linked to this object."""
- pass
+ if self._close is not None:
+ self._close()
+ self._close = None
- def isnull(self, keep_attrs: (bool | None)=None) ->Self:
+ def isnull(self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is a missing value.
Parameters
@@ -942,9 +1264,19 @@ class DataWithCoords(AttrAccessMixin):
array([False, True, False])
Dimensions without coordinates: x
"""
- pass
+ from xarray.core.computation import apply_ufunc
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ return apply_ufunc(
+ duck_array_ops.isnull,
+ self,
+ dask="allowed",
+ keep_attrs=keep_attrs,
+ )
+
-    def notnull(self, keep_attrs: (bool | None)=None) ->Self:
+    def notnull(self, keep_attrs: bool | None = None) -> Self:
"""Test each value in the array for whether it is not a missing value.
Parameters
@@ -975,9 +1307,19 @@ class DataWithCoords(AttrAccessMixin):
array([ True, False, True])
Dimensions without coordinates: x
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+ return apply_ufunc(
+ duck_array_ops.notnull,
+ self,
+ dask="allowed",
+ keep_attrs=keep_attrs,
+ )
+
-    def isin(self, test_elements: Any) ->Self:
+    def isin(self, test_elements: Any) -> Self:
"""Tests each value in the array for whether it is in test elements.
Parameters
@@ -1004,10 +1346,37 @@ class DataWithCoords(AttrAccessMixin):
--------
numpy.isin
"""
- pass
+ from xarray.core.computation import apply_ufunc
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+ from xarray.core.variable import Variable
+
+ if isinstance(test_elements, Dataset):
+ raise TypeError(
+ f"isin() argument must be convertible to an array: {test_elements}"
+ )
+ elif isinstance(test_elements, (Variable, DataArray)):
+ # need to explicitly pull out data to support dask arrays as the
+ # second argument
+ test_elements = test_elements.data
+
+ return apply_ufunc(
+ duck_array_ops.isin,
+ self,
+ kwargs=dict(test_elements=test_elements),
+ dask="allowed",
+ )
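
As a quick illustration (not part of the diff), the three element-wise tests implemented above behave like this on a toy array:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, 3.0], dims="x")

da.isnull()          # [False,  True, False]
da.notnull()         # [ True, False,  True]
da.isin([1.0, 3.0])  # [ True, False,  True]
```
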
- def astype(self, dtype, *, order=None, casting=None, subok=None, copy=
- None, keep_attrs=True) ->Self:
+ def astype(
+ self,
+ dtype,
+ *,
+ order=None,
+ casting=None,
+ subok=None,
+ copy=None,
+ keep_attrs=True,
+ ) -> Self:
"""
Copy of the xarray object, with data cast to a specified type.
Leaves coordinate dtype unchanged.
@@ -1060,22 +1429,100 @@ class DataWithCoords(AttrAccessMixin):
dask.array.Array.astype
sparse.COO.astype
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ kwargs = dict(order=order, casting=casting, subok=subok, copy=copy)
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+ return apply_ufunc(
+ duck_array_ops.astype,
+ self,
+ dtype,
+ kwargs=kwargs,
+ keep_attrs=keep_attrs,
+ dask="allowed",
+ )
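
A short, illustrative example of the `astype` wrapper above (the values are hypothetical):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.5, 2.5], dims="x", attrs={"units": "m"})

da.astype(np.int32)                    # casts data, keeps attrs (keep_attrs=True by default)
da.astype(np.int32, keep_attrs=False)  # casts data, drops attrs
```
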
- def __enter__(self) ->Self:
+ def __enter__(self) -> Self:
return self
- def __exit__(self, exc_type, exc_value, traceback) ->None:
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()
def __getitem__(self, value):
+ # implementations of this class should implement this method
raise NotImplementedError()
-def full_like(other: (Dataset | DataArray | Variable), fill_value: Any,
- dtype: (DTypeMaybeMapping | None)=None, *, chunks: T_Chunks=None,
- chunked_array_type: (str | None)=None, from_array_kwargs: (dict[str,
- Any] | None)=None) ->(Dataset | DataArray | Variable):
+@overload
+def full_like(
+ other: DataArray,
+ fill_value: Any,
+ dtype: DTypeLikeSave | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> DataArray: ...
+
+
+@overload
+def full_like(
+ other: Dataset,
+ fill_value: Any,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset: ...
+
+
+@overload
+def full_like(
+ other: Variable,
+ fill_value: Any,
+ dtype: DTypeLikeSave | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Variable: ...
+
+
+@overload
+def full_like(
+ other: Dataset | DataArray,
+ fill_value: Any,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = {},
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray: ...
+
+
+@overload
+def full_like(
+ other: Dataset | DataArray | Variable,
+ fill_value: Any,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray | Variable: ...
+
+
+def full_like(
+ other: Dataset | DataArray | Variable,
+ fill_value: Any,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray | Variable:
"""
Return a new object with the same shape and type as a given object.
@@ -1194,20 +1641,172 @@ def full_like(other: (Dataset | DataArray | Variable), fill_value: Any,
ones_like
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+ from xarray.core.variable import Variable
+ if not is_scalar(fill_value) and not (
+ isinstance(other, Dataset) and isinstance(fill_value, dict)
+ ):
+ raise ValueError(
+ f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead."
+ )
+
+ if isinstance(other, Dataset):
+ if not isinstance(fill_value, dict):
+ fill_value = {k: fill_value for k in other.data_vars.keys()}
+ dtype_: Mapping[Any, DTypeLikeSave]
+ if not isinstance(dtype, Mapping):
+ dtype_ = {k: dtype for k in other.data_vars.keys()}
+ else:
+ dtype_ = dtype
+
+ data_vars = {
+ k: _full_like_variable(
+ v.variable,
+ fill_value.get(k, dtypes.NA),
+ dtype_.get(k, None),
+ chunks,
+ chunked_array_type,
+ from_array_kwargs,
+ )
+ for k, v in other.data_vars.items()
+ }
+ return Dataset(data_vars, coords=other.coords, attrs=other.attrs)
+ elif isinstance(other, DataArray):
+ if isinstance(dtype, Mapping):
+ raise ValueError("'dtype' cannot be dict-like when passing a DataArray")
+ return DataArray(
+ _full_like_variable(
+ other.variable,
+ fill_value,
+ dtype,
+ chunks,
+ chunked_array_type,
+ from_array_kwargs,
+ ),
+ dims=other.dims,
+ coords=other.coords,
+ attrs=other.attrs,
+ name=other.name,
+ )
+ elif isinstance(other, Variable):
+ if isinstance(dtype, Mapping):
+ raise ValueError("'dtype' cannot be dict-like when passing a Variable")
+ return _full_like_variable(
+ other, fill_value, dtype, chunks, chunked_array_type, from_array_kwargs
+ )
+ else:
+ raise TypeError("Expected DataArray, Dataset, or Variable")
+
+
-def _full_like_variable(other: Variable, fill_value: Any, dtype: (DTypeLike |
-    None)=None, chunks: T_Chunks=None, chunked_array_type: (str | None)=
-    None, from_array_kwargs: (dict[str, Any] | None)=None) ->Variable:
+def _full_like_variable(
+ other: Variable,
+ fill_value: Any,
+ dtype: DTypeLike | None = None,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Variable:
"""Inner function of full_like, where other must be a variable"""
- pass
+ from xarray.core.variable import Variable
+
+ if fill_value is dtypes.NA:
+ fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype)
+
+ if (
+ is_chunked_array(other.data)
+ or chunked_array_type is not None
+ or chunks is not None
+ ):
+ if chunked_array_type is None:
+ chunkmanager = get_chunked_array_type(other.data)
+ else:
+ chunkmanager = guess_chunkmanager(chunked_array_type)
+
+ if dtype is None:
+ dtype = other.dtype
+ if from_array_kwargs is None:
+ from_array_kwargs = {}
+ data = chunkmanager.array_api.full(
+ other.shape,
+ fill_value,
+ dtype=dtype,
+ chunks=chunks if chunks else other.data.chunks,
+ **from_array_kwargs,
+ )
+ else:
+ data = np.full_like(other.data, fill_value, dtype=dtype)
+
+ return Variable(dims=other.dims, data=data, attrs=other.attrs)
+
+
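
To illustrate the Dataset branch of `full_like` above (per-variable fill values and dtypes), a small sketch with made-up data:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [0.1, 0.2, 0.3])})

xr.full_like(ds, fill_value=9)  # scalar fill applied to every data variable
xr.full_like(ds, fill_value={"a": 0, "b": np.nan}, dtype={"a": np.int64, "b": float})
```
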
-def zeros_like(other: (Dataset | DataArray | Variable), dtype: (
-    DTypeMaybeMapping | None)=None, *, chunks: T_Chunks=None,
-    chunked_array_type: (str | None)=None, from_array_kwargs: (dict[str,
-    Any] | None)=None) ->(Dataset | DataArray | Variable):
+@overload
+def zeros_like(
+ other: DataArray,
+ dtype: DTypeLikeSave | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> DataArray: ...
+
+
+@overload
+def zeros_like(
+ other: Dataset,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset: ...
+
+
+@overload
+def zeros_like(
+ other: Variable,
+ dtype: DTypeLikeSave | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Variable: ...
+
+
+@overload
+def zeros_like(
+ other: Dataset | DataArray,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray: ...
+
+
+@overload
+def zeros_like(
+ other: Dataset | DataArray | Variable,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray | Variable: ...
+
+
+def zeros_like(
+ other: Dataset | DataArray | Variable,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray | Variable:
"""Return a new object of zeros with the same shape and
type as a given dataarray or dataset.
@@ -1272,13 +1871,79 @@ def zeros_like(other: (Dataset | DataArray | Variable), dtype: (
full_like
"""
- pass
-
-
-def ones_like(other: (Dataset | DataArray | Variable), dtype: (
- DTypeMaybeMapping | None)=None, *, chunks: T_Chunks=None,
- chunked_array_type: (str | None)=None, from_array_kwargs: (dict[str,
- Any] | None)=None) ->(Dataset | DataArray | Variable):
+ return full_like(
+ other,
+ 0,
+ dtype,
+ chunks=chunks,
+ chunked_array_type=chunked_array_type,
+ from_array_kwargs=from_array_kwargs,
+ )
+
+
+@overload
+def ones_like(
+ other: DataArray,
+ dtype: DTypeLikeSave | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> DataArray: ...
+
+
+@overload
+def ones_like(
+ other: Dataset,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset: ...
+
+
+@overload
+def ones_like(
+ other: Variable,
+ dtype: DTypeLikeSave | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Variable: ...
+
+
+@overload
+def ones_like(
+ other: Dataset | DataArray,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray: ...
+
+
+@overload
+def ones_like(
+ other: Dataset | DataArray | Variable,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray | Variable: ...
+
+
+def ones_like(
+ other: Dataset | DataArray | Variable,
+ dtype: DTypeMaybeMapping | None = None,
+ *,
+ chunks: T_Chunks = None,
+ chunked_array_type: str | None = None,
+ from_array_kwargs: dict[str, Any] | None = None,
+) -> Dataset | DataArray | Variable:
"""Return a new object of ones with the same shape and
type as a given dataarray or dataset.
@@ -1335,31 +2000,64 @@ def ones_like(other: (Dataset | DataArray | Variable), dtype: (
full_like
"""
- pass
+ return full_like(
+ other,
+ 1,
+ dtype,
+ chunks=chunks,
+ chunked_array_type=chunked_array_type,
+ from_array_kwargs=from_array_kwargs,
+ )
+
+
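
Both helpers above simply delegate to `full_like`; an illustrative call for each (not part of the patch):

```python
import xarray as xr

da = xr.DataArray([[1, 2], [3, 4]], dims=("x", "y"), attrs={"units": "K"})

xr.zeros_like(da)              # zeros with the same dims, coords, attrs and dtype
xr.ones_like(da, dtype=float)  # ones, cast to float
```
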
+def get_chunksizes(
+ variables: Iterable[Variable],
+) -> Mapping[Any, tuple[int, ...]]:
+ chunks: dict[Any, tuple[int, ...]] = {}
+ for v in variables:
+ if hasattr(v._data, "chunks"):
+ for dim, c in v.chunksizes.items():
+ if dim in chunks and c != chunks[dim]:
+ raise ValueError(
+ f"Object has inconsistent chunks along dimension {dim}. "
+ "This can be fixed by calling unify_chunks()."
+ )
+ chunks[dim] = c
+ return Frozen(chunks)
-def is_np_datetime_like(dtype: DTypeLike) ->bool:
+def is_np_datetime_like(dtype: DTypeLike) -> bool:
"""Check if a dtype is a subclass of the numpy datetime types"""
- pass
+ return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
-def is_np_timedelta_like(dtype: DTypeLike) ->bool:
+def is_np_timedelta_like(dtype: DTypeLike) -> bool:
"""Check whether dtype is of the timedelta64 dtype."""
- pass
+ return np.issubdtype(dtype, np.timedelta64)
-def _contains_cftime_datetimes(array: Any) ->bool:
+def _contains_cftime_datetimes(array: Any) -> bool:
"""Check if a array inside a Variable contains cftime.datetime objects"""
- pass
+ if cftime is None:
+ return False
+
+ if array.dtype == np.dtype("O") and array.size > 0:
+ first_idx = (0,) * array.ndim
+ if isinstance(array, ExplicitlyIndexed):
+ first_idx = BasicIndexer(first_idx)
+ sample = array[first_idx]
+ return isinstance(np.asarray(sample).item(), cftime.datetime)
+
+ return False
-def contains_cftime_datetimes(var: T_Variable) ->bool:
+def contains_cftime_datetimes(var: T_Variable) -> bool:
"""Check if an xarray.Variable contains cftime.datetime objects"""
- pass
+ return _contains_cftime_datetimes(var._data)
-def _contains_datetime_like_objects(var: T_Variable) ->bool:
+def _contains_datetime_like_objects(var: T_Variable) -> bool:
"""Check if a variable contains datetime like objects (either
np.datetime64, np.timedelta64, or cftime.datetime)
"""
- pass
+ return is_np_datetime_like(var.dtype) or contains_cftime_datetimes(var)
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index fd8f4fb2..5d21d083 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -1,7 +1,9 @@
"""
Functions for applying functions that act on arrays to xarray's labeled data.
"""
+
from __future__ import annotations
+
import functools
import itertools
import operator
@@ -9,7 +11,9 @@ import warnings
from collections import Counter
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload
+
import numpy as np
+
from xarray.core import dtypes, duck_array_ops, utils
from xarray.core.alignment import align, deep_align
from xarray.core.common import zeros_like
@@ -24,25 +28,32 @@ from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
from xarray.util.deprecation_helpers import deprecate_dims
+
if TYPE_CHECKING:
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import CombineAttrsOptions, JoinOptions
- MissingCoreDimOptions = Literal['raise', 'copy', 'drop']
-_NO_FILL_VALUE = utils.ReprObject('<no-fill-value>')
-_DEFAULT_NAME = utils.ReprObject('<default-name>')
-_JOINS_WITHOUT_FILL_VALUES = frozenset({'inner', 'exact'})
+
+ MissingCoreDimOptions = Literal["raise", "copy", "drop"]
+
+_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
+_DEFAULT_NAME = utils.ReprObject("<default-name>")
+_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
def _first_of_type(args, kind):
"""Return either first object of type 'kind' or raise if not found."""
- pass
+ for arg in args:
+ if isinstance(arg, kind):
+ return arg
+
+ raise ValueError("This should be unreachable.")
def _all_of_type(args, kind):
"""Return all objects of type 'kind'"""
- pass
+ return [arg for arg in args if isinstance(arg, kind)]
class _UFuncSignature:
@@ -57,8 +68,14 @@ class _UFuncSignature:
output_core_dims : tuple[tuple]
Core dimension names on each output variable.
"""
- __slots__ = ('input_core_dims', 'output_core_dims',
- '_all_input_core_dims', '_all_output_core_dims', '_all_core_dims')
+
+ __slots__ = (
+ "input_core_dims",
+ "output_core_dims",
+ "_all_input_core_dims",
+ "_all_output_core_dims",
+ "_all_core_dims",
+ )
def __init__(self, input_core_dims, output_core_dims=((),)):
self.input_core_dims = tuple(tuple(a) for a in input_core_dims)
@@ -67,10 +84,48 @@ class _UFuncSignature:
self._all_output_core_dims = None
self._all_core_dims = None
+ @property
+ def all_input_core_dims(self):
+ if self._all_input_core_dims is None:
+ self._all_input_core_dims = frozenset(
+ dim for dims in self.input_core_dims for dim in dims
+ )
+ return self._all_input_core_dims
+
+ @property
+ def all_output_core_dims(self):
+ if self._all_output_core_dims is None:
+ self._all_output_core_dims = frozenset(
+ dim for dims in self.output_core_dims for dim in dims
+ )
+ return self._all_output_core_dims
+
+ @property
+ def all_core_dims(self):
+ if self._all_core_dims is None:
+ self._all_core_dims = self.all_input_core_dims | self.all_output_core_dims
+ return self._all_core_dims
+
+ @property
+ def dims_map(self):
+ return {
+ core_dim: f"dim{n}" for n, core_dim in enumerate(sorted(self.all_core_dims))
+ }
+
+ @property
+ def num_inputs(self):
+ return len(self.input_core_dims)
+
+ @property
+ def num_outputs(self):
+ return len(self.output_core_dims)
+
def __eq__(self, other):
try:
- return (self.input_core_dims == other.input_core_dims and self.
- output_core_dims == other.output_core_dims)
+ return (
+ self.input_core_dims == other.input_core_dims
+ and self.output_core_dims == other.output_core_dims
+ )
except AttributeError:
return False
@@ -78,16 +133,12 @@ class _UFuncSignature:
return not self == other
def __repr__(self):
- return (
- f'{type(self).__name__}({list(self.input_core_dims)!r}, {list(self.output_core_dims)!r})'
- )
+ return f"{type(self).__name__}({list(self.input_core_dims)!r}, {list(self.output_core_dims)!r})"
def __str__(self):
- lhs = ','.join('({})'.format(','.join(dims)) for dims in self.
- input_core_dims)
- rhs = ','.join('({})'.format(','.join(dims)) for dims in self.
- output_core_dims)
- return f'{lhs}->{rhs}'
+ lhs = ",".join("({})".format(",".join(dims)) for dims in self.input_core_dims)
+ rhs = ",".join("({})".format(",".join(dims)) for dims in self.output_core_dims)
+ return f"{lhs}->{rhs}"
def to_gufunc_string(self, exclude_dims=frozenset()):
"""Create an equivalent signature string for a NumPy gufunc.
@@ -97,13 +148,66 @@ class _UFuncSignature:
Also creates unique names for input_core_dims contained in exclude_dims.
"""
- pass
-
-
-def build_output_coords_and_indexes(args: Iterable[Any], signature:
- _UFuncSignature, exclude_dims: Set=frozenset(), combine_attrs:
- CombineAttrsOptions='override') ->tuple[list[dict[Any, Variable]], list
- [dict[Any, Index]]]:
+ input_core_dims = [
+ [self.dims_map[dim] for dim in core_dims]
+ for core_dims in self.input_core_dims
+ ]
+ output_core_dims = [
+ [self.dims_map[dim] for dim in core_dims]
+ for core_dims in self.output_core_dims
+ ]
+
+ # enumerate input_core_dims contained in exclude_dims to make them unique
+ if exclude_dims:
+ exclude_dims = [self.dims_map[dim] for dim in exclude_dims]
+
+ counter: Counter = Counter()
+
+ def _enumerate(dim):
+ if dim in exclude_dims:
+ n = counter[dim]
+ counter.update([dim])
+ dim = f"{dim}_{n}"
+ return dim
+
+ input_core_dims = [
+ [_enumerate(dim) for dim in arg] for arg in input_core_dims
+ ]
+
+ alt_signature = type(self)(input_core_dims, output_core_dims)
+ return str(alt_signature)
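
Worked through by hand from the renaming logic above (a sketch, not output from the test run): a two-input signature maps its core dimension to positional `dimN` names, and excluded dimensions get unique suffixes:

```python
sig = _UFuncSignature(input_core_dims=[("time",), ("time",)], output_core_dims=[()])

str(sig)                                     # "(time),(time)->()"
sig.to_gufunc_string()                       # "(dim0),(dim0)->()"
sig.to_gufunc_string(exclude_dims={"time"})  # "(dim0_0),(dim0_1)->()"
```
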
+
+
+def result_name(objects: Iterable[Any]) -> Any:
+ # use the same naming heuristics as pandas:
+ # https://github.com/blaze/blaze/issues/458#issuecomment-51936356
+ names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects}
+ names.discard(_DEFAULT_NAME)
+ if len(names) == 1:
+ (name,) = names
+ else:
+ name = None
+ return name
+
+
+def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]:
+ coords_list = []
+ for arg in args:
+ try:
+ coords = arg.coords
+ except AttributeError:
+ pass # skip this argument
+ else:
+ coords_list.append(coords)
+ return coords_list
+
+
+def build_output_coords_and_indexes(
+ args: Iterable[Any],
+ signature: _UFuncSignature,
+ exclude_dims: Set = frozenset(),
+ combine_attrs: CombineAttrsOptions = "override",
+) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]:
"""Build output coordinates and indexes for an operation.
Parameters
@@ -116,7 +220,8 @@ def build_output_coords_and_indexes(args: Iterable[Any], signature:
exclude_dims : set, optional
Dimensions excluded from the operation. Coordinates along these
dimensions are dropped.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "drop"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "drop"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -136,21 +241,167 @@ def build_output_coords_and_indexes(args: Iterable[Any], signature:
-------
Dictionaries of Variable and Index objects with merged coordinates.
"""
- pass
-
-
-def apply_dataarray_vfunc(func, *args, signature: _UFuncSignature, join:
- JoinOptions='inner', exclude_dims=frozenset(), keep_attrs='override') ->(
- tuple[DataArray, ...] | DataArray):
+ coords_list = _get_coords_list(args)
+
+ if len(coords_list) == 1 and not exclude_dims:
+ # we can skip the expensive merge
+ (unpacked_coords,) = coords_list
+ merged_vars = dict(unpacked_coords.variables)
+ merged_indexes = dict(unpacked_coords.xindexes)
+ else:
+ merged_vars, merged_indexes = merge_coordinates_without_align(
+ coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs
+ )
+
+ output_coords = []
+ output_indexes = []
+ for output_dims in signature.output_core_dims:
+ dropped_dims = signature.all_input_core_dims - set(output_dims)
+ if dropped_dims:
+ filtered_coords = {
+ k: v for k, v in merged_vars.items() if dropped_dims.isdisjoint(v.dims)
+ }
+ filtered_indexes = filter_indexes_from_coords(
+ merged_indexes, set(filtered_coords)
+ )
+ else:
+ filtered_coords = merged_vars
+ filtered_indexes = merged_indexes
+ output_coords.append(filtered_coords)
+ output_indexes.append(filtered_indexes)
+
+ return output_coords, output_indexes
+
+
+def apply_dataarray_vfunc(
+ func,
+ *args,
+ signature: _UFuncSignature,
+ join: JoinOptions = "inner",
+ exclude_dims=frozenset(),
+ keep_attrs="override",
+) -> tuple[DataArray, ...] | DataArray:
"""Apply a variable level function over DataArray, Variable and/or ndarray
objects.
"""
- pass
+ from xarray.core.dataarray import DataArray
+ if len(args) > 1:
+ args = tuple(
+ deep_align(
+ args,
+ join=join,
+ copy=False,
+ exclude=exclude_dims,
+ raise_on_invalid=False,
+ )
+ )
+
+ objs = _all_of_type(args, DataArray)
+
+ if keep_attrs == "drop":
+ name = result_name(args)
+ else:
+ first_obj = _first_of_type(args, DataArray)
+ name = first_obj.name
+ result_coords, result_indexes = build_output_coords_and_indexes(
+ args, signature, exclude_dims, combine_attrs=keep_attrs
+ )
+
+ data_vars = [getattr(a, "variable", a) for a in args]
+ result_var = func(*data_vars)
+
+ out: tuple[DataArray, ...] | DataArray
+ if signature.num_outputs > 1:
+ out = tuple(
+ DataArray(
+ variable, coords=coords, indexes=indexes, name=name, fastpath=True
+ )
+ for variable, coords, indexes in zip(
+ result_var, result_coords, result_indexes
+ )
+ )
+ else:
+ (coords,) = result_coords
+ (indexes,) = result_indexes
+ out = DataArray(
+ result_var, coords=coords, indexes=indexes, name=name, fastpath=True
+ )
+
+ attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
+ if isinstance(out, tuple):
+ for da in out:
+ da.attrs = attrs
+ else:
+ out.attrs = attrs
+
+ return out
+
+
+def ordered_set_union(all_keys: list[Iterable]) -> Iterable:
+ return {key: None for keys in all_keys for key in keys}.keys()
+
+
+def ordered_set_intersection(all_keys: list[Iterable]) -> Iterable:
+ intersection = set(all_keys[0])
+ for keys in all_keys[1:]:
+ intersection.intersection_update(keys)
+ return [key for key in all_keys[0] if key in intersection]
+
+
+def assert_and_return_exact_match(all_keys):
+ first_keys = all_keys[0]
+ for keys in all_keys[1:]:
+ if keys != first_keys:
+ raise ValueError(
+ "exact match required for all data variable names, "
+ f"but {list(keys)} != {list(first_keys)}: {set(keys) ^ set(first_keys)} are not in both."
+ )
+ return first_keys
+
+
+_JOINERS: dict[str, Callable] = {
+ "inner": ordered_set_intersection,
+ "outer": ordered_set_union,
+ "left": operator.itemgetter(0),
+ "right": operator.itemgetter(-1),
+ "exact": assert_and_return_exact_match,
+}
+
+
+def join_dict_keys(objects: Iterable[Mapping | Any], how: str = "inner") -> Iterable:
+ joiner = _JOINERS[how]
+ all_keys = [obj.keys() for obj in objects if hasattr(obj, "keys")]
+ return joiner(all_keys)
+
+
+def collect_dict_values(
+ objects: Iterable[Mapping | Any], keys: Iterable, fill_value: object = None
+) -> list[list]:
+ return [
+ [obj.get(key, fill_value) if is_dict_like(obj) else obj for obj in objects]
+ for key in keys
+ ]
+
+
+def _as_variables_or_variable(arg) -> Variable | tuple[Variable]:
+ try:
+ return arg.variables
+ except AttributeError:
+ try:
+ return arg.variable
+ except AttributeError:
+ return arg
-_JOINERS: dict[str, Callable] = {'inner': ordered_set_intersection, 'outer':
- ordered_set_union, 'left': operator.itemgetter(0), 'right': operator.
- itemgetter(-1), 'exact': assert_and_return_exact_match}
+
+def _unpack_dict_tuples(
+ result_vars: Mapping[Any, tuple[Variable, ...]], num_outputs: int
+) -> tuple[dict[Hashable, Variable], ...]:
+ out: tuple[dict[Hashable, Variable], ...] = tuple({} for _ in range(num_outputs))
+ for name, values in result_vars.items():
+ for value, results_dict in zip(values, out):
+ results_dict[name] = value
+ return out
def _check_core_dims(signature, variable_args, name):
@@ -161,75 +412,507 @@ def _check_core_dims(signature, variable_args, name):
give a detailed error message, which requires inspecting the variable in
the inner loop.
"""
- pass
-
-
-def apply_dict_of_variables_vfunc(func, *args, signature: _UFuncSignature,
- join='inner', fill_value=None, on_missing_core_dim:
- MissingCoreDimOptions='raise'):
+ missing = []
+ for i, (core_dims, variable_arg) in enumerate(
+ zip(signature.input_core_dims, variable_args)
+ ):
+ # Check whether all the dims are on the variable. Note that we need the
+ # `hasattr` to check for a dims property, to protect against the case where
+ # a numpy array is passed in.
+ if hasattr(variable_arg, "dims") and set(core_dims) - set(variable_arg.dims):
+ missing += [[i, variable_arg, core_dims]]
+ if missing:
+ message = ""
+ for i, variable_arg, core_dims in missing:
+ message += f"Missing core dims {set(core_dims) - set(variable_arg.dims)} from arg number {i + 1} on a variable named `{name}`:\n{variable_arg}\n\n"
+ message += "Either add the core dimension, or if passing a dataset alternatively pass `on_missing_core_dim` as `copy` or `drop`. "
+ return message
+ return True
+
+
+def apply_dict_of_variables_vfunc(
+ func,
+ *args,
+ signature: _UFuncSignature,
+ join="inner",
+ fill_value=None,
+ on_missing_core_dim: MissingCoreDimOptions = "raise",
+):
"""Apply a variable level function over dicts of DataArray, DataArray,
Variable and ndarray objects.
"""
- pass
-
-
-def _fast_dataset(variables: dict[Hashable, Variable], coord_variables:
- Mapping[Hashable, Variable], indexes: dict[Hashable, Index]) ->Dataset:
+ args = tuple(_as_variables_or_variable(arg) for arg in args)
+ names = join_dict_keys(args, how=join)
+ grouped_by_name = collect_dict_values(args, names, fill_value)
+
+ result_vars = {}
+ for name, variable_args in zip(names, grouped_by_name):
+ core_dim_present = _check_core_dims(signature, variable_args, name)
+ if core_dim_present is True:
+ result_vars[name] = func(*variable_args)
+ else:
+ if on_missing_core_dim == "raise":
+ raise ValueError(core_dim_present)
+ elif on_missing_core_dim == "copy":
+ result_vars[name] = variable_args[0]
+ elif on_missing_core_dim == "drop":
+ pass
+ else:
+ raise ValueError(
+ f"Invalid value for `on_missing_core_dim`: {on_missing_core_dim!r}"
+ )
+
+ if signature.num_outputs > 1:
+ return _unpack_dict_tuples(result_vars, signature.num_outputs)
+ else:
+ return result_vars
+
+
+def _fast_dataset(
+ variables: dict[Hashable, Variable],
+ coord_variables: Mapping[Hashable, Variable],
+ indexes: dict[Hashable, Index],
+) -> Dataset:
"""Create a dataset as quickly as possible.
Beware: the `variables` dict is modified INPLACE.
"""
- pass
-
+ from xarray.core.dataset import Dataset
-def apply_dataset_vfunc(func, *args, signature: _UFuncSignature, join=
- 'inner', dataset_join='exact', fill_value=_NO_FILL_VALUE, exclude_dims=
- frozenset(), keep_attrs='override', on_missing_core_dim:
- MissingCoreDimOptions='raise') ->(Dataset | tuple[Dataset, ...]):
+ variables.update(coord_variables)
+ coord_names = set(coord_variables)
+ return Dataset._construct_direct(variables, coord_names, indexes=indexes)
+
+
+def apply_dataset_vfunc(
+ func,
+ *args,
+ signature: _UFuncSignature,
+ join="inner",
+ dataset_join="exact",
+ fill_value=_NO_FILL_VALUE,
+ exclude_dims=frozenset(),
+ keep_attrs="override",
+ on_missing_core_dim: MissingCoreDimOptions = "raise",
+) -> Dataset | tuple[Dataset, ...]:
"""Apply a variable level function over Dataset, dict of DataArray,
DataArray, Variable and/or ndarray objects.
"""
- pass
+ from xarray.core.dataset import Dataset
+
+ if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE:
+ raise TypeError(
+ "to apply an operation to datasets with different "
+ "data variables with apply_ufunc, you must supply the "
+ "dataset_fill_value argument."
+ )
+
+ objs = _all_of_type(args, Dataset)
+
+ if len(args) > 1:
+ args = tuple(
+ deep_align(
+ args,
+ join=join,
+ copy=False,
+ exclude=exclude_dims,
+ raise_on_invalid=False,
+ )
+ )
+
+ list_of_coords, list_of_indexes = build_output_coords_and_indexes(
+ args, signature, exclude_dims, combine_attrs=keep_attrs
+ )
+ args = tuple(getattr(arg, "data_vars", arg) for arg in args)
+
+ result_vars = apply_dict_of_variables_vfunc(
+ func,
+ *args,
+ signature=signature,
+ join=dataset_join,
+ fill_value=fill_value,
+ on_missing_core_dim=on_missing_core_dim,
+ )
+
+ out: Dataset | tuple[Dataset, ...]
+ if signature.num_outputs > 1:
+ out = tuple(
+ _fast_dataset(*args)
+ for args in zip(result_vars, list_of_coords, list_of_indexes)
+ )
+ else:
+ (coord_vars,) = list_of_coords
+ (indexes,) = list_of_indexes
+ out = _fast_dataset(result_vars, coord_vars, indexes=indexes)
+
+ attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs)
+ if isinstance(out, tuple):
+ for ds in out:
+ ds.attrs = attrs
+ else:
+ out.attrs = attrs
+
+ return out
def _iter_over_selections(obj, dim, values):
"""Iterate over selections of an xarray object in the provided order."""
- pass
+ from xarray.core.groupby import _dummy_copy
+
+ dummy = None
+ for value in values:
+ try:
+ obj_sel = obj.sel(**{dim: value})
+ except (KeyError, IndexError):
+ if dummy is None:
+ dummy = _dummy_copy(obj)
+ obj_sel = dummy
+ yield obj_sel
def apply_groupby_func(func, *args):
"""Apply a dataset or datarray level function over GroupBy, Dataset,
DataArray, Variable and/or ndarray objects.
"""
- pass
+ from xarray.core.groupby import GroupBy, peek_at
+ from xarray.core.variable import Variable
+
+ groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
+ assert groupbys, "must have at least one groupby to iterate over"
+ first_groupby = groupbys[0]
+ (grouper,) = first_groupby.groupers
+ if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): # type: ignore[union-attr]
+ raise ValueError(
+ "apply_ufunc can only perform operations over "
+ "multiple GroupBy objects at once if they are all "
+ "grouped the same way"
+ )
+
+ grouped_dim = grouper.name
+ unique_values = grouper.unique_coord.values
+
+ iterators = []
+ for arg in args:
+ iterator: Iterator[Any]
+ if isinstance(arg, GroupBy):
+ iterator = (value for _, value in arg)
+ elif hasattr(arg, "dims") and grouped_dim in arg.dims:
+ if isinstance(arg, Variable):
+ raise ValueError(
+ "groupby operations cannot be performed with "
+ "xarray.Variable objects that share a dimension with "
+ "the grouped dimension"
+ )
+ iterator = _iter_over_selections(arg, grouped_dim, unique_values)
+ else:
+ iterator = itertools.repeat(arg)
+ iterators.append(iterator)
+
+ applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators))
+ applied_example, applied = peek_at(applied)
+ combine = first_groupby._combine # type: ignore[attr-defined]
+ if isinstance(applied_example, tuple):
+ combined = tuple(combine(output) for output in zip(*applied))
+ else:
+ combined = combine(applied)
+ return combined
+
+
+def unified_dim_sizes(
+ variables: Iterable[Variable], exclude_dims: Set = frozenset()
+) -> dict[Hashable, int]:
+ dim_sizes: dict[Hashable, int] = {}
+
+ for var in variables:
+ if len(set(var.dims)) < len(var.dims):
+ raise ValueError(
+ "broadcasting cannot handle duplicate "
+ f"dimensions on a variable: {list(var.dims)}"
+ )
+ for dim, size in zip(var.dims, var.shape):
+ if dim not in exclude_dims:
+ if dim not in dim_sizes:
+ dim_sizes[dim] = size
+ elif dim_sizes[dim] != size:
+ raise ValueError(
+ "operands cannot be broadcast together "
+ "with mismatched lengths for dimension "
+ f"{dim}: {dim_sizes[dim]} vs {size}"
+ )
+ return dim_sizes
SLICE_NONE = slice(None)
-def apply_variable_ufunc(func, *args, signature: _UFuncSignature,
- exclude_dims=frozenset(), dask='forbidden', output_dtypes=None,
- vectorize=False, keep_attrs='override', dask_gufunc_kwargs=None) ->(
- Variable | tuple[Variable, ...]):
+def broadcast_compat_data(
+ variable: Variable,
+ broadcast_dims: tuple[Hashable, ...],
+ core_dims: tuple[Hashable, ...],
+) -> Any:
+ data = variable.data
+
+ old_dims = variable.dims
+ new_dims = broadcast_dims + core_dims
+
+ if new_dims == old_dims:
+ # optimize for the typical case
+ return data
+
+ set_old_dims = set(old_dims)
+ set_new_dims = set(new_dims)
+ unexpected_dims = [d for d in old_dims if d not in set_new_dims]
+
+ if unexpected_dims:
+ raise ValueError(
+ "operand to apply_ufunc encountered unexpected "
+ f"dimensions {unexpected_dims!r} on an input variable: these are core "
+ "dimensions on other input or output variables"
+ )
+
+ # for consistency with numpy, keep broadcast dimensions to the left
+ old_broadcast_dims = tuple(d for d in broadcast_dims if d in set_old_dims)
+ reordered_dims = old_broadcast_dims + core_dims
+ if reordered_dims != old_dims:
+ order = tuple(old_dims.index(d) for d in reordered_dims)
+ data = duck_array_ops.transpose(data, order)
+
+ if new_dims != reordered_dims:
+ key_parts: list[slice | None] = []
+ for dim in new_dims:
+ if dim in set_old_dims:
+ key_parts.append(SLICE_NONE)
+ elif key_parts:
+ # no need to insert new axes at the beginning that are already
+ # handled by broadcasting
+ key_parts.append(np.newaxis)
+ data = data[tuple(key_parts)]
+
+ return data
+
+
+def _vectorize(func, signature, output_dtypes, exclude_dims):
+ if signature.all_core_dims:
+ func = np.vectorize(
+ func,
+ otypes=output_dtypes,
+ signature=signature.to_gufunc_string(exclude_dims),
+ )
+ else:
+ func = np.vectorize(func, otypes=output_dtypes)
+
+ return func
+
+
+def apply_variable_ufunc(
+ func,
+ *args,
+ signature: _UFuncSignature,
+ exclude_dims=frozenset(),
+ dask="forbidden",
+ output_dtypes=None,
+ vectorize=False,
+ keep_attrs="override",
+ dask_gufunc_kwargs=None,
+) -> Variable | tuple[Variable, ...]:
"""Apply a ndarray level function over Variable and/or ndarray objects."""
- pass
+ from xarray.core.formatting import short_array_repr
+ from xarray.core.variable import Variable, as_compatible_data
+
+ dim_sizes = unified_dim_sizes(
+ (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
+ )
+ broadcast_dims = tuple(
+ dim for dim in dim_sizes if dim not in signature.all_core_dims
+ )
+ output_dims = [broadcast_dims + out for out in signature.output_core_dims]
+
+ input_data = [
+ (
+ broadcast_compat_data(arg, broadcast_dims, core_dims)
+ if isinstance(arg, Variable)
+ else arg
+ )
+ for arg, core_dims in zip(args, signature.input_core_dims)
+ ]
+
+ if any(is_chunked_array(array) for array in input_data):
+ if dask == "forbidden":
+ raise ValueError(
+ "apply_ufunc encountered a chunked array on an "
+ "argument, but handling for chunked arrays has not "
+ "been enabled. Either set the ``dask`` argument "
+ "or load your data into memory first with "
+ "``.load()`` or ``.compute()``"
+ )
+ elif dask == "parallelized":
+ chunkmanager = get_chunked_array_type(*input_data)
+
+ numpy_func = func
+
+ if dask_gufunc_kwargs is None:
+ dask_gufunc_kwargs = {}
+ else:
+ dask_gufunc_kwargs = dask_gufunc_kwargs.copy()
+
+ allow_rechunk = dask_gufunc_kwargs.get("allow_rechunk", None)
+ if allow_rechunk is None:
+ for n, (data, core_dims) in enumerate(
+ zip(input_data, signature.input_core_dims)
+ ):
+ if is_chunked_array(data):
+ # core dimensions cannot span multiple chunks
+ for axis, dim in enumerate(core_dims, start=-len(core_dims)):
+ if len(data.chunks[axis]) != 1:
+ raise ValueError(
+ f"dimension {dim} on {n}th function argument to "
+ "apply_ufunc with dask='parallelized' consists of "
+ "multiple chunks, but is also a core dimension. To "
+ "fix, either rechunk into a single array chunk along "
+ f"this dimension, i.e., ``.chunk(dict({dim}=-1))``, or "
+ "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` "
+ "but beware that this may significantly increase memory usage."
+ )
+ dask_gufunc_kwargs["allow_rechunk"] = True
+
+ output_sizes = dask_gufunc_kwargs.pop("output_sizes", {})
+ if output_sizes:
+ output_sizes_renamed = {}
+ for key, value in output_sizes.items():
+ if key not in signature.all_output_core_dims:
+ raise ValueError(
+ f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims"
+ )
+ output_sizes_renamed[signature.dims_map[key]] = value
+ dask_gufunc_kwargs["output_sizes"] = output_sizes_renamed
+
+ for key in signature.all_output_core_dims:
+ if (
+ key not in signature.all_input_core_dims or key in exclude_dims
+ ) and key not in output_sizes:
+ raise ValueError(
+ f"dimension '{key}' in 'output_core_dims' needs corresponding (dim, size) in 'output_sizes'"
+ )
+
+ def func(*arrays):
+ res = chunkmanager.apply_gufunc(
+ numpy_func,
+ signature.to_gufunc_string(exclude_dims),
+ *arrays,
+ vectorize=vectorize,
+ output_dtypes=output_dtypes,
+ **dask_gufunc_kwargs,
+ )
+
+ return res
+
+ elif dask == "allowed":
+ pass
+ else:
+ raise ValueError(
+ f"unknown setting for chunked array handling in apply_ufunc: {dask}"
+ )
+ else:
+ if vectorize:
+ func = _vectorize(
+ func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims
+ )
+
+ result_data = func(*input_data)
+
+ if signature.num_outputs == 1:
+ result_data = (result_data,)
+ elif (
+ not isinstance(result_data, tuple) or len(result_data) != signature.num_outputs
+ ):
+ raise ValueError(
+ f"applied function does not have the number of "
+ f"outputs specified in the ufunc signature. "
+ f"Received a {type(result_data)} with {len(result_data)} elements. "
+ f"Expected a tuple of {signature.num_outputs} elements:\n\n"
+ f"{limit_lines(repr(result_data), limit=10)}"
+ )
+
+ objs = _all_of_type(args, Variable)
+ attrs = merge_attrs(
+ [obj.attrs for obj in objs],
+ combine_attrs=keep_attrs,
+ )
+
+ output: list[Variable] = []
+ for dims, data in zip(output_dims, result_data):
+ data = as_compatible_data(data)
+ if data.ndim != len(dims):
+ raise ValueError(
+ "applied function returned data with an unexpected "
+ f"number of dimensions. Received {data.ndim} dimension(s) but "
+ f"expected {len(dims)} dimensions with names {dims!r}, from:\n\n"
+ f"{short_array_repr(data)}"
+ )
+
+ var = Variable(dims, data, fastpath=True)
+ for dim, new_size in var.sizes.items():
+ if dim in dim_sizes and new_size != dim_sizes[dim]:
+ raise ValueError(
+ f"size of dimension '{dim}' on inputs was unexpectedly "
+ f"changed by applied function from {dim_sizes[dim]} to {new_size}. Only "
+ "dimensions specified in ``exclude_dims`` with "
+ "xarray.apply_ufunc are allowed to change size. "
+ "The data returned was:\n\n"
+ f"{short_array_repr(data)}"
+ )
+
+ var.attrs = attrs
+ output.append(var)
+ if signature.num_outputs == 1:
+ return output[0]
+ else:
+ return tuple(output)
-def apply_array_ufunc(func, *args, dask='forbidden'):
+
+def apply_array_ufunc(func, *args, dask="forbidden"):
"""Apply a ndarray level function over ndarray objects."""
- pass
-
-
-def apply_ufunc(func: Callable, *args: Any, input_core_dims: (Sequence[
- Sequence] | None)=None, output_core_dims: (Sequence[Sequence] | None)=(
- (),), exclude_dims: Set=frozenset(), vectorize: bool=False, join:
- JoinOptions='exact', dataset_join: str='exact', dataset_fill_value:
- object=_NO_FILL_VALUE, keep_attrs: (bool | str | None)=None, kwargs: (
- Mapping | None)=None, dask: Literal['forbidden', 'allowed',
- 'parallelized']='forbidden', output_dtypes: (Sequence | None)=None,
- output_sizes: (Mapping[Any, int] | None)=None, meta: Any=None,
- dask_gufunc_kwargs: (dict[str, Any] | None)=None, on_missing_core_dim:
- MissingCoreDimOptions='raise') ->Any:
+ if any(is_chunked_array(arg) for arg in args):
+ if dask == "forbidden":
+ raise ValueError(
+ "apply_ufunc encountered a dask array on an "
+ "argument, but handling for dask arrays has not "
+ "been enabled. Either set the ``dask`` argument "
+ "or load your data into memory first with "
+ "``.load()`` or ``.compute()``"
+ )
+ elif dask == "parallelized":
+ raise ValueError(
+ "cannot use dask='parallelized' for apply_ufunc "
+ "unless at least one input is an xarray object"
+ )
+ elif dask == "allowed":
+ pass
+ else:
+ raise ValueError(f"unknown setting for dask array handling: {dask}")
+ return func(*args)
+
+
+def apply_ufunc(
+ func: Callable,
+ *args: Any,
+ input_core_dims: Sequence[Sequence] | None = None,
+ output_core_dims: Sequence[Sequence] | None = ((),),
+ exclude_dims: Set = frozenset(),
+ vectorize: bool = False,
+ join: JoinOptions = "exact",
+ dataset_join: str = "exact",
+ dataset_fill_value: object = _NO_FILL_VALUE,
+ keep_attrs: bool | str | None = None,
+ kwargs: Mapping | None = None,
+ dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden",
+ output_dtypes: Sequence | None = None,
+ output_sizes: Mapping[Any, int] | None = None,
+ meta: Any = None,
+ dask_gufunc_kwargs: dict[str, Any] | None = None,
+ on_missing_core_dim: MissingCoreDimOptions = "raise",
+) -> Any:
"""Apply a vectorized function for unlabeled arrays on xarray objects.
The function will be mapped over the data variable(s) of the input
@@ -247,7 +930,8 @@ def apply_ufunc(func: Callable, *args: Any, input_core_dims: (Sequence[
the style of NumPy universal functions [1]_ (if this is not the case,
set ``vectorize=True``). If this function returns multiple outputs, you
must set ``output_core_dims`` as well.
- *args : Dataset, DataArray, DataArrayGroupBy, DatasetGroupBy, Variable, numpy.ndarray, dask.array.Array or scalar
+ *args : Dataset, DataArray, DataArrayGroupBy, DatasetGroupBy, Variable, \
+ numpy.ndarray, dask.array.Array or scalar
Mix of labeled and/or unlabeled arrays to which to apply the function.
input_core_dims : sequence of sequence, optional
List of the same length as ``args`` giving the list of core dimensions
@@ -471,11 +1155,139 @@ def apply_ufunc(func: Callable, *args: Any, input_core_dims: (Sequence[
.. [1] https://numpy.org/doc/stable/reference/ufuncs.html
.. [2] https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html
"""
- pass
-
+ from xarray.core.dataarray import DataArray
+ from xarray.core.groupby import GroupBy
+ from xarray.core.variable import Variable
+
+ if input_core_dims is None:
+ input_core_dims = ((),) * (len(args))
+ elif len(input_core_dims) != len(args):
+ raise ValueError(
+            f"input_core_dims must be None or a tuple with the same length as "
+ f"the number of arguments. "
+ f"Given {len(input_core_dims)} input_core_dims: {input_core_dims}, "
+ f" but number of args is {len(args)}."
+ )
+
+ if kwargs is None:
+ kwargs = {}
+
+ signature = _UFuncSignature(input_core_dims, output_core_dims)
+
+ if exclude_dims:
+ if not isinstance(exclude_dims, set):
+ raise TypeError(
+ f"Expected exclude_dims to be a 'set'. Received '{type(exclude_dims).__name__}' instead."
+ )
+ if not exclude_dims <= signature.all_core_dims:
+ raise ValueError(
+ f"each dimension in `exclude_dims` must also be a "
+ f"core dimension in the function signature. "
+ f"Please make {(exclude_dims - signature.all_core_dims)} a core dimension"
+ )
+ # handle dask_gufunc_kwargs
+ if dask == "parallelized":
+ if dask_gufunc_kwargs is None:
+ dask_gufunc_kwargs = {}
+ else:
+ dask_gufunc_kwargs = dask_gufunc_kwargs.copy()
+ # todo: remove warnings after deprecation cycle
+ if meta is not None:
+ warnings.warn(
+ "``meta`` should be given in the ``dask_gufunc_kwargs`` parameter."
+ " It will be removed as direct parameter in a future version.",
+ FutureWarning,
+ stacklevel=2,
+ )
+ dask_gufunc_kwargs.setdefault("meta", meta)
+ if output_sizes is not None:
+ warnings.warn(
+ "``output_sizes`` should be given in the ``dask_gufunc_kwargs`` "
+ "parameter. It will be removed as direct parameter in a future "
+ "version.",
+ FutureWarning,
+ stacklevel=2,
+ )
+ dask_gufunc_kwargs.setdefault("output_sizes", output_sizes)
+
+ if kwargs:
+ func = functools.partial(func, **kwargs)
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ if isinstance(keep_attrs, bool):
+ keep_attrs = "override" if keep_attrs else "drop"
+
+ variables_vfunc = functools.partial(
+ apply_variable_ufunc,
+ func,
+ signature=signature,
+ exclude_dims=exclude_dims,
+ keep_attrs=keep_attrs,
+ dask=dask,
+ vectorize=vectorize,
+ output_dtypes=output_dtypes,
+ dask_gufunc_kwargs=dask_gufunc_kwargs,
+ )
+
+ # feed groupby-apply_ufunc through apply_groupby_func
+ if any(isinstance(a, GroupBy) for a in args):
+ this_apply = functools.partial(
+ apply_ufunc,
+ func,
+ input_core_dims=input_core_dims,
+ output_core_dims=output_core_dims,
+ exclude_dims=exclude_dims,
+ join=join,
+ dataset_join=dataset_join,
+ dataset_fill_value=dataset_fill_value,
+ keep_attrs=keep_attrs,
+ dask=dask,
+ vectorize=vectorize,
+ output_dtypes=output_dtypes,
+ dask_gufunc_kwargs=dask_gufunc_kwargs,
+ )
+ return apply_groupby_func(this_apply, *args)
+ # feed datasets apply_variable_ufunc through apply_dataset_vfunc
+ elif any(is_dict_like(a) for a in args):
+ return apply_dataset_vfunc(
+ variables_vfunc,
+ *args,
+ signature=signature,
+ join=join,
+ exclude_dims=exclude_dims,
+ dataset_join=dataset_join,
+ fill_value=dataset_fill_value,
+ keep_attrs=keep_attrs,
+ on_missing_core_dim=on_missing_core_dim,
+ )
+ # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
+ elif any(isinstance(a, DataArray) for a in args):
+ return apply_dataarray_vfunc(
+ variables_vfunc,
+ *args,
+ signature=signature,
+ join=join,
+ exclude_dims=exclude_dims,
+ keep_attrs=keep_attrs,
+ )
+ # feed Variables directly through apply_variable_ufunc
+ elif any(isinstance(a, Variable) for a in args):
+ return variables_vfunc(*args)
+ else:
+ # feed anything else through apply_array_ufunc
+ return apply_array_ufunc(func, *args, dask=dask)
+
+
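
Before the `cov`/`corr` helpers, a minimal sketch of how the dispatch above is typically exercised through the public `apply_ufunc` entry point (toy data, not part of the patch):

```python
import numpy as np
import xarray as xr

a = xr.DataArray(np.random.rand(4, 3), dims=("x", "y"))
b = xr.DataArray(np.random.rand(4, 3), dims=("x", "y"))

# element-wise: no core dimensions, works on any mix of xarray objects and arrays
dist = xr.apply_ufunc(np.hypot, a, b)

# reduction: "y" is declared as a core dimension and consumed by the wrapped function
mean_y = xr.apply_ufunc(np.mean, a, input_core_dims=[["y"]], kwargs={"axis": -1})
```
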
-def cov(da_a: T_DataArray, da_b: T_DataArray, dim: Dims=None, ddof: int=1,
-    weights: (T_DataArray | None)=None) ->T_DataArray:
+def cov(
+ da_a: T_DataArray,
+ da_b: T_DataArray,
+ dim: Dims = None,
+ ddof: int = 1,
+ weights: T_DataArray | None = None,
+) -> T_DataArray:
"""
Compute covariance between two DataArray objects along a shared dimension.
@@ -563,11 +1375,25 @@ def cov(da_a: T_DataArray, da_b: T_DataArray, dim: Dims=None, ddof: int=1,
Coordinates:
* time (time) datetime64[ns] 24B 2000-01-01 2000-01-02 2000-01-03
"""
- pass
-
+ from xarray.core.dataarray import DataArray
+ if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]):
+ raise TypeError(
+ "Only xr.DataArray is supported."
+ f"Given {[type(arr) for arr in [da_a, da_b]]}."
+ )
+ if weights is not None:
+ if not isinstance(weights, DataArray):
+ raise TypeError(f"Only xr.DataArray is supported. Given {type(weights)}.")
+ return _cov_corr(da_a, da_b, weights=weights, dim=dim, ddof=ddof, method="cov")
+
+
-def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims=None, weights: (
-    T_DataArray | None)=None) ->T_DataArray:
+def corr(
+ da_a: T_DataArray,
+ da_b: T_DataArray,
+ dim: Dims = None,
+ weights: T_DataArray | None = None,
+) -> T_DataArray:
"""
Compute the Pearson correlation coefficient between
two DataArray objects along a shared dimension.
@@ -653,21 +1479,83 @@ def corr(da_a: T_DataArray, da_b: T_DataArray, dim: Dims=None, weights: (
Coordinates:
* time (time) datetime64[ns] 24B 2000-01-01 2000-01-02 2000-01-03
"""
- pass
-
+ from xarray.core.dataarray import DataArray
+ if any(not isinstance(arr, DataArray) for arr in [da_a, da_b]):
+ raise TypeError(
+ "Only xr.DataArray is supported."
+ f"Given {[type(arr) for arr in [da_a, da_b]]}."
+ )
+ if weights is not None:
+ if not isinstance(weights, DataArray):
+ raise TypeError(f"Only xr.DataArray is supported. Given {type(weights)}.")
+ return _cov_corr(da_a, da_b, weights=weights, dim=dim, method="corr")
+
+
-def _cov_corr(da_a: T_DataArray, da_b: T_DataArray, weights: (T_DataArray |
-    None)=None, dim: Dims=None, ddof: int=0, method: Literal['cov', 'corr',
-    None]=None) ->T_DataArray:
+def _cov_corr(
+ da_a: T_DataArray,
+ da_b: T_DataArray,
+ weights: T_DataArray | None = None,
+ dim: Dims = None,
+ ddof: int = 0,
+ method: Literal["cov", "corr", None] = None,
+) -> T_DataArray:
"""
Internal method for xr.cov() and xr.corr() so only have to
sanitize the input arrays once and we don't repeat code.
"""
- pass
-
-
+ # 1. Broadcast the two arrays
+ da_a, da_b = align(da_a, da_b, join="inner", copy=False)
+
+ # 2. Ignore the nans
+ valid_values = da_a.notnull() & da_b.notnull()
+ da_a = da_a.where(valid_values)
+ da_b = da_b.where(valid_values)
+
+ # 3. Detrend along the given dim
+ if weights is not None:
+ demeaned_da_a = da_a - da_a.weighted(weights).mean(dim=dim)
+ demeaned_da_b = da_b - da_b.weighted(weights).mean(dim=dim)
+ else:
+ demeaned_da_a = da_a - da_a.mean(dim=dim)
+ demeaned_da_b = da_b - da_b.mean(dim=dim)
+
+ # 4. Compute covariance along the given dim
+ # N.B. `skipna=True` is required or auto-covariance is computed incorrectly. E.g.
+ # Try xr.cov(da,da) for da = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
+ if weights is not None:
+ cov = (
+ (demeaned_da_a.conj() * demeaned_da_b)
+ .weighted(weights)
+ .mean(dim=dim, skipna=True)
+ )
+ else:
+ cov = (demeaned_da_a.conj() * demeaned_da_b).mean(dim=dim, skipna=True)
+
+ if method == "cov":
+ # Adjust covariance for degrees of freedom
+ valid_count = valid_values.sum(dim)
+ adjust = valid_count / (valid_count - ddof)
+ # I think the cast is required because of `T_DataArray` + `T_Xarray` (would be
+ # the same with `T_DatasetOrArray`)
+ # https://github.com/pydata/xarray/pull/8384#issuecomment-1784228026
+ return cast(T_DataArray, cov * adjust)
+
+ else:
+ # Compute std and corr
+ if weights is not None:
+ da_a_std = da_a.weighted(weights).std(dim=dim)
+ da_b_std = da_b.weighted(weights).std(dim=dim)
+ else:
+ da_a_std = da_a.std(dim=dim)
+ da_b_std = da_b.std(dim=dim)
+ corr = cov / (da_a_std * da_b_std)
+ return cast(T_DataArray, corr)
+
+
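
The shared `_cov_corr` path above backs both public functions; a brief usage sketch with synthetic data:

```python
import numpy as np
import xarray as xr

x = xr.DataArray(np.random.rand(100), dims="time")
y = 2 * x + 0.1 * xr.DataArray(np.random.rand(100), dims="time")
w = xr.DataArray(np.linspace(0.5, 1.5, 100), dims="time")

xr.cov(x, y, dim="time")              # sample covariance (ddof=1 by default)
xr.corr(x, y, dim="time")             # Pearson correlation, close to 1 here
xr.corr(x, y, dim="time", weights=w)  # weighted variant handled above
```
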
-def cross(a: (DataArray | Variable), b: (DataArray | Variable), *, dim:
-    Hashable) ->(DataArray | Variable):
+def cross(
+ a: DataArray | Variable, b: DataArray | Variable, *, dim: Hashable
+) -> DataArray | Variable:
"""
Compute the cross product of two (arrays of) vectors.
@@ -759,11 +1647,83 @@ def cross(a: (DataArray | Variable), b: (DataArray | Variable), *, dim:
--------
numpy.cross : Corresponding numpy function
"""
- pass
+
+ if dim not in a.dims:
+ raise ValueError(f"Dimension {dim!r} not on a")
+ elif dim not in b.dims:
+ raise ValueError(f"Dimension {dim!r} not on b")
+
+ if not 1 <= a.sizes[dim] <= 3:
+ raise ValueError(
+ f"The size of {dim!r} on a must be 1, 2, or 3 to be "
+ f"compatible with a cross product but is {a.sizes[dim]}"
+ )
+ elif not 1 <= b.sizes[dim] <= 3:
+ raise ValueError(
+ f"The size of {dim!r} on b must be 1, 2, or 3 to be "
+ f"compatible with a cross product but is {b.sizes[dim]}"
+ )
+
+ all_dims = list(dict.fromkeys(a.dims + b.dims))
+
+ if a.sizes[dim] != b.sizes[dim]:
+ # Arrays have different sizes. Append zeros where the smaller
+ # array is missing a value, zeros will not affect np.cross:
+
+ if (
+ not isinstance(a, Variable) # Only used to make mypy happy.
+ and dim in getattr(a, "coords", {})
+ and not isinstance(b, Variable) # Only used to make mypy happy.
+ and dim in getattr(b, "coords", {})
+ ):
+ # If the arrays have coords we know which indexes to fill
+ # with zeros:
+ a, b = align(
+ a,
+ b,
+ fill_value=0,
+ join="outer",
+ exclude=set(all_dims) - {dim},
+ )
+ elif min(a.sizes[dim], b.sizes[dim]) == 2:
+ # If the array doesn't have coords we can only infer
+ # that it has composite values if the size is at least 2.
+ # Once padded, rechunk the padded array because apply_ufunc
+ # requires core dimensions not to be chunked:
+ if a.sizes[dim] < b.sizes[dim]:
+ a = a.pad({dim: (0, 1)}, constant_values=0)
+ # TODO: Should pad or apply_ufunc handle correct chunking?
+ a = a.chunk({dim: -1}) if is_duck_dask_array(a.data) else a
+ else:
+ b = b.pad({dim: (0, 1)}, constant_values=0)
+ # TODO: Should pad or apply_ufunc handle correct chunking?
+ b = b.chunk({dim: -1}) if is_duck_dask_array(b.data) else b
+ else:
+ raise ValueError(
+ f"{dim!r} on {'a' if a.sizes[dim] == 1 else 'b'} is incompatible:"
+                " dimensions without coordinates must have a length of 2 or 3"
+ )
+
+ c = apply_ufunc(
+ np.cross,
+ a,
+ b,
+ input_core_dims=[[dim], [dim]],
+ output_core_dims=[[dim] if a.sizes[dim] == 3 else []],
+ dask="parallelized",
+ output_dtypes=[np.result_type(a, b)],
+ )
+ c = c.transpose(*all_dims, missing_dims="ignore")
+
+ return c
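
A compact example of the `cross` implementation above (the classic unit-vector case; illustrative only):

```python
import xarray as xr

a = xr.DataArray([1, 0, 0], dims="cartesian", coords={"cartesian": ["x", "y", "z"]})
b = xr.DataArray([0, 1, 0], dims="cartesian", coords={"cartesian": ["x", "y", "z"]})

xr.cross(a, b, dim="cartesian")  # -> [0, 0, 1], the unit vector along z
```
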
@deprecate_dims
-def dot(*arrays, dim: Dims=None, **kwargs: Any):
+def dot(
+ *arrays,
+ dim: Dims = None,
+ **kwargs: Any,
+):
"""Generalized dot product for xarray objects. Like ``np.einsum``, but
provides a simpler interface based on array dimension names.
@@ -851,7 +1811,71 @@ def dot(*arrays, dim: Dims=None, **kwargs: Any):
<xarray.DataArray ()> Size: 8B
array(235)
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.variable import Variable
+
+ if any(not isinstance(arr, (Variable, DataArray)) for arr in arrays):
+ raise TypeError(
+ "Only xr.DataArray and xr.Variable are supported."
+ f"Given {[type(arr) for arr in arrays]}."
+ )
+
+ if len(arrays) == 0:
+ raise TypeError("At least one array should be given.")
+
+ common_dims: set[Hashable] = set.intersection(*(set(arr.dims) for arr in arrays))
+ all_dims = []
+ for arr in arrays:
+ all_dims += [d for d in arr.dims if d not in all_dims]
+
+ einsum_axes = "abcdefghijklmnopqrstuvwxyz"
+ dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}
+
+ if dim is None:
+ # find dimensions that occur more than once
+ dim_counts: Counter = Counter()
+ for arr in arrays:
+ dim_counts.update(arr.dims)
+ dim = tuple(d for d, c in dim_counts.items() if c > 1)
+ else:
+ dim = parse_dims(dim, all_dims=tuple(all_dims))
+
+ dot_dims: set[Hashable] = set(dim)
+
+ # dimensions to be parallelized
+ broadcast_dims = common_dims - dot_dims
+ input_core_dims = [
+ [d for d in arr.dims if d not in broadcast_dims] for arr in arrays
+ ]
+ output_core_dims = [
+ [d for d in all_dims if d not in dot_dims and d not in broadcast_dims]
+ ]
+
+ # construct einsum subscripts, such as '...abc,...ab->...c'
+ # Note: input_core_dims are always moved to the last position
+ subscripts_list = [
+ "..." + "".join(dim_map[d] for d in ds) for ds in input_core_dims
+ ]
+ subscripts = ",".join(subscripts_list)
+ subscripts += "->..." + "".join(dim_map[d] for d in output_core_dims[0])
+
+ join = OPTIONS["arithmetic_join"]
+ # using "inner" emulates `(a * b).sum()` for all joins (except "exact")
+ if join != "exact":
+ join = "inner"
+
+ # subscripts should be passed to np.einsum as arg, not as kwargs. We need
+ # to construct a partial function for apply_ufunc to work.
+ func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)
+ result = apply_ufunc(
+ func,
+ *arrays,
+ input_core_dims=input_core_dims,
+ output_core_dims=output_core_dims,
+ join=join,
+ dask="allowed",
+ )
+ return result.transpose(*all_dims, missing_dims="ignore")
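# --- Illustrative usage (editor's sketch, not part of the patch) ---
# With a.dims == ("a", "b") and b.dims == ("b", "c"), the subscripts built
# above reduce to "...ab,...bc->...ac", so the result matches a plain einsum.
# (Older releases spelled the keyword ``dims=``; the patch deprecates it to ``dim=``.)
import numpy as np
import xarray as xr

a = xr.DataArray(np.arange(6).reshape(2, 3), dims=("a", "b"))
b = xr.DataArray(np.arange(12).reshape(3, 4), dims=("b", "c"))
res = xr.dot(a, b, dim="b")
np.testing.assert_array_equal(res.values, np.einsum("ab,bc->ac", a.values, b.values))
print(res.dims)  # ('a', 'c')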
def where(cond, x, y, keep_attrs=None):
@@ -937,11 +1961,82 @@ def where(cond, x, y, keep_attrs=None):
Dataset.where, DataArray.where :
equivalent methods
"""
- pass
-
+ from xarray.core.dataset import Dataset
-def polyval(coord: (Dataset | DataArray), coeffs: (Dataset | DataArray),
- degree_dim: Hashable='degree') ->(Dataset | DataArray):
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ # alignment for three arguments is complicated, so don't support it yet
+ result = apply_ufunc(
+ duck_array_ops.where,
+ cond,
+ x,
+ y,
+ join="exact",
+ dataset_join="exact",
+ dask="allowed",
+ keep_attrs=keep_attrs,
+ )
+
+ # keep the attributes of x, the second parameter, by default to
+ # be consistent with the `where` method of `DataArray` and `Dataset`
+ # rebuild the attrs from x at each level of the output, which could be
+ # Dataset, DataArray, or Variable, and also handle coords
+ if keep_attrs is True and hasattr(result, "attrs"):
+ if isinstance(y, Dataset) and not isinstance(x, Dataset):
+ # handle special case where x gets promoted to Dataset
+ result.attrs = {}
+ if getattr(x, "name", None) in result.data_vars:
+ result[x.name].attrs = getattr(x, "attrs", {})
+ else:
+ # otherwise, fill in global attrs and variable attrs (if they exist)
+ result.attrs = getattr(x, "attrs", {})
+ for v in getattr(result, "data_vars", []):
+ result[v].attrs = getattr(getattr(x, v, None), "attrs", {})
+ for c in getattr(result, "coords", []):
+ # always fill coord attrs of x
+ result[c].attrs = getattr(getattr(x, c, None), "attrs", {})
+
+ return result
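# --- Illustrative usage (editor's sketch, not part of the patch) ---
# With keep_attrs=True the attributes of ``x`` (the second argument) are kept,
# mirroring DataArray.where / Dataset.where.
import xarray as xr

x = xr.DataArray([1, 2, 3], dims="t", attrs={"units": "m"})
cond = xr.DataArray([True, False, True], dims="t")
out = xr.where(cond, x, 0, keep_attrs=True)
print(out.values, out.attrs)  # [1 0 3] {'units': 'm'}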
+
+
+@overload
+def polyval(
+ coord: DataArray, coeffs: DataArray, degree_dim: Hashable = "degree"
+) -> DataArray: ...
+
+
+@overload
+def polyval(
+ coord: DataArray, coeffs: Dataset, degree_dim: Hashable = "degree"
+) -> Dataset: ...
+
+
+@overload
+def polyval(
+ coord: Dataset, coeffs: DataArray, degree_dim: Hashable = "degree"
+) -> Dataset: ...
+
+
+@overload
+def polyval(
+ coord: Dataset, coeffs: Dataset, degree_dim: Hashable = "degree"
+) -> Dataset: ...
+
+
+@overload
+def polyval(
+ coord: Dataset | DataArray,
+ coeffs: Dataset | DataArray,
+ degree_dim: Hashable = "degree",
+) -> Dataset | DataArray: ...
+
+
+def polyval(
+ coord: Dataset | DataArray,
+ coeffs: Dataset | DataArray,
+ degree_dim: Hashable = "degree",
+) -> Dataset | DataArray:
"""Evaluate a polynomial at specific values
Parameters
@@ -963,10 +2058,32 @@ def polyval(coord: (Dataset | DataArray), coeffs: (Dataset | DataArray),
xarray.DataArray.polyfit
numpy.polynomial.polynomial.polyval
"""
- pass
-
-def _ensure_numeric(data: (Dataset | DataArray)) ->(Dataset | DataArray):
+ if degree_dim not in coeffs._indexes:
+ raise ValueError(
+ f"Dimension `{degree_dim}` should be a coordinate variable with labels."
+ )
+ if not np.issubdtype(coeffs[degree_dim].dtype, np.integer):
+ raise ValueError(
+ f"Dimension `{degree_dim}` should be of integer dtype. Received {coeffs[degree_dim].dtype} instead."
+ )
+ max_deg = coeffs[degree_dim].max().item()
+ coeffs = coeffs.reindex(
+ {degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False
+ )
+ coord = _ensure_numeric(coord)
+
+ # using Horner's method
+ # https://en.wikipedia.org/wiki/Horner%27s_method
+ res = zeros_like(coord) + coeffs.isel({degree_dim: max_deg}, drop=True)
+ for deg in range(max_deg - 1, -1, -1):
+ res *= coord
+ res += coeffs.isel({degree_dim: deg}, drop=True)
+
+ return res
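# --- Illustrative usage (editor's sketch, not part of the patch) ---
# Horner's method as used above: p(x) = 1 + 2*x + 3*x**2 is evaluated as
# ((3)*x + 2)*x + 1 for each element of ``coord``.
import xarray as xr

coeffs = xr.DataArray([1, 2, 3], dims="degree", coords={"degree": [0, 1, 2]})
coord = xr.DataArray([0, 1, 2], dims="x")
print(xr.polyval(coord, coeffs).values)  # [ 1  6 17]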
+
+
+def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray:
"""Converts all datetime64 variables to float64
Parameters
@@ -979,23 +2096,124 @@ def _ensure_numeric(data: (Dataset | DataArray)) ->(Dataset | DataArray):
DataArray or Dataset
Variables with datetime64 dtypes converted to float64.
"""
- pass
-
+ from xarray.core.dataset import Dataset
-def _calc_idxminmax(*, array, func: Callable, dim: (Hashable | None)=None,
- skipna: (bool | None)=None, fill_value: Any=dtypes.NA, keep_attrs: (
- bool | None)=None):
+ def _cfoffset(x: DataArray) -> Any:
+ scalar = x.compute().data[0]
+ if not is_scalar(scalar):
+ # we do not get a scalar back on dask == 2021.04.1
+ scalar = scalar.item()
+ return type(scalar)(1970, 1, 1)
+
+ def to_floatable(x: DataArray) -> DataArray:
+ if x.dtype.kind in "MO":
+ # datetimes (CFIndexes are object type)
+ offset = (
+ np.datetime64("1970-01-01") if x.dtype.kind == "M" else _cfoffset(x)
+ )
+ return x.copy(
+ data=datetime_to_numeric(x.data, offset=offset, datetime_unit="ns"),
+ )
+ elif x.dtype.kind == "m":
+ # timedeltas
+ return duck_array_ops.astype(x, dtype=float)
+ return x
+
+ if isinstance(data, Dataset):
+ return data.map(to_floatable)
+ else:
+ return to_floatable(data)
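# --- Illustrative behaviour (editor's sketch, not part of the patch) ---
# The datetime64 branch converts timestamps to float nanoseconds since
# 1970-01-01 so polyval can operate on plain numbers; datetime_to_numeric is
# the internal helper called above.
import numpy as np
from xarray.core.duck_array_ops import datetime_to_numeric

times = np.array(["1970-01-01", "1970-01-02"], dtype="datetime64[ns]")
print(datetime_to_numeric(times, offset=np.datetime64("1970-01-01"), datetime_unit="ns"))
# [0.00e+00 8.64e+13]  (one day == 86400 * 1e9 ns)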
+
+
+def _calc_idxminmax(
+ *,
+ array,
+ func: Callable,
+ dim: Hashable | None = None,
+ skipna: bool | None = None,
+ fill_value: Any = dtypes.NA,
+ keep_attrs: bool | None = None,
+):
"""Apply common operations for idxmin and idxmax."""
- pass
+ # This function doesn't make sense for scalars so don't try
+ if not array.ndim:
+ raise ValueError("This function does not apply for scalars")
+
+ if dim is not None:
+ pass # Use the dim if available
+ elif array.ndim == 1:
+ # it is okay to guess the dim if there is only 1
+ dim = array.dims[0]
+ else:
+ # The dim is not specified and ambiguous. Don't guess.
+ raise ValueError("Must supply 'dim' argument for multidimensional arrays")
+
+ if dim not in array.dims:
+ raise KeyError(
+ f"Dimension {dim!r} not found in array dimensions {array.dims!r}"
+ )
+ if dim not in array.coords:
+ raise KeyError(
+ f"Dimension {dim!r} is not one of the coordinates {tuple(array.coords.keys())}"
+ )
+
+ # These are dtypes with NaN values argmin and argmax can handle
+ na_dtypes = "cfO"
+
+ if skipna or (skipna is None and array.dtype.kind in na_dtypes):
+ # Need to skip NaN values since argmin and argmax can't handle them
+ allna = array.isnull().all(dim)
+ array = array.where(~allna, 0)
+
+ # This will run argmin or argmax.
+ indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
+
+ # Handle chunked arrays (e.g. dask).
+ if is_chunked_array(array.data):
+ chunkmanager = get_chunked_array_type(array.data)
+ chunks = dict(zip(array.dims, array.chunks))
+ dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
+ data = dask_coord[duck_array_ops.ravel(indx.data)]
+ res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
+ # we need to attach back the dim name
+ res.name = dim
+ else:
+ res = array[dim][(indx,)]
+ # The dim is gone but we need to remove the corresponding coordinate.
+ del res.coords[dim]
+ if skipna or (skipna is None and array.dtype.kind in na_dtypes):
+ # Put the NaN values back in after removing them
+ res = res.where(~allna, fill_value)
-_T = TypeVar('_T', bound=Union['Dataset', 'DataArray'])
-_U = TypeVar('_U', bound=Union['Dataset', 'DataArray'])
-_V = TypeVar('_V', bound=Union['Dataset', 'DataArray'])
+ # Copy attributes from argmin/argmax, if any
+ res.attrs = indx.attrs
+ return res
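# --- Illustrative usage (editor's sketch, not part of the patch) ---
# The public entry points for this helper are DataArray.idxmin / idxmax, which
# return the coordinate label (not the integer position) of the extremum.
import numpy as np
import xarray as xr

da = xr.DataArray([3.0, np.nan, 1.0, 2.0], dims="x", coords={"x": [10, 20, 30, 40]})
print(da.idxmin(dim="x").item(), da.idxmax(dim="x").item())  # 30 10 (NaNs skipped)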
-def unify_chunks(*objects: (Dataset | DataArray)) ->tuple[Dataset |
- DataArray, ...]:
+
+_T = TypeVar("_T", bound=Union["Dataset", "DataArray"])
+_U = TypeVar("_U", bound=Union["Dataset", "DataArray"])
+_V = TypeVar("_V", bound=Union["Dataset", "DataArray"])
+
+
+@overload
+def unify_chunks(__obj: _T) -> tuple[_T]: ...
+
+
+@overload
+def unify_chunks(__obj1: _T, __obj2: _U) -> tuple[_T, _U]: ...
+
+
+@overload
+def unify_chunks(__obj1: _T, __obj2: _U, __obj3: _V) -> tuple[_T, _U, _V]: ...
+
+
+@overload
+def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]: ...
+
+
+def unify_chunks(*objects: Dataset | DataArray) -> tuple[Dataset | DataArray, ...]:
"""
Given any number of Dataset and/or DataArray objects, returns
new objects with unified chunk size along all chunked dimensions.
@@ -1009,4 +2227,43 @@ def unify_chunks(*objects: (Dataset | DataArray)) ->tuple[Dataset |
--------
dask.array.core.unify_chunks
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ # Convert all objects to datasets
+ datasets = [
+ obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy()
+ for obj in objects
+ ]
+
+ # Get arguments to pass into dask.array.core.unify_chunks
+ unify_chunks_args = []
+ sizes: dict[Hashable, int] = {}
+ for ds in datasets:
+ for v in ds._variables.values():
+ if v.chunks is not None:
+ # Check that sizes match across different datasets
+ for dim, size in v.sizes.items():
+ try:
+ if sizes[dim] != size:
+ raise ValueError(
+ f"Dimension {dim!r} size mismatch: {sizes[dim]} != {size}"
+ )
+ except KeyError:
+ sizes[dim] = size
+ unify_chunks_args += [v._data, v._dims]
+
+ # No dask arrays: Return inputs
+ if not unify_chunks_args:
+ return objects
+
+ chunkmanager = get_chunked_array_type(*[arg for arg in unify_chunks_args])
+ _, chunked_data = chunkmanager.unify_chunks(*unify_chunks_args)
+ chunked_data_iter = iter(chunked_data)
+ out: list[Dataset | DataArray] = []
+ for obj, ds in zip(objects, datasets):
+ for k, v in ds._variables.items():
+ if v.chunks is not None:
+ ds._variables[k] = v.copy(data=next(chunked_data_iter))
+ out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds)
+
+ return tuple(out)
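# --- Illustrative usage (editor's sketch, not part of the patch; needs dask) ---
# Chunk boundaries along shared dimensions are unified across the returned objects.
import numpy as np
import xarray as xr

a = xr.DataArray(np.zeros(10), dims="x").chunk({"x": 4})
b = xr.DataArray(np.zeros(10), dims="x").chunk({"x": 5})
a2, b2 = xr.unify_chunks(a, b)
print(a2.chunks, b2.chunks)  # both ((4, 1, 3, 2),): the common refinement of both grids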
diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index 636b3856..15292bdb 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -1,26 +1,80 @@
from __future__ import annotations
+
from collections.abc import Hashable, Iterable
from typing import TYPE_CHECKING, Any, Union, overload
+
import numpy as np
import pandas as pd
+
from xarray.core import dtypes, utils
from xarray.core.alignment import align, reindex_variables
from xarray.core.coordinates import Coordinates
from xarray.core.duck_array_ops import lazy_array_equiv
from xarray.core.indexes import Index, PandasIndex
-from xarray.core.merge import _VALID_COMPAT, collect_variables_and_indexes, merge_attrs, merge_collected
+from xarray.core.merge import (
+ _VALID_COMPAT,
+ collect_variables_and_indexes,
+ merge_attrs,
+ merge_collected,
+)
from xarray.core.types import T_DataArray, T_Dataset, T_Variable
from xarray.core.variable import Variable
from xarray.core.variable import concat as concat_vars
+
if TYPE_CHECKING:
- from xarray.core.types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions
+ from xarray.core.types import (
+ CombineAttrsOptions,
+ CompatOptions,
+ ConcatOptions,
+ JoinOptions,
+ )
+
T_DataVars = Union[ConcatOptions, Iterable[Hashable]]
-def concat(objs, dim, data_vars: T_DataVars='all', coords='different',
- compat: CompatOptions='equals', positions=None, fill_value=dtypes.NA,
- join: JoinOptions='outer', combine_attrs: CombineAttrsOptions=
- 'override', create_index_for_new_dim: bool=True):
+# TODO: replace dim: Any by 1D array_likes
+@overload
+def concat(
+ objs: Iterable[T_Dataset],
+ dim: Hashable | T_Variable | T_DataArray | pd.Index | Any,
+ data_vars: T_DataVars = "all",
+ coords: ConcatOptions | list[Hashable] = "different",
+ compat: CompatOptions = "equals",
+ positions: Iterable[Iterable[int]] | None = None,
+ fill_value: object = dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "override",
+ create_index_for_new_dim: bool = True,
+) -> T_Dataset: ...
+
+
+@overload
+def concat(
+ objs: Iterable[T_DataArray],
+ dim: Hashable | T_Variable | T_DataArray | pd.Index | Any,
+ data_vars: T_DataVars = "all",
+ coords: ConcatOptions | list[Hashable] = "different",
+ compat: CompatOptions = "equals",
+ positions: Iterable[Iterable[int]] | None = None,
+ fill_value: object = dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "override",
+ create_index_for_new_dim: bool = True,
+) -> T_DataArray: ...
+
+
+def concat(
+ objs,
+ dim,
+ data_vars: T_DataVars = "all",
+ coords="different",
+ compat: CompatOptions = "equals",
+ positions=None,
+ fill_value=dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "override",
+ create_index_for_new_dim: bool = True,
+):
"""Concatenate xarray objects along a new or existing dimension.
Parameters
@@ -97,7 +151,8 @@ def concat(objs, dim, data_vars: T_DataVars='all', coords='different',
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -189,33 +244,528 @@ def concat(objs, dim, data_vars: T_DataVars='all', coords='different',
Indexes:
*empty*
"""
- pass
+ # TODO: add ignore_index arguments copied from pandas.concat
+ # TODO: support concatenating scalar coordinates even if the concatenated
+ # dimension already exists
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+ try:
+ first_obj, objs = utils.peek_at(objs)
+ except StopIteration:
+ raise ValueError("must supply at least one object to concatenate")
-def _calc_concat_dim_index(dim_or_data: (Hashable | Any)) ->tuple[Hashable,
- PandasIndex | None]:
+ if compat not in _VALID_COMPAT:
+ raise ValueError(
+ f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'"
+ )
+
+ if isinstance(first_obj, DataArray):
+ return _dataarray_concat(
+ objs,
+ dim=dim,
+ data_vars=data_vars,
+ coords=coords,
+ compat=compat,
+ positions=positions,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ create_index_for_new_dim=create_index_for_new_dim,
+ )
+ elif isinstance(first_obj, Dataset):
+ return _dataset_concat(
+ objs,
+ dim=dim,
+ data_vars=data_vars,
+ coords=coords,
+ compat=compat,
+ positions=positions,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ create_index_for_new_dim=create_index_for_new_dim,
+ )
+ else:
+ raise TypeError(
+ "can only concatenate xarray Dataset and DataArray "
+ f"objects, got {type(first_obj)}"
+ )
+
+
+def _calc_concat_dim_index(
+ dim_or_data: Hashable | Any,
+) -> tuple[Hashable, PandasIndex | None]:
"""Infer the dimension name and 1d index / coordinate variable (if appropriate)
for concatenating along the new dimension.
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ dim: Hashable | None
+
+ if utils.hashable(dim_or_data):
+ dim = dim_or_data
+ index = None
+ else:
+ if not isinstance(dim_or_data, (DataArray, Variable)):
+ dim = getattr(dim_or_data, "name", None)
+ if dim is None:
+ dim = "concat_dim"
+ else:
+ (dim,) = dim_or_data.dims
+ coord_dtype = getattr(dim_or_data, "dtype", None)
+ index = PandasIndex(dim_or_data, dim, coord_dtype=coord_dtype)
+ return dim, index
-def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars,
- coords, compat):
+
+def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, compat):
"""
Determine which dataset variables need to be concatenated in the result,
"""
- pass
+ # Return values
+ concat_over = set()
+ equals = {}
+
+ if dim in dim_names:
+ concat_over_existing_dim = True
+ concat_over.add(dim)
+ else:
+ concat_over_existing_dim = False
+
+ concat_dim_lengths = []
+ for ds in datasets:
+ if concat_over_existing_dim:
+ if dim not in ds.dims:
+ if dim in ds:
+ ds = ds.set_coords(dim)
+ concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
+ concat_dim_lengths.append(ds.sizes.get(dim, 1))
+
+ def process_subset_opt(opt, subset):
+ if isinstance(opt, str):
+ if opt == "different":
+ if compat == "override":
+ raise ValueError(
+ f"Cannot specify both {subset}='different' and compat='override'."
+ )
+ # all nonindexes that are not the same in each dataset
+ for k in getattr(datasets[0], subset):
+ if k not in concat_over:
+ equals[k] = None
+ variables = [
+ ds.variables[k] for ds in datasets if k in ds.variables
+ ]
-def _dataset_concat(datasets: Iterable[T_Dataset], dim: (str | T_Variable |
- T_DataArray | pd.Index), data_vars: T_DataVars, coords: (str | list[str
- ]), compat: CompatOptions, positions: (Iterable[Iterable[int]] | None),
- fill_value: Any=dtypes.NA, join: JoinOptions='outer', combine_attrs:
- CombineAttrsOptions='override', create_index_for_new_dim: bool=True
- ) ->T_Dataset:
+ if len(variables) == 1:
+ # coords="different" doesn't make sense when only one object
+ # contains a particular variable.
+ break
+ elif len(variables) != len(datasets) and opt == "different":
+ raise ValueError(
+ f"{k!r} not present in all datasets and coords='different'. "
+ f"Either add {k!r} to datasets where it is missing or "
+ "specify coords='minimal'."
+ )
+
+ # first check without comparing values i.e. no computes
+ for var in variables[1:]:
+ equals[k] = getattr(variables[0], compat)(
+ var, equiv=lazy_array_equiv
+ )
+ if equals[k] is not True:
+ # exit early if we know these are not equal or that
+ # equality cannot be determined i.e. one or all of
+ # the variables wraps a numpy array
+ break
+
+ if equals[k] is False:
+ concat_over.add(k)
+
+ elif equals[k] is None:
+ # Compare the variable of all datasets vs. the one
+ # of the first dataset. Perform the minimum amount of
+ # loads in order to avoid multiple loads from disk
+ # while keeping the RAM footprint low.
+ v_lhs = datasets[0].variables[k].load()
+ # We'll need to know later on if variables are equal.
+ computed = []
+ for ds_rhs in datasets[1:]:
+ v_rhs = ds_rhs.variables[k].compute()
+ computed.append(v_rhs)
+ if not getattr(v_lhs, compat)(v_rhs):
+ concat_over.add(k)
+ equals[k] = False
+ # computed variables are not to be re-computed
+ # again in the future
+ for ds, v in zip(datasets[1:], computed):
+ ds.variables[k].data = v.data
+ break
+ else:
+ equals[k] = True
+
+ elif opt == "all":
+ concat_over.update(
+ set().union(
+ *list(set(getattr(d, subset)) - set(d.dims) for d in datasets)
+ )
+ )
+ elif opt == "minimal":
+ pass
+ else:
+ raise ValueError(f"unexpected value for {subset}: {opt}")
+ else:
+ valid_vars = tuple(getattr(datasets[0], subset))
+ invalid_vars = [k for k in opt if k not in valid_vars]
+ if invalid_vars:
+ if subset == "coords":
+ raise ValueError(
+ f"the variables {invalid_vars} in coords are not "
+ f"found in the coordinates of the first dataset {valid_vars}"
+ )
+ else:
+ # note: data_vars are not listed in the error message here,
+ # because there may be lots of them
+ raise ValueError(
+ f"the variables {invalid_vars} in data_vars are not "
+ f"found in the data variables of the first dataset"
+ )
+ concat_over.update(opt)
+
+ process_subset_opt(data_vars, "data_vars")
+ process_subset_opt(coords, "coords")
+ return concat_over, equals, concat_dim_lengths
+
+
+# determine dimensional coordinate names and a dict mapping name to DataArray
+def _parse_datasets(
+ datasets: list[T_Dataset],
+) -> tuple[
+ dict[Hashable, Variable],
+ dict[Hashable, int],
+ set[Hashable],
+ set[Hashable],
+ list[Hashable],
+]:
+ dims: set[Hashable] = set()
+ all_coord_names: set[Hashable] = set()
+ data_vars: set[Hashable] = set() # list of data_vars
+ dim_coords: dict[Hashable, Variable] = {} # maps dim name to variable
+ dims_sizes: dict[Hashable, int] = {} # shared dimension sizes to expand variables
+ variables_order: dict[Hashable, Variable] = {} # variables in order of appearance
+
+ for ds in datasets:
+ dims_sizes.update(ds.sizes)
+ all_coord_names.update(ds.coords)
+ data_vars.update(ds.data_vars)
+ variables_order.update(ds.variables)
+
+ # preserves ordering of dimensions
+ for dim in ds.dims:
+ if dim in dims:
+ continue
+
+ if dim in ds.coords and dim not in dim_coords:
+ dim_coords[dim] = ds.coords[dim].variable
+ dims = dims | set(ds.dims)
+
+ return dim_coords, dims_sizes, all_coord_names, data_vars, list(variables_order)
+
+
+def _dataset_concat(
+ datasets: Iterable[T_Dataset],
+ dim: str | T_Variable | T_DataArray | pd.Index,
+ data_vars: T_DataVars,
+ coords: str | list[str],
+ compat: CompatOptions,
+ positions: Iterable[Iterable[int]] | None,
+ fill_value: Any = dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "override",
+ create_index_for_new_dim: bool = True,
+) -> T_Dataset:
"""
Concatenate a sequence of datasets along a new or existing dimension
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ datasets = list(datasets)
+
+ if not all(isinstance(dataset, Dataset) for dataset in datasets):
+ raise TypeError(
+ "The elements in the input list need to be either all 'Dataset's or all 'DataArray's"
+ )
+
+ if isinstance(dim, DataArray):
+ dim_var = dim.variable
+ elif isinstance(dim, Variable):
+ dim_var = dim
+ else:
+ dim_var = None
+
+ dim_name, index = _calc_concat_dim_index(dim)
+
+ # Make sure we're working on a copy (we'll be loading variables)
+ datasets = [ds.copy() for ds in datasets]
+ datasets = list(
+ align(
+ *datasets, join=join, copy=False, exclude=[dim_name], fill_value=fill_value
+ )
+ )
+
+ dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets(
+ datasets
+ )
+ dim_names = set(dim_coords)
+
+ both_data_and_coords = coord_names & data_names
+ if both_data_and_coords:
+ raise ValueError(
+ f"{both_data_and_coords!r} is a coordinate in some datasets but not others."
+ )
+ # we don't want the concat dimension in the result dataset yet
+ dim_coords.pop(dim_name, None)
+ dims_sizes.pop(dim_name, None)
+
+ # case where concat dimension is a coordinate or data_var but not a dimension
+ if (
+ dim_name in coord_names or dim_name in data_names
+ ) and dim_name not in dim_names:
+ datasets = [
+ ds.expand_dims(dim_name, create_index_for_new_dim=create_index_for_new_dim)
+ for ds in datasets
+ ]
+
+ # determine which variables to concatenate
+ concat_over, equals, concat_dim_lengths = _calc_concat_over(
+ datasets, dim_name, dim_names, data_vars, coords, compat
+ )
+
+ # determine which variables to merge, and then merge them according to compat
+ variables_to_merge = (coord_names | data_names) - concat_over
+
+ result_vars = {}
+ result_indexes = {}
+
+ if variables_to_merge:
+ grouped = {
+ k: v
+ for k, v in collect_variables_and_indexes(datasets).items()
+ if k in variables_to_merge
+ }
+ merged_vars, merged_indexes = merge_collected(
+ grouped, compat=compat, equals=equals
+ )
+ result_vars.update(merged_vars)
+ result_indexes.update(merged_indexes)
+
+ result_vars.update(dim_coords)
+
+ # assign attrs and encoding from first dataset
+ result_attrs = merge_attrs([ds.attrs for ds in datasets], combine_attrs)
+ result_encoding = datasets[0].encoding
+
+ # check that global attributes are fixed across all datasets if necessary
+ if compat == "identical":
+ for ds in datasets[1:]:
+ if not utils.dict_equiv(ds.attrs, result_attrs):
+ raise ValueError("Dataset global attributes not equal.")
+
+ # we've already verified everything is consistent; now, calculate
+ # shared dimension sizes so we can expand the necessary variables
+ def ensure_common_dims(vars, concat_dim_lengths):
+ # ensure each variable with the given name shares the same
+ # dimensions and the same shape for all of them except along the
+ # concat dimension
+ common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims))
+ if dim_name not in common_dims:
+ common_dims = (dim_name,) + common_dims
+ for var, dim_len in zip(vars, concat_dim_lengths):
+ if var.dims != common_dims:
+ common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims)
+ var = var.set_dims(common_dims, common_shape)
+ yield var
+
+ # get the indexes to concatenate together, create a PandasIndex
+ # for any scalar coordinate variable found with ``name`` matching ``dim``.
+    # TODO: deprecate concatenating a mix of scalar and dimensional indexed coordinates?
+ # TODO: (benbovy - explicit indexes): check index types and/or coordinates
+ # of all datasets?
+ def get_indexes(name):
+ for ds in datasets:
+ if name in ds._indexes:
+ yield ds._indexes[name]
+ elif name == dim_name:
+ var = ds._variables[name]
+ if not var.dims:
+ data = var.set_dims(dim_name).values
+ if create_index_for_new_dim:
+ yield PandasIndex(data, dim_name, coord_dtype=var.dtype)
+
+ # create concatenation index, needed for later reindexing
+ file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
+ concat_index = np.arange(file_start_indexes[-1])
+ concat_index_size = concat_index.size
+ variable_index_mask = np.ones(concat_index_size, dtype=bool)
+
+ # stack up each variable and/or index to fill-out the dataset (in order)
+ # n.b. this loop preserves variable order, needed for groupby.
+ ndatasets = len(datasets)
+ for name in vars_order:
+ if name in concat_over and name not in result_indexes:
+ variables = []
+ # Initialize the mask to all True then set False if any name is missing in
+ # the datasets:
+ variable_index_mask.fill(True)
+ var_concat_dim_length = []
+ for i, ds in enumerate(datasets):
+ if name in ds.variables:
+ variables.append(ds[name].variable)
+ var_concat_dim_length.append(concat_dim_lengths[i])
+ else:
+ # raise if coordinate not in all datasets
+ if name in coord_names:
+ raise ValueError(
+ f"coordinate {name!r} not present in all datasets."
+ )
+
+ # Mask out the indexes without the name:
+ start = file_start_indexes[i]
+ end = file_start_indexes[i + 1]
+ variable_index_mask[slice(start, end)] = False
+
+ variable_index = concat_index[variable_index_mask]
+ vars = ensure_common_dims(variables, var_concat_dim_length)
+
+ # Try to concatenate the indexes, concatenate the variables when no index
+ # is found on all datasets.
+ indexes: list[Index] = list(get_indexes(name))
+ if indexes:
+ if len(indexes) < ndatasets:
+ raise ValueError(
+ f"{name!r} must have either an index or no index in all datasets, "
+ f"found {len(indexes)}/{len(datasets)} datasets with an index."
+ )
+ combined_idx = indexes[0].concat(indexes, dim_name, positions)
+ if name in datasets[0]._indexes:
+ idx_vars = datasets[0].xindexes.get_all_coords(name)
+ else:
+ # index created from a scalar coordinate
+ idx_vars = {name: datasets[0][name].variable}
+ result_indexes.update({k: combined_idx for k in idx_vars})
+ combined_idx_vars = combined_idx.create_variables(idx_vars)
+ for k, v in combined_idx_vars.items():
+ v.attrs = merge_attrs(
+ [ds.variables[k].attrs for ds in datasets],
+ combine_attrs=combine_attrs,
+ )
+ result_vars[k] = v
+ else:
+ combined_var = concat_vars(
+ vars, dim_name, positions, combine_attrs=combine_attrs
+ )
+ # reindex if variable is not present in all datasets
+ if len(variable_index) < concat_index_size:
+ combined_var = reindex_variables(
+ variables={name: combined_var},
+ dim_pos_indexers={
+ dim_name: pd.Index(variable_index).get_indexer(concat_index)
+ },
+ fill_value=fill_value,
+ )[name]
+ result_vars[name] = combined_var
+
+ elif name in result_vars:
+ # preserves original variable order
+ result_vars[name] = result_vars.pop(name)
+
+ absent_coord_names = coord_names - set(result_vars)
+ if absent_coord_names:
+ raise ValueError(
+ f"Variables {absent_coord_names!r} are coordinates in some datasets but not others."
+ )
+
+ result_data_vars = {}
+ coord_vars = {}
+ for name, result_var in result_vars.items():
+ if name in coord_names:
+ coord_vars[name] = result_var
+ else:
+ result_data_vars[name] = result_var
+
+ if index is not None:
+ if dim_var is not None:
+ index_vars = index.create_variables({dim_name: dim_var})
+ else:
+ index_vars = index.create_variables()
+
+ coord_vars[dim_name] = index_vars[dim_name]
+ result_indexes[dim_name] = index
+
+ coords_obj = Coordinates(coord_vars, indexes=result_indexes)
+
+ result = type(datasets[0])(result_data_vars, coords=coords_obj, attrs=result_attrs)
+ result.encoding = result_encoding
+
+ return result
+
+
+def _dataarray_concat(
+ arrays: Iterable[T_DataArray],
+ dim: str | T_Variable | T_DataArray | pd.Index,
+ data_vars: T_DataVars,
+ coords: str | list[str],
+ compat: CompatOptions,
+ positions: Iterable[Iterable[int]] | None,
+ fill_value: object = dtypes.NA,
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "override",
+ create_index_for_new_dim: bool = True,
+) -> T_DataArray:
+ from xarray.core.dataarray import DataArray
+
+ arrays = list(arrays)
+
+ if not all(isinstance(array, DataArray) for array in arrays):
+ raise TypeError(
+ "The elements in the input list need to be either all 'Dataset's or all 'DataArray's"
+ )
+
+ if data_vars != "all":
+ raise ValueError(
+ "data_vars is not a valid argument when concatenating DataArray objects"
+ )
+
+ datasets = []
+ for n, arr in enumerate(arrays):
+ if n == 0:
+ name = arr.name
+ elif name != arr.name:
+ if compat == "identical":
+ raise ValueError("array names not identical")
+ else:
+ arr = arr.rename(name)
+ datasets.append(arr._to_temp_dataset())
+
+ ds = _dataset_concat(
+ datasets,
+ dim,
+ data_vars,
+ coords,
+ compat,
+ positions,
+ fill_value=fill_value,
+ join=join,
+ combine_attrs=combine_attrs,
+ create_index_for_new_dim=create_index_for_new_dim,
+ )
+
+ merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs)
+
+ result = arrays[0]._from_temp_dataset(ds, name)
+ result.attrs = merged_attrs
+
+ return result
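# --- Illustrative usage (editor's sketch, not part of the patch) ---
# Concatenating DataArrays along a new dimension routes through
# _dataarray_concat above (each array is converted to a temporary Dataset).
import xarray as xr

da1 = xr.DataArray([1, 2], dims="x", name="v")
da2 = xr.DataArray([3, 4], dims="x", name="v")
out = xr.concat([da1, da2], dim="y")
print(out.dims, out.shape)  # ('y', 'x') (2, 2)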
diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py
index 2740e51a..251edd1f 100644
--- a/xarray/core/coordinates.py
+++ b/xarray/core/coordinates.py
@@ -1,32 +1,68 @@
from __future__ import annotations
+
from collections.abc import Hashable, Iterator, Mapping, Sequence
from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Generic, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Generic,
+ cast,
+)
+
import numpy as np
import pandas as pd
+
from xarray.core import formatting
from xarray.core.alignment import Aligner
-from xarray.core.indexes import Index, Indexes, PandasIndex, PandasMultiIndex, assert_no_index_corrupted, create_default_index_implicit
+from xarray.core.indexes import (
+ Index,
+ Indexes,
+ PandasIndex,
+ PandasMultiIndex,
+ assert_no_index_corrupted,
+ create_default_index_implicit,
+)
from xarray.core.merge import merge_coordinates_without_align, merge_coords
from xarray.core.types import DataVars, Self, T_DataArray, T_Xarray
-from xarray.core.utils import Frozen, ReprObject, either_dict_or_kwargs, emit_user_level_warning
+from xarray.core.utils import (
+ Frozen,
+ ReprObject,
+ either_dict_or_kwargs,
+ emit_user_level_warning,
+)
from xarray.core.variable import Variable, as_variable, calculate_dimensions
+
if TYPE_CHECKING:
from xarray.core.common import DataWithCoords
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
-_THIS_ARRAY = ReprObject('<this-array>')
+
+# Used as the key corresponding to a DataArray's variable when converting
+# arbitrary DataArray objects to datasets
+_THIS_ARRAY = ReprObject("<this-array>")
-class AbstractCoordinates(Mapping[Hashable, 'T_DataArray']):
+class AbstractCoordinates(Mapping[Hashable, "T_DataArray"]):
_data: DataWithCoords
- __slots__ = '_data',
+ __slots__ = ("_data",)
+
+ def __getitem__(self, key: Hashable) -> T_DataArray:
+ raise NotImplementedError()
+
+ @property
+ def _names(self) -> set[Hashable]:
+ raise NotImplementedError()
+
+ @property
+ def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]:
+ raise NotImplementedError()
- def __getitem__(self, key: Hashable) ->T_DataArray:
+ @property
+ def dtypes(self) -> Frozen[Hashable, np.dtype]:
raise NotImplementedError()
@property
- def indexes(self) ->Indexes[pd.Index]:
+ def indexes(self) -> Indexes[pd.Index]:
"""Mapping of pandas.Index objects used for label based indexing.
Raises an error if this Coordinates object has indexes that cannot
@@ -36,31 +72,44 @@ class AbstractCoordinates(Mapping[Hashable, 'T_DataArray']):
--------
Coordinates.xindexes
"""
- pass
+ return self._data.indexes
@property
- def xindexes(self) ->Indexes[Index]:
+ def xindexes(self) -> Indexes[Index]:
"""Mapping of :py:class:`~xarray.indexes.Index` objects
used for label based indexing.
"""
- pass
+ return self._data.xindexes
+
+ @property
+ def variables(self):
+ raise NotImplementedError()
+
+ def _update_coords(self, coords, indexes):
+ raise NotImplementedError()
+
+ def _drop_coords(self, coord_names):
+ raise NotImplementedError()
- def __iter__(self) ->Iterator[Hashable]:
+ def __iter__(self) -> Iterator[Hashable]:
+ # needs to be in the same order as the dataset variables
for k in self.variables:
if k in self._names:
yield k
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self._names)
- def __contains__(self, key: Hashable) ->bool:
+ def __contains__(self, key: Hashable) -> bool:
return key in self._names
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
return formatting.coords_repr(self)
- def to_index(self, ordered_dims: (Sequence[Hashable] | None)=None
- ) ->pd.Index:
+ def to_dataset(self) -> Dataset:
+ raise NotImplementedError()
+
+ def to_index(self, ordered_dims: Sequence[Hashable] | None = None) -> pd.Index:
"""Convert all index coordinates into a :py:class:`pandas.Index`.
Parameters
@@ -76,7 +125,64 @@ class AbstractCoordinates(Mapping[Hashable, 'T_DataArray']):
        coordinates. This will be a MultiIndex if this object has more
        than one dimension.
"""
- pass
+ if ordered_dims is None:
+ ordered_dims = list(self.dims)
+ elif set(ordered_dims) != set(self.dims):
+ raise ValueError(
+ "ordered_dims must match dims, but does not: "
+ f"{ordered_dims} vs {self.dims}"
+ )
+
+ if len(ordered_dims) == 0:
+ raise ValueError("no valid index for a 0-dimensional object")
+ elif len(ordered_dims) == 1:
+ (dim,) = ordered_dims
+ return self._data.get_index(dim)
+ else:
+ indexes = [self._data.get_index(k) for k in ordered_dims]
+
+ # compute the sizes of the repeat and tile for the cartesian product
+ # (taken from pandas.core.reshape.util)
+ index_lengths = np.fromiter(
+ (len(index) for index in indexes), dtype=np.intp
+ )
+ cumprod_lengths = np.cumprod(index_lengths)
+
+ if cumprod_lengths[-1] == 0:
+ # if any factor is empty, the cartesian product is empty
+ repeat_counts = np.zeros_like(cumprod_lengths)
+
+ else:
+ # sizes of the repeats
+ repeat_counts = cumprod_lengths[-1] / cumprod_lengths
+ # sizes of the tiles
+ tile_counts = np.roll(cumprod_lengths, 1)
+ tile_counts[0] = 1
+
+ # loop over the indexes
+ # for each MultiIndex or Index compute the cartesian product of the codes
+
+ code_list = []
+ level_list = []
+ names = []
+
+ for i, index in enumerate(indexes):
+ if isinstance(index, pd.MultiIndex):
+ codes, levels = index.codes, index.levels
+ else:
+ code, level = pd.factorize(index)
+ codes = [code]
+ levels = [level]
+
+ # compute the cartesian product
+ code_list += [
+ np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i])
+ for code in codes
+ ]
+ level_list += levels
+ names += index.names
+
+ return pd.MultiIndex(level_list, code_list, names=names)
class Coordinates(AbstractCoordinates):
@@ -157,22 +263,36 @@ class Coordinates(AbstractCoordinates):
*empty*
"""
+
_data: DataWithCoords
- __slots__ = '_data',
- def __init__(self, coords: (Mapping[Any, Any] | None)=None, indexes: (
- Mapping[Any, Index] | None)=None) ->None:
+ __slots__ = ("_data",)
+
+ def __init__(
+ self,
+ coords: Mapping[Any, Any] | None = None,
+ indexes: Mapping[Any, Index] | None = None,
+ ) -> None:
+ # When coordinates are constructed directly, an internal Dataset is
+ # created so that it is compatible with the DatasetCoordinates and
+ # DataArrayCoordinates classes serving as a proxy for the data.
+ # TODO: refactor DataArray / Dataset so that Coordinates store the data.
from xarray.core.dataset import Dataset
+
if coords is None:
coords = {}
+
variables: dict[Hashable, Variable]
default_indexes: dict[Hashable, PandasIndex] = {}
coords_obj_indexes: dict[Hashable, Index] = {}
+
if isinstance(coords, Coordinates):
if indexes is not None:
raise ValueError(
- 'passing both a ``Coordinates`` object and a mapping of indexes to ``Coordinates.__init__`` is not allowed (this constructor does not support merging them)'
- )
+ "passing both a ``Coordinates`` object and a mapping of indexes "
+ "to ``Coordinates.__init__`` is not allowed "
+ "(this constructor does not support merging them)"
+ )
variables = {k: v.copy() for k, v in coords.variables.items()}
coords_obj_indexes = dict(coords.xindexes)
else:
@@ -180,35 +300,59 @@ class Coordinates(AbstractCoordinates):
for name, data in coords.items():
var = as_variable(data, name=name, auto_convert=False)
if var.dims == (name,) and indexes is None:
- index, index_vars = create_default_index_implicit(var,
- list(coords))
+ index, index_vars = create_default_index_implicit(var, list(coords))
default_indexes.update({k: index for k in index_vars})
variables.update(index_vars)
else:
variables[name] = var
+
if indexes is None:
indexes = {}
else:
indexes = dict(indexes)
+
indexes.update(default_indexes)
indexes.update(coords_obj_indexes)
+
no_coord_index = set(indexes) - set(variables)
if no_coord_index:
raise ValueError(
- f'no coordinate variables found for these indexes: {no_coord_index}'
- )
+ f"no coordinate variables found for these indexes: {no_coord_index}"
+ )
+
for k, idx in indexes.items():
if not isinstance(idx, Index):
- raise TypeError(
- f"'{k}' is not an `xarray.indexes.Index` object")
+ raise TypeError(f"'{k}' is not an `xarray.indexes.Index` object")
+
+ # maybe convert to base variable
for k, v in variables.items():
if k not in indexes:
variables[k] = v.to_base_variable()
- self._data = Dataset._construct_direct(coord_names=set(variables),
- variables=variables, indexes=indexes)
+
+ self._data = Dataset._construct_direct(
+ coord_names=set(variables), variables=variables, indexes=indexes
+ )
+
+ @classmethod
+ def _construct_direct(
+ cls,
+ coords: dict[Any, Variable],
+ indexes: dict[Any, Index],
+ dims: dict[Any, int] | None = None,
+ ) -> Self:
+ from xarray.core.dataset import Dataset
+
+ obj = object.__new__(cls)
+ obj._data = Dataset._construct_direct(
+ coord_names=set(coords),
+ variables=coords,
+ indexes=indexes,
+ dims=dims,
+ )
+ return obj
@classmethod
- def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: str) ->Self:
+ def from_pandas_multiindex(cls, midx: pd.MultiIndex, dim: str) -> Self:
"""Wrap a pandas multi-index as Xarray coordinates (dimension + levels).
The returned coordinates can be directly assigned to a
@@ -228,20 +372,29 @@ class Coordinates(AbstractCoordinates):
A collection of Xarray indexed coordinates created from the multi-index.
"""
- pass
+ xr_idx = PandasMultiIndex(midx, dim)
+
+ variables = xr_idx.create_variables()
+ indexes = {k: xr_idx for k in variables}
+
+ return cls(coords=variables, indexes=indexes)
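# --- Illustrative usage (editor's sketch, not part of the patch) ---
# Wrapping a pandas MultiIndex explicitly, as the deprecation warning further
# below recommends instead of relying on implicit promotion.
import pandas as pd
import xarray as xr

midx = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=("letter", "number"))
ds = xr.Dataset(coords=xr.Coordinates.from_pandas_multiindex(midx, dim="z"))
print(list(ds.coords))  # ['z', 'letter', 'number']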
+
+ @property
+ def _names(self) -> set[Hashable]:
+ return self._data._coord_names
@property
- def dims(self) ->(Frozen[Hashable, int] | tuple[Hashable, ...]):
+ def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]:
"""Mapping from dimension names to lengths or tuple of dimension names."""
- pass
+ return self._data.dims
@property
- def sizes(self) ->Frozen[Hashable, int]:
+ def sizes(self) -> Frozen[Hashable, int]:
"""Mapping from dimension names to lengths."""
- pass
+ return self._data.sizes
@property
- def dtypes(self) ->Frozen[Hashable, np.dtype]:
+ def dtypes(self) -> Frozen[Hashable, np.dtype]:
"""Mapping from coordinate names to dtypes.
Cannot be modified directly.
@@ -250,27 +403,29 @@ class Coordinates(AbstractCoordinates):
--------
Dataset.dtypes
"""
- pass
+ return Frozen({n: v.dtype for n, v in self._data.variables.items()})
@property
- def variables(self) ->Mapping[Hashable, Variable]:
+ def variables(self) -> Mapping[Hashable, Variable]:
"""Low level interface to Coordinates contents as dict of Variable objects.
This dictionary is frozen to prevent mutation.
"""
- pass
+ return self._data.variables
- def to_dataset(self) ->Dataset:
+ def to_dataset(self) -> Dataset:
"""Convert these coordinates into a new Dataset."""
- pass
+ names = [name for name in self._data._variables if name in self._names]
+ return self._data._copy_listed(names)
- def __getitem__(self, key: Hashable) ->DataArray:
+ def __getitem__(self, key: Hashable) -> DataArray:
return self._data[key]
- def __delitem__(self, key: Hashable) ->None:
+ def __delitem__(self, key: Hashable) -> None:
+ # redirect to DatasetCoordinates.__delitem__
del self._data.coords[key]
- def equals(self, other: Self) ->bool:
+ def equals(self, other: Self) -> bool:
"""Two Coordinates objects are equal if they have matching variables,
all of which are equal.
@@ -278,27 +433,61 @@ class Coordinates(AbstractCoordinates):
--------
Coordinates.identical
"""
- pass
+ if not isinstance(other, Coordinates):
+ return False
+ return self.to_dataset().equals(other.to_dataset())
- def identical(self, other: Self) ->bool:
+ def identical(self, other: Self) -> bool:
"""Like equals, but also checks all variable attributes.
See Also
--------
Coordinates.equals
"""
- pass
+ if not isinstance(other, Coordinates):
+ return False
+ return self.to_dataset().identical(other.to_dataset())
+
+ def _update_coords(
+ self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
+ ) -> None:
+ # redirect to DatasetCoordinates._update_coords
+ self._data.coords._update_coords(coords, indexes)
+
+ def _drop_coords(self, coord_names):
+ # redirect to DatasetCoordinates._drop_coords
+ self._data.coords._drop_coords(coord_names)
def _merge_raw(self, other, reflexive):
"""For use with binary arithmetic."""
- pass
+ if other is None:
+ variables = dict(self.variables)
+ indexes = dict(self.xindexes)
+ else:
+ coord_list = [self, other] if not reflexive else [other, self]
+ variables, indexes = merge_coordinates_without_align(coord_list)
+ return variables, indexes
@contextmanager
def _merge_inplace(self, other):
"""For use with in-place binary arithmetic."""
- pass
-
- def merge(self, other: (Mapping[Any, Any] | None)) ->Dataset:
+ if other is None:
+ yield
+ else:
+ # don't include indexes in prioritized, because we didn't align
+ # first and we want indexes to be checked
+ prioritized = {
+ k: (v, None)
+ for k, v in self.variables.items()
+ if k not in self.xindexes
+ }
+ variables, indexes = merge_coordinates_without_align(
+ [self, other], prioritized
+ )
+ yield
+ self._update_coords(variables, indexes)
+
+ def merge(self, other: Mapping[Any, Any] | None) -> Dataset:
"""Merge two sets of coordinates to create a new Dataset
The method implements the logic used for joining coordinates in the
@@ -321,17 +510,62 @@ class Coordinates(AbstractCoordinates):
merged : Dataset
A new Dataset with merged coordinates.
"""
- pass
+ from xarray.core.dataset import Dataset
- def __setitem__(self, key: Hashable, value: Any) ->None:
+ if other is None:
+ return self.to_dataset()
+
+ if not isinstance(other, Coordinates):
+ other = Dataset(coords=other).coords
+
+ coords, indexes = merge_coordinates_without_align([self, other])
+ coord_names = set(coords)
+ return Dataset._construct_direct(
+ variables=coords, coord_names=coord_names, indexes=indexes
+ )
+
+ def __setitem__(self, key: Hashable, value: Any) -> None:
self.update({key: value})
- def update(self, other: Mapping[Any, Any]) ->None:
+ def update(self, other: Mapping[Any, Any]) -> None:
"""Update this Coordinates variables with other coordinate variables."""
- pass
- def assign(self, coords: (Mapping | None)=None, **coords_kwargs: Any
- ) ->Self:
+ if not len(other):
+ return
+
+ other_coords: Coordinates
+
+ if isinstance(other, Coordinates):
+ # Coordinates object: just pass it (default indexes won't be created)
+ other_coords = other
+ else:
+ other_coords = create_coords_with_default_indexes(
+ getattr(other, "variables", other)
+ )
+
+        # Discarding original indexed coordinates prior to the merge allows us to:
+ # - fail early if the new coordinates don't preserve the integrity of existing
+ # multi-coordinate indexes
+ # - drop & replace coordinates without alignment (note: we must keep indexed
+ # coordinates extracted from the DataArray objects passed as values to
+ # `other` - if any - as those are still used for aligning the old/new coordinates)
+ coords_to_align = drop_indexed_coords(set(other_coords) & set(other), self)
+
+ coords, indexes = merge_coords(
+ [coords_to_align, other_coords],
+ priority_arg=1,
+ indexes=coords_to_align.xindexes,
+ )
+
+ # special case for PandasMultiIndex: updating only its dimension coordinate
+        # is still allowed but deprecated.
+ # It is the only case where we need to actually drop coordinates here (multi-index levels)
+ # TODO: remove when removing PandasMultiIndex's dimension coordinate.
+ self._drop_coords(self._names - coords_to_align._names)
+
+ self._update_coords(coords, indexes)
+
+ def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self:
"""Assign new coordinates (and indexes) to a Coordinates object, returning
a new object with all the original coordinates in addition to the new ones.
@@ -378,23 +612,75 @@ class Coordinates(AbstractCoordinates):
* y_level_1 (y) int64 32B 0 1 0 1
"""
- pass
-
- def _reindex_callback(self, aligner: Aligner, dim_pos_indexers: dict[
- Hashable, Any], variables: dict[Hashable, Variable], indexes: dict[
- Hashable, Index], fill_value: Any, exclude_dims: frozenset[Hashable
- ], exclude_vars: frozenset[Hashable]) ->Self:
+ # TODO: this doesn't support a callable, which is inconsistent with `DataArray.assign_coords`
+ coords = either_dict_or_kwargs(coords, coords_kwargs, "assign")
+ new_coords = self.copy()
+ new_coords.update(coords)
+ return new_coords
+
+ def _overwrite_indexes(
+ self,
+ indexes: Mapping[Any, Index],
+ variables: Mapping[Any, Variable] | None = None,
+ ) -> Self:
+ results = self.to_dataset()._overwrite_indexes(indexes, variables)
+
+ # TODO: remove cast once we get rid of DatasetCoordinates
+ # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates)
+ return cast(Self, results.coords)
+
+ def _reindex_callback(
+ self,
+ aligner: Aligner,
+ dim_pos_indexers: dict[Hashable, Any],
+ variables: dict[Hashable, Variable],
+ indexes: dict[Hashable, Index],
+ fill_value: Any,
+ exclude_dims: frozenset[Hashable],
+ exclude_vars: frozenset[Hashable],
+ ) -> Self:
"""Callback called from ``Aligner`` to create a new reindexed Coordinate."""
- pass
+ aligned = self.to_dataset()._reindex_callback(
+ aligner,
+ dim_pos_indexers,
+ variables,
+ indexes,
+ fill_value,
+ exclude_dims,
+ exclude_vars,
+ )
+
+ # TODO: remove cast once we get rid of DatasetCoordinates
+ # and DataArrayCoordinates (i.e., Dataset and DataArray encapsulate Coordinates)
+ return cast(Self, aligned.coords)
def _ipython_key_completions_(self):
"""Provide method for the key-autocompletions in IPython."""
- pass
+ return self._data._ipython_key_completions_()
- def copy(self, deep: bool=False, memo: (dict[int, Any] | None)=None
- ) ->Self:
+ def copy(
+ self,
+ deep: bool = False,
+ memo: dict[int, Any] | None = None,
+ ) -> Self:
"""Return a copy of this Coordinates object."""
- pass
+ # do not copy indexes (may corrupt multi-coordinate indexes)
+ # TODO: disable variables deepcopy? it may also be problematic when they
+ # encapsulate index objects like pd.Index
+ variables = {
+ k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items()
+ }
+
+ # TODO: getting an error with `self._construct_direct`, possibly because of how
+ # a subclass implements `_construct_direct`. (This was originally the same
+ # runtime code, but we switched the type definitions in #8216, which
+ # necessitates the cast.)
+ return cast(
+ Self,
+ Coordinates._construct_direct(
+ coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes)
+ ),
+ )
class DatasetCoordinates(Coordinates):
@@ -404,14 +690,24 @@ class DatasetCoordinates(Coordinates):
and :py:class:`~xarray.DataArray` constructors via their `coords` argument.
This will add both the coordinates variables and their index.
"""
+
_data: Dataset
- __slots__ = '_data',
+
+ __slots__ = ("_data",)
def __init__(self, dataset: Dataset):
self._data = dataset
@property
- def dtypes(self) ->Frozen[Hashable, np.dtype]:
+ def _names(self) -> set[Hashable]:
+ return self._data._coord_names
+
+ @property
+ def dims(self) -> Frozen[Hashable, int]:
+ return self._data.dims
+
+ @property
+ def dtypes(self) -> Frozen[Hashable, np.dtype]:
"""Mapping from coordinate names to dtypes.
Cannot be modified directly, but is updated when adding new variables.
@@ -420,27 +716,84 @@ class DatasetCoordinates(Coordinates):
--------
Dataset.dtypes
"""
- pass
+ return Frozen(
+ {
+ n: v.dtype
+ for n, v in self._data._variables.items()
+ if n in self._data._coord_names
+ }
+ )
+
+ @property
+ def variables(self) -> Mapping[Hashable, Variable]:
+ return Frozen(
+ {k: v for k, v in self._data.variables.items() if k in self._names}
+ )
- def __getitem__(self, key: Hashable) ->DataArray:
+ def __getitem__(self, key: Hashable) -> DataArray:
if key in self._data.data_vars:
raise KeyError(key)
return self._data[key]
- def to_dataset(self) ->Dataset:
+ def to_dataset(self) -> Dataset:
"""Convert these coordinates into a new Dataset"""
- pass
- def __delitem__(self, key: Hashable) ->None:
+ names = [name for name in self._data._variables if name in self._names]
+ return self._data._copy_listed(names)
+
+ def _update_coords(
+ self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
+ ) -> None:
+ variables = self._data._variables.copy()
+ variables.update(coords)
+
+ # check for inconsistent state *before* modifying anything in-place
+ dims = calculate_dimensions(variables)
+ new_coord_names = set(coords)
+ for dim, size in dims.items():
+ if dim in variables:
+ new_coord_names.add(dim)
+
+ self._data._variables = variables
+ self._data._coord_names.update(new_coord_names)
+ self._data._dims = dims
+
+ # TODO(shoyer): once ._indexes is always populated by a dict, modify
+ # it to update inplace instead.
+ original_indexes = dict(self._data.xindexes)
+ original_indexes.update(indexes)
+ self._data._indexes = original_indexes
+
+ def _drop_coords(self, coord_names):
+ # should drop indexed coordinates only
+ for name in coord_names:
+ del self._data._variables[name]
+ del self._data._indexes[name]
+ self._data._coord_names.difference_update(coord_names)
+
+ def _drop_indexed_coords(self, coords_to_drop: set[Hashable]) -> None:
+ assert self._data.xindexes is not None
+ new_coords = drop_indexed_coords(coords_to_drop, self)
+ for name in self._data._coord_names - new_coords._names:
+ del self._data._variables[name]
+ self._data._indexes = dict(new_coords.xindexes)
+ self._data._coord_names.intersection_update(new_coords._names)
+
+ def __delitem__(self, key: Hashable) -> None:
if key in self:
del self._data[key]
else:
raise KeyError(
- f'{key!r} is not in coordinate variables {tuple(self.keys())}')
+ f"{key!r} is not in coordinate variables {tuple(self.keys())}"
+ )
def _ipython_key_completions_(self):
"""Provide method for the key-autocompletions in IPython."""
- pass
+ return [
+ key
+ for key in self._data._ipython_key_completions_()
+ if key not in self._data.data_vars
+ ]
class DataArrayCoordinates(Coordinates, Generic[T_DataArray]):
@@ -450,14 +803,20 @@ class DataArrayCoordinates(Coordinates, Generic[T_DataArray]):
and :py:class:`~xarray.DataArray` constructors via their `coords` argument.
This will add both the coordinates variables and their index.
"""
+
_data: T_DataArray
- __slots__ = '_data',
- def __init__(self, dataarray: T_DataArray) ->None:
+ __slots__ = ("_data",)
+
+ def __init__(self, dataarray: T_DataArray) -> None:
self._data = dataarray
@property
- def dtypes(self) ->Frozen[Hashable, np.dtype]:
+ def dims(self) -> tuple[Hashable, ...]:
+ return self._data.dims
+
+ @property
+ def dtypes(self) -> Frozen[Hashable, np.dtype]:
"""Mapping from coordinate names to dtypes.
Cannot be modified directly, but is updated when adding new variables.
@@ -466,52 +825,201 @@ class DataArrayCoordinates(Coordinates, Generic[T_DataArray]):
--------
DataArray.dtype
"""
- pass
+ return Frozen({n: v.dtype for n, v in self._data._coords.items()})
+
+ @property
+ def _names(self) -> set[Hashable]:
+ return set(self._data._coords)
- def __getitem__(self, key: Hashable) ->T_DataArray:
+ def __getitem__(self, key: Hashable) -> T_DataArray:
return self._data._getitem_coord(key)
- def __delitem__(self, key: Hashable) ->None:
+ def _update_coords(
+ self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
+ ) -> None:
+ coords_plus_data = coords.copy()
+ coords_plus_data[_THIS_ARRAY] = self._data.variable
+ dims = calculate_dimensions(coords_plus_data)
+ if not set(dims) <= set(self.dims):
+ raise ValueError(
+ "cannot add coordinates with new dimensions to a DataArray"
+ )
+ self._data._coords = coords
+
+ # TODO(shoyer): once ._indexes is always populated by a dict, modify
+ # it to update inplace instead.
+ original_indexes = dict(self._data.xindexes)
+ original_indexes.update(indexes)
+ self._data._indexes = original_indexes
+
+ def _drop_coords(self, coord_names):
+ # should drop indexed coordinates only
+ for name in coord_names:
+ del self._data._coords[name]
+ del self._data._indexes[name]
+
+ @property
+ def variables(self):
+ return Frozen(self._data._coords)
+
+ def to_dataset(self) -> Dataset:
+ from xarray.core.dataset import Dataset
+
+ coords = {k: v.copy(deep=False) for k, v in self._data._coords.items()}
+ indexes = dict(self._data.xindexes)
+ return Dataset._construct_direct(coords, set(coords), indexes=indexes)
+
+ def __delitem__(self, key: Hashable) -> None:
if key not in self:
raise KeyError(
- f'{key!r} is not in coordinate variables {tuple(self.keys())}')
+ f"{key!r} is not in coordinate variables {tuple(self.keys())}"
+ )
assert_no_index_corrupted(self._data.xindexes, {key})
+
del self._data._coords[key]
if self._data._indexes is not None and key in self._data._indexes:
del self._data._indexes[key]
def _ipython_key_completions_(self):
"""Provide method for the key-autocompletions in IPython."""
- pass
+ return self._data._ipython_key_completions_()
-def drop_indexed_coords(coords_to_drop: set[Hashable], coords: Coordinates
- ) ->Coordinates:
+def drop_indexed_coords(
+ coords_to_drop: set[Hashable], coords: Coordinates
+) -> Coordinates:
"""Drop indexed coordinates associated with coordinates in coords_to_drop.
This will raise an error in case it corrupts any passed index and its
coordinate variables.
"""
- pass
+ new_variables = dict(coords.variables)
+ new_indexes = dict(coords.xindexes)
+
+ for idx, idx_coords in coords.xindexes.group_by_index():
+ idx_drop_coords = set(idx_coords) & coords_to_drop
+
+ # special case for pandas multi-index: still allow but deprecate
+ # dropping only its dimension coordinate.
+ # TODO: remove when removing PandasMultiIndex's dimension coordinate.
+ if isinstance(idx, PandasMultiIndex) and idx_drop_coords == {idx.dim}:
+ idx_drop_coords.update(idx.index.names)
+ emit_user_level_warning(
+ f"updating coordinate {idx.dim!r} with a PandasMultiIndex would leave "
+ f"the multi-index level coordinates {list(idx.index.names)!r} in an inconsistent state. "
+ f"This will raise an error in the future. Use `.drop_vars({list(idx_coords)!r})` before "
+ "assigning new coordinate values.",
+ FutureWarning,
+ )
+
+ elif idx_drop_coords and len(idx_drop_coords) != len(idx_coords):
+ idx_drop_coords_str = ", ".join(f"{k!r}" for k in idx_drop_coords)
+ idx_coords_str = ", ".join(f"{k!r}" for k in idx_coords)
+ raise ValueError(
+ f"cannot drop or update coordinate(s) {idx_drop_coords_str}, which would corrupt "
+ f"the following index built from coordinates {idx_coords_str}:\n"
+ f"{idx}"
+ )
+ for k in idx_drop_coords:
+ del new_variables[k]
+ del new_indexes[k]
-def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]
- ) ->None:
+ return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes)
+
+
+def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None:
"""Make sure the dimension coordinate of obj is consistent with coords.
obj: DataArray or Dataset
coords: Dict-like of variables
"""
- pass
-
-
-def create_coords_with_default_indexes(coords: Mapping[Any, Any], data_vars:
- (DataVars | None)=None) ->Coordinates:
+ for k in obj.dims:
+ # make sure there are no conflict in dimension coordinates
+ if k in coords and k in obj.coords and not coords[k].equals(obj[k].variable):
+ raise IndexError(
+ f"dimension coordinate {k!r} conflicts between "
+ f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}"
+ )
+
+
+def create_coords_with_default_indexes(
+ coords: Mapping[Any, Any], data_vars: DataVars | None = None
+) -> Coordinates:
"""Returns a Coordinates object from a mapping of coordinates (arbitrary objects).
Create default (pandas) indexes for each of the input dimension coordinates.
Extract coordinates from each input DataArray.
"""
- pass
+ # Note: data_vars is needed here only because a pd.MultiIndex object
+ # can be promoted as coordinates.
+ # TODO: this will no longer be relevant once this behavior is dropped
+ # in favor of the more explicit ``Coordinates.from_pandas_multiindex()``.
+
+ from xarray.core.dataarray import DataArray
+
+ all_variables = dict(coords)
+ if data_vars is not None:
+ all_variables.update(data_vars)
+
+ indexes: dict[Hashable, Index] = {}
+ variables: dict[Hashable, Variable] = {}
+
+ # promote any pandas multi-index in data_vars as coordinates
+ coords_promoted: dict[Hashable, Any] = {}
+ pd_mindex_keys: list[Hashable] = []
+
+ for k, v in all_variables.items():
+ if isinstance(v, pd.MultiIndex):
+ coords_promoted[k] = v
+ pd_mindex_keys.append(k)
+ elif k in coords:
+ coords_promoted[k] = v
+
+ if pd_mindex_keys:
+ pd_mindex_keys_fmt = ",".join([f"'{k}'" for k in pd_mindex_keys])
+ emit_user_level_warning(
+ f"the `pandas.MultiIndex` object(s) passed as {pd_mindex_keys_fmt} coordinate(s) or "
+ "data variable(s) will no longer be implicitly promoted and wrapped into "
+ "multiple indexed coordinates in the future "
+ "(i.e., one coordinate for each multi-index level + one dimension coordinate). "
+ "If you want to keep this behavior, you need to first wrap it explicitly using "
+ "`mindex_coords = xarray.Coordinates.from_pandas_multiindex(mindex_obj, 'dim')` "
+ "and pass it as coordinates, e.g., `xarray.Dataset(coords=mindex_coords)`, "
+ "`dataset.assign_coords(mindex_coords)` or `dataarray.assign_coords(mindex_coords)`.",
+ FutureWarning,
+ )
+
+ dataarray_coords: list[DataArrayCoordinates] = []
+
+ for name, obj in coords_promoted.items():
+ if isinstance(obj, DataArray):
+ dataarray_coords.append(obj.coords)
+
+ variable = as_variable(obj, name=name, auto_convert=False)
+
+ if variable.dims == (name,):
+ # still needed to convert to IndexVariable first due to some
+ # pandas multi-index edge cases.
+ variable = variable.to_index_variable()
+ idx, idx_vars = create_default_index_implicit(variable, all_variables)
+ indexes.update({k: idx for k in idx_vars})
+ variables.update(idx_vars)
+ all_variables.update(idx_vars)
+ else:
+ variables[name] = variable
+
+ new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes)
+
+ # extract and merge coordinates and indexes from input DataArrays
+ if dataarray_coords:
+ prioritized = {k: (v, indexes.get(k, None)) for k, v in variables.items()}
+ variables, indexes = merge_coordinates_without_align(
+ dataarray_coords + [new_coords],
+ prioritized=prioritized,
+ )
+ new_coords = Coordinates._construct_direct(coords=variables, indexes=indexes)
+
+ return new_coords
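
The `create_coords_with_default_indexes` implementation above deprecates implicit promotion of `pandas.MultiIndex` objects and points users at the explicit `Coordinates.from_pandas_multiindex` constructor instead. A minimal usage sketch of that recommended path (assuming a recent xarray release where this constructor is available):

```python
# Sketch of the explicit multi-index wrapping recommended by the FutureWarning above.
# Assumes xarray >= 2023.08, which provides Coordinates.from_pandas_multiindex.
import pandas as pd
import xarray as xr

midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("letter", "number"))

# One dimension coordinate plus one coordinate per multi-index level.
mindex_coords = xr.Coordinates.from_pandas_multiindex(midx, "dim")
ds = xr.Dataset(coords=mindex_coords)

# Passing the MultiIndex directly, e.g. xr.Dataset(coords={"dim": midx}), still works
# but triggers the FutureWarning emitted in create_coords_with_default_indexes.
```
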
diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py
index e551b759..98ff9002 100644
--- a/xarray/core/dask_array_ops.py
+++ b/xarray/core/dask_array_ops.py
@@ -1,14 +1,95 @@
from __future__ import annotations
+
from xarray.core import dtypes, nputils
def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
"""Wrapper to apply bottleneck moving window funcs on dask arrays"""
- pass
+ import dask.array as da
+
+ dtype, fill_value = dtypes.maybe_promote(a.dtype)
+ a = a.astype(dtype)
+ # inputs for overlap
+ if axis < 0:
+ axis = a.ndim + axis
+ depth = {d: 0 for d in range(a.ndim)}
+ depth[axis] = (window + 1) // 2
+ boundary = {d: fill_value for d in range(a.ndim)}
+ # Create overlap array.
+ ag = da.overlap.overlap(a, depth=depth, boundary=boundary)
+ # apply rolling func
+ out = da.map_blocks(
+ moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype
+ )
+ # trim array
+ result = da.overlap.trim_internal(out, depth)
+ return result
+
+
+def least_squares(lhs, rhs, rcond=None, skipna=False):
+ import dask.array as da
+
+ lhs_da = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1]))
+ if skipna:
+ added_dim = rhs.ndim == 1
+ if added_dim:
+ rhs = rhs.reshape(rhs.shape[0], 1)
+ results = da.apply_along_axis(
+ nputils._nanpolyfit_1d,
+ 0,
+ rhs,
+ lhs_da,
+ dtype=float,
+ shape=(lhs.shape[1] + 1,),
+ rcond=rcond,
+ )
+ coeffs = results[:-1, ...]
+ residuals = results[-1, ...]
+ if added_dim:
+ coeffs = coeffs.reshape(coeffs.shape[0])
+ residuals = residuals.reshape(residuals.shape[0])
+ else:
+ # Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
+ # See issue dask/dask#6516
+ coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
+ return coeffs, residuals
def push(array, n, axis):
"""
Dask-aware bottleneck.push
"""
- pass
+ import dask.array as da
+ import numpy as np
+
+ from xarray.core.duck_array_ops import _push
+
+ def _fill_with_last_one(a, b):
+ # cumreduction applies the push func over all the blocks first, so the only missing
+ # part is filling the missing values using the last data of the previous chunk
+ return np.where(~np.isnan(b), b, a)
+
+ if n is not None and 0 < n < array.shape[axis] - 1:
+ arange = da.broadcast_to(
+ da.arange(
+ array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype
+ ).reshape(
+ tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
+ ),
+ array.shape,
+ array.chunks,
+ )
+ valid_arange = da.where(da.notnull(array), arange, np.nan)
+ valid_limits = (arange - push(valid_arange, None, axis)) <= n
+ # omit the forward fill that violate the limit
+ return da.where(valid_limits, push(array, None, axis), np.nan)
+
+ # The method parameter makes the tests for Python 3.7 fail.
+ return da.reductions.cumreduction(
+ func=_push,
+ binop=_fill_with_last_one,
+ ident=np.nan,
+ x=array,
+ axis=axis,
+ dtype=array.dtype,
+ )
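
`push` above is the dask-aware building block behind forward filling: xarray's public `ffill`/`bfill` methods dispatch to it when the underlying data is a dask array, with the `n` branch implementing the `limit` argument. A small sketch of how it is typically exercised (assuming the optional dask and bottleneck dependencies are installed):

```python
# Sketch: forward-filling a chunked DataArray, which routes through
# dask_array_ops.push. Assumes dask and bottleneck are available.
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, np.nan, 4.0, np.nan], dims="x").chunk({"x": 2})

filled = da.ffill(dim="x")            # unlimited forward fill
limited = da.ffill(dim="x", limit=1)  # at most one step, the `n is not None` branch

print(filled.compute().values)   # -> 1, 1, 1, 4, 4
print(limited.compute().values)  # -> 1, 1, nan, 4, 4
```
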
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 3706f0b8..79fd0412 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -1,85 +1,275 @@
from __future__ import annotations
+
import datetime
import warnings
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from functools import partial
from os import PathLike
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, TypeVar, Union, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generic,
+ Literal,
+ NoReturn,
+ TypeVar,
+ Union,
+ overload,
+)
+
import numpy as np
import pandas as pd
+
from xarray.coding.calendar_ops import convert_calendar, interp_calendar
from xarray.coding.cftimeindex import CFTimeIndex
from xarray.core import alignment, computation, dtypes, indexing, ops, utils
from xarray.core._aggregations import DataArrayAggregations
from xarray.core.accessor_dt import CombinedDatetimelikeAccessor
from xarray.core.accessor_str import StringAccessor
-from xarray.core.alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
+from xarray.core.alignment import (
+ _broadcast_helper,
+ _get_broadcast_dims_map_common_coords,
+ align,
+)
from xarray.core.arithmetic import DataArrayArithmetic
from xarray.core.common import AbstractArray, DataWithCoords, get_chunksizes
from xarray.core.computation import unify_chunks
-from xarray.core.coordinates import Coordinates, DataArrayCoordinates, assert_coordinate_consistent, create_coords_with_default_indexes
+from xarray.core.coordinates import (
+ Coordinates,
+ DataArrayCoordinates,
+ assert_coordinate_consistent,
+ create_coords_with_default_indexes,
+)
from xarray.core.dataset import Dataset
from xarray.core.formatting import format_item
-from xarray.core.indexes import Index, Indexes, PandasMultiIndex, filter_indexes_from_coords, isel_indexes
+from xarray.core.indexes import (
+ Index,
+ Indexes,
+ PandasMultiIndex,
+ filter_indexes_from_coords,
+ isel_indexes,
+)
from xarray.core.indexing import is_fancy_indexer, map_index_queries
from xarray.core.merge import PANDAS_TYPES, MergeError
from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.types import Bins, DaCompatible, NetcdfWriteModes, T_DataArray, T_DataArrayOrSet, ZarrWriteModes
-from xarray.core.utils import Default, HybridMappingProxy, ReprObject, _default, either_dict_or_kwargs, hashable, infix_dims
-from xarray.core.variable import IndexVariable, Variable, as_compatible_data, as_variable
+from xarray.core.types import (
+ Bins,
+ DaCompatible,
+ NetcdfWriteModes,
+ T_DataArray,
+ T_DataArrayOrSet,
+ ZarrWriteModes,
+)
+from xarray.core.utils import (
+ Default,
+ HybridMappingProxy,
+ ReprObject,
+ _default,
+ either_dict_or_kwargs,
+ hashable,
+ infix_dims,
+)
+from xarray.core.variable import (
+ IndexVariable,
+ Variable,
+ as_compatible_data,
+ as_variable,
+)
from xarray.plot.accessor import DataArrayPlotAccessor
from xarray.plot.utils import _get_units_from_attrs
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
+
if TYPE_CHECKING:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from iris.cube import Cube as iris_Cube
from numpy.typing import ArrayLike
+
from xarray.backends import ZarrStore
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
from xarray.core.groupby import DataArrayGroupBy
from xarray.core.resample import DataArrayResample
from xarray.core.rolling import DataArrayCoarsen, DataArrayRolling
- from xarray.core.types import CoarsenBoundaryOptions, DatetimeLike, DatetimeUnitOptions, Dims, ErrorOptions, ErrorOptionsWithWarn, InterpOptions, PadModeOptions, PadReflectOptions, QuantileMethods, QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, Self, SideOptions, T_ChunkDimFreq, T_ChunksFreq, T_Xarray
+ from xarray.core.types import (
+ CoarsenBoundaryOptions,
+ DatetimeLike,
+ DatetimeUnitOptions,
+ Dims,
+ ErrorOptions,
+ ErrorOptionsWithWarn,
+ InterpOptions,
+ PadModeOptions,
+ PadReflectOptions,
+ QuantileMethods,
+ QueryEngineOptions,
+ QueryParserOptions,
+ ReindexMethodOptions,
+ Self,
+ SideOptions,
+ T_ChunkDimFreq,
+ T_ChunksFreq,
+ T_Xarray,
+ )
from xarray.core.weighted import DataArrayWeighted
from xarray.groupers import Grouper, Resampler
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
- T_XarrayOther = TypeVar('T_XarrayOther', bound=Union['DataArray', Dataset])
+ T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset])
-def _infer_coords_and_dims(shape: tuple[int, ...], coords: (Sequence[
- Sequence | pd.Index | DataArray | Variable | np.ndarray] | Mapping |
- None), dims: (str | Iterable[Hashable] | None)) ->tuple[Mapping[
- Hashable, Any], tuple[Hashable, ...]]:
+
+def _check_coords_dims(shape, coords, dim):
+ sizes = dict(zip(dim, shape))
+ for k, v in coords.items():
+ if any(d not in dim for d in v.dims):
+ raise ValueError(
+ f"coordinate {k} has dimensions {v.dims}, but these "
+ "are not a subset of the DataArray "
+ f"dimensions {dim}"
+ )
+
+ for d, s in v.sizes.items():
+ if s != sizes[d]:
+ raise ValueError(
+ f"conflicting sizes for dimension {d!r}: "
+ f"length {sizes[d]} on the data but length {s} on "
+ f"coordinate {k!r}"
+ )
+
+
+def _infer_coords_and_dims(
+ shape: tuple[int, ...],
+ coords: (
+ Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray]
+ | Mapping
+ | None
+ ),
+ dims: str | Iterable[Hashable] | None,
+) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]:
"""All the logic for creating a new DataArray"""
- pass
+
+ if (
+ coords is not None
+ and not utils.is_dict_like(coords)
+ and len(coords) != len(shape)
+ ):
+ raise ValueError(
+ f"coords is not dict-like, but it has {len(coords)} items, "
+ f"which does not match the {len(shape)} dimensions of the "
+ "data"
+ )
+
+ if isinstance(dims, str):
+ dims = (dims,)
+ elif dims is None:
+ dims = [f"dim_{n}" for n in range(len(shape))]
+ if coords is not None and len(coords) == len(shape):
+ # try to infer dimensions from coords
+ if utils.is_dict_like(coords):
+ dims = list(coords.keys())
+ else:
+ for n, (dim, coord) in enumerate(zip(dims, coords)):
+ coord = as_variable(
+ coord, name=dims[n], auto_convert=False
+ ).to_index_variable()
+ dims[n] = coord.name
+ dims_tuple = tuple(dims)
+ if len(dims_tuple) != len(shape):
+ raise ValueError(
+ "different number of dimensions on data "
+ f"and dims: {len(shape)} vs {len(dims_tuple)}"
+ )
+ for d in dims_tuple:
+ if not hashable(d):
+ raise TypeError(f"Dimension {d} is not hashable")
+
+ new_coords: Mapping[Hashable, Any]
+
+ if isinstance(coords, Coordinates):
+ new_coords = coords
+ else:
+ new_coords = {}
+ if utils.is_dict_like(coords):
+ for k, v in coords.items():
+ new_coords[k] = as_variable(v, name=k, auto_convert=False)
+ if new_coords[k].dims == (k,):
+ new_coords[k] = new_coords[k].to_index_variable()
+ elif coords is not None:
+ for dim, coord in zip(dims_tuple, coords):
+ var = as_variable(coord, name=dim, auto_convert=False)
+ var.dims = (dim,)
+ new_coords[dim] = var.to_index_variable()
+
+ _check_coords_dims(shape, new_coords, dims_tuple)
+
+ return new_coords, dims_tuple
+
+
+def _check_data_shape(
+ data: Any,
+ coords: (
+ Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray]
+ | Mapping
+ | None
+ ),
+ dims: str | Iterable[Hashable] | None,
+) -> Any:
+ if data is dtypes.NA:
+ data = np.nan
+ if coords is not None and utils.is_scalar(data, include_0d=False):
+ if utils.is_dict_like(coords):
+ if dims is None:
+ return data
+ else:
+ data_shape = tuple(
+ (
+ as_variable(coords[k], k, auto_convert=False).size
+ if k in coords.keys()
+ else 1
+ )
+ for k in dims
+ )
+ else:
+ data_shape = tuple(
+ as_variable(coord, "foo", auto_convert=False).size for coord in coords
+ )
+ data = np.full(data_shape, data)
+ return data
class _LocIndexer(Generic[T_DataArray]):
- __slots__ = 'data_array',
+ __slots__ = ("data_array",)
def __init__(self, data_array: T_DataArray):
self.data_array = data_array
- def __getitem__(self, key) ->T_DataArray:
+ def __getitem__(self, key) -> T_DataArray:
if not utils.is_dict_like(key):
+ # expand the indexer so we can handle Ellipsis
labels = indexing.expanded_indexer(key, self.data_array.ndim)
key = dict(zip(self.data_array.dims, labels))
return self.data_array.sel(key)
- def __setitem__(self, key, value) ->None:
+ def __setitem__(self, key, value) -> None:
if not utils.is_dict_like(key):
+ # expand the indexer so we can handle Ellipsis
labels = indexing.expanded_indexer(key, self.data_array.ndim)
key = dict(zip(self.data_array.dims, labels))
+
dim_indexers = map_index_queries(self.data_array, key).dim_indexers
self.data_array[dim_indexers] = value
-_THIS_ARRAY = ReprObject('<this-array>')
+# Used as the key corresponding to a DataArray's variable when converting
+# arbitrary DataArray objects to datasets
+_THIS_ARRAY = ReprObject("<this-array>")
-class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
- DataArrayAggregations):
+class DataArray(
+ AbstractArray,
+ DataWithCoords,
+ DataArrayArithmetic,
+ DataArrayAggregations,
+):
"""N-dimensional array with labeled coordinates and dimensions.
DataArray provides a wrapper around numpy ndarrays that uses
@@ -211,21 +401,41 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
description: Ambient temperature.
units: degC
"""
+
_cache: dict[str, Any]
_coords: dict[Any, Variable]
_close: Callable[[], None] | None
_indexes: dict[Hashable, Index]
_name: Hashable | None
_variable: Variable
- __slots__ = ('_cache', '_coords', '_close', '_indexes', '_name',
- '_variable', '__weakref__')
- dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor['DataArray'])
-
- def __init__(self, data: Any=dtypes.NA, coords: (Sequence[Sequence | pd
- .Index | DataArray | Variable | np.ndarray] | Mapping | None)=None,
- dims: (str | Iterable[Hashable] | None)=None, name: (Hashable |
- None)=None, attrs: (Mapping | None)=None, indexes: (Mapping[Any,
- Index] | None)=None, fastpath: bool=False) ->None:
+
+ __slots__ = (
+ "_cache",
+ "_coords",
+ "_close",
+ "_indexes",
+ "_name",
+ "_variable",
+ "__weakref__",
+ )
+
+ dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor["DataArray"])
+
+ def __init__(
+ self,
+ data: Any = dtypes.NA,
+ coords: (
+ Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray]
+ | Mapping
+ | None
+ ) = None,
+ dims: str | Iterable[Hashable] | None = None,
+ name: Hashable | None = None,
+ attrs: Mapping | None = None,
+ # internal parameters
+ indexes: Mapping[Any, Index] | None = None,
+ fastpath: bool = False,
+ ) -> None:
if fastpath:
variable = data
assert dims is None
@@ -234,8 +444,11 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
else:
if indexes is not None:
raise ValueError(
- 'Explicitly passing indexes via the `indexes` argument is not supported when `fastpath=False`. Use the `coords` argument instead.'
- )
+ "Explicitly passing indexes via the `indexes` argument is not supported "
+ "when `fastpath=False`. Use the `coords` argument instead."
+ )
+
+ # try to fill in arguments from data if they weren't supplied
if coords is None:
if isinstance(data, DataArray):
coords = data.coords
@@ -245,47 +458,206 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
coords = [data.index, data.columns]
elif isinstance(data, (pd.Index, IndexVariable)):
coords = [data]
+
if dims is None:
- dims = getattr(data, 'dims', getattr(coords, 'dims', None))
+ dims = getattr(data, "dims", getattr(coords, "dims", None))
if name is None:
- name = getattr(data, 'name', None)
+ name = getattr(data, "name", None)
if attrs is None and not isinstance(data, PANDAS_TYPES):
- attrs = getattr(data, 'attrs', None)
+ attrs = getattr(data, "attrs", None)
+
data = _check_data_shape(data, coords, dims)
data = as_compatible_data(data)
coords, dims = _infer_coords_and_dims(data.shape, coords, dims)
variable = Variable(dims, data, attrs, fastpath=True)
+
if not isinstance(coords, Coordinates):
coords = create_coords_with_default_indexes(coords)
indexes = dict(coords.xindexes)
coords = {k: v.copy() for k, v in coords.variables.items()}
+
+ # These fully describe a DataArray
self._variable = variable
assert isinstance(coords, dict)
self._coords = coords
self._name = name
- self._indexes = indexes
+ self._indexes = indexes # type: ignore[assignment]
+
self._close = None
@classmethod
- def _construct_direct(cls, variable: Variable, coords: dict[Any,
- Variable], name: Hashable, indexes: dict[Hashable, Index]) ->Self:
+ def _construct_direct(
+ cls,
+ variable: Variable,
+ coords: dict[Any, Variable],
+ name: Hashable,
+ indexes: dict[Hashable, Index],
+ ) -> Self:
"""Shortcut around __init__ for internal use when we want to skip
costly validation
"""
- pass
-
- def _overwrite_indexes(self, indexes: Mapping[Any, Index], variables: (
- Mapping[Any, Variable] | None)=None, drop_coords: (list[Hashable] |
- None)=None, rename_dims: (Mapping[Any, Any] | None)=None) ->Self:
+ obj = object.__new__(cls)
+ obj._variable = variable
+ obj._coords = coords
+ obj._name = name
+ obj._indexes = indexes
+ obj._close = None
+ return obj
+
+ def _replace(
+ self,
+ variable: Variable | None = None,
+ coords=None,
+ name: Hashable | None | Default = _default,
+ indexes=None,
+ ) -> Self:
+ if variable is None:
+ variable = self.variable
+ if coords is None:
+ coords = self._coords
+ if indexes is None:
+ indexes = self._indexes
+ if name is _default:
+ name = self.name
+ return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True)
+
+ def _replace_maybe_drop_dims(
+ self,
+ variable: Variable,
+ name: Hashable | None | Default = _default,
+ ) -> Self:
+ if variable.dims == self.dims and variable.shape == self.shape:
+ coords = self._coords.copy()
+ indexes = self._indexes
+ elif variable.dims == self.dims:
+ # Shape has changed (e.g. from reduce(..., keepdims=True))
+ new_sizes = dict(zip(self.dims, variable.shape))
+ coords = {
+ k: v
+ for k, v in self._coords.items()
+ if v.shape == tuple(new_sizes[d] for d in v.dims)
+ }
+ indexes = filter_indexes_from_coords(self._indexes, set(coords))
+ else:
+ allowed_dims = set(variable.dims)
+ coords = {
+ k: v for k, v in self._coords.items() if set(v.dims) <= allowed_dims
+ }
+ indexes = filter_indexes_from_coords(self._indexes, set(coords))
+ return self._replace(variable, coords, name, indexes=indexes)
+
+ def _overwrite_indexes(
+ self,
+ indexes: Mapping[Any, Index],
+ variables: Mapping[Any, Variable] | None = None,
+ drop_coords: list[Hashable] | None = None,
+ rename_dims: Mapping[Any, Any] | None = None,
+ ) -> Self:
"""Maybe replace indexes and their corresponding coordinates."""
- pass
+ if not indexes:
+ return self
+
+ if variables is None:
+ variables = {}
+ if drop_coords is None:
+ drop_coords = []
+
+ new_variable = self.variable.copy()
+ new_coords = self._coords.copy()
+ new_indexes = dict(self._indexes)
+
+ for name in indexes:
+ new_coords[name] = variables[name]
+ new_indexes[name] = indexes[name]
- def _to_dataset_split(self, dim: Hashable) ->Dataset:
+ for name in drop_coords:
+ new_coords.pop(name)
+ new_indexes.pop(name)
+
+ if rename_dims:
+ new_variable.dims = tuple(rename_dims.get(d, d) for d in new_variable.dims)
+
+ return self._replace(
+ variable=new_variable, coords=new_coords, indexes=new_indexes
+ )
+
+ def _to_temp_dataset(self) -> Dataset:
+ return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False)
+
+ def _from_temp_dataset(
+ self, dataset: Dataset, name: Hashable | None | Default = _default
+ ) -> Self:
+ variable = dataset._variables.pop(_THIS_ARRAY)
+ coords = dataset._variables
+ indexes = dataset._indexes
+ return self._replace(variable, coords, name, indexes=indexes)
+
+ def _to_dataset_split(self, dim: Hashable) -> Dataset:
"""splits dataarray along dimension 'dim'"""
- pass
- def to_dataset(self, dim: Hashable=None, *, name: Hashable=None,
- promote_attrs: bool=False) ->Dataset:
+ def subset(dim, label):
+ array = self.loc[{dim: label}]
+ array.attrs = {}
+ return as_variable(array)
+
+ variables_from_split = {
+ label: subset(dim, label) for label in self.get_index(dim)
+ }
+ coord_names = set(self._coords) - {dim}
+
+ ambiguous_vars = set(variables_from_split) & coord_names
+ if ambiguous_vars:
+ rename_msg_fmt = ", ".join([f"{v}=..." for v in sorted(ambiguous_vars)])
+ raise ValueError(
+ f"Splitting along the dimension {dim!r} would produce the variables "
+ f"{tuple(sorted(ambiguous_vars))} which are also existing coordinate "
+ f"variables. Use DataArray.rename({rename_msg_fmt}) or "
+ f"DataArray.assign_coords({dim}=...) to resolve this ambiguity."
+ )
+
+ variables = variables_from_split | {
+ k: v for k, v in self._coords.items() if k != dim
+ }
+ indexes = filter_indexes_from_coords(self._indexes, coord_names)
+ dataset = Dataset._construct_direct(
+ variables, coord_names, indexes=indexes, attrs=self.attrs
+ )
+ return dataset
+
+ def _to_dataset_whole(
+ self, name: Hashable = None, shallow_copy: bool = True
+ ) -> Dataset:
+ if name is None:
+ name = self.name
+ if name is None:
+ raise ValueError(
+ "unable to convert unnamed DataArray to a "
+ "Dataset without providing an explicit name"
+ )
+ if name in self.coords:
+ raise ValueError(
+ "cannot create a Dataset from a DataArray with "
+ "the same name as one of its coordinates"
+ )
+ # use private APIs for speed: this is called by _to_temp_dataset(),
+ # which is used in the guts of a lot of operations (e.g., reindex)
+ variables = self._coords.copy()
+ variables[name] = self.variable
+ if shallow_copy:
+ for k in variables:
+ variables[k] = variables[k].copy(deep=False)
+ indexes = self._indexes
+
+ coord_names = set(self._coords)
+ return Dataset._construct_direct(variables, coord_names, indexes=indexes)
+
+ def to_dataset(
+ self,
+ dim: Hashable = None,
+ *,
+ name: Hashable = None,
+ promote_attrs: bool = False,
+ ) -> Dataset:
"""Convert a DataArray to a Dataset.
Parameters
@@ -304,20 +676,39 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
-------
dataset : Dataset
"""
- pass
+ if dim is not None and dim not in self.dims:
+ raise TypeError(
+ f"{dim} is not a dim. If supplying a ``name``, pass as a kwarg."
+ )
+
+ if dim is not None:
+ if name is not None:
+ raise TypeError("cannot supply both dim and name arguments")
+ result = self._to_dataset_split(dim)
+ else:
+ result = self._to_dataset_whole(name)
+
+ if promote_attrs:
+ result.attrs = dict(self.attrs)
+
+ return result
@property
- def name(self) ->(Hashable | None):
+ def name(self) -> Hashable | None:
"""The name of this array."""
- pass
+ return self._name
+
+ @name.setter
+ def name(self, value: Hashable | None) -> None:
+ self._name = value
@property
- def variable(self) ->Variable:
+ def variable(self) -> Variable:
"""Low level interface to the Variable object for this DataArray."""
- pass
+ return self._variable
@property
- def dtype(self) ->np.dtype:
+ def dtype(self) -> np.dtype:
"""
Data-type of the array’s elements.
@@ -326,10 +717,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
ndarray.dtype
numpy.dtype
"""
- pass
+ return self.variable.dtype
@property
- def shape(self) ->tuple[int, ...]:
+ def shape(self) -> tuple[int, ...]:
"""
Tuple of array dimensions.
@@ -337,10 +728,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
numpy.ndarray.shape
"""
- pass
+ return self.variable.shape
@property
- def size(self) ->int:
+ def size(self) -> int:
"""
Number of elements in the array.
@@ -350,20 +741,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
numpy.ndarray.size
"""
- pass
+ return self.variable.size
@property
- def nbytes(self) ->int:
+ def nbytes(self) -> int:
"""
Total bytes consumed by the elements of this DataArray's data.
If the underlying data array does not include ``nbytes``, estimates
the bytes consumed based on the ``size`` and ``dtype``.
"""
- pass
+ return self.variable.nbytes
@property
- def ndim(self) ->int:
+ def ndim(self) -> int:
"""
Number of array dimensions.
@@ -371,13 +762,13 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
numpy.ndarray.ndim
"""
- pass
+ return self.variable.ndim
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self.variable)
@property
- def data(self) ->Any:
+ def data(self) -> Any:
"""
The DataArray's data as an array. The underlying array type
(e.g. dask, sparse, pint) is preserved.
@@ -388,10 +779,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.as_numpy
DataArray.values
"""
- pass
+ return self.variable.data
+
+ @data.setter
+ def data(self, value: Any) -> None:
+ self.variable.data = value
@property
- def values(self) ->np.ndarray:
+ def values(self) -> np.ndarray:
"""
The array's data converted to numpy.ndarray.
@@ -403,9 +798,13 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
numpy's rules of what generates a view vs. a copy, and changes
to this array may be reflected in the DataArray as well.
"""
- pass
+ return self.variable.values
+
+ @values.setter
+ def values(self, value: Any) -> None:
+ self.variable.values = value
- def to_numpy(self) ->np.ndarray:
+ def to_numpy(self) -> np.ndarray:
"""
Coerces wrapped data to numpy and returns a numpy.ndarray.
@@ -416,9 +815,9 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.values
DataArray.data
"""
- pass
+ return self.variable.to_numpy()
- def as_numpy(self) ->Self:
+ def as_numpy(self) -> Self:
"""
Coerces wrapped data and coordinates into numpy arrays, returning a DataArray.
@@ -429,16 +828,24 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.values
DataArray.data
"""
- pass
+ coords = {k: v.as_numpy() for k, v in self._coords.items()}
+ return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes)
+
+ @property
+ def _in_memory(self) -> bool:
+ return self.variable._in_memory
- def to_index(self) ->pd.Index:
+ def _to_index(self) -> pd.Index:
+ return self.variable._to_index()
+
+ def to_index(self) -> pd.Index:
"""Convert this variable to a pandas.Index. Only possible for 1D
arrays.
"""
- pass
+ return self.variable.to_index()
@property
- def dims(self) ->tuple[Hashable, ...]:
+ def dims(self) -> tuple[Hashable, ...]:
"""Tuple of dimension names associated with this array.
Note that the type of this property is inconsistent with
@@ -450,65 +857,116 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.sizes
Dataset.dims
"""
- pass
+ return self.variable.dims
+
+ @dims.setter
+ def dims(self, value: Any) -> NoReturn:
+ raise AttributeError(
+ "you cannot assign dims on a DataArray. Use "
+ ".rename() or .swap_dims() instead."
+ )
+
+ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]:
+ if utils.is_dict_like(key):
+ return key
+ key = indexing.expanded_indexer(key, self.ndim)
+ return dict(zip(self.dims, key))
+
+ def _getitem_coord(self, key: Any) -> Self:
+ from xarray.core.dataset import _get_virtual_variable
- def __getitem__(self, key: Any) ->Self:
+ try:
+ var = self._coords[key]
+ except KeyError:
+ dim_sizes = dict(zip(self.dims, self.shape))
+ _, key, var = _get_virtual_variable(self._coords, key, dim_sizes)
+
+ return self._replace_maybe_drop_dims(var, name=key)
+
+ def __getitem__(self, key: Any) -> Self:
if isinstance(key, str):
return self._getitem_coord(key)
else:
+ # xarray-style array indexing
return self.isel(indexers=self._item_key_to_dict(key))
- def __setitem__(self, key: Any, value: Any) ->None:
+ def __setitem__(self, key: Any, value: Any) -> None:
if isinstance(key, str):
self.coords[key] = value
else:
+ # Coordinates in key, value and self[key] should be consistent.
+ # TODO Coordinate consistency in key is checked here, but it
+ # causes unnecessary indexing. It should be optimized.
obj = self[key]
if isinstance(value, DataArray):
assert_coordinate_consistent(value, obj.coords.variables)
value = value.variable
- key = {k: (v.variable if isinstance(v, DataArray) else v) for k,
- v in self._item_key_to_dict(key).items()}
+ # DataArray key -> Variable key
+ key = {
+ k: v.variable if isinstance(v, DataArray) else v
+ for k, v in self._item_key_to_dict(key).items()
+ }
self.variable[key] = value
- def __delitem__(self, key: Any) ->None:
+ def __delitem__(self, key: Any) -> None:
del self.coords[key]
@property
- def _attr_sources(self) ->Iterable[Mapping[Hashable, Any]]:
+ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for attribute-style access"""
- pass
+ yield from self._item_sources
+ yield self.attrs
@property
- def _item_sources(self) ->Iterable[Mapping[Hashable, Any]]:
+ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for key-completion"""
- pass
+ yield HybridMappingProxy(keys=self._coords, mapping=self.coords)
+
+ # virtual coordinates
+ # uses empty dict -- everything here can already be found in self.coords.
+ yield HybridMappingProxy(keys=self.dims, mapping={})
- def __contains__(self, key: Any) ->bool:
+ def __contains__(self, key: Any) -> bool:
return key in self.data
@property
- def loc(self) ->_LocIndexer:
+ def loc(self) -> _LocIndexer:
"""Attribute for location based indexing like pandas."""
- pass
+ return _LocIndexer(self)
@property
- def attrs(self) ->dict[Any, Any]:
+ def attrs(self) -> dict[Any, Any]:
"""Dictionary storing arbitrary metadata with this array."""
- pass
+ return self.variable.attrs
+
+ @attrs.setter
+ def attrs(self, value: Mapping[Any, Any]) -> None:
+ self.variable.attrs = dict(value)
@property
- def encoding(self) ->dict[Any, Any]:
+ def encoding(self) -> dict[Any, Any]:
"""Dictionary of format-specific settings for how this array should be
serialized."""
- pass
+ return self.variable.encoding
+
+ @encoding.setter
+ def encoding(self, value: Mapping[Any, Any]) -> None:
+ self.variable.encoding = dict(value)
- def drop_encoding(self) ->Self:
+ def reset_encoding(self) -> Self:
+ warnings.warn(
+ "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead"
+ )
+ return self.drop_encoding()
+
+ def drop_encoding(self) -> Self:
"""Return a new DataArray without encoding on the array or any attached
coords."""
- pass
+ ds = self._to_temp_dataset().drop_encoding()
+ return self._from_temp_dataset(ds)
@property
- def indexes(self) ->Indexes:
+ def indexes(self) -> Indexes:
"""Mapping of pandas.Index objects used for label based indexing.
Raises an error if this Dataset has indexes that cannot be coerced
@@ -519,17 +977,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.xindexes
"""
- pass
+ return self.xindexes.to_pandas_indexes()
@property
- def xindexes(self) ->Indexes[Index]:
+ def xindexes(self) -> Indexes[Index]:
"""Mapping of :py:class:`~xarray.indexes.Index` objects
used for label based indexing.
"""
- pass
+ return Indexes(self._indexes, {k: self._coords[k] for k in self._indexes})
@property
- def coords(self) ->DataArrayCoordinates:
+ def coords(self) -> DataArrayCoordinates:
"""Mapping of :py:class:`~xarray.DataArray` objects corresponding to
coordinate variables.
@@ -537,11 +995,31 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
Coordinates
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def reset_coords(self, names: Dims=None, *, drop: bool=False) ->(Self |
- Dataset):
+ return DataArrayCoordinates(self)
+
+ @overload
+ def reset_coords(
+ self,
+ names: Dims = None,
+ *,
+ drop: Literal[False] = False,
+ ) -> Dataset: ...
+
+ @overload
+ def reset_coords(
+ self,
+ names: Dims = None,
+ *,
+ drop: Literal[True],
+ ) -> Self: ...
+
+ @_deprecate_positional_args("v2023.10.0")
+ def reset_coords(
+ self,
+ names: Dims = None,
+ *,
+ drop: bool = False,
+ ) -> Self | Dataset:
"""Given names of coordinates, reset them to become variables.
Parameters
@@ -611,12 +1089,22 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
lat (y) int64 40B 20 21 22 23 24
Dimensions without coordinates: x, y
"""
- pass
-
- def __dask_tokenize__(self) ->object:
+ if names is None:
+ names = set(self.coords) - set(self._indexes)
+ dataset = self.coords.to_dataset().reset_coords(names, drop)
+ if drop:
+ return self._replace(coords=dataset._variables)
+ if self.name is None:
+ raise ValueError(
+ "cannot reset_coords with drop=False on an unnamed DataArrray"
+ )
+ dataset[self.name] = self.variable
+ return dataset
+
+ def __dask_tokenize__(self) -> object:
from dask.base import normalize_token
- return normalize_token((type(self), self._variable, self._coords,
- self._name))
+
+ return normalize_token((type(self), self._variable, self._coords, self._name))
def __dask_graph__(self):
return self._to_temp_dataset().__dask_graph__()
@@ -643,7 +1131,15 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
func, args = self._to_temp_dataset().__dask_postpersist__()
return self._dask_finalize, (self.name, func) + args
- def load(self, **kwargs) ->Self:
+ @classmethod
+ def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self:
+ ds = func(results, *args, **kwargs)
+ variable = ds._variables.pop(_THIS_ARRAY)
+ coords = ds._variables
+ indexes = ds._indexes
+ return cls(variable, coords, name=name, indexes=indexes, fastpath=True)
+
+ def load(self, **kwargs) -> Self:
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return this array.
@@ -663,9 +1159,13 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
dask.compute
"""
- pass
+ ds = self._to_temp_dataset().load(**kwargs)
+ new = self._from_temp_dataset(ds)
+ self._variable = new._variable
+ self._coords = new._coords
+ return self
- def compute(self, **kwargs) ->Self:
+ def compute(self, **kwargs) -> Self:
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return a new array.
@@ -690,9 +1190,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
dask.compute
"""
- pass
+ new = self.copy(deep=False)
+ return new.load(**kwargs)
- def persist(self, **kwargs) ->Self:
+ def persist(self, **kwargs) -> Self:
"""Trigger computation in constituent dask arrays
This keeps them as dask arrays but encourages them to keep data in
@@ -714,9 +1215,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
dask.persist
"""
- pass
+ ds = self._to_temp_dataset().persist(**kwargs)
+ return self._from_temp_dataset(ds)
- def copy(self, deep: bool=True, data: Any=None) ->Self:
+ def copy(self, deep: bool = True, data: Any = None) -> Self:
"""Returns a copy of this array.
If `deep=True`, a deep copy is made of the data array.
@@ -784,17 +1286,38 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
pandas.DataFrame.copy
"""
- pass
+ return self._copy(deep=deep, data=data)
+
+ def _copy(
+ self,
+ deep: bool = True,
+ data: Any = None,
+ memo: dict[int, Any] | None = None,
+ ) -> Self:
+ variable = self.variable._copy(deep=deep, data=data, memo=memo)
+ indexes, index_vars = self.xindexes.copy_indexes(deep=deep)
+
+ coords = {}
+ for k, v in self._coords.items():
+ if k in index_vars:
+ coords[k] = index_vars[k]
+ else:
+ coords[k] = v._copy(deep=deep, memo=memo)
- def __copy__(self) ->Self:
+ return self._replace(variable, coords, indexes=indexes)
+
+ def __copy__(self) -> Self:
return self._copy(deep=False)
- def __deepcopy__(self, memo: (dict[int, Any] | None)=None) ->Self:
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
return self._copy(deep=True, memo=memo)
- __hash__ = None
+
+ # mutable objects should not be Hashable
+ # https://github.com/python/mypy/issues/4266
+ __hash__ = None # type: ignore[assignment]
@property
- def chunks(self) ->(tuple[tuple[int, ...], ...] | None):
+ def chunks(self) -> tuple[tuple[int, ...], ...] | None:
"""
Tuple of block lengths for this dataarray's data, in order of dimensions, or None if
the underlying data is not a dask array.
@@ -805,10 +1328,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.chunksizes
xarray.unify_chunks
"""
- pass
+ return self.variable.chunks
@property
- def chunksizes(self) ->Mapping[Any, tuple[int, ...]]:
+ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]:
"""
Mapping from dimension names to block lengths for this dataarray's data, or None if
the underlying data is not a dask array.
@@ -823,13 +1346,22 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.chunks
xarray.unify_chunks
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def chunk(self, chunks: T_ChunksFreq={}, *, name_prefix: str='xarray-',
- token: (str | None)=None, lock: bool=False, inline_array: bool=
- False, chunked_array_type: (str | ChunkManagerEntrypoint | None)=
- None, from_array_kwargs=None, **chunks_kwargs: T_ChunkDimFreq) ->Self:
+ all_variables = [self.variable] + [c.variable for c in self.coords.values()]
+ return get_chunksizes(all_variables)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def chunk(
+ self,
+ chunks: T_ChunksFreq = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
+ *,
+ name_prefix: str = "xarray-",
+ token: str | None = None,
+ lock: bool = False,
+ inline_array: bool = False,
+ chunked_array_type: str | ChunkManagerEntrypoint | None = None,
+ from_array_kwargs=None,
+ **chunks_kwargs: T_ChunkDimFreq,
+ ) -> Self:
"""Coerce this array's data into a dask arrays with the given chunks.
If this variable is a non-dask array, it will be converted to dask
@@ -881,11 +1413,46 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
xarray.unify_chunks
dask.array.from_array
"""
- pass
-
- def isel(self, indexers: (Mapping[Any, Any] | None)=None, drop: bool=
- False, missing_dims: ErrorOptionsWithWarn='raise', **
- indexers_kwargs: Any) ->Self:
+ chunk_mapping: T_ChunksFreq
+ if chunks is None:
+ warnings.warn(
+ "None value for 'chunks' is deprecated. "
+ "It will raise an error in the future. Use instead '{}'",
+ category=FutureWarning,
+ )
+ chunk_mapping = {}
+
+ if isinstance(chunks, (float, str, int)):
+ # ignoring type; unclear why it won't accept a Literal into the value.
+ chunk_mapping = dict.fromkeys(self.dims, chunks)
+ elif isinstance(chunks, (tuple, list)):
+ utils.emit_user_level_warning(
+ "Supplying chunks as dimension-order tuples is deprecated. "
+ "It will raise an error in the future. Instead use a dict with dimension names as keys.",
+ category=DeprecationWarning,
+ )
+ chunk_mapping = dict(zip(self.dims, chunks))
+ else:
+ chunk_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
+
+ ds = self._to_temp_dataset().chunk(
+ chunk_mapping,
+ name_prefix=name_prefix,
+ token=token,
+ lock=lock,
+ inline_array=inline_array,
+ chunked_array_type=chunked_array_type,
+ from_array_kwargs=from_array_kwargs,
+ )
+ return self._from_temp_dataset(ds)
+
+ def isel(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ drop: bool = False,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Return a new DataArray whose data is given by selecting indexes
along the specified dimension(s).
@@ -945,11 +1512,45 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([ 0, 6, 12, 18, 24])
Dimensions without coordinates: points
"""
- pass
- def sel(self, indexers: (Mapping[Any, Any] | None)=None, method: (str |
- None)=None, tolerance=None, drop: bool=False, **indexers_kwargs: Any
- ) ->Self:
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
+
+ if any(is_fancy_indexer(idx) for idx in indexers.values()):
+ ds = self._to_temp_dataset()._isel_fancy(
+ indexers, drop=drop, missing_dims=missing_dims
+ )
+ return self._from_temp_dataset(ds)
+
+ # Much faster algorithm for when all indexers are ints, slices, one-dimensional
+ # lists, or zero or one-dimensional np.ndarray's
+
+ variable = self._variable.isel(indexers, missing_dims=missing_dims)
+ indexes, index_variables = isel_indexes(self.xindexes, indexers)
+
+ coords = {}
+ for coord_name, coord_value in self._coords.items():
+ if coord_name in index_variables:
+ coord_value = index_variables[coord_name]
+ else:
+ coord_indexers = {
+ k: v for k, v in indexers.items() if k in coord_value.dims
+ }
+ if coord_indexers:
+ coord_value = coord_value.isel(coord_indexers)
+ if drop and coord_value.ndim == 0:
+ continue
+ coords[coord_name] = coord_value
+
+ return self._replace(variable=variable, coords=coords, indexes=indexes)
+
+ def sel(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ method: str | None = None,
+ tolerance=None,
+ drop: bool = False,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Return a new DataArray whose data is given by selecting index
labels along the specified dimension(s).
@@ -1058,10 +1659,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
y (points) int64 40B 0 1 2 3 4
Dimensions without coordinates: points
"""
- pass
-
- def head(self, indexers: (Mapping[Any, int] | int | None)=None, **
- indexers_kwargs: Any) ->Self:
+ ds = self._to_temp_dataset().sel(
+ indexers=indexers,
+ drop=drop,
+ method=method,
+ tolerance=tolerance,
+ **indexers_kwargs,
+ )
+ return self._from_temp_dataset(ds)
+
+ def head(
+ self,
+ indexers: Mapping[Any, int] | int | None = None,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Return a new DataArray whose data is given by the the first `n`
values along the specified dimension(s). Default `n` = 5
@@ -1097,10 +1708,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
[5, 6]])
Dimensions without coordinates: x, y
"""
- pass
+ ds = self._to_temp_dataset().head(indexers, **indexers_kwargs)
+ return self._from_temp_dataset(ds)
- def tail(self, indexers: (Mapping[Any, int] | int | None)=None, **
- indexers_kwargs: Any) ->Self:
+ def tail(
+ self,
+ indexers: Mapping[Any, int] | int | None = None,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Return a new DataArray whose data is given by the the last `n`
values along the specified dimension(s). Default `n` = 5
@@ -1140,10 +1755,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
[23, 24]])
Dimensions without coordinates: x, y
"""
- pass
+ ds = self._to_temp_dataset().tail(indexers, **indexers_kwargs)
+ return self._from_temp_dataset(ds)
- def thin(self, indexers: (Mapping[Any, int] | int | None)=None, **
- indexers_kwargs: Any) ->Self:
+ def thin(
+ self,
+ indexers: Mapping[Any, int] | int | None = None,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Return a new DataArray whose data is given by each `n` value
along the specified dimension(s).
@@ -1186,11 +1805,16 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.head
DataArray.tail
"""
- pass
+ ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs)
+ return self._from_temp_dataset(ds)
- @_deprecate_positional_args('v2023.10.0')
- def broadcast_like(self, other: T_DataArrayOrSet, *, exclude: (Iterable
- [Hashable] | None)=None) ->Self:
+ @_deprecate_positional_args("v2023.10.0")
+ def broadcast_like(
+ self,
+ other: T_DataArrayOrSet,
+ *,
+ exclude: Iterable[Hashable] | None = None,
+ ) -> Self:
"""Broadcast this DataArray against another Dataset or DataArray.
This is equivalent to xr.broadcast(other, self)[1]
@@ -1252,20 +1876,61 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* x (x) <U1 12B 'a' 'b' 'c'
* y (y) <U1 12B 'a' 'b' 'c'
"""
- pass
-
- def _reindex_callback(self, aligner: alignment.Aligner,
- dim_pos_indexers: dict[Hashable, Any], variables: dict[Hashable,
- Variable], indexes: dict[Hashable, Index], fill_value: Any,
- exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable]
- ) ->Self:
+ if exclude is None:
+ exclude = set()
+ else:
+ exclude = set(exclude)
+ args = align(other, self, join="outer", copy=False, exclude=exclude)
+
+ dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude)
+
+ return _broadcast_helper(args[1], exclude, dims_map, common_coords)
+
+ def _reindex_callback(
+ self,
+ aligner: alignment.Aligner,
+ dim_pos_indexers: dict[Hashable, Any],
+ variables: dict[Hashable, Variable],
+ indexes: dict[Hashable, Index],
+ fill_value: Any,
+ exclude_dims: frozenset[Hashable],
+ exclude_vars: frozenset[Hashable],
+ ) -> Self:
"""Callback called from ``Aligner`` to create a new reindexed DataArray."""
- pass
- @_deprecate_positional_args('v2023.10.0')
- def reindex_like(self, other: T_DataArrayOrSet, *, method:
- ReindexMethodOptions=None, tolerance: (float | Iterable[float] |
- str | None)=None, copy: bool=True, fill_value=dtypes.NA) ->Self:
+ if isinstance(fill_value, dict):
+ fill_value = fill_value.copy()
+ sentinel = object()
+ value = fill_value.pop(self.name, sentinel)
+ if value is not sentinel:
+ fill_value[_THIS_ARRAY] = value
+
+ ds = self._to_temp_dataset()
+ reindexed = ds._reindex_callback(
+ aligner,
+ dim_pos_indexers,
+ variables,
+ indexes,
+ fill_value,
+ exclude_dims,
+ exclude_vars,
+ )
+
+ da = self._from_temp_dataset(reindexed)
+ da.encoding = self.encoding
+
+ return da
+
+ @_deprecate_positional_args("v2023.10.0")
+ def reindex_like(
+ self,
+ other: T_DataArrayOrSet,
+ *,
+ method: ReindexMethodOptions = None,
+ tolerance: float | Iterable[float] | str | None = None,
+ copy: bool = True,
+ fill_value=dtypes.NA,
+ ) -> Self:
"""
Conform this object onto the indexes of another object, for indexes which the
objects share. Missing values are filled with ``fill_value``. The default fill
@@ -1434,13 +2099,26 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.broadcast_like
align
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def reindex(self, indexers: (Mapping[Any, Any] | None)=None, *, method:
- ReindexMethodOptions=None, tolerance: (float | Iterable[float] |
- str | None)=None, copy: bool=True, fill_value=dtypes.NA, **
- indexers_kwargs: Any) ->Self:
+ return alignment.reindex_like(
+ self,
+ other=other,
+ method=method,
+ tolerance=tolerance,
+ copy=copy,
+ fill_value=fill_value,
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def reindex(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ *,
+ method: ReindexMethodOptions = None,
+ tolerance: float | Iterable[float] | str | None = None,
+ copy: bool = True,
+ fill_value=dtypes.NA,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Conform this object onto the indexes of another object, filling in
missing values with ``fill_value``. The default fill value is NaN.
@@ -1513,11 +2191,24 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.reindex_like
align
"""
- pass
-
- def interp(self, coords: (Mapping[Any, Any] | None)=None, method:
- InterpOptions='linear', assume_sorted: bool=False, kwargs: (Mapping
- [str, Any] | None)=None, **coords_kwargs: Any) ->Self:
+ indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
+ return alignment.reindex(
+ self,
+ indexers=indexers,
+ method=method,
+ tolerance=tolerance,
+ copy=copy,
+ fill_value=fill_value,
+ )
+
+ def interp(
+ self,
+ coords: Mapping[Any, Any] | None = None,
+ method: InterpOptions = "linear",
+ assume_sorted: bool = False,
+ kwargs: Mapping[str, Any] | None = None,
+ **coords_kwargs: Any,
+ ) -> Self:
"""Interpolate a DataArray onto new coordinates
Performs univariate or multivariate interpolation of a DataArray onto
@@ -1643,11 +2334,26 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* x (x) float64 32B 0.0 0.75 1.25 1.75
* y (y) int64 24B 11 13 15
"""
- pass
-
- def interp_like(self, other: T_Xarray, method: InterpOptions='linear',
- assume_sorted: bool=False, kwargs: (Mapping[str, Any] | None)=None
- ) ->Self:
+ if self.dtype.kind not in "uifc":
+ raise TypeError(
+ f"interp only works for a numeric type array. Given {self.dtype}."
+ )
+ ds = self._to_temp_dataset().interp(
+ coords,
+ method=method,
+ kwargs=kwargs,
+ assume_sorted=assume_sorted,
+ **coords_kwargs,
+ )
+ return self._from_temp_dataset(ds)
+
+ def interp_like(
+ self,
+ other: T_Xarray,
+ method: InterpOptions = "linear",
+ assume_sorted: bool = False,
+ kwargs: Mapping[str, Any] | None = None,
+ ) -> Self:
"""Interpolate this object onto the coordinates of another object,
filling out of range values with NaN.
@@ -1755,10 +2461,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.interp
DataArray.reindex_like
"""
- pass
-
- def rename(self, new_name_or_name_dict: (Hashable | Mapping[Any,
- Hashable] | None)=None, **names: Hashable) ->Self:
+ if self.dtype.kind not in "uifc":
+ raise TypeError(
+ f"interp only works for a numeric type array. Given {self.dtype}."
+ )
+ ds = self._to_temp_dataset().interp_like(
+ other, method=method, kwargs=kwargs, assume_sorted=assume_sorted
+ )
+ return self._from_temp_dataset(ds)
+
+ def rename(
+ self,
+ new_name_or_name_dict: Hashable | Mapping[Any, Hashable] | None = None,
+ **names: Hashable,
+ ) -> Self:
"""Returns a new DataArray with renamed coordinates, dimensions or a new name.
Parameters
@@ -1782,10 +2498,27 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Dataset.rename
DataArray.swap_dims
"""
- pass
-
- def swap_dims(self, dims_dict: (Mapping[Any, Hashable] | None)=None, **
- dims_kwargs) ->Self:
+ if new_name_or_name_dict is None and not names:
+ # change name to None?
+ return self._replace(name=None)
+ if utils.is_dict_like(new_name_or_name_dict) or new_name_or_name_dict is None:
+ # change dims/coords
+ name_dict = either_dict_or_kwargs(new_name_or_name_dict, names, "rename")
+ dataset = self._to_temp_dataset()._rename(name_dict)
+ return self._from_temp_dataset(dataset)
+ if utils.hashable(new_name_or_name_dict) and names:
+ # change name + dims/coords
+ dataset = self._to_temp_dataset()._rename(names)
+ dataarray = self._from_temp_dataset(dataset)
+ return dataarray._replace(name=new_name_or_name_dict)
+ # only change name
+ return self._replace(name=new_name_or_name_dict)
+
+ def swap_dims(
+ self,
+ dims_dict: Mapping[Any, Hashable] | None = None,
+ **dims_kwargs,
+ ) -> Self:
"""Returns a new DataArray with swapped dimensions.
Parameters
@@ -1836,11 +2569,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.rename
Dataset.swap_dims
"""
- pass
-
- def expand_dims(self, dim: (None | Hashable | Sequence[Hashable] |
- Mapping[Any, Any])=None, axis: (None | int | Sequence[int])=None,
- create_index_for_new_dim: bool=True, **dim_kwargs: Any) ->Self:
+ dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
+ ds = self._to_temp_dataset().swap_dims(dims_dict)
+ return self._from_temp_dataset(ds)
+
+ def expand_dims(
+ self,
+ dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
+ axis: None | int | Sequence[int] = None,
+ create_index_for_new_dim: bool = True,
+ **dim_kwargs: Any,
+ ) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape. The new object is a
view into the underlying array, not a copy.
@@ -1921,11 +2660,27 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* y (y) int64 40B 0 1 2 3 4
Dimensions without coordinates: x
"""
- pass
-
- def set_index(self, indexes: (Mapping[Any, Hashable | Sequence[Hashable
- ]] | None)=None, append: bool=False, **indexes_kwargs: (Hashable |
- Sequence[Hashable])) ->Self:
+ if isinstance(dim, int):
+ raise TypeError("dim should be Hashable or sequence/mapping of Hashables")
+ elif isinstance(dim, Sequence) and not isinstance(dim, str):
+ if len(dim) != len(set(dim)):
+ raise ValueError("dims should not contain duplicate values.")
+ dim = dict.fromkeys(dim, 1)
+ elif dim is not None and not isinstance(dim, Mapping):
+ dim = {dim: 1}
+
+ dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
+ ds = self._to_temp_dataset().expand_dims(
+ dim, axis, create_index_for_new_dim=create_index_for_new_dim
+ )
+ return self._from_temp_dataset(ds)
+
+ def set_index(
+ self,
+ indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None,
+ append: bool = False,
+ **indexes_kwargs: Hashable | Sequence[Hashable],
+ ) -> Self:
"""Set DataArray (multi-)indexes using one or more existing
coordinates.
@@ -1980,10 +2735,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.reset_index
DataArray.set_xindex
"""
- pass
+ ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs)
+ return self._from_temp_dataset(ds)
- def reset_index(self, dims_or_levels: (Hashable | Sequence[Hashable]),
- drop: bool=False) ->Self:
+ def reset_index(
+ self,
+ dims_or_levels: Hashable | Sequence[Hashable],
+ drop: bool = False,
+ ) -> Self:
"""Reset the specified index(es) or multi-index level(s).
This legacy method is specific to pandas (multi-)indexes and
@@ -2013,10 +2772,15 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.set_xindex
DataArray.drop_indexes
"""
- pass
+ ds = self._to_temp_dataset().reset_index(dims_or_levels, drop=drop)
+ return self._from_temp_dataset(ds)
- def set_xindex(self, coord_names: (str | Sequence[Hashable]), index_cls:
- (type[Index] | None)=None, **options) ->Self:
+ def set_xindex(
+ self,
+ coord_names: str | Sequence[Hashable],
+ index_cls: type[Index] | None = None,
+ **options,
+ ) -> Self:
"""Set a new, Xarray-compatible index from one or more existing
coordinate(s).
@@ -2037,11 +2801,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Another dataarray, with this dataarray's data and with a new index.
"""
- pass
+ ds = self._to_temp_dataset().set_xindex(coord_names, index_cls, **options)
+ return self._from_temp_dataset(ds)
- def reorder_levels(self, dim_order: (Mapping[Any, Sequence[int |
- Hashable]] | None)=None, **dim_order_kwargs: Sequence[int | Hashable]
- ) ->Self:
+ def reorder_levels(
+ self,
+ dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None,
+ **dim_order_kwargs: Sequence[int | Hashable],
+ ) -> Self:
"""Rearrange index levels using input order.
Parameters
@@ -2060,12 +2827,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Another dataarray, with this dataarray's data but replaced
coordinates.
"""
- pass
+ ds = self._to_temp_dataset().reorder_levels(dim_order, **dim_order_kwargs)
+ return self._from_temp_dataset(ds)
- @partial(deprecate_dims, old_name='dimensions')
- def stack(self, dim: (Mapping[Any, Sequence[Hashable]] | None)=None,
- create_index: (bool | None)=True, index_cls: type[Index]=
- PandasMultiIndex, **dim_kwargs: Sequence[Hashable]) ->Self:
+ @partial(deprecate_dims, old_name="dimensions")
+ def stack(
+ self,
+ dim: Mapping[Any, Sequence[Hashable]] | None = None,
+ create_index: bool | None = True,
+ index_cls: type[Index] = PandasMultiIndex,
+ **dim_kwargs: Sequence[Hashable],
+ ) -> Self:
"""
Stack any number of existing dimensions into a single new dimension.
@@ -2124,11 +2896,22 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
DataArray.unstack
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def unstack(self, dim: Dims=None, *, fill_value: Any=dtypes.NA, sparse:
- bool=False) ->Self:
+ ds = self._to_temp_dataset().stack(
+ dim,
+ create_index=create_index,
+ index_cls=index_cls,
+ **dim_kwargs,
+ )
+ return self._from_temp_dataset(ds)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def unstack(
+ self,
+ dim: Dims = None,
+ *,
+ fill_value: Any = dtypes.NA,
+ sparse: bool = False,
+ ) -> Self:
"""
Unstack existing dimensions corresponding to MultiIndexes into
multiple new dimensions.
@@ -2183,10 +2966,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
DataArray.stack
"""
- pass
+ ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse)
+ return self._from_temp_dataset(ds)
- def to_unstacked_dataset(self, dim: Hashable, level: (int | Hashable)=0
- ) ->Dataset:
+ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Dataset:
"""Unstack DataArray expanding to Dataset along a given level of a
stacked coordinate.
@@ -2235,11 +3018,29 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
Dataset.to_stacked_array
"""
- pass
+ idx = self._indexes[dim].to_pandas_index()
+ if not isinstance(idx, pd.MultiIndex):
+ raise ValueError(f"'{dim}' is not a stacked coordinate")
+
+ level_number = idx._get_level_number(level) # type: ignore[attr-defined]
+ variables = idx.levels[level_number]
+ variable_dim = idx.names[level_number]
+
+ # pull variables out of datarray
+ data_dict = {}
+ for k in variables:
+ data_dict[k] = self.sel({variable_dim: k}, drop=True).squeeze(drop=True)
+
+ # unstacked dataset
+ return Dataset(data_dict)
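A sketch (hypothetical data) of the round trip with `Dataset.to_stacked_array`, which produces the stacked MultiIndex coordinate that `to_unstacked_dataset` expects:

```python
import xarray as xr

ds = xr.Dataset({"a": (("x", "y"), [[0, 1], [2, 3]]), "b": ("x", [4, 5])})
stacked = ds.to_stacked_array("z", sample_dims=["x"])  # DataArray with MultiIndex dim "z"
unstacked = stacked.to_unstacked_dataset(dim="z")      # back to variables "a" and "b"
```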
@deprecate_dims
- def transpose(self, *dim: Hashable, transpose_coords: bool=True,
- missing_dims: ErrorOptionsWithWarn='raise') ->Self:
+ def transpose(
+ self,
+ *dim: Hashable,
+ transpose_coords: bool = True,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ ) -> Self:
"""Return a new DataArray object with transposed dimensions.
Parameters
@@ -2272,10 +3073,28 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
numpy.transpose
Dataset.transpose
"""
- pass
+ if dim:
+ dim = tuple(infix_dims(dim, self.dims, missing_dims))
+ variable = self.variable.transpose(*dim)
+ if transpose_coords:
+ coords: dict[Hashable, Variable] = {}
+ for name, coord in self.coords.items():
+ coord_dims = tuple(d for d in dim if d in coord.dims)
+ coords[name] = coord.variable.transpose(*coord_dims)
+ return self._replace(variable, coords)
+ else:
+ return self._replace(variable)
- def drop_vars(self, names: (str | Iterable[Hashable] | Callable[[Self],
- str | Iterable[Hashable]]), *, errors: ErrorOptions='raise') ->Self:
+ @property
+ def T(self) -> Self:
+ return self.transpose()
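The `T` property added here is simply `transpose()` with no arguments, i.e. a full reversal of the dimension order; a quick sketch:

```python
import xarray as xr

da = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("x", "y"))
da.transpose("y", "x").dims  # ('y', 'x')
da.T.dims                    # ('y', 'x') as well
```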
+
+ def drop_vars(
+ self,
+ names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
+ *,
+ errors: ErrorOptions = "raise",
+ ) -> Self:
"""Returns an array with dropped variables.
Parameters
@@ -2341,10 +3160,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
[ 9, 10, 11]])
Dimensions without coordinates: x, y
"""
- pass
-
- def drop_indexes(self, coord_names: (Hashable | Iterable[Hashable]), *,
- errors: ErrorOptions='raise') ->Self:
+ if callable(names):
+ names = names(self)
+ ds = self._to_temp_dataset().drop_vars(names, errors=errors)
+ return self._from_temp_dataset(ds)
+
+ def drop_indexes(
+ self,
+ coord_names: Hashable | Iterable[Hashable],
+ *,
+ errors: ErrorOptions = "raise",
+ ) -> Self:
"""Drop the indexes assigned to the given coordinates.
Parameters
@@ -2361,10 +3187,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
dropped : DataArray
A new dataarray with dropped indexes.
"""
- pass
-
- def drop(self, labels: (Mapping[Any, Any] | None)=None, dim: (Hashable |
- None)=None, *, errors: ErrorOptions='raise', **labels_kwargs) ->Self:
+ ds = self._to_temp_dataset().drop_indexes(coord_names, errors=errors)
+ return self._from_temp_dataset(ds)
+
+ def drop(
+ self,
+ labels: Mapping[Any, Any] | None = None,
+ dim: Hashable | None = None,
+ *,
+ errors: ErrorOptions = "raise",
+ **labels_kwargs,
+ ) -> Self:
"""Backward compatible method based on `drop_vars` and `drop_sel`
Using either `drop_vars` or `drop_sel` is encouraged
@@ -2374,10 +3207,16 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.drop_vars
DataArray.drop_sel
"""
- pass
+ ds = self._to_temp_dataset().drop(labels, dim, errors=errors, **labels_kwargs)
+ return self._from_temp_dataset(ds)
- def drop_sel(self, labels: (Mapping[Any, Any] | None)=None, *, errors:
- ErrorOptions='raise', **labels_kwargs) ->Self:
+ def drop_sel(
+ self,
+ labels: Mapping[Any, Any] | None = None,
+ *,
+ errors: ErrorOptions = "raise",
+ **labels_kwargs,
+ ) -> Self:
"""Drop index labels from this DataArray.
Parameters
@@ -2433,10 +3272,15 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* x (x) int64 32B 0 2 4 8
* y (y) int64 24B 6 9 12
"""
- pass
+ if labels_kwargs or isinstance(labels, dict):
+ labels = either_dict_or_kwargs(labels, labels_kwargs, "drop")
- def drop_isel(self, indexers: (Mapping[Any, Any] | None)=None, **
- indexers_kwargs) ->Self:
+ ds = self._to_temp_dataset().drop_sel(labels, errors=errors)
+ return self._from_temp_dataset(ds)
+
+ def drop_isel(
+ self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs
+ ) -> Self:
"""Drop index positions from this DataArray.
Parameters
@@ -2481,11 +3325,18 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
[20, 21, 22, 24]])
Dimensions without coordinates: X, Y
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def dropna(self, dim: Hashable, *, how: Literal['any', 'all']='any',
- thresh: (int | None)=None) ->Self:
+ dataset = self._to_temp_dataset()
+ dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs)
+ return self._from_temp_dataset(dataset)
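A sketch (hypothetical data) contrasting the label-based and position-based drops implemented above:

```python
import xarray as xr

da = xr.DataArray([10, 20, 30], dims="x", coords={"x": [0, 1, 2]})
da.drop_sel(x=[0, 2])  # drop by coordinate label -> only x=1 remains
da.drop_isel(x=[0])    # drop by integer position -> x=1 and x=2 remain
```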
+
+ @_deprecate_positional_args("v2023.10.0")
+ def dropna(
+ self,
+ dim: Hashable,
+ *,
+ how: Literal["any", "all"] = "any",
+ thresh: int | None = None,
+ ) -> Self:
"""Returns a new array with dropped labels for missing values along
the provided dimension.
@@ -2553,9 +3404,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
lon (X) float64 32B 10.0 10.25 10.5 10.75
Dimensions without coordinates: Y, X
"""
- pass
+ ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh)
+ return self._from_temp_dataset(ds)
- def fillna(self, value: Any) ->Self:
+ def fillna(self, value: Any) -> Self:
"""Fill missing values in this object.
This operation follows the normal broadcasting and alignment rules that
@@ -2609,20 +3461,40 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* Z (Z) int64 48B 0 1 2 3 4 5
height (Z) int64 48B 0 10 20 30 40 50
"""
- pass
-
- def interpolate_na(self, dim: (Hashable | None)=None, method:
- InterpOptions='linear', limit: (int | None)=None, use_coordinate: (
- bool | str)=True, max_gap: (None | int | float | str | pd.Timedelta |
- np.timedelta64 | datetime.timedelta)=None, keep_attrs: (bool | None
- )=None, **kwargs: Any) ->Self:
+ if utils.is_dict_like(value):
+ raise TypeError(
+ "cannot provide fill value as a dictionary with "
+ "fillna on a DataArray"
+ )
+ out = ops.fillna(self, value)
+ return out
+
+ def interpolate_na(
+ self,
+ dim: Hashable | None = None,
+ method: InterpOptions = "linear",
+ limit: int | None = None,
+ use_coordinate: bool | str = True,
+ max_gap: (
+ None
+ | int
+ | float
+ | str
+ | pd.Timedelta
+ | np.timedelta64
+ | datetime.timedelta
+ ) = None,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""Fill in NaNs by interpolating according to different methods.
Parameters
----------
dim : Hashable or None, optional
Specifies the dimension along which to interpolate.
- method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
+ method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \
+ "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
String indicating which method to use for interpolation:
- 'linear': linear interpolation. Additional keyword
@@ -2708,9 +3580,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Coordinates:
* x (x) int64 40B 0 1 2 3 4
"""
- pass
+ from xarray.core.missing import interp_na
- def ffill(self, dim: Hashable, limit: (int | None)=None) ->Self:
+ return interp_na(
+ self,
+ dim=dim,
+ method=method,
+ limit=limit,
+ use_coordinate=use_coordinate,
+ max_gap=max_gap,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def ffill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values forward
*Requires bottleneck.*
@@ -2790,9 +3673,11 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
lon (X) float64 24B 10.0 10.25 10.5
Dimensions without coordinates: Y, X
"""
- pass
+ from xarray.core.missing import ffill
+
+ return ffill(self, dim, limit=limit)
- def bfill(self, dim: Hashable, limit: (int | None)=None) ->Self:
+ def bfill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values backward
*Requires bottleneck.*
@@ -2872,9 +3757,11 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
lon (X) float64 24B 10.0 10.25 10.5
Dimensions without coordinates: Y, X
"""
- pass
+ from xarray.core.missing import bfill
- def combine_first(self, other: Self) ->Self:
+ return bfill(self, dim, limit=limit)
+
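The three gap-filling methods above delegate to `xarray.core.missing`; a sketch of their user-facing behaviour (hypothetical data; `ffill`/`bfill` assume bottleneck is installed):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([0.0, np.nan, np.nan, 3.0], dims="x", coords={"x": [0, 1, 2, 3]})
da.interpolate_na(dim="x")  # linear in x -> [0.0, 1.0, 2.0, 3.0]
da.ffill("x")               # propagate forward  -> [0.0, 0.0, 0.0, 3.0]
da.bfill("x")               # propagate backward -> [0.0, 3.0, 3.0, 3.0]
```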
+ def combine_first(self, other: Self) -> Self:
"""Combine two DataArray objects, with union of coordinates.
This operation follows the normal broadcasting and alignment rules of
@@ -2890,11 +3777,18 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
-------
DataArray
"""
- pass
-
- def reduce(self, func: Callable[..., Any], dim: Dims=None, *, axis: (
- int | Sequence[int] | None)=None, keep_attrs: (bool | None)=None,
- keepdims: bool=False, **kwargs: Any) ->Self:
+ return ops.fillna(self, other, join="outer")
+
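`combine_first` is implemented as an outer-join `fillna`, so the result covers the union of coordinates and prefers values from `self`; a sketch with hypothetical data:

```python
import numpy as np
import xarray as xr

a = xr.DataArray([1.0, np.nan], dims="x", coords={"x": [0, 1]})
b = xr.DataArray([10.0, 20.0, 30.0], dims="x", coords={"x": [1, 2, 3]})
a.combine_first(b)  # x=[0, 1, 2, 3] -> [1.0, 10.0, 20.0, 30.0]
```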
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> Self:
"""Reduce this array by applying `func` along some dimension(s).
Parameters
@@ -2928,9 +3822,11 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray with this object's array replaced with an array with
summarized data and the indicated dimension(s) removed.
"""
- pass
- def to_pandas(self) ->(Self | pd.Series | pd.DataFrame):
+ var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
+ return self._replace_maybe_drop_dims(var)
+
+ def to_pandas(self) -> Self | pd.Series | pd.DataFrame:
"""Convert this array into a pandas object with the same shape.
The type of the returned object depends on the number of DataArray
@@ -2949,10 +3845,22 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
result : DataArray | Series | DataFrame
DataArray, pandas Series or pandas DataFrame.
"""
- pass
-
- def to_dataframe(self, name: (Hashable | None)=None, dim_order: (
- Sequence[Hashable] | None)=None) ->pd.DataFrame:
+ # TODO: consolidate the info about pandas constructors and the
+ # attributes that correspond to their indexes into a separate module?
+ constructors = {0: lambda x: x, 1: pd.Series, 2: pd.DataFrame}
+ try:
+ constructor = constructors[self.ndim]
+ except KeyError:
+ raise ValueError(
+ f"Cannot convert arrays with {self.ndim} dimensions into "
+ "pandas objects. Requires 2 or fewer dimensions."
+ )
+ indexes = [self.get_index(dim) for dim in self.dims]
+ return constructor(self.values, *indexes) # type: ignore[operator]
+
+ def to_dataframe(
+ self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None
+ ) -> pd.DataFrame:
"""Convert this array and its coordinates into a tidy pandas.DataFrame.
The DataFrame is indexed by the Cartesian product of index coordinates
@@ -2987,9 +3895,34 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.to_pandas
DataArray.to_series
"""
- pass
+ if name is None:
+ name = self.name
+ if name is None:
+ raise ValueError(
+ "cannot convert an unnamed DataArray to a "
+ "DataFrame: use the ``name`` parameter"
+ )
+ if self.ndim == 0:
+ raise ValueError("cannot convert a scalar to a DataFrame")
+
+ # By using a unique name, we can convert a DataArray into a DataFrame
+ # even if it shares a name with one of its coordinates.
+ # I would normally use unique_name = object() but that results in a
+ # dataframe with columns in the wrong order, for reasons I have not
+ # been able to debug (possibly a pandas bug?).
+ unique_name = "__unique_name_identifier_z98xfz98xugfg73ho__"
+ ds = self._to_dataset_whole(name=unique_name)
+
+ if dim_order is None:
+ ordered_dims = dict(zip(self.dims, self.shape))
+ else:
+ ordered_dims = ds._normalize_dim_order(dim_order=dim_order)
+
+ df = ds._to_dataframe(ordered_dims)
+ df.columns = [name if c == unique_name else c for c in df.columns]
+ return df
- def to_series(self) ->pd.Series:
+ def to_series(self) -> pd.Series:
"""Convert this array into a pandas.Series.
The Series is indexed by the Cartesian product of index coordinates
@@ -3005,9 +3938,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.to_pandas
DataArray.to_dataframe
"""
- pass
+ index = self.coords.to_index()
+ return pd.Series(self.values.reshape(-1), index=index, name=self.name)
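The pandas conversion helpers implemented above differ mainly in how shape is handled: `to_pandas` picks the pandas type from `ndim`, `to_series` always flattens to a MultiIndex-ed Series, and `to_dataframe` keeps the array as a single named column; a sketch with hypothetical data:

```python
import xarray as xr

da = xr.DataArray(
    [[1, 2], [3, 4]],
    dims=("x", "y"),
    coords={"x": ["a", "b"], "y": [0, 1]},
    name="v",
)
da.to_pandas()     # 2D -> pandas.DataFrame with index x and columns y
da.to_series()     # pandas.Series with an (x, y) MultiIndex
da.to_dataframe()  # tidy DataFrame with a single "v" column
```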
- def to_masked_array(self, copy: bool=True) ->np.ma.MaskedArray:
+ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
"""Convert this array into a numpy.ma.MaskedArray
Parameters
@@ -3021,14 +3955,84 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
result : MaskedArray
Masked where invalid values (nan or inf) occur.
"""
- pass
-
- def to_netcdf(self, path: (str | PathLike | None)=None, mode:
- NetcdfWriteModes='w', format: (T_NetcdfTypes | None)=None, group: (
- str | None)=None, engine: (T_NetcdfEngine | None)=None, encoding: (
- Mapping[Hashable, Mapping[str, Any]] | None)=None, unlimited_dims:
- (Iterable[Hashable] | None)=None, compute: bool=True,
- invalid_netcdf: bool=False) ->(bytes | Delayed | None):
+ values = self.to_numpy() # only compute lazy arrays once
+ isnull = pd.isnull(values)
+ return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)
+
+ # path=None writes to bytes
+ @overload
+ def to_netcdf(
+ self,
+ path: None = None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ invalid_netcdf: bool = False,
+ ) -> bytes: ...
+
+ # compute=False returns dask.Delayed
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ *,
+ compute: Literal[False],
+ invalid_netcdf: bool = False,
+ ) -> Delayed: ...
+
+ # default return None
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: Literal[True] = True,
+ invalid_netcdf: bool = False,
+ ) -> None: ...
+
+ # if compute cannot be evaluated at type check time
+ # we may get back either Delayed or None
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ invalid_netcdf: bool = False,
+ ) -> Delayed | None: ...
+
+ def to_netcdf(
+ self,
+ path: str | PathLike | None = None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Hashable, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ invalid_netcdf: bool = False,
+ ) -> bytes | Delayed | None:
"""Write DataArray contents to a netCDF file.
Parameters
@@ -3043,7 +4047,8 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Write ('w') or append ('a') mode. If mode='w', any existing file at
this location will be overwritten. If mode='a', existing variables
will be overwritten.
- format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"}, optional
+ format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \
+ "NETCDF3_CLASSIC"}, optional
File format for the resulting netCDF file:
* NETCDF4: Data is stored in an HDF5 file, using netCDF4 API
@@ -3117,17 +4122,91 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
Dataset.to_netcdf
"""
- pass
-
- def to_zarr(self, store: (MutableMapping | str | PathLike[str] | None)=
- None, chunk_store: (MutableMapping | str | PathLike | None)=None,
- mode: (ZarrWriteModes | None)=None, synchronizer=None, group: (str |
- None)=None, encoding: (Mapping | None)=None, *, compute: bool=True,
- consolidated: (bool | None)=None, append_dim: (Hashable | None)=
- None, region: (Mapping[str, slice | Literal['auto']] | Literal[
- 'auto'] | None)=None, safe_chunks: bool=True, storage_options: (
- dict[str, str] | None)=None, zarr_version: (int | None)=None) ->(
- ZarrStore | Delayed):
+ from xarray.backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE, to_netcdf
+
+ if self.name is None:
+ # If no name is set then use a generic xarray name
+ dataset = self.to_dataset(name=DATAARRAY_VARIABLE)
+ elif self.name in self.coords or self.name in self.dims:
+ # The name is the same as one of the coords names, which netCDF
+ # doesn't support, so rename it but keep track of the old name
+ dataset = self.to_dataset(name=DATAARRAY_VARIABLE)
+ dataset.attrs[DATAARRAY_NAME] = self.name
+ else:
+ # No problems with the name - so we're fine!
+ dataset = self.to_dataset()
+
+ return to_netcdf( # type: ignore # mypy cannot resolve the overloads:(
+ dataset,
+ path,
+ mode=mode,
+ format=format,
+ group=group,
+ engine=engine,
+ encoding=encoding,
+ unlimited_dims=unlimited_dims,
+ compute=compute,
+ multifile=False,
+ invalid_netcdf=invalid_netcdf,
+ )
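The overloads above encode the three possible return types of `to_netcdf`; a sketch of how they surface at the call site (assumes a netCDF backend such as scipy or h5netcdf is installed, and dask for the delayed form):

```python
import xarray as xr

da = xr.DataArray([1, 2, 3], dims="x", name="data")
raw = da.to_netcdf()                             # path=None -> in-memory bytes
da.to_netcdf("out.nc")                           # write to disk, returns None
delayed = da.to_netcdf("out.nc", compute=False)  # dask Delayed; call delayed.compute() later
```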
+
+ # compute=True (default) returns ZarrStore
+ @overload
+ def to_zarr(
+ self,
+ store: MutableMapping | str | PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ *,
+ encoding: Mapping | None = None,
+ compute: Literal[True] = True,
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ ) -> ZarrStore: ...
+
+ # compute=False returns dask.Delayed
+ @overload
+ def to_zarr(
+ self,
+ store: MutableMapping | str | PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: Literal[False],
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ ) -> Delayed: ...
+
+ def to_zarr(
+ self,
+ store: MutableMapping | str | PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: bool = True,
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ ) -> ZarrStore | Delayed:
"""Write DataArray contents to a Zarr store
Zarr chunks are determined in the following way:
@@ -3251,10 +4330,40 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
:ref:`io.zarr`
The I/O user guide, with more details and examples.
"""
- pass
+ from xarray.backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE, to_zarr
- def to_dict(self, data: (bool | Literal['list', 'array'])='list',
- encoding: bool=False) ->dict[str, Any]:
+ if self.name is None:
+ # If no name is set then use a generic xarray name
+ dataset = self.to_dataset(name=DATAARRAY_VARIABLE)
+ elif self.name in self.coords or self.name in self.dims:
+ # The name is the same as one of the coords names, which the netCDF data model
+ # does not support, so rename it but keep track of the old name
+ dataset = self.to_dataset(name=DATAARRAY_VARIABLE)
+ dataset.attrs[DATAARRAY_NAME] = self.name
+ else:
+ # No problems with the name - so we're fine!
+ dataset = self.to_dataset()
+
+ return to_zarr( # type: ignore[call-overload,misc]
+ dataset,
+ store=store,
+ chunk_store=chunk_store,
+ mode=mode,
+ synchronizer=synchronizer,
+ group=group,
+ encoding=encoding,
+ compute=compute,
+ consolidated=consolidated,
+ append_dim=append_dim,
+ region=region,
+ safe_chunks=safe_chunks,
+ storage_options=storage_options,
+ zarr_version=zarr_version,
+ )
+
+ def to_dict(
+ self, data: bool | Literal["list", "array"] = "list", encoding: bool = False
+ ) -> dict[str, Any]:
"""
Convert this xarray.DataArray into a dictionary following xarray
naming conventions.
@@ -3285,10 +4394,16 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.from_dict
Dataset.to_dict
"""
- pass
+ d = self.variable.to_dict(data=data)
+ d.update({"coords": {}, "name": self.name})
+ for k, coord in self.coords.items():
+ d["coords"][k] = coord.variable.to_dict(data=data)
+ if encoding:
+ d["encoding"] = dict(self.encoding)
+ return d
@classmethod
- def from_dict(cls, d: Mapping[str, Any]) ->Self:
+ def from_dict(cls, d: Mapping[str, Any]) -> Self:
"""Convert a dictionary into an xarray.DataArray
Parameters
@@ -3332,10 +4447,31 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Attributes:
title: air temperature
"""
- pass
+ coords = None
+ if "coords" in d:
+ try:
+ coords = {
+ k: (v["dims"], v["data"], v.get("attrs"))
+ for k, v in d["coords"].items()
+ }
+ except KeyError as e:
+ raise ValueError(
+ "cannot convert dict when coords are missing the key "
+ f"'{str(e.args[0])}'"
+ )
+ try:
+ data = d["data"]
+ except KeyError:
+ raise ValueError("cannot convert dict without the key 'data'")
+ else:
+ obj = cls(data, coords, d.get("dims"), d.get("name"), d.get("attrs"))
+
+ obj.encoding.update(d.get("encoding", {}))
+
+ return obj
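A sketch of the dictionary layout `from_dict` parses above, mirroring what `to_dict` produces (hypothetical values):

```python
import xarray as xr

d = {
    "dims": ("t",),
    "data": [10, 20, 30],
    "coords": {"t": {"dims": ("t",), "data": [0, 1, 2]}},
    "name": "series",
    "attrs": {"units": "m"},
}
da = xr.DataArray.from_dict(d)  # round-trips with da.to_dict()
```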
@classmethod
- def from_series(cls, series: pd.Series, sparse: bool=False) ->DataArray:
+ def from_series(cls, series: pd.Series, sparse: bool = False) -> DataArray:
"""Convert a pandas.Series into an xarray.DataArray.
If the series's index is a MultiIndex, it will be expanded into a
@@ -3356,22 +4492,37 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.to_series
Dataset.from_dataframe
"""
- pass
+ temp_name = "__temporary_name"
+ df = pd.DataFrame({temp_name: series})
+ ds = Dataset.from_dataframe(df, sparse=sparse)
+ result = ds[temp_name]
+ result.name = series.name
+ return result
- def to_iris(self) ->iris_Cube:
+ def to_iris(self) -> iris_Cube:
"""Convert this array into a iris.cube.Cube"""
- pass
+ from xarray.convert import to_iris
+
+ return to_iris(self)
@classmethod
- def from_iris(cls, cube: iris_Cube) ->Self:
+ def from_iris(cls, cube: iris_Cube) -> Self:
"""Convert a iris.cube.Cube into an xarray.DataArray"""
- pass
+ from xarray.convert import from_iris
- def _all_compat(self, other: Self, compat_str: str) ->bool:
+ return from_iris(cube)
+
+ def _all_compat(self, other: Self, compat_str: str) -> bool:
"""Helper function for equals, broadcast_equals, and identical"""
- pass
- def broadcast_equals(self, other: Self) ->bool:
+ def compat(x, y):
+ return getattr(x.variable, compat_str)(y.variable)
+
+ return utils.dict_equiv(self.coords, other.coords, compat=compat) and compat(
+ self, other
+ )
+
+ def broadcast_equals(self, other: Self) -> bool:
"""Two DataArrays are broadcast equal if they are equal after
broadcasting them against each other such that they have the same
dimensions.
@@ -3415,9 +4566,12 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
>>> a.broadcast_equals(b)
True
"""
- pass
+ try:
+ return self._all_compat(other, "broadcast_equals")
+ except (TypeError, AttributeError):
+ return False
- def equals(self, other: Self) ->bool:
+ def equals(self, other: Self) -> bool:
"""True if two DataArrays have the same dimensions, coordinates and
values; otherwise False.
@@ -3474,9 +4628,12 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
>>> a.equals(d)
False
"""
- pass
+ try:
+ return self._all_compat(other, "equals")
+ except (TypeError, AttributeError):
+ return False
- def identical(self, other: Self) ->bool:
+ def identical(self, other: Self) -> bool:
"""Like equals, but also checks the array name and attributes, and
attributes on all coordinates.
@@ -3529,20 +4686,101 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
>>> a.identical(c)
False
"""
- pass
+ try:
+ return self.name == other.name and self._all_compat(other, "identical")
+ except (TypeError, AttributeError):
+ return False
+
+ def _result_name(self, other: Any = None) -> Hashable | None:
+ # use the same naming heuristics as pandas:
+ # https://github.com/ContinuumIO/blaze/issues/458#issuecomment-51936356
+ other_name = getattr(other, "name", _default)
+ if other_name is _default or other_name == self.name:
+ return self.name
+ else:
+ return None
- def __array_wrap__(self, obj, context=None) ->Self:
+ def __array_wrap__(self, obj, context=None) -> Self:
new_var = self.variable.__array_wrap__(obj, context)
return self._replace(new_var)
- def __matmul__(self, obj: T_Xarray) ->T_Xarray:
+ def __matmul__(self, obj: T_Xarray) -> T_Xarray:
return self.dot(obj)
- def __rmatmul__(self, other: T_Xarray) ->T_Xarray:
+ def __rmatmul__(self, other: T_Xarray) -> T_Xarray:
+ # currently somewhat duplicative, as only other DataArrays are
+ # compatible with matmul
return computation.dot(other, self)
+
+ def _unary_op(self, f: Callable, *args, **kwargs) -> Self:
+ keep_attrs = kwargs.pop("keep_attrs", None)
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
+ warnings.filterwarnings(
+ "ignore", r"Mean of empty slice", category=RuntimeWarning
+ )
+ with np.errstate(all="ignore"):
+ da = self.__array_wrap__(f(self.variable.data, *args, **kwargs))
+ if keep_attrs:
+ da.attrs = self.attrs
+ return da
+
+ def _binary_op(
+ self, other: DaCompatible, f: Callable, reflexive: bool = False
+ ) -> Self:
+ from xarray.core.groupby import GroupBy
+
+ if isinstance(other, (Dataset, GroupBy)):
+ return NotImplemented
+ if isinstance(other, DataArray):
+ align_type = OPTIONS["arithmetic_join"]
+ self, other = align(self, other, join=align_type, copy=False)
+ other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other)
+ other_coords = getattr(other, "coords", None)
+
+ variable = (
+ f(self.variable, other_variable_or_arraylike)
+ if not reflexive
+ else f(other_variable_or_arraylike, self.variable)
+ )
+ coords, indexes = self.coords._merge_raw(other_coords, reflexive)
+ name = self._result_name(other)
+
+ return self._replace(variable, coords, name, indexes=indexes)
+
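`_binary_op` aligns two DataArray operands according to the `arithmetic_join` option before applying `f`; a sketch of the visible effect (hypothetical data):

```python
import xarray as xr

a = xr.DataArray([1, 2, 3], dims="x", coords={"x": [0, 1, 2]})
b = xr.DataArray([10, 20, 30], dims="x", coords={"x": [1, 2, 3]})
a + b  # default inner join: x=[1, 2] -> [12, 23]
with xr.set_options(arithmetic_join="outer"):
    a + b  # x=[0, 1, 2, 3], NaN where either operand is missing
```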
+ def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self:
+ from xarray.core.groupby import GroupBy
+
+ if isinstance(other, GroupBy):
+ raise TypeError(
+ "in-place operations between a DataArray and "
+ "a grouped object are not permitted"
+ )
+ # n.b. we can't align other to self (with other.reindex_like(self))
+ # because `other` may be converted into floats, which would cause
+ # in-place arithmetic to fail unpredictably. Instead, we simply
+ # don't support automatic alignment with in-place arithmetic.
+ other_coords = getattr(other, "coords", None)
+ other_variable = getattr(other, "variable", other)
+ try:
+ with self.coords._merge_inplace(other_coords):
+ f(self.variable, other_variable)
+ except MergeError as exc:
+ raise MergeError(
+ "Automatic alignment is not supported for in-place operations.\n"
+ "Consider aligning the indices manually or using a not-in-place operation.\n"
+ "See https://github.com/pydata/xarray/issues/3910 for more explanations."
+ ) from exc
+ return self
+
+ def _copy_attrs_from(self, other: DataArray | Dataset | Variable) -> None:
+ self.attrs = other.attrs
+
plot = utils.UncachedAccessor(DataArrayPlotAccessor)
- def _title_for_slice(self, truncate: int=50) ->str:
+ def _title_for_slice(self, truncate: int = 50) -> str:
"""
If the dataarray has 1 dimensional coordinates or comes from a slice
we can show that info in the title
@@ -3558,11 +4796,27 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Can be used for plot titles
"""
- pass
+ one_dims = []
+ for dim, coord in self.coords.items():
+ if coord.size == 1:
+ one_dims.append(
+ f"{dim} = {format_item(coord.values)}{_get_units_from_attrs(coord)}"
+ )
- @_deprecate_positional_args('v2023.10.0')
- def diff(self, dim: Hashable, n: int=1, *, label: Literal['upper',
- 'lower']='upper') ->Self:
+ title = ", ".join(one_dims)
+ if len(title) > truncate:
+ title = title[: (truncate - 3)] + "..."
+
+ return title
+
+ @_deprecate_positional_args("v2023.10.0")
+ def diff(
+ self,
+ dim: Hashable,
+ n: int = 1,
+ *,
+ label: Literal["upper", "lower"] = "upper",
+ ) -> Self:
"""Calculate the n-th order discrete difference along given axis.
Parameters
@@ -3604,10 +4858,15 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
DataArray.differentiate
"""
- pass
+ ds = self._to_temp_dataset().diff(n=n, dim=dim, label=label)
+ return self._from_temp_dataset(ds)
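A sketch (hypothetical data) of the `label` and `n` options forwarded by the delegation above:

```python
import xarray as xr

da = xr.DataArray([5, 5, 6, 6], dims="x", coords={"x": [0, 1, 2, 3]})
da.diff("x")                 # values [0, 1, 0], labelled with the upper x of each pair
da.diff("x", label="lower")  # same values, labelled with x=[0, 1, 2]
da.diff("x", n=2)            # second-order difference
```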
- def shift(self, shifts: (Mapping[Any, int] | None)=None, fill_value:
- Any=dtypes.NA, **shifts_kwargs: int) ->Self:
+ def shift(
+ self,
+ shifts: Mapping[Any, int] | None = None,
+ fill_value: Any = dtypes.NA,
+ **shifts_kwargs: int,
+ ) -> Self:
"""Shift this DataArray by an offset along one or more dimensions.
Only the data is moved; coordinates stay in place. This is consistent
@@ -3647,10 +4906,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([nan, 5., 6.])
Dimensions without coordinates: x
"""
- pass
-
- def roll(self, shifts: (Mapping[Hashable, int] | None)=None,
- roll_coords: bool=False, **shifts_kwargs: int) ->Self:
+ variable = self.variable.shift(
+ shifts=shifts, fill_value=fill_value, **shifts_kwargs
+ )
+ return self._replace(variable=variable)
+
+ def roll(
+ self,
+ shifts: Mapping[Hashable, int] | None = None,
+ roll_coords: bool = False,
+ **shifts_kwargs: int,
+ ) -> Self:
"""Roll this array by an offset along one or more dimensions.
Unlike shift, roll treats the given dimensions as periodic, so will not
@@ -3689,10 +4955,13 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([7, 5, 6])
Dimensions without coordinates: x
"""
- pass
+ ds = self._to_temp_dataset().roll(
+ shifts=shifts, roll_coords=roll_coords, **shifts_kwargs
+ )
+ return self._from_temp_dataset(ds)
@property
- def real(self) ->Self:
+ def real(self) -> Self:
"""
The real part of the array.
@@ -3700,10 +4969,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
numpy.ndarray.real
"""
- pass
+ return self._replace(self.variable.real)
@property
- def imag(self) ->Self:
+ def imag(self) -> Self:
"""
The imaginary part of the array.
@@ -3711,10 +4980,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
numpy.ndarray.imag
"""
- pass
+ return self._replace(self.variable.imag)
@deprecate_dims
- def dot(self, other: T_Xarray, dim: Dims=None) ->T_Xarray:
+ def dot(
+ self,
+ other: T_Xarray,
+ dim: Dims = None,
+ ) -> T_Xarray:
"""Perform dot product of two DataArrays along their shared dims.
Equivalent to taking tensordot over all shared dims.
@@ -3755,11 +5028,25 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
('x', 'y')
"""
- pass
-
- def sortby(self, variables: (Hashable | DataArray | Sequence[Hashable |
- DataArray] | Callable[[Self], Hashable | DataArray | Sequence[
- Hashable | DataArray]]), ascending: bool=True) ->Self:
+ if isinstance(other, Dataset):
+ raise NotImplementedError(
+ "dot products are not yet supported with Dataset objects."
+ )
+ if not isinstance(other, DataArray):
+ raise TypeError("dot only operates on DataArrays.")
+
+ return computation.dot(self, other, dim=dim)
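A sketch of the contraction over shared dimensions, which `__matmul__` also routes through (hypothetical data):

```python
import xarray as xr

a = xr.DataArray([[1, 2], [3, 4]], dims=("x", "y"))
b = xr.DataArray([1, 1], dims="y")
a.dot(b)  # contract over the shared "y" dim -> [3, 7] along x
a @ b     # __matmul__, same result
```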
+
+ def sortby(
+ self,
+ variables: (
+ Hashable
+ | DataArray
+ | Sequence[Hashable | DataArray]
+ | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]]
+ ),
+ ascending: bool = True,
+ ) -> Self:
"""Sort object by labels or values (along an axis).
Sorts the dataarray, either along specified dimensions,
@@ -3823,13 +5110,25 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Coordinates:
* time (time) datetime64[ns] 40B 2000-01-05 2000-01-04 ... 2000-01-01
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def quantile(self, q: ArrayLike, dim: Dims=None, *, method:
- QuantileMethods='linear', keep_attrs: (bool | None)=None, skipna: (
- bool | None)=None, interpolation: (QuantileMethods | None)=None
- ) ->Self:
+ # We need to convert the callable here rather than pass it through to the
+ # dataset method, since otherwise the dataset method would try to call the
+ # callable with the dataset as the object
+ if callable(variables):
+ variables = variables(self)
+ ds = self._to_temp_dataset().sortby(variables, ascending=ascending)
+ return self._from_temp_dataset(ds)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def quantile(
+ self,
+ q: ArrayLike,
+ dim: Dims = None,
+ *,
+ method: QuantileMethods = "linear",
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ interpolation: QuantileMethods | None = None,
+ ) -> Self:
"""Compute the qth quantile of the data along the specified dimension.
Returns the qth quantiles(s) of the array elements.
@@ -3928,11 +5227,25 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
"Sample quantiles in statistical packages,"
The American Statistician, 50(4), pp. 361-365, 1996
"""
- pass
- @_deprecate_positional_args('v2023.10.0')
- def rank(self, dim: Hashable, *, pct: bool=False, keep_attrs: (bool |
- None)=None) ->Self:
+ ds = self._to_temp_dataset().quantile(
+ q,
+ dim=dim,
+ keep_attrs=keep_attrs,
+ method=method,
+ skipna=skipna,
+ interpolation=interpolation,
+ )
+ return self._from_temp_dataset(ds)
+
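A sketch (hypothetical data) of the scalar and vector `q` forms handled by the delegation above:

```python
import xarray as xr

da = xr.DataArray([[1.0, 2.0], [3.0, 4.0]], dims=("x", "y"))
da.quantile(0.5)                    # median over all values -> 2.5
da.quantile([0.25, 0.75], dim="x")  # adds a leading "quantile" dimension
```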
+ @_deprecate_positional_args("v2023.10.0")
+ def rank(
+ self,
+ dim: Hashable,
+ *,
+ pct: bool = False,
+ keep_attrs: bool | None = None,
+ ) -> Self:
"""Ranks the data.
Equal values are assigned a rank that is the average of the ranks that
@@ -3967,10 +5280,16 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([1., 2., 3.])
Dimensions without coordinates: x
"""
- pass
- def differentiate(self, coord: Hashable, edge_order: Literal[1, 2]=1,
- datetime_unit: DatetimeUnitOptions=None) ->Self:
+ ds = self._to_temp_dataset().rank(dim, pct=pct, keep_attrs=keep_attrs)
+ return self._from_temp_dataset(ds)
+
+ def differentiate(
+ self,
+ coord: Hashable,
+ edge_order: Literal[1, 2] = 1,
+ datetime_unit: DatetimeUnitOptions = None,
+ ) -> Self:
"""Differentiate the array with the second order accurate central
differences.
@@ -3984,7 +5303,8 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
The coordinate to be used to compute the gradient.
edge_order : {1, 2}, default: 1
N-th order accurate differences at the boundaries.
- datetime_unit : {"W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as", None}, optional
+ datetime_unit : {"W", "D", "h", "m", "s", "ms", \
+ "us", "ns", "ps", "fs", "as", None}, optional
Unit used to compute the gradient. Only valid for a datetime coordinate. "Y" and "M" are not available as
datetime_unit.
@@ -4024,10 +5344,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* x (x) float64 32B 0.0 0.1 1.1 1.2
Dimensions without coordinates: y
"""
- pass
+ ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit)
+ return self._from_temp_dataset(ds)
- def integrate(self, coord: (Hashable | Sequence[Hashable])=None,
- datetime_unit: DatetimeUnitOptions=None) ->Self:
+ def integrate(
+ self,
+ coord: Hashable | Sequence[Hashable] = None,
+ datetime_unit: DatetimeUnitOptions = None,
+ ) -> Self:
"""Integrate along the given coordinate using the trapezoidal rule.
.. note::
@@ -4038,7 +5362,8 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
----------
coord : Hashable, or sequence of Hashable
Coordinate(s) used for the integration.
- datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', 'as', None}, optional
+ datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \
+ 'ps', 'fs', 'as', None}, optional
Specify the unit if a datetime coordinate is used.
Returns
@@ -4073,10 +5398,14 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([5.4, 6.6, 7.8])
Dimensions without coordinates: y
"""
- pass
+ ds = self._to_temp_dataset().integrate(coord, datetime_unit)
+ return self._from_temp_dataset(ds)
- def cumulative_integrate(self, coord: (Hashable | Sequence[Hashable])=
- None, datetime_unit: DatetimeUnitOptions=None) ->Self:
+ def cumulative_integrate(
+ self,
+ coord: Hashable | Sequence[Hashable] = None,
+ datetime_unit: DatetimeUnitOptions = None,
+ ) -> Self:
"""Integrate cumulatively along the given coordinate using the trapezoidal rule.
.. note::
@@ -4090,7 +5419,8 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
----------
coord : Hashable, or sequence of Hashable
Coordinate(s) used for the integration.
- datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', 'as', None}, optional
+ datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \
+ 'ps', 'fs', 'as', None}, optional
Specify the unit if a datetime coordinate is used.
Returns
@@ -4130,9 +5460,10 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* x (x) float64 32B 0.0 0.1 1.1 1.2
Dimensions without coordinates: y
"""
- pass
+ ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit)
+ return self._from_temp_dataset(ds)
- def unify_chunks(self) ->Self:
+ def unify_chunks(self) -> Self:
"""Unify chunk size along all chunked dimensions of this DataArray.
Returns
@@ -4143,11 +5474,16 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
--------
dask.array.core.unify_chunks
"""
- pass
- def map_blocks(self, func: Callable[..., T_Xarray], args: Sequence[Any]
- =(), kwargs: (Mapping[str, Any] | None)=None, template: (DataArray |
- Dataset | None)=None) ->T_Xarray:
+ return unify_chunks(self)[0]
+
+ def map_blocks(
+ self,
+ func: Callable[..., T_Xarray],
+ args: Sequence[Any] = (),
+ kwargs: Mapping[str, Any] | None = None,
+ template: DataArray | Dataset | None = None,
+ ) -> T_Xarray:
"""
Apply a function to each block of this DataArray.
@@ -4244,11 +5580,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 192B dask.array<chunksize=(24,), meta=np.ndarray>
"""
- pass
+ from xarray.core.parallel import map_blocks
- def polyfit(self, dim: Hashable, deg: int, skipna: (bool | None)=None,
- rcond: (float | None)=None, w: (Hashable | Any | None)=None, full:
- bool=False, cov: (bool | Literal['unscaled'])=False) ->Dataset:
+ return map_blocks(func, self, args, kwargs, template)
+
+ def polyfit(
+ self,
+ dim: Hashable,
+ deg: int,
+ skipna: bool | None = None,
+ rcond: float | None = None,
+ w: Hashable | Any | None = None,
+ full: bool = False,
+ cov: bool | Literal["unscaled"] = False,
+ ) -> Dataset:
"""
Least squares polynomial fit.
@@ -4301,16 +5646,25 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
xarray.polyval
DataArray.curvefit
"""
- pass
-
- def pad(self, pad_width: (Mapping[Any, int | tuple[int, int]] | None)=
- None, mode: PadModeOptions='constant', stat_length: (int | tuple[
- int, int] | Mapping[Any, tuple[int, int]] | None)=None,
- constant_values: (float | tuple[float, float] | Mapping[Any, tuple[
- float, float]] | None)=None, end_values: (int | tuple[int, int] |
- Mapping[Any, tuple[int, int]] | None)=None, reflect_type:
- PadReflectOptions=None, keep_attrs: (bool | None)=None, **
- pad_width_kwargs: Any) ->Self:
+ return self._to_temp_dataset().polyfit(
+ dim, deg, skipna=skipna, rcond=rcond, w=w, full=full, cov=cov
+ )
+
+ def pad(
+ self,
+ pad_width: Mapping[Any, int | tuple[int, int]] | None = None,
+ mode: PadModeOptions = "constant",
+ stat_length: (
+ int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None
+ ) = None,
+ constant_values: (
+ float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None
+ ) = None,
+ end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None,
+ reflect_type: PadReflectOptions = None,
+ keep_attrs: bool | None = None,
+ **pad_width_kwargs: Any,
+ ) -> Self:
"""Pad this array along one or more dimensions.
.. warning::
@@ -4327,7 +5681,8 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Mapping with the form of {dim: (pad_before, pad_after)}
describing the number of values padded along each dimension.
{dim: pad} is a shortcut for pad_before = pad_after = pad
- mode : {"constant", "edge", "linear_ramp", "maximum", "mean", "median", "minimum", "reflect", "symmetric", "wrap"}, default: "constant"
+ mode : {"constant", "edge", "linear_ramp", "maximum", "mean", "median", \
+ "minimum", "reflect", "symmetric", "wrap"}, default: "constant"
How to pad the DataArray (taken from numpy docs):
- "constant": Pads with a constant value.
@@ -4449,12 +5804,27 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* y (y) int64 32B 10 20 30 40
z (x) float64 32B nan 100.0 200.0 nan
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def idxmin(self, dim: (Hashable | None)=None, *, skipna: (bool | None)=
- None, fill_value: Any=dtypes.NA, keep_attrs: (bool | None)=None
- ) ->Self:
+ ds = self._to_temp_dataset().pad(
+ pad_width=pad_width,
+ mode=mode,
+ stat_length=stat_length,
+ constant_values=constant_values,
+ end_values=end_values,
+ reflect_type=reflect_type,
+ keep_attrs=keep_attrs,
+ **pad_width_kwargs,
+ )
+ return self._from_temp_dataset(ds)
+
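A sketch (hypothetical data) of the default constant padding versus one of the numpy modes forwarded above:

```python
import xarray as xr

da = xr.DataArray([1, 2, 3], dims="x")
da.pad(x=(1, 2))          # NaN padding -> [nan, 1, 2, 3, nan, nan] (dtype promoted to float)
da.pad(x=1, mode="edge")  # repeat edge values -> [1, 1, 2, 3, 3]
```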
+ @_deprecate_positional_args("v2023.10.0")
+ def idxmin(
+ self,
+ dim: Hashable | None = None,
+ *,
+ skipna: bool | None = None,
+ fill_value: Any = dtypes.NA,
+ keep_attrs: bool | None = None,
+ ) -> Self:
"""Return the coordinate label of the minimum value along a dimension.
Returns a new `DataArray` named after the dimension with the values of
@@ -4535,11 +5905,24 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Coordinates:
* y (y) int64 24B -1 0 1
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def idxmax(self, dim: Hashable=None, *, skipna: (bool | None)=None,
- fill_value: Any=dtypes.NA, keep_attrs: (bool | None)=None) ->Self:
+ return computation._calc_idxminmax(
+ array=self,
+ func=lambda x, *args, **kwargs: x.argmin(*args, **kwargs),
+ dim=dim,
+ skipna=skipna,
+ fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def idxmax(
+ self,
+ dim: Hashable = None,
+ *,
+ skipna: bool | None = None,
+ fill_value: Any = dtypes.NA,
+ keep_attrs: bool | None = None,
+ ) -> Self:
"""Return the coordinate label of the maximum value along a dimension.
Returns a new `DataArray` named after the dimension with the values of
@@ -4620,12 +6003,24 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Coordinates:
* y (y) int64 24B -1 0 1
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def argmin(self, dim: Dims=None, *, axis: (int | None)=None, keep_attrs:
- (bool | None)=None, skipna: (bool | None)=None) ->(Self | dict[
- Hashable, Self]):
+ return computation._calc_idxminmax(
+ array=self,
+ func=lambda x, *args, **kwargs: x.argmax(*args, **kwargs),
+ dim=dim,
+ skipna=skipna,
+ fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ )
+
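A sketch (hypothetical data) of the label-returning `idxmin`/`idxmax` pair versus the position-returning `argmin` implemented further below:

```python
import xarray as xr

da = xr.DataArray([3, 1, 2], dims="x", coords={"x": [10, 20, 30]})
da.idxmin()  # coordinate label of the minimum -> 20
da.idxmax()  # coordinate label of the maximum -> 10
da.argmin()  # integer position of the minimum -> 1
```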
+ @_deprecate_positional_args("v2023.10.0")
+ def argmin(
+ self,
+ dim: Dims = None,
+ *,
+ axis: int | None = None,
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ ) -> Self | dict[Hashable, Self]:
"""Index or indices of the minimum of the DataArray over one or more dimensions.
If a sequence is passed to 'dim', then result returned as dict of DataArrays,
@@ -4713,12 +6108,21 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([ 1, -5, 1])
Dimensions without coordinates: y
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def argmax(self, dim: Dims=None, *, axis: (int | None)=None, keep_attrs:
- (bool | None)=None, skipna: (bool | None)=None) ->(Self | dict[
- Hashable, Self]):
+ result = self.variable.argmin(dim, axis, keep_attrs, skipna)
+ if isinstance(result, dict):
+ return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
+ else:
+ return self._replace_maybe_drop_dims(result)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def argmax(
+ self,
+ dim: Dims = None,
+ *,
+ axis: int | None = None,
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ ) -> Self | dict[Hashable, Self]:
"""Index or indices of the maximum of the DataArray over one or more dimensions.
If a sequence is passed to 'dim', then result returned as dict of DataArrays,
@@ -4806,12 +6210,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([3, 5, 3])
Dimensions without coordinates: y
"""
- pass
-
- def query(self, queries: (Mapping[Any, Any] | None)=None, parser:
- QueryParserOptions='pandas', engine: QueryEngineOptions=None,
- missing_dims: ErrorOptionsWithWarn='raise', **queries_kwargs: Any
- ) ->DataArray:
+ result = self.variable.argmax(dim, axis, keep_attrs, skipna)
+ if isinstance(result, dict):
+ return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
+ else:
+ return self._replace_maybe_drop_dims(result)
+
+ def query(
+ self,
+ queries: Mapping[Any, Any] | None = None,
+ parser: QueryParserOptions = "pandas",
+ engine: QueryEngineOptions = None,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ **queries_kwargs: Any,
+ ) -> DataArray:
"""Return a new data array indexed along the specified
dimension(s), where the indexers are given as strings containing
Python expressions to be evaluated against the values in the array.
@@ -4872,14 +6284,29 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
array([3, 4])
Dimensions without coordinates: x
"""
- pass
- def curvefit(self, coords: (str | DataArray | Iterable[str | DataArray]
- ), func: Callable[..., Any], reduce_dims: Dims=None, skipna: bool=
- True, p0: (Mapping[str, float | DataArray] | None)=None, bounds: (
- Mapping[str, tuple[float | DataArray, float | DataArray]] | None)=
- None, param_names: (Sequence[str] | None)=None, errors:
- ErrorOptions='raise', kwargs: (dict[str, Any] | None)=None) ->Dataset:
+ ds = self._to_dataset_whole(shallow_copy=True)
+ ds = ds.query(
+ queries=queries,
+ parser=parser,
+ engine=engine,
+ missing_dims=missing_dims,
+ **queries_kwargs,
+ )
+ return ds[self.name]
+
+ def curvefit(
+ self,
+ coords: str | DataArray | Iterable[str | DataArray],
+ func: Callable[..., Any],
+ reduce_dims: Dims = None,
+ skipna: bool = True,
+ p0: Mapping[str, float | DataArray] | None = None,
+ bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None,
+ param_names: Sequence[str] | None = None,
+ errors: ErrorOptions = "raise",
+ kwargs: dict[str, Any] | None = None,
+ ) -> Dataset:
"""
Curve fitting optimization for arbitrary functions.
@@ -5024,11 +6451,25 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray.polyfit
scipy.optimize.curve_fit
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def drop_duplicates(self, dim: (Hashable | Iterable[Hashable]), *, keep:
- Literal['first', 'last', False]='first') ->Self:
+ return self._to_temp_dataset().curvefit(
+ coords,
+ func,
+ reduce_dims=reduce_dims,
+ skipna=skipna,
+ p0=p0,
+ bounds=bounds,
+ param_names=param_names,
+ errors=errors,
+ kwargs=kwargs,
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def drop_duplicates(
+ self,
+ dim: Hashable | Iterable[Hashable],
+ *,
+ keep: Literal["first", "last", False] = "first",
+ ) -> Self:
"""Returns a new DataArray with duplicate dimension values removed.
Parameters
@@ -5100,11 +6541,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
* x (x) int64 32B 0 1 2 3
* y (y) int64 32B 0 1 2 3
"""
- pass
-
- def convert_calendar(self, calendar: str, dim: str='time', align_on: (
- str | None)=None, missing: (Any | None)=None, use_cftime: (bool |
- None)=None) ->Self:
+ deduplicated = self._to_temp_dataset().drop_duplicates(dim, keep=keep)
+ return self._from_temp_dataset(deduplicated)
+
+ def convert_calendar(
+ self,
+ calendar: str,
+ dim: str = "time",
+ align_on: str | None = None,
+ missing: Any | None = None,
+ use_cftime: bool | None = None,
+ ) -> Self:
"""Convert the DataArray to another calendar.
Only converts the individual timestamps, does not modify any data except
@@ -5211,10 +6658,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
This option is best used with data on a frequency coarser than daily.
"""
- pass
-
- def interp_calendar(self, target: (pd.DatetimeIndex | CFTimeIndex |
- DataArray), dim: str='time') ->Self:
+ return convert_calendar(
+ self,
+ calendar,
+ dim=dim,
+ align_on=align_on,
+ missing=missing,
+ use_cftime=use_cftime,
+ )
+
+ def interp_calendar(
+ self,
+ target: pd.DatetimeIndex | CFTimeIndex | DataArray,
+ dim: str = "time",
+ ) -> Self:
"""Interpolates the DataArray to another calendar based on decimal year measure.
Each timestamp in `source` and `target` are first converted to their decimal
@@ -5239,13 +6696,19 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
DataArray
The source interpolated on the decimal years of target,
"""
- pass
-
- @_deprecate_positional_args('v2024.07.0')
- def groupby(self, group: (Hashable | DataArray | IndexVariable |
- Mapping[Any, Grouper] | None)=None, *, squeeze: Literal[False]=
- False, restore_coord_dims: bool=False, **groupers: Grouper
- ) ->DataArrayGroupBy:
+ return interp_calendar(self, target, dim=dim)
+
+ @_deprecate_positional_args("v2024.07.0")
+ def groupby(
+ self,
+ group: (
+ Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
+ ) = None,
+ *,
+ squeeze: Literal[False] = False,
+ restore_coord_dims: bool = False,
+ **groupers: Grouper,
+ ) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
Parameters
@@ -5320,14 +6783,54 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Dataset.resample
DataArray.resample
"""
- pass
+ from xarray.core.groupby import (
+ DataArrayGroupBy,
+ ResolvedGrouper,
+ _validate_groupby_squeeze,
+ )
+ from xarray.groupers import UniqueGrouper
+
+ _validate_groupby_squeeze(squeeze)
- @_deprecate_positional_args('v2024.07.0')
- def groupby_bins(self, group: (Hashable | DataArray | IndexVariable),
- bins: Bins, right: bool=True, labels: (ArrayLike | Literal[False] |
- None)=None, precision: int=3, include_lowest: bool=False, squeeze:
- Literal[False]=False, restore_coord_dims: bool=False, duplicates:
- Literal['raise', 'drop']='raise') ->DataArrayGroupBy:
+ if isinstance(group, Mapping):
+ groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
+ group = None
+
+ grouper: Grouper
+ if group is not None:
+ if groupers:
+ raise ValueError(
+ "Providing a combination of `group` and **groupers is not supported."
+ )
+ grouper = UniqueGrouper()
+ else:
+ if len(groupers) > 1:
+ raise ValueError("grouping by multiple variables is not supported yet.")
+ if not groupers:
+ raise ValueError("Either `group` or `**groupers` must be provided.")
+ group, grouper = next(iter(groupers.items()))
+
+ rgrouper = ResolvedGrouper(grouper, group, self)
+
+ return DataArrayGroupBy(
+ self,
+ (rgrouper,),
+ restore_coord_dims=restore_coord_dims,
+ )
+
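A sketch of the two entry points dispatched above, the classic `group` argument and the newer `**groupers` keyword path (hypothetical data; the keyword form assumes a recent xarray that exposes `xarray.groupers`):

```python
import xarray as xr
from xarray.groupers import UniqueGrouper

da = xr.DataArray([1, 2, 3, 4], dims="t", coords={"t": [0, 0, 1, 1]})
da.groupby("t").mean()                # classic label-based grouping -> [1.5, 3.5]
da.groupby(t=UniqueGrouper()).mean()  # explicit Grouper keyword form, same result
```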
+ @_deprecate_positional_args("v2024.07.0")
+ def groupby_bins(
+ self,
+ group: Hashable | DataArray | IndexVariable,
+ bins: Bins,
+ right: bool = True,
+ labels: ArrayLike | Literal[False] | None = None,
+ precision: int = 3,
+ include_lowest: bool = False,
+ squeeze: Literal[False] = False,
+ restore_coord_dims: bool = False,
+ duplicates: Literal["raise", "drop"] = "raise",
+ ) -> DataArrayGroupBy:
"""Returns a DataArrayGroupBy object for performing grouped operations.
Rather than using all unique values of `group`, the values are discretized
@@ -5385,9 +6888,30 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
----------
.. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
"""
- pass
-
- def weighted(self, weights: DataArray) ->DataArrayWeighted:
+ from xarray.core.groupby import (
+ DataArrayGroupBy,
+ ResolvedGrouper,
+ _validate_groupby_squeeze,
+ )
+ from xarray.groupers import BinGrouper
+
+ _validate_groupby_squeeze(squeeze)
+ grouper = BinGrouper(
+ bins=bins,
+ right=right,
+ labels=labels,
+ precision=precision,
+ include_lowest=include_lowest,
+ )
+ rgrouper = ResolvedGrouper(grouper, group, self)
+
+ return DataArrayGroupBy(
+ self,
+ (rgrouper,),
+ restore_coord_dims=restore_coord_dims,
+ )
+
+ def weighted(self, weights: DataArray) -> DataArrayWeighted:
"""
Weighted DataArray operations.
@@ -5418,11 +6942,17 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Tutorial on Weighted Reduction using :py:func:`~xarray.DataArray.weighted`
"""
- pass
+ from xarray.core.weighted import DataArrayWeighted
- def rolling(self, dim: (Mapping[Any, int] | None)=None, min_periods: (
- int | None)=None, center: (bool | Mapping[Any, bool])=False, **
- window_kwargs: int) ->DataArrayRolling:
+ return DataArrayWeighted(self, weights)
+
+ def rolling(
+ self,
+ dim: Mapping[Any, int] | None = None,
+ min_periods: int | None = None,
+ center: bool | Mapping[Any, bool] = False,
+ **window_kwargs: int,
+ ) -> DataArrayRolling:
"""
Rolling window object for DataArrays.
@@ -5485,10 +7015,16 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Dataset.rolling
core.rolling.DataArrayRolling
"""
- pass
+ from xarray.core.rolling import DataArrayRolling
+
+ dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
+ return DataArrayRolling(self, dim, min_periods=min_periods, center=center)
- def cumulative(self, dim: (str | Iterable[Hashable]), min_periods: int=1
- ) ->DataArrayRolling:
+ def cumulative(
+ self,
+ dim: str | Iterable[Hashable],
+ min_periods: int = 1,
+ ) -> DataArrayRolling:
"""
Accumulating object for DataArrays.
@@ -5539,12 +7075,34 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Dataset.cumulative
core.rolling.DataArrayRolling
"""
- pass
+ from xarray.core.rolling import DataArrayRolling
- def coarsen(self, dim: (Mapping[Any, int] | None)=None, boundary:
- CoarsenBoundaryOptions='exact', side: (SideOptions | Mapping[Any,
- SideOptions])='left', coord_func: (str | Callable | Mapping[Any,
- str | Callable])='mean', **window_kwargs: int) ->DataArrayCoarsen:
+        # Could we abstract this "normalize and check 'dim'" logic? It's currently shared
+        # with the same method in Dataset.
+        if isinstance(dim, str):
+            if dim not in self.dims:
+                raise ValueError(
+                    f"Dimension {dim} not found in data dimensions: {self.dims}"
+                )
+            dim = {dim: self.sizes[dim]}
+        else:
+            missing_dims = set(dim) - set(self.dims)
+            if missing_dims:
+                raise ValueError(
+                    f"Dimensions {missing_dims} not found in data dimensions: {self.dims}"
+                )
+            dim = {d: self.sizes[d] for d in dim}
+
+        return DataArrayRolling(self, dim, min_periods=min_periods, center=False)
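
A short usage sketch for the two window constructors above (not part of the patch; the example data is made up):

```python
import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(30.0),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=30)},
)

# centered 7-point rolling mean; windows with fewer than 3 valid points yield NaN
smoothed = da.rolling(time=7, center=True, min_periods=3).mean()

# cumulative() is an expanding window starting from the first element
running_max = da.cumulative("time").max()
```
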
+
+ def coarsen(
+ self,
+ dim: Mapping[Any, int] | None = None,
+ boundary: CoarsenBoundaryOptions = "exact",
+ side: SideOptions | Mapping[Any, SideOptions] = "left",
+ coord_func: str | Callable | Mapping[Any, str | Callable] = "mean",
+ **window_kwargs: int,
+ ) -> DataArrayCoarsen:
"""
Coarsen object for DataArrays.
@@ -5671,15 +7229,30 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
Tutorial on windowed computation using :py:func:`~xarray.DataArray.coarsen`
"""
- pass
-
- @_deprecate_positional_args('v2024.07.0')
- def resample(self, indexer: (Mapping[Hashable, str | Resampler] | None)
- =None, *, skipna: (bool | None)=None, closed: (SideOptions | None)=
- None, label: (SideOptions | None)=None, offset: (pd.Timedelta |
- datetime.timedelta | str | None)=None, origin: (str | DatetimeLike)
- ='start_day', restore_coord_dims: (bool | None)=None, **
- indexer_kwargs: (str | Resampler)) ->DataArrayResample:
+        from xarray.core.rolling import DataArrayCoarsen
+
+        dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen")
+        return DataArrayCoarsen(
+            self,
+            dim,
+            boundary=boundary,
+            side=side,
+            coord_func=coord_func,
+        )
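
A corresponding sketch for coarsen (block aggregation); again illustrative data, not part of the diff:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(24.0), dims="x")

# average non-overlapping blocks of 4 along "x"; 24 is a multiple of 4, so the
# default boundary="exact" is satisfied
block_means = da.coarsen(x=4).mean()

# with boundary="trim", leftover samples at the end are silently dropped
trimmed = xr.DataArray(np.arange(10.0), dims="x").coarsen(x=4, boundary="trim").mean()
```
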
+
+ @_deprecate_positional_args("v2024.07.0")
+ def resample(
+ self,
+ indexer: Mapping[Hashable, str | Resampler] | None = None,
+ *,
+ skipna: bool | None = None,
+ closed: SideOptions | None = None,
+ label: SideOptions | None = None,
+ offset: pd.Timedelta | datetime.timedelta | str | None = None,
+ origin: str | DatetimeLike = "start_day",
+ restore_coord_dims: bool | None = None,
+ **indexer_kwargs: str | Resampler,
+ ) -> DataArrayResample:
"""Returns a Resample object for performing resampling operations.
Handles both downsampling and upsampling. The resampled
@@ -5806,10 +7379,25 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
----------
.. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
"""
- pass
-
- def to_dask_dataframe(self, dim_order: (Sequence[Hashable] | None)=None,
- set_index: bool=False) ->DaskDataFrame:
+ from xarray.core.resample import DataArrayResample
+
+ return self._resample(
+ resample_cls=DataArrayResample,
+ indexer=indexer,
+ skipna=skipna,
+ closed=closed,
+ label=label,
+ offset=offset,
+ origin=origin,
+ restore_coord_dims=restore_coord_dims,
+ **indexer_kwargs,
+ )
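
resample() defers to the shared `_resample` helper; typical downsampling looks like the sketch below (illustrative data, standard pandas frequency string):

```python
import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(365.0),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=365)},
)

# downsample daily values to monthly means ("MS" = month-start frequency)
monthly = da.resample(time="MS").mean()
```
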
+
+ def to_dask_dataframe(
+ self,
+ dim_order: Sequence[Hashable] | None = None,
+ set_index: bool = False,
+ ) -> DaskDataFrame:
"""Convert this array into a dask.dataframe.DataFrame.
Parameters
@@ -5861,10 +7449,20 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
14 -20 130 2 11
15 -20 130 3 15
"""
- pass
- str = utils.UncachedAccessor(StringAccessor['DataArray'])
+ if self.name is None:
+ raise ValueError(
+ "Cannot convert an unnamed DataArray to a "
+ "dask dataframe : use the ``.rename`` method to assign a name."
+ )
+ name = self.name
+ ds = self._to_dataset_whole(name, shallow_copy=False)
+ return ds.to_dask_dataframe(dim_order, set_index)
+
+ # this needs to be at the end, or mypy will confuse with `str`
+ # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
+ str = utils.UncachedAccessor(StringAccessor["DataArray"])
- def drop_attrs(self, *, deep: bool=True) ->Self:
+ def drop_attrs(self, *, deep: bool = True) -> Self:
"""
Removes all attributes from the DataArray.
@@ -5877,4 +7475,6 @@ class DataArray(AbstractArray, DataWithCoords, DataArrayArithmetic,
-------
DataArray
"""
- pass
+        return (
+            self._to_temp_dataset().drop_attrs(deep=deep).pipe(self._from_temp_dataset)
+        )
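
drop_attrs round-trips through a temporary Dataset, as shown above; from the user's side the result is simply an attribute-free copy. A minimal sketch with made-up attributes (not part of the diff):

```python
import xarray as xr

da = xr.DataArray(
    [1, 2, 3],
    dims="x",
    coords={"x": ("x", [0, 1, 2], {"long_name": "position"})},
    attrs={"units": "m"},
    name="a",
)

bare = da.drop_attrs()  # deep=True by default: array and coordinate attrs are removed
assert bare.attrs == {} and bare["x"].attrs == {}
# the original `da` keeps its attributes
```
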
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index fec06ba8..cad2f00c 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import copy
import datetime
import inspect
@@ -7,47 +8,132 @@ import math
import sys
import warnings
from collections import defaultdict
-from collections.abc import Collection, Hashable, Iterable, Iterator, Mapping, MutableMapping, Sequence
+from collections.abc import (
+ Collection,
+ Hashable,
+ Iterable,
+ Iterator,
+ Mapping,
+ MutableMapping,
+ Sequence,
+)
from functools import partial
from html import escape
from numbers import Number
from operator import methodcaller
from os import PathLike
from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload
+
import numpy as np
from pandas.api.types import is_extension_array_dtype
+
+# remove once numpy 2.0 is the oldest supported version
try:
from numpy.exceptions import RankWarning
except ImportError:
- from numpy import RankWarning
+ from numpy import RankWarning # type: ignore[no-redef,attr-defined,unused-ignore]
+
import pandas as pd
+
from xarray.coding.calendar_ops import convert_calendar, interp_calendar
from xarray.coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
-from xarray.core import alignment, duck_array_ops, formatting, formatting_html, ops, utils
+from xarray.core import (
+ alignment,
+ duck_array_ops,
+ formatting,
+ formatting_html,
+ ops,
+ utils,
+)
from xarray.core import dtypes as xrdtypes
from xarray.core._aggregations import DatasetAggregations
-from xarray.core.alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
+from xarray.core.alignment import (
+ _broadcast_helper,
+ _get_broadcast_dims_map_common_coords,
+ align,
+)
from xarray.core.arithmetic import DatasetArithmetic
-from xarray.core.common import DataWithCoords, _contains_datetime_like_objects, get_chunksizes
+from xarray.core.common import (
+ DataWithCoords,
+ _contains_datetime_like_objects,
+ get_chunksizes,
+)
from xarray.core.computation import unify_chunks
-from xarray.core.coordinates import Coordinates, DatasetCoordinates, assert_coordinate_consistent, create_coords_with_default_indexes
+from xarray.core.coordinates import (
+ Coordinates,
+ DatasetCoordinates,
+ assert_coordinate_consistent,
+ create_coords_with_default_indexes,
+)
from xarray.core.duck_array_ops import datetime_to_numeric
-from xarray.core.indexes import Index, Indexes, PandasIndex, PandasMultiIndex, assert_no_index_corrupted, create_default_index_implicit, filter_indexes_from_coords, isel_indexes, remove_unused_levels_categories, roll_indexes
+from xarray.core.indexes import (
+ Index,
+ Indexes,
+ PandasIndex,
+ PandasMultiIndex,
+ assert_no_index_corrupted,
+ create_default_index_implicit,
+ filter_indexes_from_coords,
+ isel_indexes,
+ remove_unused_levels_categories,
+ roll_indexes,
+)
from xarray.core.indexing import is_fancy_indexer, map_index_queries
-from xarray.core.merge import dataset_merge_method, dataset_update_method, merge_coordinates_without_align, merge_core
+from xarray.core.merge import (
+ dataset_merge_method,
+ dataset_update_method,
+ merge_coordinates_without_align,
+ merge_core,
+)
from xarray.core.missing import get_clean_interp_index
from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.types import Bins, NetcdfWriteModes, QuantileMethods, Self, T_ChunkDim, T_ChunksFreq, T_DataArray, T_DataArrayOrSet, T_Dataset, ZarrWriteModes
-from xarray.core.utils import Default, Frozen, FrozenMappingWarningOnValuesAccess, HybridMappingProxy, OrderedSet, _default, decode_numpy_dict_values, drop_dims_from_indexers, either_dict_or_kwargs, emit_user_level_warning, infix_dims, is_dict_like, is_duck_array, is_duck_dask_array, is_scalar, maybe_wrap_array
-from xarray.core.variable import IndexVariable, Variable, as_variable, broadcast_variables, calculate_dimensions
+from xarray.core.types import (
+ Bins,
+ NetcdfWriteModes,
+ QuantileMethods,
+ Self,
+ T_ChunkDim,
+ T_ChunksFreq,
+ T_DataArray,
+ T_DataArrayOrSet,
+ T_Dataset,
+ ZarrWriteModes,
+)
+from xarray.core.utils import (
+ Default,
+ Frozen,
+ FrozenMappingWarningOnValuesAccess,
+ HybridMappingProxy,
+ OrderedSet,
+ _default,
+ decode_numpy_dict_values,
+ drop_dims_from_indexers,
+ either_dict_or_kwargs,
+ emit_user_level_warning,
+ infix_dims,
+ is_dict_like,
+ is_duck_array,
+ is_duck_dask_array,
+ is_scalar,
+ maybe_wrap_array,
+)
+from xarray.core.variable import (
+ IndexVariable,
+ Variable,
+ as_variable,
+ broadcast_variables,
+ calculate_dimensions,
+)
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager
from xarray.namedarray.pycompat import array_type, is_chunked_array
from xarray.plot.accessor import DatasetPlotAccessor
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims
+
if TYPE_CHECKING:
from dask.dataframe import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from numpy.typing import ArrayLike
+
from xarray.backends import AbstractDataStore, ZarrStore
from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
from xarray.core.dataarray import DataArray
@@ -55,89 +141,336 @@ if TYPE_CHECKING:
from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult
from xarray.core.resample import DatasetResample
from xarray.core.rolling import DatasetCoarsen, DatasetRolling
- from xarray.core.types import CFCalendar, CoarsenBoundaryOptions, CombineAttrsOptions, CompatOptions, DataVars, DatetimeLike, DatetimeUnitOptions, Dims, DsCompatible, ErrorOptions, ErrorOptionsWithWarn, InterpOptions, JoinOptions, PadModeOptions, PadReflectOptions, QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, SideOptions, T_ChunkDimFreq, T_Xarray
+ from xarray.core.types import (
+ CFCalendar,
+ CoarsenBoundaryOptions,
+ CombineAttrsOptions,
+ CompatOptions,
+ DataVars,
+ DatetimeLike,
+ DatetimeUnitOptions,
+ Dims,
+ DsCompatible,
+ ErrorOptions,
+ ErrorOptionsWithWarn,
+ InterpOptions,
+ JoinOptions,
+ PadModeOptions,
+ PadReflectOptions,
+ QueryEngineOptions,
+ QueryParserOptions,
+ ReindexMethodOptions,
+ SideOptions,
+ T_ChunkDimFreq,
+ T_Xarray,
+ )
from xarray.core.weighted import DatasetWeighted
from xarray.groupers import Grouper, Resampler
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
-_DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute',
- 'second', 'microsecond', 'nanosecond', 'date', 'time', 'dayofyear',
- 'weekofyear', 'dayofweek', 'quarter']
-def _get_virtual_variable(variables, key: Hashable, dim_sizes: (Mapping |
- None)=None) ->tuple[Hashable, Hashable, Variable]:
+# list of attributes of pd.DatetimeIndex that are ndarrays of time info
+_DATETIMEINDEX_COMPONENTS = [
+ "year",
+ "month",
+ "day",
+ "hour",
+ "minute",
+ "second",
+ "microsecond",
+ "nanosecond",
+ "date",
+ "time",
+ "dayofyear",
+ "weekofyear",
+ "dayofweek",
+ "quarter",
+]
+
+
+def _get_virtual_variable(
+ variables, key: Hashable, dim_sizes: Mapping | None = None
+) -> tuple[Hashable, Hashable, Variable]:
"""Get a virtual variable (e.g., 'time.year') from a dict of xarray.Variable
objects (if possible)
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ if dim_sizes is None:
+ dim_sizes = {}
+
+ if key in dim_sizes:
+ data = pd.Index(range(dim_sizes[key]), name=key)
+ variable = IndexVariable((key,), data)
+ return key, key, variable
+
+ if not isinstance(key, str):
+ raise KeyError(key)
+
+ split_key = key.split(".", 1)
+ if len(split_key) != 2:
+ raise KeyError(key)
+
+ ref_name, var_name = split_key
+ ref_var = variables[ref_name]
+
+ if _contains_datetime_like_objects(ref_var):
+ ref_var = DataArray(ref_var)
+ data = getattr(ref_var.dt, var_name).data
+ else:
+ data = getattr(ref_var, var_name).data
+ virtual_var = Variable(ref_var.dims, data)
+
+ return ref_name, var_name, virtual_var
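
_get_virtual_variable backs "dotted" keys such as "time.dayofyear": when the key is not a stored variable, the component after the dot is looked up on the reference variable (through the .dt accessor for datetime-like data). A minimal illustration, not part of the patch:

```python
import pandas as pd
import xarray as xr

ds = xr.Dataset(coords={"time": pd.date_range("2000-01-01", periods=3)})

# "time.dayofyear" is not stored anywhere; it is resolved on the fly
print(ds["time.dayofyear"].values)  # [1 2 3]
```
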
def _get_chunk(var: Variable, chunks, chunkmanager: ChunkManagerEntrypoint):
"""
Return map from each dim to chunk sizes, accounting for backend's preferred chunks.
"""
- pass
+
+ if isinstance(var, IndexVariable):
+ return {}
+ dims = var.dims
+ shape = var.shape
+
+ # Determine the explicit requested chunks.
+ preferred_chunks = var.encoding.get("preferred_chunks", {})
+ preferred_chunk_shape = tuple(
+ preferred_chunks.get(dim, size) for dim, size in zip(dims, shape)
+ )
+ if isinstance(chunks, Number) or (chunks == "auto"):
+ chunks = dict.fromkeys(dims, chunks)
+ chunk_shape = tuple(
+ chunks.get(dim, None) or preferred_chunk_sizes
+ for dim, preferred_chunk_sizes in zip(dims, preferred_chunk_shape)
+ )
+
+ chunk_shape = chunkmanager.normalize_chunks(
+ chunk_shape, shape=shape, dtype=var.dtype, previous_chunks=preferred_chunk_shape
+ )
+
+ # Warn where requested chunks break preferred chunks, provided that the variable
+ # contains data.
+ if var.size:
+ for dim, size, chunk_sizes in zip(dims, shape, chunk_shape):
+ try:
+ preferred_chunk_sizes = preferred_chunks[dim]
+ except KeyError:
+ continue
+ # Determine the stop indices of the preferred chunks, but omit the last stop
+ # (equal to the dim size). In particular, assume that when a sequence
+ # expresses the preferred chunks, the sequence sums to the size.
+ preferred_stops = (
+ range(preferred_chunk_sizes, size, preferred_chunk_sizes)
+ if isinstance(preferred_chunk_sizes, int)
+ else itertools.accumulate(preferred_chunk_sizes[:-1])
+ )
+ # Gather any stop indices of the specified chunks that are not a stop index
+ # of a preferred chunk. Again, omit the last stop, assuming that it equals
+ # the dim size.
+ breaks = set(itertools.accumulate(chunk_sizes[:-1])).difference(
+ preferred_stops
+ )
+ if breaks:
+ warnings.warn(
+ "The specified chunks separate the stored chunks along "
+ f'dimension "{dim}" starting at index {min(breaks)}. This could '
+ "degrade performance. Instead, consider rechunking after loading."
+ )
+
+ return dict(zip(dims, chunk_shape))
+
+
+def _maybe_chunk(
+ name: Hashable,
+ var: Variable,
+ chunks: Mapping[Any, T_ChunkDim] | None,
+ token=None,
+ lock=None,
+ name_prefix: str = "xarray-",
+ overwrite_encoded_chunks: bool = False,
+ inline_array: bool = False,
+ chunked_array_type: str | ChunkManagerEntrypoint | None = None,
+ from_array_kwargs=None,
+) -> Variable:
+ from xarray.namedarray.daskmanager import DaskManager
+
+ if chunks is not None:
+ chunks = {dim: chunks[dim] for dim in var.dims if dim in chunks}
+
+ if var.ndim:
+ chunked_array_type = guess_chunkmanager(
+ chunked_array_type
+ ) # coerce string to ChunkManagerEntrypoint type
+ if isinstance(chunked_array_type, DaskManager):
+ from dask.base import tokenize
+
+ # when rechunking by different amounts, make sure dask names change
+ # by providing chunks as an input to tokenize.
+ # subtle bugs result otherwise. see GH3350
+ # we use str() for speed, and use the name for the final array name on the next line
+ token2 = tokenize(token if token else var._data, str(chunks))
+ name2 = f"{name_prefix}{name}-{token2}"
+
+ from_array_kwargs = utils.consolidate_dask_from_array_kwargs(
+ from_array_kwargs,
+ name=name2,
+ lock=lock,
+ inline_array=inline_array,
+ )
+
+ var = var.chunk(
+ chunks,
+ chunked_array_type=chunked_array_type,
+ from_array_kwargs=from_array_kwargs,
+ )
+
+ if overwrite_encoded_chunks and var.chunks is not None:
+ var.encoding["chunks"] = tuple(x[0] for x in var.chunks)
+ return var
+ else:
+ return var
-def as_dataset(obj: Any) ->Dataset:
+def as_dataset(obj: Any) -> Dataset:
"""Cast the given object to a Dataset.
Handles Datasets, DataArrays and dictionaries of variables. A new Dataset
object is only created if the provided object is not already one.
"""
- pass
+ if hasattr(obj, "to_dataset"):
+ obj = obj.to_dataset()
+ if not isinstance(obj, Dataset):
+ obj = Dataset(obj)
+ return obj
def _get_func_args(func, param_names):
"""Use `inspect.signature` to try accessing `func` args. Otherwise, ensure
they are provided by user.
"""
- pass
+    try:
+        func_args = inspect.signature(func).parameters
+    except ValueError:
+        func_args = {}
+        if not param_names:
+            raise ValueError(
+                "Unable to inspect `func` signature, and `param_names` was not provided."
+            )
+    if param_names:
+        params = param_names
+    else:
+        params = list(func_args)[1:]
+        if any(
+            [(p.kind in [p.VAR_POSITIONAL, p.VAR_KEYWORD]) for p in func_args.values()]
+        ):
+            raise ValueError(
+                "`param_names` must be provided because `func` takes variable length arguments."
+            )
+    return params, func_args
def _initialize_curvefit_params(params, p0, bounds, func_args):
"""Set initial guess and bounds for curvefit.
Priority: 1) passed args 2) func signature 3) scipy defaults
"""
- pass
+ from xarray.core.computation import where
+
+ def _initialize_feasible(lb, ub):
+ # Mimics functionality of scipy.optimize.minpack._initialize_feasible
+ lb_finite = np.isfinite(lb)
+ ub_finite = np.isfinite(ub)
+ p0 = where(
+ lb_finite,
+ where(
+ ub_finite,
+ 0.5 * (lb + ub), # both bounds finite
+ lb + 1, # lower bound finite, upper infinite
+ ),
+ where(
+ ub_finite,
+ ub - 1, # lower bound infinite, upper finite
+ 0, # both bounds infinite
+ ),
+ )
+ return p0
+
+ param_defaults = {p: 1 for p in params}
+ bounds_defaults = {p: (-np.inf, np.inf) for p in params}
+ for p in params:
+ if p in func_args and func_args[p].default is not func_args[p].empty:
+ param_defaults[p] = func_args[p].default
+ if p in bounds:
+ lb, ub = bounds[p]
+ bounds_defaults[p] = (lb, ub)
+ param_defaults[p] = where(
+ (param_defaults[p] < lb) | (param_defaults[p] > ub),
+ _initialize_feasible(lb, ub),
+ param_defaults[p],
+ )
+ if p in p0:
+ param_defaults[p] = p0[p]
+ return param_defaults, bounds_defaults
-def merge_data_and_coords(data_vars: DataVars, coords) ->_MergeResult:
+def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult:
"""Used in Dataset.__init__."""
- pass
+ if isinstance(coords, Coordinates):
+ coords = coords.copy()
+ else:
+ coords = create_coords_with_default_indexes(coords, data_vars)
+ # exclude coords from alignment (all variables in a Coordinates object should
+ # already be aligned together) and use coordinates' indexes to align data_vars
+ return merge_core(
+ [data_vars, coords],
+ compat="broadcast_equals",
+ join="outer",
+ explicit_coords=tuple(coords),
+ indexes=coords.xindexes,
+ priority_arg=1,
+ skip_align_args=[1],
+ )
-class DataVariables(Mapping[Any, 'DataArray']):
- __slots__ = '_dataset',
+
+class DataVariables(Mapping[Any, "DataArray"]):
+ __slots__ = ("_dataset",)
def __init__(self, dataset: Dataset):
self._dataset = dataset
- def __iter__(self) ->Iterator[Hashable]:
- return (key for key in self._dataset._variables if key not in self.
- _dataset._coord_names)
+ def __iter__(self) -> Iterator[Hashable]:
+ return (
+ key
+ for key in self._dataset._variables
+ if key not in self._dataset._coord_names
+ )
- def __len__(self) ->int:
- length = len(self._dataset._variables) - len(self._dataset._coord_names
- )
- assert length >= 0, 'something is wrong with Dataset._coord_names'
+ def __len__(self) -> int:
+ length = len(self._dataset._variables) - len(self._dataset._coord_names)
+ assert length >= 0, "something is wrong with Dataset._coord_names"
return length
- def __contains__(self, key: Hashable) ->bool:
- return (key in self._dataset._variables and key not in self.
- _dataset._coord_names)
+ def __contains__(self, key: Hashable) -> bool:
+ return key in self._dataset._variables and key not in self._dataset._coord_names
- def __getitem__(self, key: Hashable) ->DataArray:
+ def __getitem__(self, key: Hashable) -> DataArray:
if key not in self._dataset._coord_names:
return self._dataset[key]
raise KeyError(key)
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
return formatting.data_vars_repr(self)
@property
- def dtypes(self) ->Frozen[Hashable, np.dtype]:
+ def variables(self) -> Mapping[Hashable, Variable]:
+ all_variables = self._dataset.variables
+ return Frozen({k: all_variables[k] for k in self})
+
+ @property
+ def dtypes(self) -> Frozen[Hashable, np.dtype]:
"""Mapping from data variable names to dtypes.
Cannot be modified directly, but is updated when adding new variables.
@@ -146,35 +479,46 @@ class DataVariables(Mapping[Any, 'DataArray']):
--------
Dataset.dtype
"""
- pass
+ return self._dataset.dtypes
def _ipython_key_completions_(self):
"""Provide method for the key-autocompletions in IPython."""
- pass
+ return [
+ key
+ for key in self._dataset._ipython_key_completions_()
+ if key not in self._dataset._coord_names
+ ]
class _LocIndexer(Generic[T_Dataset]):
- __slots__ = 'dataset',
+ __slots__ = ("dataset",)
def __init__(self, dataset: T_Dataset):
self.dataset = dataset
- def __getitem__(self, key: Mapping[Any, Any]) ->T_Dataset:
+ def __getitem__(self, key: Mapping[Any, Any]) -> T_Dataset:
if not utils.is_dict_like(key):
- raise TypeError('can only lookup dictionaries from Dataset.loc')
+ raise TypeError("can only lookup dictionaries from Dataset.loc")
return self.dataset.sel(key)
- def __setitem__(self, key, value) ->None:
+ def __setitem__(self, key, value) -> None:
if not utils.is_dict_like(key):
raise TypeError(
- f'can only set locations defined by dictionaries from Dataset.loc. Got: {key}'
- )
+ "can only set locations defined by dictionaries from Dataset.loc."
+ f" Got: {key}"
+ )
+
+ # set new values
dim_indexers = map_index_queries(self.dataset, key).dim_indexers
self.dataset[dim_indexers] = value
-class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
- Mapping[Hashable, 'DataArray']):
+class Dataset(
+ DataWithCoords,
+ DatasetAggregations,
+ DatasetArithmetic,
+ Mapping[Hashable, "DataArray"],
+):
"""A multi-dimensional, in memory, array database.
A dataset resembles an in-memory representation of a NetCDF file,
@@ -322,6 +666,7 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
description: Weather related data.
"""
+
_attrs: dict[Hashable, Any] | None
_cache: dict[str, Any]
_coord_names: set[Hashable]
@@ -330,24 +675,45 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
_close: Callable[[], None] | None
_indexes: dict[Hashable, Index]
_variables: dict[Hashable, Variable]
- __slots__ = ('_attrs', '_cache', '_coord_names', '_dims', '_encoding',
- '_close', '_indexes', '_variables', '__weakref__')
- def __init__(self, data_vars: (DataVars | None)=None, coords: (Mapping[
- Any, Any] | None)=None, attrs: (Mapping[Any, Any] | None)=None) ->None:
+ __slots__ = (
+ "_attrs",
+ "_cache",
+ "_coord_names",
+ "_dims",
+ "_encoding",
+ "_close",
+ "_indexes",
+ "_variables",
+ "__weakref__",
+ )
+
+ def __init__(
+ self,
+ # could make a VariableArgs to use more generally, and refine these
+ # categories
+ data_vars: DataVars | None = None,
+ coords: Mapping[Any, Any] | None = None,
+ attrs: Mapping[Any, Any] | None = None,
+ ) -> None:
if data_vars is None:
data_vars = {}
if coords is None:
coords = {}
+
both_data_and_coords = set(data_vars) & set(coords)
if both_data_and_coords:
raise ValueError(
- f'variables {both_data_and_coords!r} are found in both data_vars and coords'
- )
+ f"variables {both_data_and_coords!r} are found in both data_vars and coords"
+ )
+
if isinstance(coords, Dataset):
coords = coords._variables
+
variables, coord_names, dims, indexes, _ = merge_data_and_coords(
- data_vars, coords)
+ data_vars, coords
+ )
+
self._attrs = dict(attrs) if attrs else None
self._close = None
self._encoding = None
@@ -356,18 +722,25 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
self._dims = dims
self._indexes = indexes
- def __eq__(self, other: DsCompatible) ->Self:
+ # TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping
+ # related to https://github.com/python/mypy/issues/9319?
+ def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override]
return super().__eq__(other)
@classmethod
- def load_store(cls, store, decoder=None) ->Self:
+ def load_store(cls, store, decoder=None) -> Self:
"""Create a new dataset from the contents of a backends.*DataStore
object
"""
- pass
+ variables, attributes = store.load()
+ if decoder:
+ variables, attributes = decoder(variables, attributes)
+ obj = cls(variables, attrs=attributes)
+ obj.set_close(store.close)
+ return obj
@property
- def variables(self) ->Frozen[Hashable, Variable]:
+ def variables(self) -> Frozen[Hashable, Variable]:
"""Low level interface to Dataset contents as dict of Variable objects.
This ordered dictionary is frozen to prevent mutation that could
@@ -375,25 +748,44 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
constituting the Dataset, including both data variables and
coordinates.
"""
- pass
+ return Frozen(self._variables)
@property
- def attrs(self) ->dict[Any, Any]:
+ def attrs(self) -> dict[Any, Any]:
"""Dictionary of global attributes on this dataset"""
- pass
+ if self._attrs is None:
+ self._attrs = {}
+ return self._attrs
+
+ @attrs.setter
+ def attrs(self, value: Mapping[Any, Any]) -> None:
+ self._attrs = dict(value) if value else None
@property
- def encoding(self) ->dict[Any, Any]:
+ def encoding(self) -> dict[Any, Any]:
"""Dictionary of global encoding attributes on this dataset"""
- pass
+ if self._encoding is None:
+ self._encoding = {}
+ return self._encoding
+
+ @encoding.setter
+ def encoding(self, value: Mapping[Any, Any]) -> None:
+ self._encoding = dict(value)
+
+ def reset_encoding(self) -> Self:
+ warnings.warn(
+ "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead"
+ )
+ return self.drop_encoding()
- def drop_encoding(self) ->Self:
+ def drop_encoding(self) -> Self:
"""Return a new Dataset without encoding on the dataset or any of its
variables/coords."""
- pass
+ variables = {k: v.drop_encoding() for k, v in self.variables.items()}
+ return self._replace(variables=variables, encoding={})
@property
- def dims(self) ->Frozen[Hashable, int]:
+ def dims(self) -> Frozen[Hashable, int]:
"""Mapping from dimension names to lengths.
Cannot be modified directly, but is updated when adding new variables.
@@ -408,10 +800,10 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.sizes
DataArray.dims
"""
- pass
+ return FrozenMappingWarningOnValuesAccess(self._dims)
@property
- def sizes(self) ->Frozen[Hashable, int]:
+ def sizes(self) -> Frozen[Hashable, int]:
"""Mapping from dimension names to lengths.
Cannot be modified directly, but is updated when adding new variables.
@@ -423,10 +815,10 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
DataArray.sizes
"""
- pass
+ return Frozen(self._dims)
@property
- def dtypes(self) ->Frozen[Hashable, np.dtype]:
+ def dtypes(self) -> Frozen[Hashable, np.dtype]:
"""Mapping from data variable names to dtypes.
Cannot be modified directly, but is updated when adding new variables.
@@ -435,9 +827,15 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
DataArray.dtype
"""
- pass
+ return Frozen(
+ {
+ n: v.dtype
+ for n, v in self._variables.items()
+ if n not in self._coord_names
+ }
+ )
- def load(self, **kwargs) ->Self:
+ def load(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this dataset's data
from disk or a remote source into memory and return this dataset.
Unlike compute, the original dataset is modified and returned.
@@ -456,12 +854,34 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
dask.compute
"""
- pass
+ # access .data to coerce everything to numpy or dask arrays
+ lazy_data = {
+ k: v._data for k, v in self.variables.items() if is_chunked_array(v._data)
+ }
+ if lazy_data:
+ chunkmanager = get_chunked_array_type(*lazy_data.values())
- def __dask_tokenize__(self) ->object:
+ # evaluate all the chunked arrays simultaneously
+ evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute(
+ *lazy_data.values(), **kwargs
+ )
+
+ for k, data in zip(lazy_data, evaluated_data):
+ self.variables[k].data = data
+
+ # load everything else sequentially
+ for k, v in self.variables.items():
+ if k not in lazy_data:
+ v.load()
+
+ return self
+
+ def __dask_tokenize__(self) -> object:
from dask.base import normalize_token
- return normalize_token((type(self), self._variables, self.
- _coord_names, self._attrs or None))
+
+ return normalize_token(
+ (type(self), self._variables, self._coord_names, self._attrs or None)
+ )
def __dask_graph__(self):
graphs = {k: v.__dask_graph__() for k, v in self.variables.items()}
@@ -471,29 +891,44 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
else:
try:
from dask.highlevelgraph import HighLevelGraph
+
return HighLevelGraph.merge(*graphs.values())
except ImportError:
from dask import sharedict
+
return sharedict.merge(*graphs.values())
def __dask_keys__(self):
import dask
- return [v.__dask_keys__() for v in self.variables.values() if dask.
- is_dask_collection(v)]
+
+ return [
+ v.__dask_keys__()
+ for v in self.variables.values()
+ if dask.is_dask_collection(v)
+ ]
def __dask_layers__(self):
import dask
- return sum((v.__dask_layers__() for v in self.variables.values() if
- dask.is_dask_collection(v)), ())
+
+ return sum(
+ (
+ v.__dask_layers__()
+ for v in self.variables.values()
+ if dask.is_dask_collection(v)
+ ),
+ (),
+ )
@property
def __dask_optimize__(self):
import dask.array as da
+
return da.Array.__dask_optimize__
@property
def __dask_scheduler__(self):
import dask.array as da
+
return da.Array.__dask_scheduler__
def __dask_postcompute__(self):
@@ -502,7 +937,80 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
def __dask_postpersist__(self):
return self._dask_postpersist, ()
- def compute(self, **kwargs) ->Self:
+ def _dask_postcompute(self, results: Iterable[Variable]) -> Self:
+ import dask
+
+ variables = {}
+ results_iter = iter(results)
+
+ for k, v in self._variables.items():
+ if dask.is_dask_collection(v):
+ rebuild, args = v.__dask_postcompute__()
+ v = rebuild(next(results_iter), *args)
+ variables[k] = v
+
+ return type(self)._construct_direct(
+ variables,
+ self._coord_names,
+ self._dims,
+ self._attrs,
+ self._indexes,
+ self._encoding,
+ self._close,
+ )
+
+ def _dask_postpersist(
+ self, dsk: Mapping, *, rename: Mapping[str, str] | None = None
+ ) -> Self:
+ from dask import is_dask_collection
+ from dask.highlevelgraph import HighLevelGraph
+ from dask.optimization import cull
+
+ variables = {}
+
+ for k, v in self._variables.items():
+ if not is_dask_collection(v):
+ variables[k] = v
+ continue
+
+ if isinstance(dsk, HighLevelGraph):
+ # dask >= 2021.3
+ # __dask_postpersist__() was called by dask.highlevelgraph.
+ # Don't use dsk.cull(), as we need to prevent partial layers:
+ # https://github.com/dask/dask/issues/7137
+ layers = v.__dask_layers__()
+ if rename:
+ layers = [rename.get(k, k) for k in layers]
+ dsk2 = dsk.cull_layers(layers)
+ elif rename: # pragma: nocover
+ # At the moment of writing, this is only for forward compatibility.
+ # replace_name_in_key requires dask >= 2021.3.
+ from dask.base import flatten, replace_name_in_key
+
+ keys = [
+ replace_name_in_key(k, rename) for k in flatten(v.__dask_keys__())
+ ]
+ dsk2, _ = cull(dsk, keys)
+ else:
+ # __dask_postpersist__() was called by dask.optimize or dask.persist
+ dsk2, _ = cull(dsk, v.__dask_keys__())
+
+ rebuild, args = v.__dask_postpersist__()
+ # rename was added in dask 2021.3
+ kwargs = {"rename": rename} if rename else {}
+ variables[k] = rebuild(dsk2, *args, **kwargs)
+
+ return type(self)._construct_direct(
+ variables,
+ self._coord_names,
+ self._dims,
+ self._attrs,
+ self._indexes,
+ self._encoding,
+ self._close,
+ )
+
+ def compute(self, **kwargs) -> Self:
"""Manually trigger loading and/or computation of this dataset's data
from disk or a remote source into memory and return a new dataset.
Unlike load, the original dataset is left unaltered.
@@ -526,13 +1034,27 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
dask.compute
"""
- pass
+ new = self.copy(deep=False)
+ return new.load(**kwargs)
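
compute() above is a thin wrapper that copies first and then calls load(); the difference only matters for chunked data. A sketch, assuming dask is installed (the dataset is made up):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(10))}).chunk({"x": 5})  # dask-backed

computed = ds.compute()  # new dataset holding numpy data; `ds` keeps its dask arrays
loaded = ds.load()       # evaluates in place and returns the same object
assert loaded is ds
```
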
- def _persist_inplace(self, **kwargs) ->Self:
+ def _persist_inplace(self, **kwargs) -> Self:
"""Persist all Dask arrays in memory"""
- pass
+ # access .data to coerce everything to numpy or dask arrays
+ lazy_data = {
+ k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data)
+ }
+ if lazy_data:
+ import dask
+
+ # evaluate all the dask arrays simultaneously
+ evaluated_data = dask.persist(*lazy_data.values(), **kwargs)
+
+ for k, data in zip(lazy_data, evaluated_data):
+ self.variables[k].data = data
- def persist(self, **kwargs) ->Self:
+ return self
+
+ def persist(self, **kwargs) -> Self:
"""Trigger computation, keeping data as dask arrays
This operation can be used to trigger computation on underlying dask
@@ -556,23 +1078,47 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
dask.persist
"""
- pass
+ new = self.copy(deep=False)
+ return new._persist_inplace(**kwargs)
@classmethod
- def _construct_direct(cls, variables: dict[Any, Variable], coord_names:
- set[Hashable], dims: (dict[Any, int] | None)=None, attrs: (dict |
- None)=None, indexes: (dict[Any, Index] | None)=None, encoding: (
- dict | None)=None, close: (Callable[[], None] | None)=None) ->Self:
+ def _construct_direct(
+ cls,
+ variables: dict[Any, Variable],
+ coord_names: set[Hashable],
+ dims: dict[Any, int] | None = None,
+ attrs: dict | None = None,
+ indexes: dict[Any, Index] | None = None,
+ encoding: dict | None = None,
+ close: Callable[[], None] | None = None,
+ ) -> Self:
"""Shortcut around __init__ for internal use when we want to skip
costly validation
"""
- pass
-
- def _replace(self, variables: (dict[Hashable, Variable] | None)=None,
- coord_names: (set[Hashable] | None)=None, dims: (dict[Any, int] |
- None)=None, attrs: (dict[Hashable, Any] | None | Default)=_default,
- indexes: (dict[Hashable, Index] | None)=None, encoding: (dict |
- None | Default)=_default, inplace: bool=False) ->Self:
+ if dims is None:
+ dims = calculate_dimensions(variables)
+ if indexes is None:
+ indexes = {}
+ obj = object.__new__(cls)
+ obj._variables = variables
+ obj._coord_names = coord_names
+ obj._dims = dims
+ obj._indexes = indexes
+ obj._attrs = attrs
+ obj._close = close
+ obj._encoding = encoding
+ return obj
+
+ def _replace(
+ self,
+ variables: dict[Hashable, Variable] | None = None,
+ coord_names: set[Hashable] | None = None,
+ dims: dict[Any, int] | None = None,
+ attrs: dict[Hashable, Any] | None | Default = _default,
+ indexes: dict[Hashable, Index] | None = None,
+ encoding: dict | None | Default = _default,
+ inplace: bool = False,
+ ) -> Self:
"""Fastpath constructor for internal use.
Returns an object with optionally with replaced attributes.
@@ -581,40 +1127,146 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
dataset. It is up to the caller to ensure that they have the right type
and are not used elsewhere.
"""
- pass
-
- def _replace_with_new_dims(self, variables: dict[Hashable, Variable],
- coord_names: (set | None)=None, attrs: (dict[Hashable, Any] | None |
- Default)=_default, indexes: (dict[Hashable, Index] | None)=None,
- inplace: bool=False) ->Self:
+ if inplace:
+ if variables is not None:
+ self._variables = variables
+ if coord_names is not None:
+ self._coord_names = coord_names
+ if dims is not None:
+ self._dims = dims
+ if attrs is not _default:
+ self._attrs = attrs
+ if indexes is not None:
+ self._indexes = indexes
+ if encoding is not _default:
+ self._encoding = encoding
+ obj = self
+ else:
+ if variables is None:
+ variables = self._variables.copy()
+ if coord_names is None:
+ coord_names = self._coord_names.copy()
+ if dims is None:
+ dims = self._dims.copy()
+ if attrs is _default:
+ attrs = copy.copy(self._attrs)
+ if indexes is None:
+ indexes = self._indexes.copy()
+ if encoding is _default:
+ encoding = copy.copy(self._encoding)
+ obj = self._construct_direct(
+ variables, coord_names, dims, attrs, indexes, encoding
+ )
+ return obj
+
+ def _replace_with_new_dims(
+ self,
+ variables: dict[Hashable, Variable],
+ coord_names: set | None = None,
+ attrs: dict[Hashable, Any] | None | Default = _default,
+ indexes: dict[Hashable, Index] | None = None,
+ inplace: bool = False,
+ ) -> Self:
"""Replace variables with recalculated dimensions."""
- pass
-
- def _replace_vars_and_dims(self, variables: dict[Hashable, Variable],
- coord_names: (set | None)=None, dims: (dict[Hashable, int] | None)=
- None, attrs: (dict[Hashable, Any] | None | Default)=_default,
- inplace: bool=False) ->Self:
+ dims = calculate_dimensions(variables)
+ return self._replace(
+ variables, coord_names, dims, attrs, indexes, inplace=inplace
+ )
+
+ def _replace_vars_and_dims(
+ self,
+ variables: dict[Hashable, Variable],
+ coord_names: set | None = None,
+ dims: dict[Hashable, int] | None = None,
+ attrs: dict[Hashable, Any] | None | Default = _default,
+ inplace: bool = False,
+ ) -> Self:
"""Deprecated version of _replace_with_new_dims().
Unlike _replace_with_new_dims(), this method always recalculates
indexes from variables.
"""
- pass
-
- def _overwrite_indexes(self, indexes: Mapping[Hashable, Index],
- variables: (Mapping[Hashable, Variable] | None)=None,
- drop_variables: (list[Hashable] | None)=None, drop_indexes: (list[
- Hashable] | None)=None, rename_dims: (Mapping[Hashable, Hashable] |
- None)=None) ->Self:
+ if dims is None:
+ dims = calculate_dimensions(variables)
+ return self._replace(
+ variables, coord_names, dims, attrs, indexes=None, inplace=inplace
+ )
+
+ def _overwrite_indexes(
+ self,
+ indexes: Mapping[Hashable, Index],
+ variables: Mapping[Hashable, Variable] | None = None,
+ drop_variables: list[Hashable] | None = None,
+ drop_indexes: list[Hashable] | None = None,
+ rename_dims: Mapping[Hashable, Hashable] | None = None,
+ ) -> Self:
"""Maybe replace indexes.
This function may do a lot more depending on index query
results.
"""
- pass
+ if not indexes:
+ return self
+
+ if variables is None:
+ variables = {}
+ if drop_variables is None:
+ drop_variables = []
+ if drop_indexes is None:
+ drop_indexes = []
+
+ new_variables = self._variables.copy()
+ new_coord_names = self._coord_names.copy()
+ new_indexes = dict(self._indexes)
+
+ index_variables = {}
+ no_index_variables = {}
+ for name, var in variables.items():
+ old_var = self._variables.get(name)
+ if old_var is not None:
+ var.attrs.update(old_var.attrs)
+ var.encoding.update(old_var.encoding)
+ if name in indexes:
+ index_variables[name] = var
+ else:
+ no_index_variables[name] = var
+
+ for name in indexes:
+ new_indexes[name] = indexes[name]
+
+ for name, var in index_variables.items():
+ new_coord_names.add(name)
+ new_variables[name] = var
+
+ # append no-index variables at the end
+ for k in no_index_variables:
+ new_variables.pop(k)
+ new_variables.update(no_index_variables)
+
+ for name in drop_indexes:
+ new_indexes.pop(name)
+
+ for name in drop_variables:
+ new_variables.pop(name)
+ new_indexes.pop(name, None)
+ new_coord_names.remove(name)
+
+ replaced = self._replace(
+ variables=new_variables, coord_names=new_coord_names, indexes=new_indexes
+ )
+
+ if rename_dims:
+ # skip rename indexes: they should already have the right name(s)
+ dims = replaced._rename_dims(rename_dims)
+ new_variables, new_coord_names = replaced._rename_vars({}, rename_dims)
+ return replaced._replace(
+ variables=new_variables, coord_names=new_coord_names, dims=dims
+ )
+ else:
+ return replaced
- def copy(self, deep: bool=False, data: (DataVars | None)=None) ->Self:
+ def copy(self, deep: bool = False, data: DataVars | None = None) -> Self:
"""Returns a copy of this dataset.
If `deep=True`, a deep copy is made of each of the component variables.
@@ -711,15 +1363,58 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
pandas.DataFrame.copy
"""
- pass
+ return self._copy(deep=deep, data=data)
+
+ def _copy(
+ self,
+ deep: bool = False,
+ data: DataVars | None = None,
+ memo: dict[int, Any] | None = None,
+ ) -> Self:
+ if data is None:
+ data = {}
+ elif not utils.is_dict_like(data):
+ raise ValueError("Data must be dict-like")
+
+ if data:
+ var_keys = set(self.data_vars.keys())
+ data_keys = set(data.keys())
+ keys_not_in_vars = data_keys - var_keys
+ if keys_not_in_vars:
+ raise ValueError(
+ "Data must only contain variables in original "
+ f"dataset. Extra variables: {keys_not_in_vars}"
+ )
+ keys_missing_from_data = var_keys - data_keys
+ if keys_missing_from_data:
+ raise ValueError(
+ "Data must contain all variables in original "
+ f"dataset. Data is missing {keys_missing_from_data}"
+ )
+
+ indexes, index_vars = self.xindexes.copy_indexes(deep=deep)
+
+ variables = {}
+ for k, v in self._variables.items():
+ if k in index_vars:
+ variables[k] = index_vars[k]
+ else:
+ variables[k] = v._copy(deep=deep, data=data.get(k), memo=memo)
+
+ attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
+ encoding = (
+ copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding)
+ )
+
+ return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding)
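
A short sketch of the copy semantics implemented above; per the checks in `_copy`, the `data` argument must supply every data variable (example dataset is illustrative):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(3.0))}, coords={"x": [10, 20, 30]})

shallow = ds.copy()                          # new container, array data still shared
deep = ds.copy(deep=True)                    # array data copied as well
replaced = ds.copy(data={"a": np.zeros(3)})  # same structure, new values for "a"
```
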
- def __copy__(self) ->Self:
+ def __copy__(self) -> Self:
return self._copy(deep=False)
- def __deepcopy__(self, memo: (dict[int, Any] | None)=None) ->Self:
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
return self._copy(deep=True, memo=memo)
- def as_numpy(self) ->Self:
+ def as_numpy(self) -> Self:
"""
Coerces wrapped data and coordinates into numpy arrays, returning a Dataset.
@@ -728,84 +1423,152 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
DataArray.as_numpy
DataArray.to_numpy : Returns only the data as a numpy.ndarray object.
"""
- pass
+ numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()}
+ return self._replace(variables=numpy_variables)
- def _copy_listed(self, names: Iterable[Hashable]) ->Self:
+ def _copy_listed(self, names: Iterable[Hashable]) -> Self:
"""Create a new Dataset with the listed variables from this dataset and
the all relevant coordinates. Skips all validation.
"""
- pass
+ variables: dict[Hashable, Variable] = {}
+ coord_names = set()
+ indexes: dict[Hashable, Index] = {}
+
+ for name in names:
+ try:
+ variables[name] = self._variables[name]
+ except KeyError:
+ ref_name, var_name, var = _get_virtual_variable(
+ self._variables, name, self.sizes
+ )
+ variables[var_name] = var
+ if ref_name in self._coord_names or ref_name in self.dims:
+ coord_names.add(var_name)
+ if (var_name,) == var.dims:
+ index, index_vars = create_default_index_implicit(var, names)
+ indexes.update({k: index for k in index_vars})
+ variables.update(index_vars)
+ coord_names.update(index_vars)
+
+ needed_dims: OrderedSet[Hashable] = OrderedSet()
+ for v in variables.values():
+ needed_dims.update(v.dims)
+
+ dims = {k: self.sizes[k] for k in needed_dims}
- def _construct_dataarray(self, name: Hashable) ->DataArray:
+ # preserves ordering of coordinates
+ for k in self._variables:
+ if k not in self._coord_names:
+ continue
+
+ if set(self.variables[k].dims) <= needed_dims:
+ variables[k] = self._variables[k]
+ coord_names.add(k)
+
+ indexes.update(filter_indexes_from_coords(self._indexes, coord_names))
+
+ return self._replace(variables, coord_names, dims, indexes=indexes)
+
+ def _construct_dataarray(self, name: Hashable) -> DataArray:
"""Construct a DataArray by indexing this dataset"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ try:
+ variable = self._variables[name]
+ except KeyError:
+ _, name, variable = _get_virtual_variable(self._variables, name, self.sizes)
+
+ needed_dims = set(variable.dims)
+
+ coords: dict[Hashable, Variable] = {}
+ # preserve ordering
+ for k in self._variables:
+ if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
+ coords[k] = self._variables[k]
+
+ indexes = filter_indexes_from_coords(self._indexes, set(coords))
+
+ return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True)
@property
- def _attr_sources(self) ->Iterable[Mapping[Hashable, Any]]:
+ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for attribute-style access"""
- pass
+ yield from self._item_sources
+ yield self.attrs
@property
- def _item_sources(self) ->Iterable[Mapping[Hashable, Any]]:
+ def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for key-completion"""
- pass
+ yield self.data_vars
+ yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords)
+
+ # virtual coordinates
+ yield HybridMappingProxy(keys=self.sizes, mapping=self)
- def __contains__(self, key: object) ->bool:
+ def __contains__(self, key: object) -> bool:
"""The 'in' operator will return true or false depending on whether
'key' is an array in the dataset or not.
"""
return key in self._variables
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self.data_vars)
- def __bool__(self) ->bool:
+ def __bool__(self) -> bool:
return bool(self.data_vars)
- def __iter__(self) ->Iterator[Hashable]:
+ def __iter__(self) -> Iterator[Hashable]:
return iter(self.data_vars)
+
if TYPE_CHECKING:
- __array__ = None
+ # needed because __getattr__ is returning Any and otherwise
+ # this class counts as part of the SupportsArray Protocol
+ __array__ = None # type: ignore[var-annotated,unused-ignore]
+
else:
def __array__(self, dtype=None, copy=None):
raise TypeError(
- 'cannot directly convert an xarray.Dataset into a numpy array. Instead, create an xarray.DataArray first, either with indexing on the Dataset or by invoking the `to_dataarray()` method.'
- )
+ "cannot directly convert an xarray.Dataset into a "
+ "numpy array. Instead, create an xarray.DataArray "
+ "first, either with indexing on the Dataset or by "
+ "invoking the `to_dataarray()` method."
+ )
@property
- def nbytes(self) ->int:
+ def nbytes(self) -> int:
"""
Total bytes consumed by the data arrays of all variables in this dataset.
If the backend array for any variable does not include ``nbytes``, estimates
the total bytes for that array based on the ``size`` and ``dtype``.
"""
- pass
+ return sum(v.nbytes for v in self.variables.values())
@property
- def loc(self) ->_LocIndexer[Self]:
+ def loc(self) -> _LocIndexer[Self]:
"""Attribute for location based indexing. Only supports __getitem__,
and only when the key is a dict of the form {dim: labels}.
"""
- pass
+ return _LocIndexer(self)
@overload
- def __getitem__(self, key: Hashable) ->DataArray:
- ...
+ def __getitem__(self, key: Hashable) -> DataArray: ...
+ # Mapping is Iterable
@overload
- def __getitem__(self, key: Iterable[Hashable]) ->Self:
- ...
+ def __getitem__(self, key: Iterable[Hashable]) -> Self: ...
- def __getitem__(self, key: (Mapping[Any, Any] | Hashable | Iterable[
- Hashable])) ->(Self | DataArray):
+ def __getitem__(
+ self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable]
+ ) -> Self | DataArray:
"""Access variables or coordinates of this dataset as a
:py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset.
Indexing with a list of names will return a new ``Dataset`` object.
"""
from xarray.core.formatting import shorten_list_repr
+
if utils.is_dict_like(key):
return self.isel(**key)
if utils.hashable(key):
@@ -813,14 +1576,16 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
return self._construct_dataarray(key)
except KeyError as e:
raise KeyError(
- f'No variable named {key!r}. Variables on the dataset include {shorten_list_repr(list(self.variables.keys()), max_items=10)}'
- ) from e
+ f"No variable named {key!r}. Variables on the dataset include {shorten_list_repr(list(self.variables.keys()), max_items=10)}"
+ ) from e
+
if utils.iterable_of_hashable(key):
return self._copy_listed(key)
- raise ValueError(f'Unsupported key-type {type(key)}')
+ raise ValueError(f"Unsupported key-type {type(key)}")
- def __setitem__(self, key: (Hashable | Iterable[Hashable] | Mapping),
- value: Any) ->None:
+ def __setitem__(
+ self, key: Hashable | Iterable[Hashable] | Mapping, value: Any
+ ) -> None:
"""Add an array to this dataset.
Multiple arrays can be added at the same time, in which case each of
the following operations is applied to the respective value.
@@ -840,8 +1605,11 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
variable.
"""
from xarray.core.dataarray import DataArray
+
if utils.is_dict_like(key):
+ # check for consistency and convert value to dataset
value = self._setitem_check(key, value)
+ # loop over dataset variables and set new values
processed = []
for name, var in self.items():
try:
@@ -850,37 +1618,43 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
except Exception as e:
if processed:
raise RuntimeError(
- f"""An error occurred while setting values of the variable '{name}'. The following variables have been successfully updated:
-{processed}"""
- ) from e
+ "An error occurred while setting values of the"
+ f" variable '{name}'. The following variables have"
+ f" been successfully updated:\n{processed}"
+ ) from e
else:
raise e
+
elif utils.hashable(key):
if isinstance(value, Dataset):
raise TypeError(
- 'Cannot assign a Dataset to a single key - only a DataArray or Variable object can be stored under a single key.'
- )
+ "Cannot assign a Dataset to a single key - only a DataArray or Variable "
+ "object can be stored under a single key."
+ )
self.update({key: value})
+
elif utils.iterable_of_hashable(key):
keylist = list(key)
if len(keylist) == 0:
- raise ValueError('Empty list of variables to be set')
+ raise ValueError("Empty list of variables to be set")
if len(keylist) == 1:
self.update({keylist[0]: value})
else:
if len(keylist) != len(value):
raise ValueError(
- f'Different lengths of variables to be set ({len(keylist)}) and data used as input for setting ({len(value)})'
- )
+ f"Different lengths of variables to be set "
+ f"({len(keylist)}) and data used as input for "
+ f"setting ({len(value)})"
+ )
if isinstance(value, Dataset):
self.update(dict(zip(keylist, value.data_vars.values())))
elif isinstance(value, DataArray):
- raise ValueError(
- 'Cannot assign single DataArray to multiple keys')
+ raise ValueError("Cannot assign single DataArray to multiple keys")
else:
self.update(dict(zip(keylist, value)))
+
else:
- raise ValueError(f'Unsupported key-type {type(key)}')
+ raise ValueError(f"Unsupported key-type {type(key)}")
def _setitem_check(self, key, value):
"""Consistency check for __setitem__
@@ -888,23 +1662,91 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
When assigning values to a subset of a Dataset, do consistency check beforehand
to avoid leaving the dataset in a partially updated state when an error occurs.
"""
- pass
+ from xarray.core.alignment import align
+ from xarray.core.dataarray import DataArray
+
+ if isinstance(value, Dataset):
+ missing_vars = [
+ name for name in value.data_vars if name not in self.data_vars
+ ]
+ if missing_vars:
+ raise ValueError(
+ f"Variables {missing_vars} in new values"
+ f" not available in original dataset:\n{self}"
+ )
+ elif not any([isinstance(value, t) for t in [DataArray, Number, str]]):
+ raise TypeError(
+ "Dataset assignment only accepts DataArrays, Datasets, and scalars."
+ )
+
+ new_value = Dataset()
+ for name, var in self.items():
+ # test indexing
+ try:
+ var_k = var[key]
+ except Exception as e:
+ raise ValueError(
+ f"Variable '{name}': indexer {key} not available"
+ ) from e
- def __delitem__(self, key: Hashable) ->None:
+ if isinstance(value, Dataset):
+ val = value[name]
+ else:
+ val = value
+
+ if isinstance(val, DataArray):
+ # check consistency of dimensions
+ for dim in val.dims:
+ if dim not in var_k.dims:
+ raise KeyError(
+ f"Variable '{name}': dimension '{dim}' appears in new values "
+ f"but not in the indexed original data"
+ )
+ dims = tuple(dim for dim in var_k.dims if dim in val.dims)
+ if dims != val.dims:
+ raise ValueError(
+ f"Variable '{name}': dimension order differs between"
+ f" original and new data:\n{dims}\nvs.\n{val.dims}"
+ )
+ else:
+ val = np.array(val)
+
+ # type conversion
+ new_value[name] = duck_array_ops.astype(val, dtype=var_k.dtype, copy=False)
+
+ # check consistency of dimension sizes and dimension coordinates
+ if isinstance(value, DataArray) or isinstance(value, Dataset):
+ align(self[key], value, join="exact", copy=False)
+
+ return new_value
+
+ def __delitem__(self, key: Hashable) -> None:
"""Remove a variable from this dataset."""
assert_no_index_corrupted(self.xindexes, {key})
+
if key in self._indexes:
del self._indexes[key]
del self._variables[key]
self._coord_names.discard(key)
self._dims = calculate_dimensions(self._variables)
- __hash__ = None
- def _all_compat(self, other: Self, compat_str: str) ->bool:
+ # mutable objects should not be hashable
+ # https://github.com/python/mypy/issues/4266
+ __hash__ = None # type: ignore[assignment]
+
+ def _all_compat(self, other: Self, compat_str: str) -> bool:
"""Helper function for equals and identical"""
- pass
- def broadcast_equals(self, other: Self) ->bool:
+ # some stores (e.g., scipy) do not seem to preserve order, so don't
+ # require matching order for equality
+ def compat(x: Variable, y: Variable) -> bool:
+ return getattr(x, compat_str)(y)
+
+ return self._coord_names == other._coord_names and utils.dict_equiv(
+ self._variables, other._variables, compat=compat
+ )
+
+ def broadcast_equals(self, other: Self) -> bool:
"""Two Datasets are broadcast equal if they are equal after
broadcasting all variables against each other.
@@ -966,9 +1808,12 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.identical
Dataset.broadcast
"""
- pass
+ try:
+ return self._all_compat(other, "broadcast_equals")
+ except (TypeError, AttributeError):
+ return False
- def equals(self, other: Self) ->bool:
+ def equals(self, other: Self) -> bool:
"""Two Datasets are equal if they have matching variables and
coordinates, all of which are equal.
@@ -1044,9 +1889,12 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.broadcast_equals
Dataset.identical
"""
- pass
+ try:
+ return self._all_compat(other, "equals")
+ except (TypeError, AttributeError):
+ return False
- def identical(self, other: Self) ->bool:
+ def identical(self, other: Self) -> bool:
"""Like equals, but also checks all dataset attributes and the
attributes on all variables and coordinates.
@@ -1115,10 +1963,15 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.broadcast_equals
Dataset.equals
"""
- pass
+ try:
+ return utils.dict_equiv(self.attrs, other.attrs) and self._all_compat(
+ other, "identical"
+ )
+ except (TypeError, AttributeError):
+ return False
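
The three comparison methods above differ only in strictness; a compact illustration with made-up datasets (not part of the patch):

```python
import xarray as xr

a = xr.Dataset({"v": ("x", [1, 2, 3])}, attrs={"title": "demo"})
b = a.copy()
b.attrs = {}

assert a.equals(b)         # same variables and coordinates; attrs ignored
assert not a.identical(b)  # identical() also compares attributes

scalar = xr.Dataset({"v": 1})
array = xr.Dataset({"v": ("x", [1, 1, 1])})
assert array.broadcast_equals(scalar)  # equal after broadcasting the scalar
assert not array.equals(scalar)
```
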
@property
- def indexes(self) ->Indexes[pd.Index]:
+ def indexes(self) -> Indexes[pd.Index]:
"""Mapping of pandas.Index objects used for label based indexing.
Raises an error if this Dataset has indexes that cannot be coerced
@@ -1129,17 +1982,17 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.xindexes
"""
- pass
+ return self.xindexes.to_pandas_indexes()
@property
- def xindexes(self) ->Indexes[Index]:
+ def xindexes(self) -> Indexes[Index]:
"""Mapping of :py:class:`~xarray.indexes.Index` objects
used for label based indexing.
"""
- pass
+ return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes})
@property
- def coords(self) ->DatasetCoordinates:
+ def coords(self) -> DatasetCoordinates:
"""Mapping of :py:class:`~xarray.DataArray` objects corresponding to
coordinate variables.
@@ -1147,14 +2000,14 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Coordinates
"""
- pass
+ return DatasetCoordinates(self)
@property
- def data_vars(self) ->DataVariables:
+ def data_vars(self) -> DataVariables:
"""Dictionary of DataArray objects corresponding to data variables"""
- pass
+ return DataVariables(self)
- def set_coords(self, names: (Hashable | Iterable[Hashable])) ->Self:
+ def set_coords(self, names: Hashable | Iterable[Hashable]) -> Self:
"""Given names of one or more variables, set them as coordinates
Parameters
@@ -1198,9 +2051,24 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.swap_dims
Dataset.assign_coords
"""
- pass
-
- def reset_coords(self, names: Dims=None, drop: bool=False) ->Self:
+ # TODO: allow inserting new coordinates with this method, like
+ # DataFrame.set_index?
+        # nb. check in self._variables, not self.data_vars to ensure that the
+ # operation is idempotent
+ if isinstance(names, str) or not isinstance(names, Iterable):
+ names = [names]
+ else:
+ names = list(names)
+ self._assert_all_in_dataset(names)
+ obj = self.copy()
+ obj._coord_names.update(names)
+ return obj
+
+ def reset_coords(
+ self,
+ names: Dims = None,
+ drop: bool = False,
+ ) -> Self:
"""Given names of coordinates, reset them to become variables
Parameters
@@ -1273,18 +2141,108 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.set_coords
"""
- pass
-
- def dump_to_store(self, store: AbstractDataStore, **kwargs) ->None:
+ if names is None:
+ names = self._coord_names - set(self._indexes)
+ else:
+ if isinstance(names, str) or not isinstance(names, Iterable):
+ names = [names]
+ else:
+ names = list(names)
+ self._assert_all_in_dataset(names)
+ bad_coords = set(names) & set(self._indexes)
+ if bad_coords:
+ raise ValueError(
+ f"cannot remove index coordinates with reset_coords: {bad_coords}"
+ )
+ obj = self.copy()
+ obj._coord_names.difference_update(names)
+ if drop:
+ for name in names:
+ del obj._variables[name]
+ return obj
+
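A small usage sketch for `set_coords`/`reset_coords` as implemented above (the variable names are invented for the example):

```python
import xarray as xr

ds = xr.Dataset({"temp": ("x", [11.0, 12.0, 13.0]),
                 "station": ("x", ["a", "b", "c"])})

ds2 = ds.set_coords("station")          # promote 'station' to a (non-index) coordinate
list(ds2.data_vars)                     # ['temp']

ds3 = ds2.reset_coords("station")       # demote it back to a data variable
ds2.reset_coords("station", drop=True)  # or remove it entirely
# Index coordinates cannot be reset, matching the ValueError raised above.
```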
+ def dump_to_store(self, store: AbstractDataStore, **kwargs) -> None:
"""Store dataset contents to a backends.*DataStore object."""
- pass
-
- def to_netcdf(self, path: (str | PathLike | None)=None, mode:
- NetcdfWriteModes='w', format: (T_NetcdfTypes | None)=None, group: (
- str | None)=None, engine: (T_NetcdfEngine | None)=None, encoding: (
- Mapping[Any, Mapping[str, Any]] | None)=None, unlimited_dims: (
- Iterable[Hashable] | None)=None, compute: bool=True, invalid_netcdf:
- bool=False) ->(bytes | Delayed | None):
+ from xarray.backends.api import dump_to_store
+
+ # TODO: rename and/or cleanup this method to make it more consistent
+ # with to_netcdf()
+ dump_to_store(self, store, **kwargs)
+
+ # path=None writes to bytes
+ @overload
+ def to_netcdf(
+ self,
+ path: None = None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Any, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ invalid_netcdf: bool = False,
+ ) -> bytes: ...
+
+ # compute=False returns dask.Delayed
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Any, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ *,
+ compute: Literal[False],
+ invalid_netcdf: bool = False,
+ ) -> Delayed: ...
+
+ # default return None
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Any, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: Literal[True] = True,
+ invalid_netcdf: bool = False,
+ ) -> None: ...
+
+ # if compute cannot be evaluated at type check time
+ # we may get back either Delayed or None
+ @overload
+ def to_netcdf(
+ self,
+ path: str | PathLike,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Any, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ invalid_netcdf: bool = False,
+ ) -> Delayed | None: ...
+
+ def to_netcdf(
+ self,
+ path: str | PathLike | None = None,
+ mode: NetcdfWriteModes = "w",
+ format: T_NetcdfTypes | None = None,
+ group: str | None = None,
+ engine: T_NetcdfEngine | None = None,
+ encoding: Mapping[Any, Mapping[str, Any]] | None = None,
+ unlimited_dims: Iterable[Hashable] | None = None,
+ compute: bool = True,
+ invalid_netcdf: bool = False,
+ ) -> bytes | Delayed | None:
"""Write dataset contents to a netCDF file.
Parameters
@@ -1299,7 +2257,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Write ('w') or append ('a') mode. If mode='w', any existing file at
this location will be overwritten. If mode='a', existing variables
will be overwritten.
- format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"}, optional
+ format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \
+ "NETCDF3_CLASSIC"}, optional
File format for the resulting netCDF file:
* NETCDF4: Data is stored in an HDF5 file, using netCDF4 API
@@ -1363,18 +2322,87 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
DataArray.to_netcdf
"""
- pass
-
- def to_zarr(self, store: (MutableMapping | str | PathLike[str] | None)=
- None, chunk_store: (MutableMapping | str | PathLike | None)=None,
- mode: (ZarrWriteModes | None)=None, synchronizer=None, group: (str |
- None)=None, encoding: (Mapping | None)=None, *, compute: bool=True,
- consolidated: (bool | None)=None, append_dim: (Hashable | None)=
- None, region: (Mapping[str, slice | Literal['auto']] | Literal[
- 'auto'] | None)=None, safe_chunks: bool=True, storage_options: (
- dict[str, str] | None)=None, zarr_version: (int | None)=None,
- write_empty_chunks: (bool | None)=None, chunkmanager_store_kwargs:
- (dict[str, Any] | None)=None) ->(ZarrStore | Delayed):
+ if encoding is None:
+ encoding = {}
+ from xarray.backends.api import to_netcdf
+
+ return to_netcdf( # type: ignore # mypy cannot resolve the overloads:(
+ self,
+ path,
+ mode=mode,
+ format=format,
+ group=group,
+ engine=engine,
+ encoding=encoding,
+ unlimited_dims=unlimited_dims,
+ compute=compute,
+ multifile=False,
+ invalid_netcdf=invalid_netcdf,
+ )
+
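A hedged sketch of the `to_netcdf` return types encoded by the overloads above (file names and data are illustrative; writing requires a netCDF backend such as scipy, netCDF4 or h5netcdf, and `compute=False` additionally requires dask):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.arange(4.0))}, coords={"x": np.arange(4)})

ds.to_netcdf("out.nc")   # path given, compute=True -> returns None
raw = ds.to_netcdf()     # path=None -> returns the file contents as bytes

# with dask-backed data, compute=False defers the write and returns a Delayed
delayed = ds.chunk({"x": 2}).to_netcdf("out_lazy.nc", compute=False)
delayed.compute()
```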
+ # compute=True (default) returns ZarrStore
+ @overload
+ def to_zarr(
+ self,
+ store: MutableMapping | str | PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: Literal[True] = True,
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ write_empty_chunks: bool | None = None,
+ chunkmanager_store_kwargs: dict[str, Any] | None = None,
+ ) -> ZarrStore: ...
+
+ # compute=False returns dask.Delayed
+ @overload
+ def to_zarr(
+ self,
+ store: MutableMapping | str | PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: Literal[False],
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ write_empty_chunks: bool | None = None,
+ chunkmanager_store_kwargs: dict[str, Any] | None = None,
+ ) -> Delayed: ...
+
+ def to_zarr(
+ self,
+ store: MutableMapping | str | PathLike[str] | None = None,
+ chunk_store: MutableMapping | str | PathLike | None = None,
+ mode: ZarrWriteModes | None = None,
+ synchronizer=None,
+ group: str | None = None,
+ encoding: Mapping | None = None,
+ *,
+ compute: bool = True,
+ consolidated: bool | None = None,
+ append_dim: Hashable | None = None,
+ region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
+ safe_chunks: bool = True,
+ storage_options: dict[str, str] | None = None,
+ zarr_version: int | None = None,
+ write_empty_chunks: bool | None = None,
+ chunkmanager_store_kwargs: dict[str, Any] | None = None,
+ ) -> ZarrStore | Delayed:
"""Write dataset contents to a zarr group.
Zarr chunks are determined in the following way:
@@ -1520,12 +2548,36 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
:ref:`io.zarr`
The I/O user guide, with more details and examples.
"""
- pass
-
- def __repr__(self) ->str:
+ from xarray.backends.api import to_zarr
+
+ return to_zarr( # type: ignore[call-overload,misc]
+ self,
+ store=store,
+ chunk_store=chunk_store,
+ storage_options=storage_options,
+ mode=mode,
+ synchronizer=synchronizer,
+ group=group,
+ encoding=encoding,
+ compute=compute,
+ consolidated=consolidated,
+ append_dim=append_dim,
+ region=region,
+ safe_chunks=safe_chunks,
+ zarr_version=zarr_version,
+ write_empty_chunks=write_empty_chunks,
+ chunkmanager_store_kwargs=chunkmanager_store_kwargs,
+ )
+
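Similarly, a sketch of `to_zarr` usage (the store paths are illustrative; requires the zarr package, plus dask for lazy writes):

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"t": (("time", "x"), np.zeros((4, 3)))},
    coords={"time": np.arange(4), "x": np.arange(3)},
)

store = ds.to_zarr("out.zarr", mode="w")             # returns a ZarrStore
ds.to_zarr("out.zarr", mode="a", append_dim="time")  # append another block along 'time'

delayed = ds.chunk({"time": 2}).to_zarr("out2.zarr", mode="w", compute=False)
delayed.compute()                                    # Delayed until computed
```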
+ def __repr__(self) -> str:
return formatting.dataset_repr(self)
- def info(self, buf: (IO | None)=None) ->None:
+ def _repr_html_(self) -> str:
+ if OPTIONS["display_style"] == "text":
+ return f"<pre>{escape(repr(self))}</pre>"
+ return formatting_html.dataset_repr(self)
+
+ def info(self, buf: IO | None = None) -> None:
"""
Concise summary of a Dataset variables and attributes.
@@ -1539,10 +2591,29 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
pandas.DataFrame.assign
ncdump : netCDF's ncdump
"""
- pass
+ if buf is None: # pragma: no cover
+ buf = sys.stdout
+
+ lines = []
+ lines.append("xarray.Dataset {")
+ lines.append("dimensions:")
+ for name, size in self.sizes.items():
+ lines.append(f"\t{name} = {size} ;")
+ lines.append("\nvariables:")
+ for name, da in self.variables.items():
+ dims = ", ".join(map(str, da.dims))
+ lines.append(f"\t{da.dtype} {name}({dims}) ;")
+ for k, v in da.attrs.items():
+ lines.append(f"\t\t{name}:{k} = {v} ;")
+ lines.append("\n// global attributes:")
+ for k, v in self.attrs.items():
+ lines.append(f"\t:{k} = {v} ;")
+ lines.append("}")
+
+ buf.write("\n".join(lines))
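For reference, a sketch of how the ncdump-style `info` summary above is used on a toy dataset:

```python
import io
import xarray as xr

ds = xr.Dataset({"t": ("x", [1.0, 2.0])}, attrs={"title": "demo"})

ds.info()              # prints dimensions, variables and attributes to stdout

buf = io.StringIO()
ds.info(buf)           # or write the summary into any file-like object
text = buf.getvalue()
```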
@property
- def chunks(self) ->Mapping[Hashable, tuple[int, ...]]:
+ def chunks(self) -> Mapping[Hashable, tuple[int, ...]]:
"""
Mapping from dimension names to block lengths for this dataset's data, or None if
the underlying data is not a dask array.
@@ -1556,10 +2627,10 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.chunksizes
xarray.unify_chunks
"""
- pass
+ return get_chunksizes(self.variables.values())
@property
- def chunksizes(self) ->Mapping[Hashable, tuple[int, ...]]:
+ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]:
"""
Mapping from dimension names to block lengths for this dataset's data, or None if
the underlying data is not a dask array.
@@ -1573,12 +2644,19 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.chunks
xarray.unify_chunks
"""
- pass
-
- def chunk(self, chunks: T_ChunksFreq={}, name_prefix: str='xarray-',
- token: (str | None)=None, lock: bool=False, inline_array: bool=
- False, chunked_array_type: (str | ChunkManagerEntrypoint | None)=
- None, from_array_kwargs=None, **chunks_kwargs: T_ChunkDimFreq) ->Self:
+ return get_chunksizes(self.variables.values())
+
+ def chunk(
+ self,
+ chunks: T_ChunksFreq = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667)
+ name_prefix: str = "xarray-",
+ token: str | None = None,
+ lock: bool = False,
+ inline_array: bool = False,
+ chunked_array_type: str | ChunkManagerEntrypoint | None = None,
+ from_array_kwargs=None,
+ **chunks_kwargs: T_ChunkDimFreq,
+ ) -> Self:
"""Coerce all arrays in this dataset into dask arrays with the given
chunks.
@@ -1630,23 +2708,155 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
xarray.unify_chunks
dask.array.from_array
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.groupers import TimeResampler
- def _validate_indexers(self, indexers: Mapping[Any, Any], missing_dims:
- ErrorOptionsWithWarn='raise') ->Iterator[tuple[Hashable, int |
- slice | np.ndarray | Variable]]:
+ if chunks is None and not chunks_kwargs:
+ warnings.warn(
+ "None value for 'chunks' is deprecated. "
+ "It will raise an error in the future. Use instead '{}'",
+ category=DeprecationWarning,
+ )
+ chunks = {}
+ chunks_mapping: Mapping[Any, Any]
+ if not isinstance(chunks, Mapping) and chunks is not None:
+ if isinstance(chunks, (tuple, list)):
+ utils.emit_user_level_warning(
+ "Supplying chunks as dimension-order tuples is deprecated. "
+ "It will raise an error in the future. Instead use a dict with dimensions as keys.",
+ category=DeprecationWarning,
+ )
+ chunks_mapping = dict.fromkeys(self.dims, chunks)
+ else:
+ chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
+
+ bad_dims = chunks_mapping.keys() - self.sizes.keys()
+ if bad_dims:
+ raise ValueError(
+ f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
+ )
+
+ def _resolve_frequency(
+ name: Hashable, resampler: TimeResampler
+ ) -> tuple[int, ...]:
+ variable = self._variables.get(name, None)
+ if variable is None:
+ raise ValueError(
+ f"Cannot chunk by resampler {resampler!r} for virtual variables."
+ )
+ elif not _contains_datetime_like_objects(variable):
+ raise ValueError(
+ f"chunks={resampler!r} only supported for datetime variables. "
+ f"Received variable {name!r} with dtype {variable.dtype!r} instead."
+ )
+
+ assert variable.ndim == 1
+ chunks: tuple[int, ...] = tuple(
+ DataArray(
+ np.ones(variable.shape, dtype=int),
+ dims=(name,),
+ coords={name: variable},
+ )
+ .resample({name: resampler})
+ .sum()
+ .data.tolist()
+ )
+ return chunks
+
+ chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
+ name: (
+ _resolve_frequency(name, chunks)
+ if isinstance(chunks, TimeResampler)
+ else chunks
+ )
+ for name, chunks in chunks_mapping.items()
+ }
+
+ chunkmanager = guess_chunkmanager(chunked_array_type)
+ if from_array_kwargs is None:
+ from_array_kwargs = {}
+
+ variables = {
+ k: _maybe_chunk(
+ k,
+ v,
+ chunks_mapping_ints,
+ token,
+ lock,
+ name_prefix,
+ inline_array=inline_array,
+ chunked_array_type=chunkmanager,
+ from_array_kwargs=from_array_kwargs.copy(),
+ )
+ for k, v in self.variables.items()
+ }
+ return self._replace(variables)
+
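A brief sketch of the chunking behavior implemented above, including frequency-based chunks via `TimeResampler` (requires dask; the dataset is invented):

```python
import numpy as np
import pandas as pd
import xarray as xr
from xarray.groupers import TimeResampler

time = pd.date_range("2000-01-01", periods=365, freq="D")
ds = xr.Dataset({"t": ("time", np.random.rand(365))}, coords={"time": time})

ds.chunk({"time": 100}).chunksizes             # uniform chunks: (100, 100, 100, 65)
ds.chunk(time=TimeResampler("MS")).chunksizes  # one chunk per calendar month
```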
+ def _validate_indexers(
+ self, indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise"
+ ) -> Iterator[tuple[Hashable, int | slice | np.ndarray | Variable]]:
"""Here we make sure
        + indexer has valid keys
        + indexer is of a valid data type
+ string indexers are cast to the appropriate date type if the
associated index is a DatetimeIndex or CFTimeIndex
"""
- pass
+ from xarray.coding.cftimeindex import CFTimeIndex
+ from xarray.core.dataarray import DataArray
- def _validate_interp_indexers(self, indexers: Mapping[Any, Any]
- ) ->Iterator[tuple[Hashable, Variable]]:
+ indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)
+
+ # all indexers should be int, slice, np.ndarrays, or Variable
+ for k, v in indexers.items():
+ if isinstance(v, (int, slice, Variable)):
+ yield k, v
+ elif isinstance(v, DataArray):
+ yield k, v.variable
+ elif isinstance(v, tuple):
+ yield k, as_variable(v)
+ elif isinstance(v, Dataset):
+ raise TypeError("cannot use a Dataset as an indexer")
+ elif isinstance(v, Sequence) and len(v) == 0:
+ yield k, np.empty((0,), dtype="int64")
+ else:
+ if not is_duck_array(v):
+ v = np.asarray(v)
+
+ if v.dtype.kind in "US":
+ index = self._indexes[k].to_pandas_index()
+ if isinstance(index, pd.DatetimeIndex):
+ v = duck_array_ops.astype(v, dtype="datetime64[ns]")
+ elif isinstance(index, CFTimeIndex):
+ v = _parse_array_of_cftime_strings(v, index.date_type)
+
+ if v.ndim > 1:
+ raise IndexError(
+ "Unlabeled multi-dimensional array cannot be "
+ f"used for indexing: {k}"
+ )
+ yield k, v
+
+ def _validate_interp_indexers(
+ self, indexers: Mapping[Any, Any]
+ ) -> Iterator[tuple[Hashable, Variable]]:
"""Variant of _validate_indexers to be used for interpolation"""
- pass
+ for k, v in self._validate_indexers(indexers):
+ if isinstance(v, Variable):
+ if v.ndim == 1:
+ yield k, v.to_index_variable()
+ else:
+ yield k, v
+ elif isinstance(v, int):
+ yield k, Variable((), v, attrs=self.coords[k].attrs)
+ elif isinstance(v, np.ndarray):
+ if v.ndim == 0:
+ yield k, Variable((), v, attrs=self.coords[k].attrs)
+ elif v.ndim == 1:
+ yield k, IndexVariable((k,), v, attrs=self.coords[k].attrs)
+ else:
+ raise AssertionError() # Already tested by _validate_indexers
+ else:
+ raise TypeError(type(v))
def _get_indexers_coords_and_indexes(self, indexers):
"""Extract coordinates and indexes from indexers.
@@ -1654,11 +2864,45 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Only coordinate with a name different from any of self.variables will
be attached.
"""
- pass
+ from xarray.core.dataarray import DataArray
- def isel(self, indexers: (Mapping[Any, Any] | None)=None, drop: bool=
- False, missing_dims: ErrorOptionsWithWarn='raise', **
- indexers_kwargs: Any) ->Self:
+ coords_list = []
+ for k, v in indexers.items():
+ if isinstance(v, DataArray):
+ if v.dtype.kind == "b":
+ if v.ndim != 1: # we only support 1-d boolean array
+ raise ValueError(
+ f"{v.ndim:d}d-boolean array is used for indexing along "
+ f"dimension {k!r}, but only 1d boolean arrays are "
+ "supported."
+ )
+ # Make sure in case of boolean DataArray, its
+ # coordinate also should be indexed.
+ v_coords = v[v.values.nonzero()[0]].coords
+ else:
+ v_coords = v.coords
+ coords_list.append(v_coords)
+
+ # we don't need to call align() explicitly or check indexes for
+ # alignment, because merge_variables already checks for exact alignment
+ # between dimension coordinates
+ coords, indexes = merge_coordinates_without_align(coords_list)
+ assert_coordinate_consistent(self, coords)
+
+ # silently drop the conflicted variables.
+ attached_coords = {k: v for k, v in coords.items() if k not in self._variables}
+ attached_indexes = {
+ k: v for k, v in indexes.items() if k not in self._variables
+ }
+ return attached_coords, attached_indexes
+
+ def isel(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ drop: bool = False,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Returns a new dataset with each array indexed along the specified
dimension(s).
@@ -1769,11 +3013,93 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Tutorial material on basics of indexing
"""
- pass
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
+ if any(is_fancy_indexer(idx) for idx in indexers.values()):
+ return self._isel_fancy(indexers, drop=drop, missing_dims=missing_dims)
+
+ # Much faster algorithm for when all indexers are ints, slices, one-dimensional
+ # lists, or zero or one-dimensional np.ndarray's
+ indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)
+
+ variables = {}
+ dims: dict[Hashable, int] = {}
+ coord_names = self._coord_names.copy()
+
+ indexes, index_variables = isel_indexes(self.xindexes, indexers)
- def sel(self, indexers: (Mapping[Any, Any] | None)=None, method: (str |
- None)=None, tolerance: (int | float | Iterable[int | float] | None)
- =None, drop: bool=False, **indexers_kwargs: Any) ->Self:
+ for name, var in self._variables.items():
+ # preserve variable order
+ if name in index_variables:
+ var = index_variables[name]
+ else:
+ var_indexers = {k: v for k, v in indexers.items() if k in var.dims}
+ if var_indexers:
+ var = var.isel(var_indexers)
+ if drop and var.ndim == 0 and name in coord_names:
+ coord_names.remove(name)
+ continue
+ variables[name] = var
+ dims.update(zip(var.dims, var.shape))
+
+ return self._construct_direct(
+ variables=variables,
+ coord_names=coord_names,
+ dims=dims,
+ attrs=self._attrs,
+ indexes=indexes,
+ encoding=self._encoding,
+ close=self._close,
+ )
+
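An illustrative sketch of positional indexing with `isel` on toy data:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"t": (("x", "y"), np.arange(12).reshape(4, 3))},
    coords={"x": [10, 20, 30, 40], "y": list("abc")},
)

ds.isel(x=slice(0, 2))       # ints/slices take the fast path above
ds.isel(x=0, drop=True)      # drop=True discards the resulting scalar coord 'x'
ds.isel(x=[0, 2], y=[0, 1])  # array indexers select orthogonally per dimension
```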
+ def _isel_fancy(
+ self,
+ indexers: Mapping[Any, Any],
+ *,
+ drop: bool,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ ) -> Self:
+ valid_indexers = dict(self._validate_indexers(indexers, missing_dims))
+
+ variables: dict[Hashable, Variable] = {}
+ indexes, index_variables = isel_indexes(self.xindexes, valid_indexers)
+
+ for name, var in self.variables.items():
+ if name in index_variables:
+ new_var = index_variables[name]
+ else:
+ var_indexers = {
+ k: v for k, v in valid_indexers.items() if k in var.dims
+ }
+ if var_indexers:
+ new_var = var.isel(indexers=var_indexers)
+ # drop scalar coordinates
+ # https://github.com/pydata/xarray/issues/6554
+ if name in self.coords and drop and new_var.ndim == 0:
+ continue
+ else:
+ new_var = var.copy(deep=False)
+ if name not in indexes:
+ new_var = new_var.to_base_variable()
+ variables[name] = new_var
+
+ coord_names = self._coord_names & variables.keys()
+ selected = self._replace_with_new_dims(variables, coord_names, indexes)
+
+ # Extract coordinates from indexers
+ coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers)
+ variables.update(coord_vars)
+ indexes.update(new_indexes)
+ coord_names = self._coord_names & variables.keys() | coord_vars.keys()
+ return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
+
+ def sel(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ method: str | None = None,
+ tolerance: int | float | Iterable[int | float] | None = None,
+ drop: bool = False,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Returns a new dataset with each array indexed by tick labels
along the specified dimension(s).
@@ -1841,10 +3167,29 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Tutorial material on basics of indexing
"""
- pass
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "sel")
+ query_results = map_index_queries(
+ self, indexers=indexers, method=method, tolerance=tolerance
+ )
- def head(self, indexers: (Mapping[Any, int] | int | None)=None, **
- indexers_kwargs: Any) ->Self:
+ if drop:
+ no_scalar_variables = {}
+ for k, v in query_results.variables.items():
+ if v.dims:
+ no_scalar_variables[k] = v
+ else:
+ if k in self._coord_names:
+ query_results.drop_coords.append(k)
+ query_results.variables = no_scalar_variables
+
+ result = self.isel(indexers=query_results.dim_indexers, drop=drop)
+ return result._overwrite_indexes(*query_results.as_tuple()[1:])
+
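And the label-based counterpart, `sel`, on the same kind of toy dataset:

```python
import xarray as xr

ds = xr.Dataset({"t": ("x", [1.0, 2.0, 3.0])}, coords={"x": [10, 20, 30]})

ds.sel(x=20)                                 # exact label lookup
ds.sel(x=[10, 30])                           # list of labels
ds.sel(x=22, method="nearest", tolerance=5)  # inexact lookup within tolerance
ds.sel(x=20, drop=True)                      # drop the resulting scalar coordinate
```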
+ def head(
+ self,
+ indexers: Mapping[Any, int] | int | None = None,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Returns a new dataset with the first `n` values of each array
for the specified dimension(s).
@@ -1908,10 +3253,33 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.thin
DataArray.head
"""
- pass
-
- def tail(self, indexers: (Mapping[Any, int] | int | None)=None, **
- indexers_kwargs: Any) ->Self:
+ if not indexers_kwargs:
+ if indexers is None:
+ indexers = 5
+ if not isinstance(indexers, int) and not is_dict_like(indexers):
+ raise TypeError("indexers must be either dict-like or a single integer")
+ if isinstance(indexers, int):
+ indexers = {dim: indexers for dim in self.dims}
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "head")
+ for k, v in indexers.items():
+ if not isinstance(v, int):
+ raise TypeError(
+ "expected integer type indexer for "
+ f"dimension {k!r}, found {type(v)!r}"
+ )
+ elif v < 0:
+ raise ValueError(
+ "expected positive integer as indexer "
+ f"for dimension {k!r}, found {v}"
+ )
+ indexers_slices = {k: slice(val) for k, val in indexers.items()}
+ return self.isel(indexers_slices)
+
+ def tail(
+ self,
+ indexers: Mapping[Any, int] | int | None = None,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Returns a new dataset with the last `n` values of each array
for the specified dimension(s).
@@ -1973,10 +3341,36 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.thin
DataArray.tail
"""
- pass
-
- def thin(self, indexers: (Mapping[Any, int] | int | None)=None, **
- indexers_kwargs: Any) ->Self:
+ if not indexers_kwargs:
+ if indexers is None:
+ indexers = 5
+ if not isinstance(indexers, int) and not is_dict_like(indexers):
+ raise TypeError("indexers must be either dict-like or a single integer")
+ if isinstance(indexers, int):
+ indexers = {dim: indexers for dim in self.dims}
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "tail")
+ for k, v in indexers.items():
+ if not isinstance(v, int):
+ raise TypeError(
+ "expected integer type indexer for "
+ f"dimension {k!r}, found {type(v)!r}"
+ )
+ elif v < 0:
+ raise ValueError(
+ "expected positive integer as indexer "
+ f"for dimension {k!r}, found {v}"
+ )
+ indexers_slices = {
+ k: slice(-val, None) if val != 0 else slice(val)
+ for k, val in indexers.items()
+ }
+ return self.isel(indexers_slices)
+
+ def thin(
+ self,
+ indexers: Mapping[Any, int] | int | None = None,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Returns a new dataset with each array indexed along every `n`-th
value for the specified dimension(s)
@@ -2032,10 +3426,36 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.tail
DataArray.thin
"""
- pass
-
- def broadcast_like(self, other: T_DataArrayOrSet, exclude: (Iterable[
- Hashable] | None)=None) ->Self:
+ if (
+ not indexers_kwargs
+ and not isinstance(indexers, int)
+ and not is_dict_like(indexers)
+ ):
+ raise TypeError("indexers must be either dict-like or a single integer")
+ if isinstance(indexers, int):
+ indexers = {dim: indexers for dim in self.dims}
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "thin")
+ for k, v in indexers.items():
+ if not isinstance(v, int):
+ raise TypeError(
+ "expected integer type indexer for "
+ f"dimension {k!r}, found {type(v)!r}"
+ )
+ elif v < 0:
+ raise ValueError(
+ "expected positive integer as indexer "
+ f"for dimension {k!r}, found {v}"
+ )
+ elif v == 0:
+ raise ValueError("step cannot be zero")
+ indexers_slices = {k: slice(None, None, val) for k, val in indexers.items()}
+ return self.isel(indexers_slices)
+
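The three convenience methods above all reduce to `isel` with slices; a quick sketch:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.arange(10))}, coords={"x": np.arange(10)})

ds.head(x=3)   # same as ds.isel(x=slice(3))
ds.tail(x=3)   # same as ds.isel(x=slice(-3, None))
ds.thin(x=2)   # same as ds.isel(x=slice(None, None, 2))
ds.head(3)     # a bare integer applies to every dimension
```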
+ def broadcast_like(
+ self,
+ other: T_DataArrayOrSet,
+ exclude: Iterable[Hashable] | None = None,
+ ) -> Self:
"""Broadcast this DataArray against another Dataset or DataArray.
This is equivalent to xr.broadcast(other, self)[1]
@@ -2047,19 +3467,85 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dimensions that must not be broadcasted
"""
- pass
-
- def _reindex_callback(self, aligner: alignment.Aligner,
- dim_pos_indexers: dict[Hashable, Any], variables: dict[Hashable,
- Variable], indexes: dict[Hashable, Index], fill_value: Any,
- exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable]
- ) ->Self:
+ if exclude is None:
+ exclude = set()
+ else:
+ exclude = set(exclude)
+ args = align(other, self, join="outer", copy=False, exclude=exclude)
+
+ dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude)
+
+ return _broadcast_helper(args[1], exclude, dims_map, common_coords)
+
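A quick sketch of `broadcast_like` (toy data):

```python
import xarray as xr

a = xr.Dataset({"t": ("x", [1.0, 2.0, 3.0])}, coords={"x": [0, 1, 2]})
b = xr.Dataset({"u": ("y", [10.0, 20.0])}, coords={"y": [0, 1]})

a.broadcast_like(b)   # adds dimension 'y' to every variable in 'a'
# equivalent to xr.broadcast(b, a)[1]
```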
+ def _reindex_callback(
+ self,
+ aligner: alignment.Aligner,
+ dim_pos_indexers: dict[Hashable, Any],
+ variables: dict[Hashable, Variable],
+ indexes: dict[Hashable, Index],
+ fill_value: Any,
+ exclude_dims: frozenset[Hashable],
+ exclude_vars: frozenset[Hashable],
+ ) -> Self:
"""Callback called from ``Aligner`` to create a new reindexed Dataset."""
- pass
- def reindex_like(self, other: T_Xarray, method: ReindexMethodOptions=
- None, tolerance: (float | Iterable[float] | str | None)=None, copy:
- bool=True, fill_value: Any=xrdtypes.NA) ->Self:
+ new_variables = variables.copy()
+ new_indexes = indexes.copy()
+
+ # re-assign variable metadata
+ for name, new_var in new_variables.items():
+ var = self._variables.get(name)
+ if var is not None:
+ new_var.attrs = var.attrs
+ new_var.encoding = var.encoding
+
+ # pass through indexes from excluded dimensions
+ # no extra check needed for multi-coordinate indexes, potential conflicts
+ # should already have been detected when aligning the indexes
+ for name, idx in self._indexes.items():
+ var = self._variables[name]
+ if set(var.dims) <= exclude_dims:
+ new_indexes[name] = idx
+ new_variables[name] = var
+
+ if not dim_pos_indexers:
+ # fast path for no reindexing necessary
+ if set(new_indexes) - set(self._indexes):
+ # this only adds new indexes and their coordinate variables
+ reindexed = self._overwrite_indexes(new_indexes, new_variables)
+ else:
+ reindexed = self.copy(deep=aligner.copy)
+ else:
+ to_reindex = {
+ k: v
+ for k, v in self.variables.items()
+ if k not in variables and k not in exclude_vars
+ }
+ reindexed_vars = alignment.reindex_variables(
+ to_reindex,
+ dim_pos_indexers,
+ copy=aligner.copy,
+ fill_value=fill_value,
+ sparse=aligner.sparse,
+ )
+ new_variables.update(reindexed_vars)
+ new_coord_names = self._coord_names | set(new_indexes)
+ reindexed = self._replace_with_new_dims(
+ new_variables, new_coord_names, indexes=new_indexes
+ )
+
+ reindexed.encoding = self.encoding
+
+ return reindexed
+
+ def reindex_like(
+ self,
+ other: T_Xarray,
+ method: ReindexMethodOptions = None,
+ tolerance: float | Iterable[float] | str | None = None,
+ copy: bool = True,
+ fill_value: Any = xrdtypes.NA,
+ ) -> Self:
"""
Conform this object onto the indexes of another object, for indexes which the
objects share. Missing values are filled with ``fill_value``. The default fill
@@ -2113,12 +3599,24 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
align
"""
- pass
-
- def reindex(self, indexers: (Mapping[Any, Any] | None)=None, method:
- ReindexMethodOptions=None, tolerance: (float | Iterable[float] |
- str | None)=None, copy: bool=True, fill_value: Any=xrdtypes.NA, **
- indexers_kwargs: Any) ->Self:
+ return alignment.reindex_like(
+ self,
+ other=other,
+ method=method,
+ tolerance=tolerance,
+ copy=copy,
+ fill_value=fill_value,
+ )
+
+ def reindex(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ method: ReindexMethodOptions = None,
+ tolerance: float | Iterable[float] | str | None = None,
+ copy: bool = True,
+ fill_value: Any = xrdtypes.NA,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Conform this object onto a new set of indexes, filling in
missing values with ``fill_value``. The default fill value is NaN.
@@ -2316,21 +3814,49 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
original dataset, use the :py:meth:`~Dataset.fillna()` method.
"""
- pass
-
- def _reindex(self, indexers: (Mapping[Any, Any] | None)=None, method: (
- str | None)=None, tolerance: (int | float | Iterable[int | float] |
- None)=None, copy: bool=True, fill_value: Any=xrdtypes.NA, sparse:
- bool=False, **indexers_kwargs: Any) ->Self:
+ indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
+ return alignment.reindex(
+ self,
+ indexers=indexers,
+ method=method,
+ tolerance=tolerance,
+ copy=copy,
+ fill_value=fill_value,
+ )
+
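A short sketch of `reindex` behavior as wired up above (labels and fill values invented):

```python
import xarray as xr

ds = xr.Dataset({"t": ("x", [1.0, 2.0, 3.0])}, coords={"x": [10, 20, 30]})

ds.reindex(x=[10, 15, 20, 30])                              # unmatched labels -> NaN
ds.reindex(x=[10, 15, 20, 30], fill_value=0.0)              # or a custom fill value
ds.reindex(x=[11, 19, 31], method="nearest", tolerance=2)   # inexact matching
```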
+ def _reindex(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ method: str | None = None,
+ tolerance: int | float | Iterable[int | float] | None = None,
+ copy: bool = True,
+ fill_value: Any = xrdtypes.NA,
+ sparse: bool = False,
+ **indexers_kwargs: Any,
+ ) -> Self:
"""
Same as reindex but supports sparse option.
"""
- pass
-
- def interp(self, coords: (Mapping[Any, Any] | None)=None, method:
- InterpOptions='linear', assume_sorted: bool=False, kwargs: (Mapping
- [str, Any] | None)=None, method_non_numeric: str='nearest', **
- coords_kwargs: Any) ->Self:
+ indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
+ return alignment.reindex(
+ self,
+ indexers=indexers,
+ method=method,
+ tolerance=tolerance,
+ copy=copy,
+ fill_value=fill_value,
+ sparse=sparse,
+ )
+
+ def interp(
+ self,
+ coords: Mapping[Any, Any] | None = None,
+ method: InterpOptions = "linear",
+ assume_sorted: bool = False,
+ kwargs: Mapping[str, Any] | None = None,
+ method_non_numeric: str = "nearest",
+ **coords_kwargs: Any,
+ ) -> Self:
"""Interpolate a Dataset onto new coordinates
Performs univariate or multivariate interpolation of a Dataset onto
@@ -2349,7 +3875,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
New coordinate can be a scalar, array-like or DataArray.
If DataArrays are passed as new coordinates, their dimensions are
used for the broadcasting. Missing values are skipped.
- method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
+ method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \
+ "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
String indicating which method to use for interpolation:
- 'linear': linear interpolation. Additional keyword
@@ -2468,11 +3995,149 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
a (x) float64 32B 5.0 6.5 6.25 4.75
b (x, y) float64 96B 2.5 3.0 nan 4.0 5.625 ... nan nan nan nan nan
"""
- pass
+ from xarray.core import missing
+
+ if kwargs is None:
+ kwargs = {}
- def interp_like(self, other: T_Xarray, method: InterpOptions='linear',
- assume_sorted: bool=False, kwargs: (Mapping[str, Any] | None)=None,
- method_non_numeric: str='nearest') ->Self:
+ coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
+ indexers = dict(self._validate_interp_indexers(coords))
+
+ if coords:
+ # This avoids broadcasting over coordinates that are both in
+ # the original array AND in the indexing array. It essentially
+ # forces interpolation along the shared coordinates.
+ sdims = (
+ set(self.dims)
+ .intersection(*[set(nx.dims) for nx in indexers.values()])
+ .difference(coords.keys())
+ )
+ indexers.update({d: self.variables[d] for d in sdims})
+
+ obj = self if assume_sorted else self.sortby([k for k in coords])
+
+ def maybe_variable(obj, k):
+ # workaround to get variable for dimension without coordinate.
+ try:
+ return obj._variables[k]
+ except KeyError:
+ return as_variable((k, range(obj.sizes[k])))
+
+ def _validate_interp_indexer(x, new_x):
+ # In the case of datetimes, the restrictions placed on indexers
+ # used with interp are stronger than those which are placed on
+ # isel, so we need an additional check after _validate_indexers.
+ if _contains_datetime_like_objects(
+ x
+ ) and not _contains_datetime_like_objects(new_x):
+ raise TypeError(
+ "When interpolating over a datetime-like "
+ "coordinate, the coordinates to "
+ "interpolate to must be either datetime "
+ "strings or datetimes. "
+ f"Instead got\n{new_x}"
+ )
+ return x, new_x
+
+ validated_indexers = {
+ k: _validate_interp_indexer(maybe_variable(obj, k), v)
+ for k, v in indexers.items()
+ }
+
+ # optimization: subset to coordinate range of the target index
+ if method in ["linear", "nearest"]:
+ for k, v in validated_indexers.items():
+ obj, newidx = missing._localize(obj, {k: v})
+ validated_indexers[k] = newidx[k]
+
+ # optimization: create dask coordinate arrays once per Dataset
+ # rather than once per Variable when dask.array.unify_chunks is called later
+ # GH4739
+ if obj.__dask_graph__():
+ dask_indexers = {
+ k: (index.to_base_variable().chunk(), dest.to_base_variable().chunk())
+ for k, (index, dest) in validated_indexers.items()
+ }
+
+ variables: dict[Hashable, Variable] = {}
+ reindex: bool = False
+ for name, var in obj._variables.items():
+ if name in indexers:
+ continue
+
+ if is_duck_dask_array(var.data):
+ use_indexers = dask_indexers
+ else:
+ use_indexers = validated_indexers
+
+ dtype_kind = var.dtype.kind
+ if dtype_kind in "uifc":
+ # For normal number types do the interpolation:
+ var_indexers = {k: v for k, v in use_indexers.items() if k in var.dims}
+ variables[name] = missing.interp(var, var_indexers, method, **kwargs)
+ elif dtype_kind in "ObU" and (use_indexers.keys() & var.dims):
+ # For types that we do not understand do stepwise
+ # interpolation to avoid modifying the elements.
+ # reindex the variable instead because it supports
+ # booleans and objects and retains the dtype but inside
+ # this loop there might be some duplicate code that slows it
+ # down, therefore collect these signals and run it later:
+ reindex = True
+ elif all(d not in indexers for d in var.dims):
+ # For anything else we can only keep variables if they
+ # are not dependent on any coords that are being
+ # interpolated along:
+ variables[name] = var
+
+ if reindex:
+ reindex_indexers = {
+ k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)
+ }
+ reindexed = alignment.reindex(
+ obj,
+ indexers=reindex_indexers,
+ method=method_non_numeric,
+ exclude_vars=variables.keys(),
+ )
+ indexes = dict(reindexed._indexes)
+ variables.update(reindexed.variables)
+ else:
+ # Get the indexes that are not being interpolated along
+ indexes = {k: v for k, v in obj._indexes.items() if k not in indexers}
+
+ # Get the coords that also exist in the variables:
+ coord_names = obj._coord_names & variables.keys()
+ selected = self._replace_with_new_dims(
+ variables.copy(), coord_names, indexes=indexes
+ )
+
+ # Attach indexer as coordinate
+ for k, v in indexers.items():
+ assert isinstance(v, Variable)
+ if v.dims == (k,):
+ index = PandasIndex(v, k, coord_dtype=v.dtype)
+ index_vars = index.create_variables({k: v})
+ indexes[k] = index
+ variables.update(index_vars)
+ else:
+ variables[k] = v
+
+ # Extract coordinates from indexers
+ coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords)
+ variables.update(coord_vars)
+ indexes.update(new_indexes)
+
+ coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
+ return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
+
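A minimal sketch of `interp` (requires SciPy; the coordinates and values are invented):

```python
import xarray as xr

ds = xr.Dataset({"t": ("x", [0.0, 10.0, 20.0])}, coords={"x": [0, 1, 2]})

ds.interp(x=[0.5, 1.5])                    # linear interpolation by default
ds.interp(x=[0.5, 1.5], method="nearest")  # any of the methods listed above
```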
+ def interp_like(
+ self,
+ other: T_Xarray,
+ method: InterpOptions = "linear",
+ assume_sorted: bool = False,
+ kwargs: Mapping[str, Any] | None = None,
+ method_non_numeric: str = "nearest",
+ ) -> Self:
"""Interpolate this object onto the coordinates of another object,
filling the out of range values with NaN.
@@ -2489,7 +4154,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Object with an 'indexes' attribute giving a mapping from dimension
names to an 1d array-like, which provides coordinates upon
which to index the variables in this dataset. Missing values are skipped.
- method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
+ method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \
+ "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
String indicating which method to use for interpolation:
- 'linear': linear interpolation. Additional keyword
@@ -2529,17 +4195,147 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.interp
Dataset.reindex_like
"""
- pass
-
- def _rename(self, name_dict: (Mapping[Any, Hashable] | None)=None, **
- names: Hashable) ->Self:
+ if kwargs is None:
+ kwargs = {}
+
+ # pick only dimension coordinates with a single index
+ coords: dict[Hashable, Variable] = {}
+ other_indexes = other.xindexes
+ for dim in self.dims:
+ other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore")
+ if len(other_dim_coords) == 1:
+ coords[dim] = other_dim_coords[dim]
+
+ numeric_coords: dict[Hashable, Variable] = {}
+ object_coords: dict[Hashable, Variable] = {}
+ for k, v in coords.items():
+ if v.dtype.kind in "uifcMm":
+ numeric_coords[k] = v
+ else:
+ object_coords[k] = v
+
+ ds = self
+ if object_coords:
+ # We do not support interpolation along object coordinate.
+ # reindex instead.
+ ds = self.reindex(object_coords)
+ return ds.interp(
+ coords=numeric_coords,
+ method=method,
+ assume_sorted=assume_sorted,
+ kwargs=kwargs,
+ method_non_numeric=method_non_numeric,
+ )
+
+ # Helper methods for rename()
+ def _rename_vars(
+ self, name_dict, dims_dict
+ ) -> tuple[dict[Hashable, Variable], set[Hashable]]:
+ variables = {}
+ coord_names = set()
+ for k, v in self.variables.items():
+ var = v.copy(deep=False)
+ var.dims = tuple(dims_dict.get(dim, dim) for dim in v.dims)
+ name = name_dict.get(k, k)
+ if name in variables:
+ raise ValueError(f"the new name {name!r} conflicts")
+ variables[name] = var
+ if k in self._coord_names:
+ coord_names.add(name)
+ return variables, coord_names
+
+ def _rename_dims(self, name_dict: Mapping[Any, Hashable]) -> dict[Hashable, int]:
+ return {name_dict.get(k, k): v for k, v in self.sizes.items()}
+
+ def _rename_indexes(
+ self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable]
+ ) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+ if not self._indexes:
+ return {}, {}
+
+ indexes = {}
+ variables = {}
+
+ for index, coord_names in self.xindexes.group_by_index():
+ new_index = index.rename(name_dict, dims_dict)
+ new_coord_names = [name_dict.get(k, k) for k in coord_names]
+ indexes.update({k: new_index for k in new_coord_names})
+ new_index_vars = new_index.create_variables(
+ {
+ new: self._variables[old]
+ for old, new in zip(coord_names, new_coord_names)
+ }
+ )
+ variables.update(new_index_vars)
+
+ return indexes, variables
+
+ def _rename_all(
+ self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable]
+ ) -> tuple[
+ dict[Hashable, Variable],
+ set[Hashable],
+ dict[Hashable, int],
+ dict[Hashable, Index],
+ ]:
+ variables, coord_names = self._rename_vars(name_dict, dims_dict)
+ dims = self._rename_dims(dims_dict)
+
+ indexes, index_vars = self._rename_indexes(name_dict, dims_dict)
+ variables = {k: index_vars.get(k, v) for k, v in variables.items()}
+
+ return variables, coord_names, dims, indexes
+
+ def _rename(
+ self,
+ name_dict: Mapping[Any, Hashable] | None = None,
+ **names: Hashable,
+ ) -> Self:
"""Also used internally by DataArray so that the warning (if any)
is raised at the right stack level.
"""
- pass
+ name_dict = either_dict_or_kwargs(name_dict, names, "rename")
+ for k in name_dict.keys():
+ if k not in self and k not in self.dims:
+ raise ValueError(
+ f"cannot rename {k!r} because it is not a "
+ "variable or dimension in this dataset"
+ )
+
+ create_dim_coord = False
+ new_k = name_dict[k]
+
+ if k == new_k:
+ continue # Same name, nothing to do
+
+ if k in self.dims and new_k in self._coord_names:
+ coord_dims = self._variables[name_dict[k]].dims
+ if coord_dims == (k,):
+ create_dim_coord = True
+ elif k in self._coord_names and new_k in self.dims:
+ coord_dims = self._variables[k].dims
+ if coord_dims == (new_k,):
+ create_dim_coord = True
+
+ if create_dim_coord:
+ warnings.warn(
+ f"rename {k!r} to {name_dict[k]!r} does not create an index "
+ "anymore. Try using swap_dims instead or use set_index "
+ "after rename to create an indexed coordinate.",
+ UserWarning,
+ stacklevel=3,
+ )
- def rename(self, name_dict: (Mapping[Any, Hashable] | None)=None, **
- names: Hashable) ->Self:
+ variables, coord_names, dims, indexes = self._rename_all(
+ name_dict=name_dict, dims_dict=name_dict
+ )
+ return self._replace(variables, coord_names, dims=dims, indexes=indexes)
+
+ def rename(
+ self,
+ name_dict: Mapping[Any, Hashable] | None = None,
+ **names: Hashable,
+ ) -> Self:
"""Returns a new object with renamed variables, coordinates and dimensions.
Parameters
@@ -2563,10 +4359,13 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.rename_dims
DataArray.rename
"""
- pass
+ return self._rename(name_dict=name_dict, **names)
- def rename_dims(self, dims_dict: (Mapping[Any, Hashable] | None)=None,
- **dims: Hashable) ->Self:
+ def rename_dims(
+ self,
+ dims_dict: Mapping[Any, Hashable] | None = None,
+ **dims: Hashable,
+ ) -> Self:
"""Returns a new object with renamed dimensions only.
Parameters
@@ -2591,10 +4390,29 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.rename_vars
DataArray.rename
"""
- pass
+ dims_dict = either_dict_or_kwargs(dims_dict, dims, "rename_dims")
+ for k, v in dims_dict.items():
+ if k not in self.dims:
+ raise ValueError(
+ f"cannot rename {k!r} because it is not found "
+ f"in the dimensions of this dataset {tuple(self.dims)}"
+ )
+ if v in self.dims or v in self:
+ raise ValueError(
+ f"Cannot rename {k} to {v} because {v} already exists. "
+ "Try using swap_dims instead."
+ )
- def rename_vars(self, name_dict: (Mapping[Any, Hashable] | None)=None,
- **names: Hashable) ->Self:
+ variables, coord_names, sizes, indexes = self._rename_all(
+ name_dict={}, dims_dict=dims_dict
+ )
+ return self._replace(variables, coord_names, dims=sizes, indexes=indexes)
+
+ def rename_vars(
+ self,
+ name_dict: Mapping[Any, Hashable] | None = None,
+ **names: Hashable,
+ ) -> Self:
"""Returns a new object with renamed variables including coordinates
Parameters
@@ -2618,10 +4436,21 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.rename_dims
DataArray.rename
"""
- pass
-
- def swap_dims(self, dims_dict: (Mapping[Any, Hashable] | None)=None, **
- dims_kwargs) ->Self:
+ name_dict = either_dict_or_kwargs(name_dict, names, "rename_vars")
+ for k in name_dict:
+ if k not in self:
+ raise ValueError(
+ f"cannot rename {k!r} because it is not a "
+ "variable or coordinate in this dataset"
+ )
+ variables, coord_names, dims, indexes = self._rename_all(
+ name_dict=name_dict, dims_dict={}
+ )
+ return self._replace(variables, coord_names, dims=dims, indexes=indexes)
+
+ def swap_dims(
+ self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs
+ ) -> Self:
"""Returns a new object with swapped dimensions.
Parameters
@@ -2680,11 +4509,59 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.rename
DataArray.swap_dims
"""
- pass
+ # TODO: deprecate this method in favor of a (less confusing)
+ # rename_dims() method that only renames dimensions.
- def expand_dims(self, dim: (None | Hashable | Sequence[Hashable] |
- Mapping[Any, Any])=None, axis: (None | int | Sequence[int])=None,
- create_index_for_new_dim: bool=True, **dim_kwargs: Any) ->Self:
+ dims_dict = either_dict_or_kwargs(dims_dict, dims_kwargs, "swap_dims")
+ for current_name, new_name in dims_dict.items():
+ if current_name not in self.dims:
+ raise ValueError(
+ f"cannot swap from dimension {current_name!r} because it is "
+ f"not one of the dimensions of this dataset {tuple(self.dims)}"
+ )
+ if new_name in self.variables and self.variables[new_name].dims != (
+ current_name,
+ ):
+ raise ValueError(
+ f"replacement dimension {new_name!r} is not a 1D "
+ f"variable along the old dimension {current_name!r}"
+ )
+
+ result_dims = {dims_dict.get(dim, dim) for dim in self.dims}
+
+ coord_names = self._coord_names.copy()
+ coord_names.update({dim for dim in dims_dict.values() if dim in self.variables})
+
+ variables: dict[Hashable, Variable] = {}
+ indexes: dict[Hashable, Index] = {}
+ for current_name, current_variable in self.variables.items():
+ dims = tuple(dims_dict.get(dim, dim) for dim in current_variable.dims)
+ var: Variable
+ if current_name in result_dims:
+ var = current_variable.to_index_variable()
+ var.dims = dims
+ if current_name in self._indexes:
+ indexes[current_name] = self._indexes[current_name]
+ variables[current_name] = var
+ else:
+ index, index_vars = create_default_index_implicit(var)
+ indexes.update({name: index for name in index_vars})
+ variables.update(index_vars)
+ coord_names.update(index_vars)
+ else:
+ var = current_variable.to_base_variable()
+ var.dims = dims
+ variables[current_name] = var
+
+ return self._replace_with_new_dims(variables, coord_names, indexes=indexes)
+
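A sketch contrasting `rename` and `swap_dims`, which the warning above points users toward (names invented):

```python
import xarray as xr

ds = xr.Dataset(
    {"t": ("x", [1.0, 2.0, 3.0])},
    coords={"x": [10, 20, 30], "station": ("x", ["a", "b", "c"])},
)

ds.rename({"t": "temperature"})  # rename a data variable
ds.rename({"x": "lon"})          # renames the dimension and its coordinate together
ds.swap_dims({"x": "station"})   # make 'station' the dimension (and index) coordinate
```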
+ def expand_dims(
+ self,
+ dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None,
+ axis: None | int | Sequence[int] = None,
+ create_index_for_new_dim: bool = True,
+ **dim_kwargs: Any,
+ ) -> Self:
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape. The new object is a
view into the underlying array, not a copy.
@@ -2812,11 +4689,123 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
DataArray.expand_dims
"""
- pass
+ if dim is None:
+ pass
+ elif isinstance(dim, Mapping):
+ # We're later going to modify dim in place; don't tamper with
+ # the input
+ dim = dict(dim)
+ elif isinstance(dim, int):
+ raise TypeError(
+ "dim should be hashable or sequence of hashables or mapping"
+ )
+ elif isinstance(dim, str) or not isinstance(dim, Sequence):
+ dim = {dim: 1}
+ elif isinstance(dim, Sequence):
+ if len(dim) != len(set(dim)):
+ raise ValueError("dims should not contain duplicate values.")
+ dim = {d: 1 for d in dim}
+
+ dim = either_dict_or_kwargs(dim, dim_kwargs, "expand_dims")
+ assert isinstance(dim, MutableMapping)
+
+ if axis is None:
+ axis = list(range(len(dim)))
+ elif not isinstance(axis, Sequence):
+ axis = [axis]
+
+ if len(dim) != len(axis):
+ raise ValueError("lengths of dim and axis should be identical.")
+ for d in dim:
+ if d in self.dims:
+ raise ValueError(f"Dimension {d} already exists.")
+ if d in self._variables and not utils.is_scalar(self._variables[d]):
+ raise ValueError(
+ f"{d} already exists as coordinate or" " variable name."
+ )
- def set_index(self, indexes: (Mapping[Any, Hashable | Sequence[Hashable
- ]] | None)=None, append: bool=False, **indexes_kwargs: (Hashable |
- Sequence[Hashable])) ->Self:
+ variables: dict[Hashable, Variable] = {}
+ indexes: dict[Hashable, Index] = dict(self._indexes)
+ coord_names = self._coord_names.copy()
+ # If dim is a dict, then ensure that the values are either integers
+ # or iterables.
+ for k, v in dim.items():
+ if hasattr(v, "__iter__"):
+ # If the value for the new dimension is an iterable, then
+ # save the coordinates to the variables dict, and set the
+ # value within the dim dict to the length of the iterable
+ # for later use.
+
+ if create_index_for_new_dim:
+ index = PandasIndex(v, k)
+ indexes[k] = index
+ name_and_new_1d_var = index.create_variables()
+ else:
+ name_and_new_1d_var = {k: Variable(data=v, dims=k)}
+ variables.update(name_and_new_1d_var)
+ coord_names.add(k)
+ dim[k] = variables[k].size
+ elif isinstance(v, int):
+ pass # Do nothing if the dimensions value is just an int
+ else:
+ raise TypeError(
+ f"The value of new dimension {k} must be " "an iterable or an int"
+ )
+
+ for k, v in self._variables.items():
+ if k not in dim:
+ if k in coord_names: # Do not change coordinates
+ variables[k] = v
+ else:
+ result_ndim = len(v.dims) + len(axis)
+ for a in axis:
+ if a < -result_ndim or result_ndim - 1 < a:
+ raise IndexError(
+ f"Axis {a} of variable {k} is out of bounds of the "
+ f"expanded dimension size {result_ndim}"
+ )
+
+ axis_pos = [a if a >= 0 else result_ndim + a for a in axis]
+ if len(axis_pos) != len(set(axis_pos)):
+ raise ValueError("axis should not contain duplicate values")
+ # We need to sort them to make sure `axis` equals to the
+ # axis positions of the result array.
+ zip_axis_dim = sorted(zip(axis_pos, dim.items()))
+
+ all_dims = list(zip(v.dims, v.shape))
+ for d, c in zip_axis_dim:
+ all_dims.insert(d, c)
+ variables[k] = v.set_dims(dict(all_dims))
+ else:
+ if k not in variables:
+ if k in coord_names and create_index_for_new_dim:
+ # If dims includes a label of a non-dimension coordinate,
+ # it will be promoted to a 1D coordinate with a single value.
+ index, index_vars = create_default_index_implicit(v.set_dims(k))
+ indexes[k] = index
+ variables.update(index_vars)
+ else:
+ if create_index_for_new_dim:
+ warnings.warn(
+ f"No index created for dimension {k} because variable {k} is not a coordinate. "
+ f"To create an index for {k}, please first call `.set_coords('{k}')` on this object.",
+ UserWarning,
+ )
+
+ # create 1D variable without creating a new index
+ new_1d_var = v.set_dims(k)
+ variables.update({k: new_1d_var})
+
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes
+ )
+
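A brief sketch of `expand_dims`, including the `create_index_for_new_dim` flag handled above:

```python
import pandas as pd
import xarray as xr

ds = xr.Dataset({"t": ("x", [1.0, 2.0])}, coords={"x": [0, 1]})

ds.expand_dims("ensemble")                                    # new length-1 dim at axis 0
ds.expand_dims(time=pd.date_range("2000-01-01", periods=3))   # new dim with an indexed coord
ds.expand_dims(time=[0, 1, 2], create_index_for_new_dim=False)  # coordinate, but no index
```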
+ def set_index(
+ self,
+ indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None,
+ append: bool = False,
+ **indexes_kwargs: Hashable | Sequence[Hashable],
+ ) -> Self:
"""Set Dataset (multi-)indexes using one or more existing coordinates
or variables.
@@ -2875,11 +4864,114 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.set_xindex
Dataset.swap_dims
"""
- pass
+ dim_coords = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index")
+
+ new_indexes: dict[Hashable, Index] = {}
+ new_variables: dict[Hashable, Variable] = {}
+ drop_indexes: set[Hashable] = set()
+ drop_variables: set[Hashable] = set()
+ replace_dims: dict[Hashable, Hashable] = {}
+ all_var_names: set[Hashable] = set()
+
+ for dim, _var_names in dim_coords.items():
+ if isinstance(_var_names, str) or not isinstance(_var_names, Sequence):
+ var_names = [_var_names]
+ else:
+ var_names = list(_var_names)
+
+ invalid_vars = set(var_names) - set(self._variables)
+ if invalid_vars:
+ raise ValueError(
+ ", ".join([str(v) for v in invalid_vars])
+ + " variable(s) do not exist"
+ )
+
+ all_var_names.update(var_names)
+ drop_variables.update(var_names)
- @_deprecate_positional_args('v2023.10.0')
- def reset_index(self, dims_or_levels: (Hashable | Sequence[Hashable]),
- *, drop: bool=False) ->Self:
+ # drop any pre-existing index involved and its corresponding coordinates
+ index_coord_names = self.xindexes.get_all_coords(dim, errors="ignore")
+ all_index_coord_names = set(index_coord_names)
+ for k in var_names:
+ all_index_coord_names.update(
+ self.xindexes.get_all_coords(k, errors="ignore")
+ )
+
+ drop_indexes.update(all_index_coord_names)
+ drop_variables.update(all_index_coord_names)
+
+ if len(var_names) == 1 and (not append or dim not in self._indexes):
+ var_name = var_names[0]
+ var = self._variables[var_name]
+ # an error with a better message will be raised for scalar variables
+ # when creating the PandasIndex
+ if var.ndim > 0 and var.dims != (dim,):
+ raise ValueError(
+ f"dimension mismatch: try setting an index for dimension {dim!r} with "
+ f"variable {var_name!r} that has dimensions {var.dims}"
+ )
+ idx = PandasIndex.from_variables({dim: var}, options={})
+ idx_vars = idx.create_variables({var_name: var})
+
+ # trick to preserve coordinate order in this case
+ if dim in self._coord_names:
+ drop_variables.remove(dim)
+ else:
+ if append:
+ current_variables = {
+ k: self._variables[k] for k in index_coord_names
+ }
+ else:
+ current_variables = {}
+ idx, idx_vars = PandasMultiIndex.from_variables_maybe_expand(
+ dim,
+ current_variables,
+ {k: self._variables[k] for k in var_names},
+ )
+ for n in idx.index.names:
+ replace_dims[n] = dim
+
+ new_indexes.update({k: idx for k in idx_vars})
+ new_variables.update(idx_vars)
+
+ # re-add deindexed coordinates (convert to base variables)
+ for k in drop_variables:
+ if (
+ k not in new_variables
+ and k not in all_var_names
+ and k in self._coord_names
+ ):
+ new_variables[k] = self._variables[k].to_base_variable()
+
+ indexes_: dict[Any, Index] = {
+ k: v for k, v in self._indexes.items() if k not in drop_indexes
+ }
+ indexes_.update(new_indexes)
+
+ variables = {
+ k: v for k, v in self._variables.items() if k not in drop_variables
+ }
+ variables.update(new_variables)
+
+ # update dimensions if necessary, GH: 3512
+ for k, v in variables.items():
+ if any(d in replace_dims for d in v.dims):
+ new_dims = [replace_dims.get(d, d) for d in v.dims]
+ variables[k] = v._replace(dims=new_dims)
+
+ coord_names = self._coord_names - drop_variables | set(new_variables)
+
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes_
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def reset_index(
+ self,
+ dims_or_levels: Hashable | Sequence[Hashable],
+ *,
+ drop: bool = False,
+ ) -> Self:
"""Reset the specified index(es) or multi-index level(s).
This legacy method is specific to pandas (multi-)indexes and
@@ -2908,10 +5000,90 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.set_xindex
Dataset.drop_indexes
"""
- pass
+ if isinstance(dims_or_levels, str) or not isinstance(dims_or_levels, Sequence):
+ dims_or_levels = [dims_or_levels]
- def set_xindex(self, coord_names: (str | Sequence[Hashable]), index_cls:
- (type[Index] | None)=None, **options) ->Self:
+ invalid_coords = set(dims_or_levels) - set(self._indexes)
+ if invalid_coords:
+ raise ValueError(
+ f"{tuple(invalid_coords)} are not coordinates with an index"
+ )
+
+ drop_indexes: set[Hashable] = set()
+ drop_variables: set[Hashable] = set()
+ seen: set[Index] = set()
+ new_indexes: dict[Hashable, Index] = {}
+ new_variables: dict[Hashable, Variable] = {}
+
+ def drop_or_convert(var_names):
+ if drop:
+ drop_variables.update(var_names)
+ else:
+ base_vars = {
+ k: self._variables[k].to_base_variable() for k in var_names
+ }
+ new_variables.update(base_vars)
+
+ for name in dims_or_levels:
+ index = self._indexes[name]
+
+ if index in seen:
+ continue
+ seen.add(index)
+
+ idx_var_names = set(self.xindexes.get_all_coords(name))
+ drop_indexes.update(idx_var_names)
+
+ if isinstance(index, PandasMultiIndex):
+ # special case for pd.MultiIndex
+ level_names = index.index.names
+ keep_level_vars = {
+ k: self._variables[k]
+ for k in level_names
+ if k not in dims_or_levels
+ }
+
+ if index.dim not in dims_or_levels and keep_level_vars:
+ # do not drop the multi-index completely
+ # instead replace it by a new (multi-)index with dropped level(s)
+ idx = index.keep_levels(keep_level_vars)
+ idx_vars = idx.create_variables(keep_level_vars)
+ new_indexes.update({k: idx for k in idx_vars})
+ new_variables.update(idx_vars)
+ if not isinstance(idx, PandasMultiIndex):
+ # multi-index reduced to single index
+ # backward compatibility: unique level coordinate renamed to dimension
+ drop_variables.update(keep_level_vars)
+ drop_or_convert(
+ [k for k in level_names if k not in keep_level_vars]
+ )
+ else:
+ # always drop the multi-index dimension variable
+ drop_variables.add(index.dim)
+ drop_or_convert(level_names)
+ else:
+ drop_or_convert(idx_var_names)
+
+ indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes}
+ indexes.update(new_indexes)
+
+ variables = {
+ k: v for k, v in self._variables.items() if k not in drop_variables
+ }
+ variables.update(new_variables)
+
+ coord_names = self._coord_names - drop_variables
+
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes
+ )
+
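A small sketch of how `reset_index` (implemented above) behaves against a stacked multi-index; the example dataset is illustrative only:

```python
import xarray as xr

ds = xr.Dataset(
    {"v": (("lab", "num"), [[1.0, 2.0], [3.0, 4.0]])},
    coords={"lab": ["a", "b"], "num": [0, 1]},
).stack(x=("lab", "num"))

# demote the whole "x" multi-index to plain (non-indexed) coordinates
flat = ds.reset_index("x")

# drop a single level; the remaining level is kept as a single index
partial = ds.reset_index("num")

# remove the coordinate values entirely
dropped = ds.reset_index("x", drop=True)
```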
+ def set_xindex(
+ self,
+ coord_names: str | Sequence[Hashable],
+ index_cls: type[Index] | None = None,
+ **options,
+ ) -> Self:
"""Set a new, Xarray-compatible index from one or more existing
coordinate(s).
@@ -2933,11 +5105,96 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Another dataset, with this dataset's data and with a new index.
"""
- pass
+ # the Sequence check is required for mypy
+ if is_scalar(coord_names) or not isinstance(coord_names, Sequence):
+ coord_names = [coord_names]
+
+ if index_cls is None:
+ if len(coord_names) == 1:
+ index_cls = PandasIndex
+ else:
+ index_cls = PandasMultiIndex
+ else:
+ if not issubclass(index_cls, Index):
+ raise TypeError(f"{index_cls} is not a subclass of xarray.Index")
+
+ invalid_coords = set(coord_names) - self._coord_names
+
+ if invalid_coords:
+ msg = ["invalid coordinate(s)"]
+ no_vars = invalid_coords - set(self._variables)
+ data_vars = invalid_coords - no_vars
+ if no_vars:
+ msg.append(f"those variables don't exist: {no_vars}")
+ if data_vars:
+ msg.append(
+ f"those variables are data variables: {data_vars}, use `set_coords` first"
+ )
+ raise ValueError("\n".join(msg))
+
+ # we could be more clever here (e.g., drop-in index replacement if index
+ # coordinates do not conflict), but let's not allow this for now
+ indexed_coords = set(coord_names) & set(self._indexes)
+
+ if indexed_coords:
+ raise ValueError(
+ f"those coordinates already have an index: {indexed_coords}"
+ )
+
+ coord_vars = {name: self._variables[name] for name in coord_names}
+
+ index = index_cls.from_variables(coord_vars, options=options)
+
+ new_coord_vars = index.create_variables(coord_vars)
+
+ # special case for setting a pandas multi-index from level coordinates
+        # TODO: remove it once we deprecate pandas multi-index dimension (tuple
+ # elements) coordinate
+ if isinstance(index, PandasMultiIndex):
+ coord_names = [index.dim] + list(coord_names)
- def reorder_levels(self, dim_order: (Mapping[Any, Sequence[int |
- Hashable]] | None)=None, **dim_order_kwargs: Sequence[int | Hashable]
- ) ->Self:
+ variables: dict[Hashable, Variable]
+ indexes: dict[Hashable, Index]
+
+ if len(coord_names) == 1:
+ variables = self._variables.copy()
+ indexes = self._indexes.copy()
+
+ name = list(coord_names).pop()
+ if name in new_coord_vars:
+ variables[name] = new_coord_vars[name]
+ indexes[name] = index
+ else:
+ # reorder variables and indexes so that coordinates having the same
+ # index are next to each other
+ variables = {}
+ for name, var in self._variables.items():
+ if name not in coord_names:
+ variables[name] = var
+
+ indexes = {}
+ for name, idx in self._indexes.items():
+ if name not in coord_names:
+ indexes[name] = idx
+
+ for name in coord_names:
+ try:
+ variables[name] = new_coord_vars[name]
+ except KeyError:
+ variables[name] = self._variables[name]
+ indexes[name] = index
+
+ return self._replace(
+ variables=variables,
+ coord_names=self._coord_names | set(coord_names),
+ indexes=indexes,
+ )
+
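A hedged usage sketch of `set_xindex`, matching the default index-class selection in the implementation above (PandasIndex for one coordinate, PandasMultiIndex for several); the coordinate names are made up:

```python
import xarray as xr

ds = xr.Dataset(
    {"v": ("x", [1, 2, 3, 4])},
    coords={"one": ("x", ["a", "a", "b", "b"]), "two": ("x", [0, 1, 0, 1])},
)

# one coordinate -> PandasIndex by default
ds1 = ds.set_xindex("one")

# several coordinates -> PandasMultiIndex by default
ds2 = ds.set_xindex(["one", "two"])
```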
+ def reorder_levels(
+ self,
+ dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None,
+ **dim_order_kwargs: Sequence[int | Hashable],
+ ) -> Self:
"""Rearrange index levels using input order.
Parameters
@@ -2956,10 +5213,38 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Another dataset, with this dataset's data but replaced
coordinates.
"""
- pass
+ dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels")
+ variables = self._variables.copy()
+ indexes = dict(self._indexes)
+ new_indexes: dict[Hashable, Index] = {}
+ new_variables: dict[Hashable, IndexVariable] = {}
+
+ for dim, order in dim_order.items():
+ index = self._indexes[dim]
+
+ if not isinstance(index, PandasMultiIndex):
+ raise ValueError(f"coordinate {dim} has no MultiIndex")
- def _get_stack_index(self, dim, multi=False, create_index=False) ->tuple[
- Index | None, dict[Hashable, Variable]]:
+ level_vars = {k: self._variables[k] for k in order}
+ idx = index.reorder_levels(level_vars)
+ idx_vars = idx.create_variables(level_vars)
+ new_indexes.update({k: idx for k in idx_vars})
+ new_variables.update(idx_vars)
+
+ indexes = {k: v for k, v in self._indexes.items() if k not in new_indexes}
+ indexes.update(new_indexes)
+
+ variables = {k: v for k, v in self._variables.items() if k not in new_variables}
+ variables.update(new_variables)
+
+ return self._replace(variables, indexes=indexes)
+
+ def _get_stack_index(
+ self,
+ dim,
+ multi=False,
+ create_index=False,
+ ) -> tuple[Index | None, dict[Hashable, Variable]]:
"""Used by stack and unstack to get one pandas (multi-)index among
the indexed coordinates along dimension `dim`.
@@ -2970,13 +5255,112 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
an error if multiple indexes are found.
"""
- pass
+ stack_index: Index | None = None
+ stack_coords: dict[Hashable, Variable] = {}
+
+ for name, index in self._indexes.items():
+ var = self._variables[name]
+ if (
+ var.ndim == 1
+ and var.dims[0] == dim
+ and (
+ # stack: must be a single coordinate index
+ not multi
+ and not self.xindexes.is_multi(name)
+ # unstack: must be an index that implements .unstack
+ or multi
+ and type(index).unstack is not Index.unstack
+ )
+ ):
+ if stack_index is not None and index is not stack_index:
+ # more than one index found, stop
+ if create_index:
+ raise ValueError(
+ f"cannot stack dimension {dim!r} with `create_index=True` "
+ "and with more than one index found along that dimension"
+ )
+ return None, {}
+ stack_index = index
+ stack_coords[name] = var
- @partial(deprecate_dims, old_name='dimensions')
- def stack(self, dim: (Mapping[Any, Sequence[Hashable | ellipsis]] |
- None)=None, create_index: (bool | None)=True, index_cls: type[Index
- ]=PandasMultiIndex, **dim_kwargs: Sequence[Hashable | ellipsis]
- ) ->Self:
+ if create_index and stack_index is None:
+ if dim in self._variables:
+ var = self._variables[dim]
+ else:
+ _, _, var = _get_virtual_variable(self._variables, dim, self.sizes)
+ # dummy index (only `stack_coords` will be used to construct the multi-index)
+ stack_index = PandasIndex([0], dim)
+ stack_coords = {dim: var}
+
+ return stack_index, stack_coords
+
+ def _stack_once(
+ self,
+ dims: Sequence[Hashable | ellipsis],
+ new_dim: Hashable,
+ index_cls: type[Index],
+ create_index: bool | None = True,
+ ) -> Self:
+ if dims == ...:
+ raise ValueError("Please use [...] for dims, rather than just ...")
+ if ... in dims:
+ dims = list(infix_dims(dims, self.dims))
+
+ new_variables: dict[Hashable, Variable] = {}
+ stacked_var_names: list[Hashable] = []
+ drop_indexes: list[Hashable] = []
+
+ for name, var in self.variables.items():
+ if any(d in var.dims for d in dims):
+ add_dims = [d for d in dims if d not in var.dims]
+ vdims = list(var.dims) + add_dims
+ shape = [self.sizes[d] for d in vdims]
+ exp_var = var.set_dims(vdims, shape)
+ stacked_var = exp_var.stack(**{new_dim: dims})
+ new_variables[name] = stacked_var
+ stacked_var_names.append(name)
+ else:
+ new_variables[name] = var.copy(deep=False)
+
+ # drop indexes of stacked coordinates (if any)
+ for name in stacked_var_names:
+ drop_indexes += list(self.xindexes.get_all_coords(name, errors="ignore"))
+
+ new_indexes = {}
+ new_coord_names = set(self._coord_names)
+ if create_index or create_index is None:
+ product_vars: dict[Any, Variable] = {}
+ for dim in dims:
+ idx, idx_vars = self._get_stack_index(dim, create_index=create_index)
+ if idx is not None:
+ product_vars.update(idx_vars)
+
+ if len(product_vars) == len(dims):
+ idx = index_cls.stack(product_vars, new_dim)
+ new_indexes[new_dim] = idx
+ new_indexes.update({k: idx for k in product_vars})
+ idx_vars = idx.create_variables(product_vars)
+ # keep consistent multi-index coordinate order
+ for k in idx_vars:
+ new_variables.pop(k, None)
+ new_variables.update(idx_vars)
+ new_coord_names.update(idx_vars)
+
+ indexes = {k: v for k, v in self._indexes.items() if k not in drop_indexes}
+ indexes.update(new_indexes)
+
+ return self._replace_with_new_dims(
+ new_variables, coord_names=new_coord_names, indexes=indexes
+ )
+
+ @partial(deprecate_dims, old_name="dimensions")
+ def stack(
+ self,
+ dim: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None,
+ create_index: bool | None = True,
+ index_cls: type[Index] = PandasMultiIndex,
+ **dim_kwargs: Sequence[Hashable | ellipsis],
+ ) -> Self:
"""
Stack any number of existing dimensions into a single new dimension.
@@ -3014,11 +5398,19 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.unstack
"""
- pass
-
- def to_stacked_array(self, new_dim: Hashable, sample_dims: Collection[
- Hashable], variable_dim: Hashable='variable', name: (Hashable |
- None)=None) ->DataArray:
+ dim = either_dict_or_kwargs(dim, dim_kwargs, "stack")
+ result = self
+ for new_dim, dims in dim.items():
+ result = result._stack_once(dims, new_dim, index_cls, create_index)
+ return result
+
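A short usage sketch of `Dataset.stack` as implemented above (the toy dataset is illustrative); note that `create_index=False` skips building the multi-index entirely:

```python
import xarray as xr

ds = xr.Dataset(
    {"v": (("x", "y"), [[0, 1, 2], [3, 4, 5]])},
    coords={"x": ["a", "b"], "y": [0, 1, 2]},
)

# stack two dimensions into a new "z" dimension with a PandasMultiIndex
stacked = ds.stack(z=("x", "y"))

# stack without creating any index
flat = ds.stack(z=("x", "y"), create_index=False)

# use an ellipsis to stack every dimension
everything = ds.stack(z=[...])
```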
+ def to_stacked_array(
+ self,
+ new_dim: Hashable,
+ sample_dims: Collection[Hashable],
+ variable_dim: Hashable = "variable",
+ name: Hashable | None = None,
+ ) -> DataArray:
"""Combine variables of differing dimensionality into a DataArray
without broadcasting.
@@ -3086,11 +5478,141 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dimensions without coordinates: x
"""
- pass
+ from xarray.core.concat import concat
+
+ stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims)
+
+ for key, da in self.data_vars.items():
+ missing_sample_dims = set(sample_dims) - set(da.dims)
+ if missing_sample_dims:
+ raise ValueError(
+ "Variables in the dataset must contain all ``sample_dims`` "
+ f"({sample_dims!r}) but '{key}' misses {sorted(map(str, missing_sample_dims))}"
+ )
+
+ def stack_dataarray(da):
+            # add missing dims/coords and the name of the variable
+
+ missing_stack_coords = {variable_dim: da.name}
+ for dim in set(stacking_dims) - set(da.dims):
+ missing_stack_coords[dim] = None
+
+ missing_stack_dims = list(missing_stack_coords)
+
+ return (
+ da.assign_coords(**missing_stack_coords)
+ .expand_dims(missing_stack_dims)
+ .stack({new_dim: (variable_dim,) + stacking_dims})
+ )
+
+ # concatenate the arrays
+ stackable_vars = [stack_dataarray(da) for da in self.data_vars.values()]
+ data_array = concat(stackable_vars, dim=new_dim)
+
+ if name is not None:
+ data_array.name = name
+
+ return data_array
+
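A minimal sketch of `to_stacked_array`, which the block above implements by expanding, stacking and concatenating each data variable; the variable names are illustrative:

```python
import xarray as xr

data = xr.Dataset(
    {"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])},
    coords={"y": ["u", "v", "w"]},
)

# combine "a" and "b" into one 2-D DataArray: "x" is kept as the sample
# dimension, the remaining dims plus the variable name are stacked into "z"
arr = data.to_stacked_array("z", sample_dims=["x"])
```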
+ def _unstack_once(
+ self,
+ dim: Hashable,
+ index_and_vars: tuple[Index, dict[Hashable, Variable]],
+ fill_value,
+ sparse: bool = False,
+ ) -> Self:
+ index, index_vars = index_and_vars
+ variables: dict[Hashable, Variable] = {}
+ indexes = {k: v for k, v in self._indexes.items() if k != dim}
+
+ new_indexes, clean_index = index.unstack()
+ indexes.update(new_indexes)
+
+ for name, idx in new_indexes.items():
+ variables.update(idx.create_variables(index_vars))
+
+ for name, var in self.variables.items():
+ if name not in index_vars:
+ if dim in var.dims:
+ if isinstance(fill_value, Mapping):
+ fill_value_ = fill_value[name]
+ else:
+ fill_value_ = fill_value
+
+ variables[name] = var._unstack_once(
+ index=clean_index,
+ dim=dim,
+ fill_value=fill_value_,
+ sparse=sparse,
+ )
+ else:
+ variables[name] = var
+
+ coord_names = set(self._coord_names) - {dim} | set(new_indexes)
+
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes
+ )
+
+ def _unstack_full_reindex(
+ self,
+ dim: Hashable,
+ index_and_vars: tuple[Index, dict[Hashable, Variable]],
+ fill_value,
+ sparse: bool,
+ ) -> Self:
+ index, index_vars = index_and_vars
+ variables: dict[Hashable, Variable] = {}
+ indexes = {k: v for k, v in self._indexes.items() if k != dim}
+
+ new_indexes, clean_index = index.unstack()
+ indexes.update(new_indexes)
+
+ new_index_variables = {}
+ for name, idx in new_indexes.items():
+ new_index_variables.update(idx.create_variables(index_vars))
+
+ new_dim_sizes = {k: v.size for k, v in new_index_variables.items()}
+ variables.update(new_index_variables)
+
+ # take a shortcut in case the MultiIndex was not modified.
+ full_idx = pd.MultiIndex.from_product(
+ clean_index.levels, names=clean_index.names
+ )
+ if clean_index.equals(full_idx):
+ obj = self
+ else:
+            # TODO: we may deprecate implicit re-indexing with a pandas.MultiIndex
+ xr_full_idx = PandasMultiIndex(full_idx, dim)
+ indexers = Indexes(
+ {k: xr_full_idx for k in index_vars},
+ xr_full_idx.create_variables(index_vars),
+ )
+ obj = self._reindex(
+ indexers, copy=False, fill_value=fill_value, sparse=sparse
+ )
+
+ for name, var in obj.variables.items():
+ if name not in index_vars:
+ if dim in var.dims:
+ variables[name] = var.unstack({dim: new_dim_sizes})
+ else:
+ variables[name] = var
+
+ coord_names = set(self._coord_names) - {dim} | set(new_dim_sizes)
- @_deprecate_positional_args('v2023.10.0')
- def unstack(self, dim: Dims=None, *, fill_value: Any=xrdtypes.NA,
- sparse: bool=False) ->Self:
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def unstack(
+ self,
+ dim: Dims = None,
+ *,
+ fill_value: Any = xrdtypes.NA,
+ sparse: bool = False,
+ ) -> Self:
"""
Unstack existing dimensions corresponding to MultiIndexes into
multiple new dimensions.
@@ -3118,9 +5640,76 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.stack
"""
- pass
- def update(self, other: CoercibleMapping) ->Self:
+ if dim is None:
+ dims = list(self.dims)
+ else:
+ if isinstance(dim, str) or not isinstance(dim, Iterable):
+ dims = [dim]
+ else:
+ dims = list(dim)
+
+ missing_dims = set(dims) - set(self.dims)
+ if missing_dims:
+ raise ValueError(
+ f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}"
+ )
+
+ # each specified dimension must have exactly one multi-index
+ stacked_indexes: dict[Any, tuple[Index, dict[Hashable, Variable]]] = {}
+ for d in dims:
+ idx, idx_vars = self._get_stack_index(d, multi=True)
+ if idx is not None:
+ stacked_indexes[d] = idx, idx_vars
+
+ if dim is None:
+ dims = list(stacked_indexes)
+ else:
+ non_multi_dims = set(dims) - set(stacked_indexes)
+ if non_multi_dims:
+ raise ValueError(
+ "cannot unstack dimensions that do not "
+ f"have exactly one multi-index: {tuple(non_multi_dims)}"
+ )
+
+ result = self.copy(deep=False)
+
+ # we want to avoid allocating an object-dtype ndarray for a MultiIndex,
+ # so we can't just access self.variables[v].data for every variable.
+ # We only check the non-index variables.
+ # https://github.com/pydata/xarray/issues/5902
+ nonindexes = [
+ self.variables[k] for k in set(self.variables) - set(self._indexes)
+ ]
+ # Notes for each of these cases:
+ # 1. Dask arrays don't support assignment by index, which the fast unstack
+ # function requires.
+ # https://github.com/pydata/xarray/pull/4746#issuecomment-753282125
+ # 2. Sparse doesn't currently support (though we could special-case it)
+ # https://github.com/pydata/sparse/issues/422
+ # 3. pint requires checking if it's a NumPy array until
+ # https://github.com/pydata/xarray/pull/4751 is resolved,
+ # Once that is resolved, explicitly exclude pint arrays.
+ # pint doesn't implement `np.full_like` in a way that's
+ # currently compatible.
+ sparse_array_type = array_type("sparse")
+ needs_full_reindex = any(
+ is_duck_dask_array(v.data)
+ or isinstance(v.data, sparse_array_type)
+ or not isinstance(v.data, np.ndarray)
+ for v in nonindexes
+ )
+
+ for d in dims:
+ if needs_full_reindex:
+ result = result._unstack_full_reindex(
+ d, stacked_indexes[d], fill_value, sparse
+ )
+ else:
+ result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)
+ return result
+
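A usage sketch of `unstack` and its `fill_value` option, assuming a stacked toy dataset; slicing the stacked dimension first makes the multi-index non-full so the fill value is actually used:

```python
import xarray as xr

ds = xr.Dataset(
    {"v": (("x", "y"), [[0.0, 1.0], [2.0, 3.0]])},
    coords={"x": ["a", "b"], "y": [0, 1]},
)
stacked = ds.stack(z=("x", "y"))

# unstack every multi-indexed dimension (here just "z") ...
restored = stacked.unstack()

# ... or a specific one, filling missing combinations explicitly
partial = stacked.isel(z=slice(0, 3)).unstack("z", fill_value=-1.0)
```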
+ def update(self, other: CoercibleMapping) -> Self:
"""Update this dataset's variables with those from another dataset.
Just like :py:meth:`dict.update` this is an in-place operation.
@@ -3156,12 +5745,18 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.assign
Dataset.merge
"""
- pass
-
- def merge(self, other: (CoercibleMapping | DataArray), overwrite_vars:
- (Hashable | Iterable[Hashable])=frozenset(), compat: CompatOptions=
- 'no_conflicts', join: JoinOptions='outer', fill_value: Any=xrdtypes
- .NA, combine_attrs: CombineAttrsOptions='override') ->Self:
+ merge_result = dataset_update_method(self, other)
+ return self._replace(inplace=True, **merge_result._asdict())
+
+ def merge(
+ self,
+ other: CoercibleMapping | DataArray,
+ overwrite_vars: Hashable | Iterable[Hashable] = frozenset(),
+ compat: CompatOptions = "no_conflicts",
+ join: JoinOptions = "outer",
+ fill_value: Any = xrdtypes.NA,
+ combine_attrs: CombineAttrsOptions = "override",
+ ) -> Self:
"""Merge the arrays of two datasets into a single dataset.
This method generally does not allow for overriding data, with the
@@ -3176,7 +5771,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
overwrite_vars : hashable or iterable of hashable, optional
If provided, update variables of these name(s) without checking for
conflicts in this dataset.
- compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal"}, default: "no_conflicts"
+ compat : {"identical", "equals", "broadcast_equals", \
+ "no_conflicts", "override", "minimal"}, default: "no_conflicts"
String indicating how to compare variables of the same name for
potential conflicts:
@@ -3191,7 +5787,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
- 'override': skip comparing and pick variable from first dataset
- 'minimal': drop conflicting coordinates
- join : {"outer", "inner", "left", "right", "exact", "override"}, default: "outer"
+ join : {"outer", "inner", "left", "right", "exact", "override"}, \
+ default: "outer"
Method for joining ``self`` and ``other`` along shared dimensions:
- 'outer': use the union of the indexes
@@ -3205,7 +5802,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
fill_value : scalar or dict-like, optional
Value to use for newly missing values. If a dict-like, maps
variable names (including coordinates) to fill values.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -3235,10 +5833,38 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.update
"""
- pass
+ from xarray.core.dataarray import DataArray
- def drop_vars(self, names: (str | Iterable[Hashable] | Callable[[Self],
- str | Iterable[Hashable]]), *, errors: ErrorOptions='raise') ->Self:
+ other = other.to_dataset() if isinstance(other, DataArray) else other
+ merge_result = dataset_merge_method(
+ self,
+ other,
+ overwrite_vars=overwrite_vars,
+ compat=compat,
+ join=join,
+ fill_value=fill_value,
+ combine_attrs=combine_attrs,
+ )
+ return self._replace(**merge_result._asdict())
+
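A small sketch of `Dataset.merge` with its default outer join, using invented inputs:

```python
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2])}, coords={"x": [10, 20]})
other = xr.Dataset({"b": ("x", [3, 4])}, coords={"x": [20, 30]})

# outer join on "x" by default; unmatched positions become NaN
merged = ds.merge(other)

# require exactly aligned indexes instead
# ds.merge(other, join="exact")  # raises an alignment error for this input
```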
+ def _assert_all_in_dataset(
+ self, names: Iterable[Hashable], virtual_okay: bool = False
+ ) -> None:
+ bad_names = set(names) - set(self._variables)
+ if virtual_okay:
+ bad_names -= self.virtual_variables
+ if bad_names:
+ ordered_bad_names = [name for name in names if name in bad_names]
+ raise ValueError(
+ f"These variables cannot be found in this dataset: {ordered_bad_names}"
+ )
+
+ def drop_vars(
+ self,
+ names: str | Iterable[Hashable] | Callable[[Self], str | Iterable[Hashable]],
+ *,
+ errors: ErrorOptions = "raise",
+ ) -> Self:
"""Drop variables from this dataset.
Parameters
@@ -3357,10 +5983,48 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
DataArray.drop_vars
"""
- pass
+ if callable(names):
+ names = names(self)
+ # the Iterable check is required for mypy
+ if is_scalar(names) or not isinstance(names, Iterable):
+ names_set = {names}
+ else:
+ names_set = set(names)
+ if errors == "raise":
+ self._assert_all_in_dataset(names_set)
+
+ # GH6505
+ other_names = set()
+ for var in names_set:
+ maybe_midx = self._indexes.get(var, None)
+ if isinstance(maybe_midx, PandasMultiIndex):
+ idx_coord_names = set(list(maybe_midx.index.names) + [maybe_midx.dim])
+ idx_other_names = idx_coord_names - set(names_set)
+ other_names.update(idx_other_names)
+ if other_names:
+ names_set |= set(other_names)
+ warnings.warn(
+ f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. "
+ f"Please also drop the following variables: {other_names!r} to avoid an error in the future.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
- def drop_indexes(self, coord_names: (Hashable | Iterable[Hashable]), *,
- errors: ErrorOptions='raise') ->Self:
+ assert_no_index_corrupted(self.xindexes, names_set)
+
+ variables = {k: v for k, v in self._variables.items() if k not in names_set}
+ coord_names = {k for k in self._coord_names if k in variables}
+ indexes = {k: v for k, v in self._indexes.items() if k not in names_set}
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes
+ )
+
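A brief sketch of the `drop_vars` call forms handled above (string, iterable, callable, and `errors="ignore"`); the names are illustrative:

```python
import xarray as xr

ds = xr.Dataset(
    {"a": ("x", [1, 2]), "b": ("x", [3, 4])},
    coords={"x": [0, 1]},
)

ds.drop_vars("a")                             # drop a single variable
ds.drop_vars(["a", "x"])                      # drop a data variable and a coordinate
ds.drop_vars(lambda d: list(d.data_vars))     # a callable receives the dataset itself
ds.drop_vars("missing", errors="ignore")      # silently skip unknown names
```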
+ def drop_indexes(
+ self,
+ coord_names: Hashable | Iterable[Hashable],
+ *,
+ errors: ErrorOptions = "raise",
+ ) -> Self:
"""Drop the indexes assigned to the given coordinates.
Parameters
@@ -3372,16 +6036,53 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
passed have no index or are not in the dataset.
If 'ignore', no error is raised.
- Returns
- -------
- dropped : Dataset
- A new dataset with dropped indexes.
+ Returns
+ -------
+ dropped : Dataset
+ A new dataset with dropped indexes.
+
+ """
+ # the Iterable check is required for mypy
+ if is_scalar(coord_names) or not isinstance(coord_names, Iterable):
+ coord_names = {coord_names}
+ else:
+ coord_names = set(coord_names)
+
+ if errors == "raise":
+ invalid_coords = coord_names - self._coord_names
+ if invalid_coords:
+ raise ValueError(
+ f"The coordinates {tuple(invalid_coords)} are not found in the "
+ f"dataset coordinates {tuple(self.coords.keys())}"
+ )
+
+ unindexed_coords = set(coord_names) - set(self._indexes)
+ if unindexed_coords:
+ raise ValueError(
+ f"those coordinates do not have an index: {unindexed_coords}"
+ )
+
+ assert_no_index_corrupted(self.xindexes, coord_names, action="remove index(es)")
+
+ variables = {}
+ for name, var in self._variables.items():
+ if name in coord_names:
+ variables[name] = var.to_base_variable()
+ else:
+ variables[name] = var
+
+ indexes = {k: v for k, v in self._indexes.items() if k not in coord_names}
- """
- pass
+ return self._replace(variables=variables, indexes=indexes)
- def drop(self, labels=None, dim=None, *, errors: ErrorOptions='raise',
- **labels_kwargs) ->Self:
+ def drop(
+ self,
+ labels=None,
+ dim=None,
+ *,
+ errors: ErrorOptions = "raise",
+ **labels_kwargs,
+ ) -> Self:
"""Backward compatible method based on `drop_vars` and `drop_sel`
Using either `drop_vars` or `drop_sel` is encouraged
@@ -3391,10 +6092,48 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.drop_vars
Dataset.drop_sel
"""
- pass
+ if errors not in ["raise", "ignore"]:
+ raise ValueError('errors must be either "raise" or "ignore"')
+
+ if is_dict_like(labels) and not isinstance(labels, dict):
+ emit_user_level_warning(
+ "dropping coordinates using `drop` is deprecated; use drop_vars.",
+ DeprecationWarning,
+ )
+ return self.drop_vars(labels, errors=errors)
- def drop_sel(self, labels=None, *, errors: ErrorOptions='raise', **
- labels_kwargs) ->Self:
+ if labels_kwargs or isinstance(labels, dict):
+ if dim is not None:
+ raise ValueError("cannot specify dim and dict-like arguments.")
+ labels = either_dict_or_kwargs(labels, labels_kwargs, "drop")
+
+ if dim is None and (is_scalar(labels) or isinstance(labels, Iterable)):
+ emit_user_level_warning(
+ "dropping variables using `drop` is deprecated; use drop_vars.",
+ DeprecationWarning,
+ )
+ # for mypy
+ if is_scalar(labels):
+ labels = [labels]
+ return self.drop_vars(labels, errors=errors)
+ if dim is not None:
+ warnings.warn(
+ "dropping labels using list-like labels is deprecated; using "
+ "dict-like arguments with `drop_sel`, e.g. `ds.drop_sel(dim=[labels]).",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.drop_sel({dim: labels}, errors=errors, **labels_kwargs)
+
+ emit_user_level_warning(
+ "dropping labels using `drop` is deprecated; use `drop_sel` instead.",
+ DeprecationWarning,
+ )
+ return self.drop_sel(labels, errors=errors)
+
+ def drop_sel(
+ self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs
+ ) -> Self:
"""Drop index labels from this dataset.
Parameters
@@ -3443,9 +6182,27 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Data variables:
A (x, y) int64 32B 0 2 3 5
"""
- pass
+ if errors not in ["raise", "ignore"]:
+ raise ValueError('errors must be either "raise" or "ignore"')
+
+ labels = either_dict_or_kwargs(labels, labels_kwargs, "drop_sel")
- def drop_isel(self, indexers=None, **indexers_kwargs) ->Self:
+ ds = self
+ for dim, labels_for_dim in labels.items():
+ # Don't cast to set, as it would harm performance when labels
+ # is a large numpy array
+ if utils.is_scalar(labels_for_dim):
+ labels_for_dim = [labels_for_dim]
+ labels_for_dim = np.asarray(labels_for_dim)
+ try:
+ index = self.get_index(dim)
+ except KeyError:
+ raise ValueError(f"dimension {dim!r} does not have coordinate labels")
+ new_index = index.drop(labels_for_dim, errors=errors)
+ ds = ds.loc[{dim: new_index}]
+ return ds
+
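A minimal sketch of `drop_sel`, which drops by coordinate label along a dimension; the toy dataset is invented:

```python
import xarray as xr

ds = xr.Dataset(
    {"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]])},
    coords={"y": ["u", "v", "w"]},
)

ds.drop_sel(y=["u", "w"])   # drop several labels along "y"
ds.drop_sel(y="v")          # a scalar label works too
```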
+ def drop_isel(self, indexers=None, **indexers_kwargs) -> Self:
"""Drop index positions from this Dataset.
Parameters
@@ -3493,10 +6250,29 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Data variables:
A (x, y) int64 32B 0 2 3 5
"""
- pass
- def drop_dims(self, drop_dims: (str | Iterable[Hashable]), *, errors:
- ErrorOptions='raise') ->Self:
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "drop_isel")
+
+ ds = self
+ dimension_index = {}
+ for dim, pos_for_dim in indexers.items():
+ # Don't cast to set, as it would harm performance when labels
+ # is a large numpy array
+ if utils.is_scalar(pos_for_dim):
+ pos_for_dim = [pos_for_dim]
+ pos_for_dim = np.asarray(pos_for_dim)
+ index = self.get_index(dim)
+ new_index = index.delete(pos_for_dim)
+ dimension_index[dim] = new_index
+ ds = ds.loc[dimension_index]
+ return ds
+
+ def drop_dims(
+ self,
+ drop_dims: str | Iterable[Hashable],
+ *,
+ errors: ErrorOptions = "raise",
+ ) -> Self:
"""Drop dimensions and associated variables from this dataset.
Parameters
@@ -3514,11 +6290,30 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
The dataset without the given dimensions (or any variables
containing those dimensions).
"""
- pass
+ if errors not in ["raise", "ignore"]:
+ raise ValueError('errors must be either "raise" or "ignore"')
+
+ if isinstance(drop_dims, str) or not isinstance(drop_dims, Iterable):
+ drop_dims = {drop_dims}
+ else:
+ drop_dims = set(drop_dims)
+
+ if errors == "raise":
+ missing_dims = drop_dims - set(self.dims)
+ if missing_dims:
+ raise ValueError(
+ f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}"
+ )
+
+ drop_vars = {k for k, v in self._variables.items() if set(v.dims) & drop_dims}
+ return self.drop_vars(drop_vars)
@deprecate_dims
- def transpose(self, *dim: Hashable, missing_dims: ErrorOptionsWithWarn=
- 'raise') ->Self:
+ def transpose(
+ self,
+ *dim: Hashable,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ ) -> Self:
"""Return a new Dataset object with all array dimensions transposed.
Although the order of dimensions on each array will change, the dataset
@@ -3553,12 +6348,32 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
numpy.transpose
DataArray.transpose
"""
- pass
+ # Raise error if list is passed as dim
+ if (len(dim) > 0) and (isinstance(dim[0], list)):
+ list_fix = [f"{repr(x)}" if isinstance(x, str) else f"{x}" for x in dim[0]]
+ raise TypeError(
+ f'transpose requires dim to be passed as multiple arguments. Expected `{", ".join(list_fix)}`. Received `{dim[0]}` instead'
+ )
- @_deprecate_positional_args('v2023.10.0')
- def dropna(self, dim: Hashable, *, how: Literal['any', 'all']='any',
- thresh: (int | None)=None, subset: (Iterable[Hashable] | None)=None
- ) ->Self:
+ # Use infix_dims to check once for missing dimensions
+ if len(dim) != 0:
+ _ = list(infix_dims(dim, self.dims, missing_dims))
+
+ ds = self.copy()
+ for name, var in self._variables.items():
+ var_dims = tuple(d for d in dim if d in (var.dims + (...,)))
+ ds._variables[name] = var.transpose(*var_dims)
+ return ds
+
+ @_deprecate_positional_args("v2023.10.0")
+ def dropna(
+ self,
+ dim: Hashable,
+ *,
+ how: Literal["any", "all"] = "any",
+ thresh: int | None = None,
+ subset: Iterable[Hashable] | None = None,
+ ) -> Self:
"""Returns a new dataset with dropped labels for missing values along
the provided dimension.
@@ -3650,9 +6465,42 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
DataArray.dropna
pandas.DataFrame.dropna
"""
- pass
+ # TODO: consider supporting multiple dimensions? Or not, given that
+ # there are some ugly edge cases, e.g., pandas's dropna differs
+ # depending on the order of the supplied axes.
+
+ if dim not in self.dims:
+ raise ValueError(
+ f"Dimension {dim!r} not found in data dimensions {tuple(self.dims)}"
+ )
+
+ if subset is None:
+ subset = iter(self.data_vars)
+
+ count = np.zeros(self.sizes[dim], dtype=np.int64)
+ size = np.int_(0) # for type checking
+
+ for k in subset:
+ array = self._variables[k]
+ if dim in array.dims:
+ dims = [d for d in array.dims if d != dim]
+ count += np.asarray(array.count(dims))
+ size += math.prod([self.sizes[d] for d in dims])
+
+ if thresh is not None:
+ mask = count >= thresh
+ elif how == "any":
+ mask = count == size
+ elif how == "all":
+ mask = count > 0
+ elif how is not None:
+ raise ValueError(f"invalid how option: {how}")
+ else:
+ raise TypeError("must specify how or thresh")
+
+ return self.isel({dim: mask})
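A short sketch of the `how`/`thresh` options of `dropna` implemented above, on a made-up dataset:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"a": ("x", [0.0, np.nan, 2.0]), "b": ("x", [np.nan, np.nan, 5.0])}
)

ds.dropna("x")              # drop labels where any variable is NaN (default)
ds.dropna("x", how="all")   # drop only where every variable is NaN
ds.dropna("x", thresh=2)    # keep labels with at least 2 non-NaN values
```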
- def fillna(self, value: Any) ->Self:
+ def fillna(self, value: Any) -> Self:
"""Fill missing values in this object.
This operation follows the normal broadcasting and alignment rules that
@@ -3722,20 +6570,41 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
C (x) float64 32B 2.0 2.0 2.0 5.0
D (x) float64 32B 3.0 3.0 3.0 4.0
"""
- pass
-
- def interpolate_na(self, dim: (Hashable | None)=None, method:
- InterpOptions='linear', limit: (int | None)=None, use_coordinate: (
- bool | Hashable)=True, max_gap: (int | float | str | pd.Timedelta |
- np.timedelta64 | datetime.timedelta | None)=None, **kwargs: Any
- ) ->Self:
+ if utils.is_dict_like(value):
+ value_keys = getattr(value, "data_vars", value).keys()
+ if not set(value_keys) <= set(self.data_vars.keys()):
+ raise ValueError(
+ "all variables in the argument to `fillna` "
+ "must be contained in the original dataset"
+ )
+ out = ops.fillna(self, value)
+ return out
+
+ def interpolate_na(
+ self,
+ dim: Hashable | None = None,
+ method: InterpOptions = "linear",
+ limit: int | None = None,
+ use_coordinate: bool | Hashable = True,
+ max_gap: (
+ int
+ | float
+ | str
+ | pd.Timedelta
+ | np.timedelta64
+ | datetime.timedelta
+ | None
+ ) = None,
+ **kwargs: Any,
+ ) -> Self:
"""Fill in NaNs by interpolating according to different methods.
Parameters
----------
dim : Hashable or None, optional
Specifies the dimension along which to interpolate.
- method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
+ method : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial", \
+ "barycentric", "krogh", "pchip", "spline", "akima"}, default: "linear"
String indicating which method to use for interpolation:
- 'linear': linear interpolation. Additional keyword
@@ -3758,7 +6627,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
or None for no limit. This filling is done regardless of the size of
the gap in the data. To only interpolate over gaps less than a given length,
see ``max_gap``.
- max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta or None, default: None
+ max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta \
+ or None, default: None
Maximum size of gap, a continuous sequence of NaNs, that will be filled.
Use None for no limit. When interpolating along a datetime64 dimension
and ``use_coordinate=True``, ``max_gap`` can be one of the following:
@@ -3843,9 +6713,21 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
C (x) float64 40B 20.0 15.0 10.0 5.0 0.0
D (x) float64 40B 5.0 3.0 1.0 -1.0 4.0
"""
- pass
-
- def ffill(self, dim: Hashable, limit: (int | None)=None) ->Self:
+ from xarray.core.missing import _apply_over_vars_with_dim, interp_na
+
+ new = _apply_over_vars_with_dim(
+ interp_na,
+ self,
+ dim=dim,
+ method=method,
+ limit=limit,
+ use_coordinate=use_coordinate,
+ max_gap=max_gap,
+ **kwargs,
+ )
+ return new
+
+ def ffill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values forward
*Requires bottleneck.*
@@ -3904,9 +6786,12 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.bfill
"""
- pass
+ from xarray.core.missing import _apply_over_vars_with_dim, ffill
+
+ new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit)
+ return new
- def bfill(self, dim: Hashable, limit: (int | None)=None) ->Self:
+ def bfill(self, dim: Hashable, limit: int | None = None) -> Self:
"""Fill NaN values by propagating values backward
*Requires bottleneck.*
@@ -3966,9 +6851,12 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.ffill
"""
- pass
+ from xarray.core.missing import _apply_over_vars_with_dim, bfill
- def combine_first(self, other: Self) ->Self:
+ new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit)
+ return new
+
+ def combine_first(self, other: Self) -> Self:
"""Combine two Datasets, default to data_vars of self.
The new coordinates follow the normal broadcasting and alignment rules
@@ -3984,11 +6872,19 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
-------
Dataset
"""
- pass
-
- def reduce(self, func: Callable, dim: Dims=None, *, keep_attrs: (bool |
- None)=None, keepdims: bool=False, numeric_only: bool=False, **
- kwargs: Any) ->Self:
+ out = ops.fillna(self, other, join="outer", dataset_join="outer")
+ return out
+
+ def reduce(
+ self,
+ func: Callable,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ numeric_only: bool = False,
+ **kwargs: Any,
+ ) -> Self:
"""Reduce this dataset by applying `func` along some dimension(s).
Parameters
@@ -4051,10 +6947,77 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
math_scores (student) float64 24B 91.0 82.5 96.5
english_scores (student) float64 24B 91.0 80.5 94.5
"""
- pass
+ if kwargs.get("axis", None) is not None:
+ raise ValueError(
+ "passing 'axis' to Dataset reduce methods is ambiguous."
+ " Please use 'dim' instead."
+ )
+
+ if dim is None or dim is ...:
+ dims = set(self.dims)
+ elif isinstance(dim, str) or not isinstance(dim, Iterable):
+ dims = {dim}
+ else:
+ dims = set(dim)
+
+ missing_dimensions = tuple(d for d in dims if d not in self.dims)
+ if missing_dimensions:
+ raise ValueError(
+ f"Dimensions {missing_dimensions} not found in data dimensions {tuple(self.dims)}"
+ )
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ variables: dict[Hashable, Variable] = {}
+ for name, var in self._variables.items():
+ reduce_dims = [d for d in var.dims if d in dims]
+ if name in self.coords:
+ if not reduce_dims:
+ variables[name] = var
+ else:
+ if (
+ # Some reduction functions (e.g. std, var) need to run on variables
+ # that don't have the reduce dims: PR5393
+ not is_extension_array_dtype(var.dtype)
+ and (
+ not reduce_dims
+ or not numeric_only
+ or np.issubdtype(var.dtype, np.number)
+ or (var.dtype == np.bool_)
+ )
+ ):
+ # prefer to aggregate over axis=None rather than
+ # axis=(0, 1) if they will be equivalent, because
+ # the former is often more efficient
+ # keep single-element dims as list, to support Hashables
+ reduce_maybe_single = (
+ None
+ if len(reduce_dims) == var.ndim and var.ndim != 1
+ else reduce_dims
+ )
+ variables[name] = var.reduce(
+ func,
+ dim=reduce_maybe_single,
+ keep_attrs=keep_attrs,
+ keepdims=keepdims,
+ **kwargs,
+ )
- def map(self, func: Callable, keep_attrs: (bool | None)=None, args:
- Iterable[Any]=(), **kwargs: Any) ->Self:
+ coord_names = {k for k in self.coords if k in variables}
+ indexes = {k: v for k, v in self._indexes.items() if k in variables}
+ attrs = self.attrs if keep_attrs else None
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, attrs=attrs, indexes=indexes
+ )
+
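A usage sketch of `Dataset.reduce` with a plain NumPy reduction; variables lacking the reduced dimension are still passed to `func` (see the PR5393 comment above), and `numeric_only` filters non-numeric variables:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"a": (("x", "y"), [[1.0, 2.0], [3.0, 4.0]]), "b": ("x", [10, 20])}
)

ds.reduce(np.mean, dim="y")                      # reduce over one dimension
ds.reduce(np.mean)                               # reduce over every dimension
ds.reduce(np.std, dim="x", numeric_only=True)    # skip non-numeric variables
```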
+ def map(
+ self,
+ func: Callable,
+ keep_attrs: bool | None = None,
+ args: Iterable[Any] = (),
+ **kwargs: Any,
+ ) -> Self:
"""Apply a function to each data variable in this dataset
Parameters
@@ -4096,10 +7059,25 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773
bar (x) float64 16B 1.0 2.0
"""
- pass
-
- def apply(self, func: Callable, keep_attrs: (bool | None)=None, args:
- Iterable[Any]=(), **kwargs: Any) ->Self:
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+ variables = {
+ k: maybe_wrap_array(v, func(v, *args, **kwargs))
+ for k, v in self.data_vars.items()
+ }
+ if keep_attrs:
+ for k, v in variables.items():
+ v._copy_attrs_from(self.data_vars[k])
+ attrs = self.attrs if keep_attrs else None
+ return type(self)(variables, attrs=attrs)
+
+ def apply(
+ self,
+ func: Callable,
+ keep_attrs: bool | None = None,
+ args: Iterable[Any] = (),
+ **kwargs: Any,
+ ) -> Self:
"""
Backward compatible implementation of ``map``
@@ -4107,10 +7085,18 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.map
"""
- pass
-
- def assign(self, variables: (Mapping[Any, Any] | None)=None, **
- variables_kwargs: Any) ->Self:
+ warnings.warn(
+ "Dataset.apply may be deprecated in the future. Using Dataset.map is encouraged",
+ PendingDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.map(func, keep_attrs, args, **kwargs)
+
+ def assign(
+ self,
+ variables: Mapping[Any, Any] | None = None,
+ **variables_kwargs: Any,
+ ) -> Self:
"""Assign new data variables to a Dataset, returning a new object
with all the original variables in addition to the new ones.
@@ -4196,10 +7182,30 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
temperature_f (lat, lon) float64 32B 51.76 57.75 53.7 51.62
"""
- pass
+ variables = either_dict_or_kwargs(variables, variables_kwargs, "assign")
+ data = self.copy()
+
+ # do all calculations first...
+ results: CoercibleMapping = data._calc_assign_results(variables)
+
+ # split data variables to add/replace vs. coordinates to replace
+ results_data_vars: dict[Hashable, CoercibleValue] = {}
+ results_coords: dict[Hashable, CoercibleValue] = {}
+ for k, v in results.items():
+ if k in data._coord_names:
+ results_coords[k] = v
+ else:
+ results_data_vars[k] = v
- def to_dataarray(self, dim: Hashable='variable', name: (Hashable | None
- )=None) ->DataArray:
+ # ... and then assign
+ data.coords.update(results_coords)
+ data.update(results_data_vars)
+
+ return data
+
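A minimal sketch of `assign`: right-hand sides may be array-like or callables evaluated against the dataset, and they are all computed before anything is assigned (as the implementation above makes explicit). Names are illustrative:

```python
import xarray as xr

ds = xr.Dataset({"t_c": ("x", [0.0, 10.0, 20.0])})

ds2 = ds.assign(
    t_f=lambda d: d["t_c"] * 9 / 5 + 32,   # callable evaluated on the dataset
    site=("x", ["a", "b", "c"]),           # plain (dims, data) tuple
)
```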
+ def to_dataarray(
+ self, dim: Hashable = "variable", name: Hashable | None = None
+ ) -> DataArray:
"""Convert this dataset into an xarray.DataArray
The data variables of this dataset will be broadcast against each other
@@ -4217,15 +7223,32 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
-------
array : xarray.DataArray
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ data_vars = [self.variables[k] for k in self.data_vars]
+ broadcast_vars = broadcast_variables(*data_vars)
+ data = duck_array_ops.stack([b.data for b in broadcast_vars], axis=0)
- def to_array(self, dim: Hashable='variable', name: (Hashable | None)=None
- ) ->DataArray:
+ dims = (dim,) + broadcast_vars[0].dims
+ variable = Variable(dims, data, self.attrs, fastpath=True)
+
+ coords = {k: v.variable for k, v in self.coords.items()}
+ indexes = filter_indexes_from_coords(self._indexes, set(coords))
+ new_dim_index = PandasIndex(list(self.data_vars), dim)
+ indexes[dim] = new_dim_index
+ coords.update(new_dim_index.create_variables())
+
+ return DataArray._construct_direct(variable, coords, name, indexes)
+
+ def to_array(
+ self, dim: Hashable = "variable", name: Hashable | None = None
+ ) -> DataArray:
"""Deprecated version of to_dataarray"""
- pass
+ return self.to_dataarray(dim=dim, name=name)
- def _normalize_dim_order(self, dim_order: (Sequence[Hashable] | None)=None
- ) ->dict[Hashable, int]:
+ def _normalize_dim_order(
+ self, dim_order: Sequence[Hashable] | None = None
+ ) -> dict[Hashable, int]:
"""
Check the validity of the provided dimensions if any and return the mapping
between dimension name and their size.
@@ -4241,9 +7264,19 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Validated dimensions mapping.
"""
- pass
+ if dim_order is None:
+ dim_order = list(self.dims)
+ elif set(dim_order) != set(self.dims):
+ raise ValueError(
+ f"dim_order {dim_order} does not match the set of dimensions of this "
+ f"Dataset: {list(self.dims)}"
+ )
+
+ ordered_dims = {k: self.sizes[k] for k in dim_order}
- def to_pandas(self) ->(pd.Series | pd.DataFrame):
+ return ordered_dims
+
+ def to_pandas(self) -> pd.Series | pd.DataFrame:
"""Convert this dataset into a pandas object without changing the number of dimensions.
The type of the returned object depends on the number of Dataset
@@ -4254,10 +7287,50 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Only works for Datasets with 1 or fewer dimensions.
"""
- pass
+ if len(self.dims) == 0:
+ return pd.Series({k: v.item() for k, v in self.items()})
+ if len(self.dims) == 1:
+ return self.to_dataframe()
+ raise ValueError(
+ f"cannot convert Datasets with {len(self.dims)} dimensions into "
+ "pandas objects without changing the number of dimensions. "
+ "Please use Dataset.to_dataframe() instead."
+ )
+
+ def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
+ columns_in_order = [k for k in self.variables if k not in self.dims]
+ non_extension_array_columns = [
+ k
+ for k in columns_in_order
+ if not is_extension_array_dtype(self.variables[k].data)
+ ]
+ extension_array_columns = [
+ k
+ for k in columns_in_order
+ if is_extension_array_dtype(self.variables[k].data)
+ ]
+ data = [
+ self._variables[k].set_dims(ordered_dims).values.reshape(-1)
+ for k in non_extension_array_columns
+ ]
+ index = self.coords.to_index([*ordered_dims])
+ broadcasted_df = pd.DataFrame(
+ dict(zip(non_extension_array_columns, data)), index=index
+ )
+ for extension_array_column in extension_array_columns:
+ extension_array = self.variables[extension_array_column].data.array
+ index = self[self.variables[extension_array_column].dims[0]].data
+ extension_array_df = pd.DataFrame(
+ {extension_array_column: extension_array},
+ index=self[self.variables[extension_array_column].dims[0]].data,
+ )
+ extension_array_df.index.name = self.variables[extension_array_column].dims[
+ 0
+ ]
+ broadcasted_df = broadcasted_df.join(extension_array_df)
+ return broadcasted_df[columns_in_order]
- def to_dataframe(self, dim_order: (Sequence[Hashable] | None)=None
- ) ->pd.DataFrame:
+ def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame:
"""Convert this dataset into a pandas.DataFrame.
Non-index variables in this dataset form the columns of the
@@ -4283,11 +7356,82 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset as a pandas DataFrame.
"""
- pass
+
+ ordered_dims = self._normalize_dim_order(dim_order=dim_order)
+
+ return self._to_dataframe(ordered_dims=ordered_dims)
+
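A small round-trip sketch for `to_dataframe` (implemented above) and its counterpart `from_dataframe`, using an invented 2-D dataset:

```python
import xarray as xr

ds = xr.Dataset(
    {"a": (("x", "y"), [[1, 2], [3, 4]])},
    coords={"x": [10, 20], "y": ["u", "v"]},
)

df = ds.to_dataframe()                        # rows indexed by an ("x", "y") MultiIndex
df2 = ds.to_dataframe(dim_order=["y", "x"])   # control how the index is ordered

roundtrip = xr.Dataset.from_dataframe(df)     # back to a 2-D dataset
```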
+ def _set_sparse_data_from_dataframe(
+ self, idx: pd.Index, arrays: list[tuple[Hashable, np.ndarray]], dims: tuple
+ ) -> None:
+ from sparse import COO
+
+ if isinstance(idx, pd.MultiIndex):
+ coords = np.stack([np.asarray(code) for code in idx.codes], axis=0)
+ is_sorted = idx.is_monotonic_increasing
+ shape = tuple(lev.size for lev in idx.levels)
+ else:
+ coords = np.arange(idx.size).reshape(1, -1)
+ is_sorted = True
+ shape = (idx.size,)
+
+ for name, values in arrays:
+ # In virtually all real use cases, the sparse array will now have
+ # missing values and needs a fill_value. For consistency, don't
+ # special case the rare exceptions (e.g., dtype=int without a
+ # MultiIndex).
+ dtype, fill_value = xrdtypes.maybe_promote(values.dtype)
+ values = np.asarray(values, dtype=dtype)
+
+ data = COO(
+ coords,
+ values,
+ shape,
+ has_duplicates=False,
+ sorted=is_sorted,
+ fill_value=fill_value,
+ )
+ self[name] = (dims, data)
+
+ def _set_numpy_data_from_dataframe(
+ self, idx: pd.Index, arrays: list[tuple[Hashable, np.ndarray]], dims: tuple
+ ) -> None:
+ if not isinstance(idx, pd.MultiIndex):
+ for name, values in arrays:
+ self[name] = (dims, values)
+ return
+
+ # NB: similar, more general logic, now exists in
+ # variable.unstack_once; we could consider combining them at some
+ # point.
+
+ shape = tuple(lev.size for lev in idx.levels)
+ indexer = tuple(idx.codes)
+
+ # We already verified that the MultiIndex has all unique values, so
+ # there are missing values if and only if the size of output arrays is
+        # larger than the index.
+ missing_values = math.prod(shape) > idx.shape[0]
+
+ for name, values in arrays:
+ # NumPy indexing is much faster than using DataFrame.reindex() to
+ # fill in missing values:
+ # https://stackoverflow.com/a/35049899/809705
+ if missing_values:
+ dtype, fill_value = xrdtypes.maybe_promote(values.dtype)
+ data = np.full(shape, fill_value, dtype)
+ else:
+ # If there are no missing values, keep the existing dtype
+ # instead of promoting to support NA, e.g., keep integer
+ # columns as integers.
+ # TODO: consider removing this special case, which doesn't
+ # exist for sparse=True.
+ data = np.zeros(shape, values.dtype)
+ data[indexer] = values
+ self[name] = (dims, data)
@classmethod
- def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool=False
- ) ->Self:
+ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
"""Convert a pandas.DataFrame into an xarray.Dataset
Each column will be converted into an independent variable in the
@@ -4317,10 +7461,64 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
xarray.DataArray.from_series
pandas.DataFrame.to_xarray
"""
- pass
+ # TODO: Add an option to remove dimensions along which the variables
+ # are constant, to enable consistent serialization to/from a dataframe,
+ # even if some variables have different dimensionality.
+
+ if not dataframe.columns.is_unique:
+ raise ValueError("cannot convert DataFrame with non-unique columns")
+
+ idx = remove_unused_levels_categories(dataframe.index)
+
+ if isinstance(idx, pd.MultiIndex) and not idx.is_unique:
+ raise ValueError(
+ "cannot convert a DataFrame with a non-unique MultiIndex into xarray"
+ )
+
+ arrays = []
+ extension_arrays = []
+ for k, v in dataframe.items():
+ if not is_extension_array_dtype(v) or isinstance(
+ v.array, (pd.arrays.DatetimeArray, pd.arrays.TimedeltaArray)
+ ):
+ arrays.append((k, np.asarray(v)))
+ else:
+ extension_arrays.append((k, v))
+
+ indexes: dict[Hashable, Index] = {}
+ index_vars: dict[Hashable, Variable] = {}
+
+ if isinstance(idx, pd.MultiIndex):
+ dims = tuple(
+ name if name is not None else "level_%i" % n
+ for n, name in enumerate(idx.names)
+ )
+ for dim, lev in zip(dims, idx.levels):
+ xr_idx = PandasIndex(lev, dim)
+ indexes[dim] = xr_idx
+ index_vars.update(xr_idx.create_variables())
+ arrays += [(k, np.asarray(v)) for k, v in extension_arrays]
+ extension_arrays = []
+ else:
+ index_name = idx.name if idx.name is not None else "index"
+ dims = (index_name,)
+ xr_idx = PandasIndex(idx, index_name)
+ indexes[index_name] = xr_idx
+ index_vars.update(xr_idx.create_variables())
+
+ obj = cls._construct_direct(index_vars, set(index_vars), indexes=indexes)
+
+ if sparse:
+ obj._set_sparse_data_from_dataframe(idx, arrays, dims)
+ else:
+ obj._set_numpy_data_from_dataframe(idx, arrays, dims)
+ for name, extension_array in extension_arrays:
+ obj[name] = (dims, extension_array)
+ return obj[dataframe.columns] if len(dataframe.columns) else obj
- def to_dask_dataframe(self, dim_order: (Sequence[Hashable] | None)=None,
- set_index: bool=False) ->DaskDataFrame:
+ def to_dask_dataframe(
+ self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False
+ ) -> DaskDataFrame:
"""
Convert this dataset into a dask.dataframe.DataFrame.
@@ -4348,10 +7546,63 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
-------
dask.dataframe.DataFrame
"""
- pass
- def to_dict(self, data: (bool | Literal['list', 'array'])='list',
- encoding: bool=False) ->dict[str, Any]:
+ import dask.array as da
+ import dask.dataframe as dd
+
+ ordered_dims = self._normalize_dim_order(dim_order=dim_order)
+
+ columns = list(ordered_dims)
+ columns.extend(k for k in self.coords if k not in self.dims)
+ columns.extend(self.data_vars)
+
+ ds_chunks = self.chunks
+
+ series_list = []
+ df_meta = pd.DataFrame()
+ for name in columns:
+ try:
+ var = self.variables[name]
+ except KeyError:
+ # dimension without a matching coordinate
+ size = self.sizes[name]
+ data = da.arange(size, chunks=size, dtype=np.int64)
+ var = Variable((name,), data)
+
+ # IndexVariable objects have a dummy .chunk() method
+ if isinstance(var, IndexVariable):
+ var = var.to_base_variable()
+
+ # Make sure var is a dask array, otherwise the array can become too large
+ # when it is broadcasted to several dimensions:
+ if not is_duck_dask_array(var._data):
+ var = var.chunk()
+
+ # Broadcast then flatten the array:
+ var_new_dims = var.set_dims(ordered_dims).chunk(ds_chunks)
+ dask_array = var_new_dims._data.reshape(-1)
+
+ series = dd.from_dask_array(dask_array, columns=name, meta=df_meta)
+ series_list.append(series)
+
+ df = dd.concat(series_list, axis=1)
+
+ if set_index:
+ dim_order = [*ordered_dims]
+
+ if len(dim_order) == 1:
+ (dim,) = dim_order
+ df = df.set_index(dim)
+ else:
+ # triggers an error about multi-indexes, even if only one
+ # dimension is passed
+ df = df.set_index(dim_order)
+
+ return df
+
+ def to_dict(
+ self, data: bool | Literal["list", "array"] = "list", encoding: bool = False
+ ) -> dict[str, Any]:
"""
Convert this dataset to a dictionary following xarray naming
conventions.
@@ -4384,17 +7635,34 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.from_dict
DataArray.to_dict
"""
- pass
+ d: dict = {
+ "coords": {},
+ "attrs": decode_numpy_dict_values(self.attrs),
+ "dims": dict(self.sizes),
+ "data_vars": {},
+ }
+ for k in self.coords:
+ d["coords"].update(
+ {k: self[k].variable.to_dict(data=data, encoding=encoding)}
+ )
+ for k in self.data_vars:
+ d["data_vars"].update(
+ {k: self[k].variable.to_dict(data=data, encoding=encoding)}
+ )
+ if encoding:
+ d["encoding"] = dict(self.encoding)
+ return d
@classmethod
- def from_dict(cls, d: Mapping[Any, Any]) ->Self:
+ def from_dict(cls, d: Mapping[Any, Any]) -> Self:
"""Convert a dictionary into an xarray.Dataset.
Parameters
----------
d : dict-like
Mapping with a minimum structure of
- ``{"var_0": {"dims": [..], "data": [..]}, ...}``
+ ``{"var_0": {"dims": [..], "data": [..]}, \
+ ...}``
Returns
-------
@@ -4446,11 +7714,147 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
title: air temperature
"""
- pass
- @_deprecate_positional_args('v2023.10.0')
- def diff(self, dim: Hashable, n: int=1, *, label: Literal['upper',
- 'lower']='upper') ->Self:
+ variables: Iterable[tuple[Hashable, Any]]
+ if not {"coords", "data_vars"}.issubset(set(d)):
+ variables = d.items()
+ else:
+ import itertools
+
+ variables = itertools.chain(
+ d.get("coords", {}).items(), d.get("data_vars", {}).items()
+ )
+ try:
+ variable_dict = {
+ k: (v["dims"], v["data"], v.get("attrs"), v.get("encoding"))
+ for k, v in variables
+ }
+ except KeyError as e:
+ raise ValueError(f"cannot convert dict without the key '{str(e.args[0])}'")
+ obj = cls(variable_dict)
+
+ # what if coords aren't dims?
+ coords = set(d.get("coords", {})) - set(d.get("dims", {}))
+ obj = obj.set_coords(coords)
+
+ obj.attrs.update(d.get("attrs", {}))
+ obj.encoding.update(d.get("encoding", {}))
+
+ return obj
+
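A round-trip sketch for `to_dict`/`from_dict` as implemented above; the dataset contents are illustrative:

```python
import xarray as xr

ds = xr.Dataset(
    {"a": ("t", [1, 2, 3])},
    coords={"t": [0, 1, 2]},
    attrs={"title": "demo"},
)

d = ds.to_dict()                     # nested dict with "coords", "attrs", "dims", "data_vars"
d_arrays = ds.to_dict(data="array")  # keep numpy arrays instead of nested lists
roundtrip = xr.Dataset.from_dict(d)
```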
+ def _unary_op(self, f, *args, **kwargs) -> Self:
+ variables = {}
+ keep_attrs = kwargs.pop("keep_attrs", None)
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+ for k, v in self._variables.items():
+ if k in self._coord_names:
+ variables[k] = v
+ else:
+ variables[k] = f(v, *args, **kwargs)
+ if keep_attrs:
+ variables[k]._attrs = v._attrs
+ attrs = self._attrs if keep_attrs else None
+ return self._replace_with_new_dims(variables, attrs=attrs)
+
+ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset:
+ from xarray.core.dataarray import DataArray
+ from xarray.core.groupby import GroupBy
+
+ if isinstance(other, GroupBy):
+ return NotImplemented
+ align_type = OPTIONS["arithmetic_join"] if join is None else join
+ if isinstance(other, (DataArray, Dataset)):
+ self, other = align(self, other, join=align_type, copy=False)
+ g = f if not reflexive else lambda x, y: f(y, x)
+ ds = self._calculate_binary_op(g, other, join=align_type)
+ keep_attrs = _get_keep_attrs(default=False)
+ if keep_attrs:
+ ds.attrs = self.attrs
+ return ds
+
+ def _inplace_binary_op(self, other, f) -> Self:
+ from xarray.core.dataarray import DataArray
+ from xarray.core.groupby import GroupBy
+
+ if isinstance(other, GroupBy):
+ raise TypeError(
+ "in-place operations between a Dataset and "
+ "a grouped object are not permitted"
+ )
+ # we don't actually modify arrays in-place with in-place Dataset
+ # arithmetic -- this lets us automatically align things
+ if isinstance(other, (DataArray, Dataset)):
+ other = other.reindex_like(self, copy=False)
+ g = ops.inplace_to_noninplace_op(f)
+ ds = self._calculate_binary_op(g, other, inplace=True)
+ self._replace_with_new_dims(
+ ds._variables,
+ ds._coord_names,
+ attrs=ds._attrs,
+ indexes=ds._indexes,
+ inplace=True,
+ )
+ return self
+
+ def _calculate_binary_op(
+ self, f, other, join="inner", inplace: bool = False
+ ) -> Dataset:
+ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
+ if inplace and set(lhs_data_vars) != set(rhs_data_vars):
+ raise ValueError(
+ "datasets must have the same data variables "
+ f"for in-place arithmetic operations: {list(lhs_data_vars)}, {list(rhs_data_vars)}"
+ )
+
+ dest_vars = {}
+
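+            # Apply f where a variable exists on both sides; depending on the
+            # join, variables present on only one side are combined with NaN.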
+ for k in lhs_data_vars:
+ if k in rhs_data_vars:
+ dest_vars[k] = f(lhs_vars[k], rhs_vars[k])
+ elif join in ["left", "outer"]:
+ dest_vars[k] = f(lhs_vars[k], np.nan)
+ for k in rhs_data_vars:
+ if k not in dest_vars and join in ["right", "outer"]:
+ dest_vars[k] = f(rhs_vars[k], np.nan)
+ return dest_vars
+
+ if utils.is_dict_like(other) and not isinstance(other, Dataset):
+ # can't use our shortcut of doing the binary operation with
+ # Variable objects, so apply over our data vars instead.
+ new_data_vars = apply_over_both(
+ self.data_vars, other, self.data_vars, other
+ )
+ return type(self)(new_data_vars)
+
+ other_coords: Coordinates | None = getattr(other, "coords", None)
+ ds = self.coords.merge(other_coords)
+
+ if isinstance(other, Dataset):
+ new_vars = apply_over_both(
+ self.data_vars, other.data_vars, self.variables, other.variables
+ )
+ else:
+ other_variable = getattr(other, "variable", other)
+ new_vars = {k: f(self.variables[k], other_variable) for k in self.data_vars}
+ ds._variables.update(new_vars)
+ ds._dims = calculate_dimensions(ds._variables)
+ return ds
+
+ def _copy_attrs_from(self, other):
+ self.attrs = other.attrs
+ for v in other.variables:
+ if v in self.variables:
+ self.variables[v].attrs = other.variables[v].attrs
+
+ @_deprecate_positional_args("v2023.10.0")
+ def diff(
+ self,
+ dim: Hashable,
+ n: int = 1,
+ *,
+ label: Literal["upper", "lower"] = "upper",
+ ) -> Self:
"""Calculate the n-th order discrete difference along given axis.
Parameters
@@ -4494,10 +7898,50 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
Dataset.differentiate
"""
- pass
+ if n == 0:
+ return self
+ if n < 0:
+ raise ValueError(f"order `n` must be non-negative but got {n}")
+
+ # prepare slices
+ slice_start = {dim: slice(None, -1)}
+ slice_end = {dim: slice(1, None)}
+
+ # prepare new coordinate
+ if label == "upper":
+ slice_new = slice_end
+ elif label == "lower":
+ slice_new = slice_start
+ else:
+ raise ValueError("The 'label' argument has to be either 'upper' or 'lower'")
+
+ indexes, index_vars = isel_indexes(self.xindexes, slice_new)
+ variables = {}
+
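+        # Data variables take a first-order difference along `dim`;
+        # coordinates along `dim` are sliced to match the chosen label side,
+        # and everything else is carried over unchanged.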
+ for name, var in self.variables.items():
+ if name in index_vars:
+ variables[name] = index_vars[name]
+ elif dim in var.dims:
+ if name in self.data_vars:
+ variables[name] = var.isel(slice_end) - var.isel(slice_start)
+ else:
+ variables[name] = var.isel(slice_new)
+ else:
+ variables[name] = var
+
+ difference = self._replace_with_new_dims(variables, indexes=indexes)
- def shift(self, shifts: (Mapping[Any, int] | None)=None, fill_value:
- Any=xrdtypes.NA, **shifts_kwargs: int) ->Self:
+ if n > 1:
+ return difference.diff(dim, n - 1)
+ else:
+ return difference
+
+ def shift(
+ self,
+ shifts: Mapping[Any, int] | None = None,
+ fill_value: Any = xrdtypes.NA,
+ **shifts_kwargs: int,
+ ) -> Self:
"""Shift this dataset by an offset along one or more dimensions.
Only data variables are moved; coordinates stay in place. This is
@@ -4540,10 +7984,35 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Data variables:
foo (x) object 40B nan nan 'a' 'b' 'c'
"""
- pass
+ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift")
+ invalid = tuple(k for k in shifts if k not in self.dims)
+ if invalid:
+ raise ValueError(
+ f"Dimensions {invalid} not found in data dimensions {tuple(self.dims)}"
+ )
+
+ variables = {}
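+        # Only data variables are shifted; coordinates are carried over as-is.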
+ for name, var in self.variables.items():
+ if name in self.data_vars:
+ fill_value_ = (
+ fill_value.get(name, xrdtypes.NA)
+ if isinstance(fill_value, dict)
+ else fill_value
+ )
+
+ var_shifts = {k: v for k, v in shifts.items() if k in var.dims}
+ variables[name] = var.shift(fill_value=fill_value_, shifts=var_shifts)
+ else:
+ variables[name] = var
+
+ return self._replace(variables)
- def roll(self, shifts: (Mapping[Any, int] | None)=None, roll_coords:
- bool=False, **shifts_kwargs: int) ->Self:
+ def roll(
+ self,
+ shifts: Mapping[Any, int] | None = None,
+ roll_coords: bool = False,
+ **shifts_kwargs: int,
+ ) -> Self:
"""Roll this dataset by an offset along one or more dimensions.
Unlike shift, roll treats the given dimensions as periodic, so will not
@@ -4594,11 +8063,46 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
foo (x) <U1 20B 'd' 'e' 'a' 'b' 'c'
"""
- pass
+ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll")
+ invalid = [k for k in shifts if k not in self.dims]
+ if invalid:
+ raise ValueError(
+ f"Dimensions {invalid} not found in data dimensions {tuple(self.dims)}"
+ )
+
+ unrolled_vars: tuple[Hashable, ...]
+
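+        # With roll_coords=True the indexes are rolled as well; otherwise all
+        # coordinates (and their indexes) stay fixed.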
+ if roll_coords:
+ indexes, index_vars = roll_indexes(self.xindexes, shifts)
+ unrolled_vars = ()
+ else:
+ indexes = dict(self._indexes)
+ index_vars = dict(self.xindexes.variables)
+ unrolled_vars = tuple(self.coords)
+
+ variables = {}
+ for k, var in self.variables.items():
+ if k in index_vars:
+ variables[k] = index_vars[k]
+ elif k not in unrolled_vars:
+ variables[k] = var.roll(
+ shifts={k: s for k, s in shifts.items() if k in var.dims}
+ )
+ else:
+ variables[k] = var
+
+ return self._replace(variables, indexes=indexes)
- def sortby(self, variables: (Hashable | DataArray | Sequence[Hashable |
- DataArray] | Callable[[Self], Hashable | DataArray | list[Hashable |
- DataArray]]), ascending: bool=True) ->Self:
+ def sortby(
+ self,
+ variables: (
+ Hashable
+ | DataArray
+ | Sequence[Hashable | DataArray]
+ | Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]]
+ ),
+ ascending: bool = True,
+ ) -> Self:
"""
Sort object by labels or values (along an axis).
@@ -4666,13 +8170,43 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
A (x, y) int64 32B 1 2 3 4
B (x, y) int64 32B 5 6 7 8
"""
- pass
+ from xarray.core.dataarray import DataArray
- @_deprecate_positional_args('v2023.10.0')
- def quantile(self, q: ArrayLike, dim: Dims=None, *, method:
- QuantileMethods='linear', numeric_only: bool=False, keep_attrs: (
- bool | None)=None, skipna: (bool | None)=None, interpolation: (
- QuantileMethods | None)=None) ->Self:
+ if callable(variables):
+ variables = variables(self)
+ if not isinstance(variables, list):
+ variables = [variables]
+ else:
+ variables = variables
+ arrays = [v if isinstance(v, DataArray) else self[v] for v in variables]
+ aligned_vars = align(self, *arrays, join="left")
+ aligned_self = cast("Self", aligned_vars[0])
+ aligned_other_vars = cast(tuple[DataArray, ...], aligned_vars[1:])
+ vars_by_dim = defaultdict(list)
+ for data_array in aligned_other_vars:
+ if data_array.ndim != 1:
+ raise ValueError("Input DataArray is not 1-D.")
+ (key,) = data_array.dims
+ vars_by_dim[key].append(data_array)
+
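+        # np.lexsort uses the last key as the primary sort key, hence the
+        # reversed(...); the resulting order is flipped for descending sorts.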
+ indices = {}
+ for key, arrays in vars_by_dim.items():
+ order = np.lexsort(tuple(reversed(arrays)))
+ indices[key] = order if ascending else order[::-1]
+ return aligned_self.isel(indices)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def quantile(
+ self,
+ q: ArrayLike,
+ dim: Dims = None,
+ *,
+ method: QuantileMethods = "linear",
+ numeric_only: bool = False,
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ interpolation: QuantileMethods | None = None,
+ ) -> Self:
"""Compute the qth quantile of the data along the specified dimension.
Returns the qth quantiles(s) of the array elements for each variable
@@ -4779,11 +8313,76 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
"Sample quantiles in statistical packages,"
The American Statistician, 50(4), pp. 361-365, 1996
"""
- pass
- @_deprecate_positional_args('v2023.10.0')
- def rank(self, dim: Hashable, *, pct: bool=False, keep_attrs: (bool |
- None)=None) ->Self:
+ # interpolation renamed to method in version 0.21.0
+ # check here and in variable to avoid repeated warnings
+ if interpolation is not None:
+ warnings.warn(
+ "The `interpolation` argument to quantile was renamed to `method`.",
+ FutureWarning,
+ )
+
+ if method != "linear":
+ raise TypeError("Cannot pass interpolation and method keywords!")
+
+ method = interpolation
+
+ dims: set[Hashable]
+ if isinstance(dim, str):
+ dims = {dim}
+ elif dim is None or dim is ...:
+ dims = set(self.dims)
+ else:
+ dims = set(dim)
+
+ invalid_dims = set(dims) - set(self.dims)
+ if invalid_dims:
+ raise ValueError(
+ f"Dimensions {tuple(invalid_dims)} not found in data dimensions {tuple(self.dims)}"
+ )
+
+ q = np.asarray(q, dtype=np.float64)
+
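+        # Compute quantiles variable by variable: data variables overlapping
+        # the requested dims are reduced (honouring numeric_only), while
+        # variables that do not touch them are passed through unchanged.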
+ variables = {}
+ for name, var in self.variables.items():
+ reduce_dims = [d for d in var.dims if d in dims]
+ if reduce_dims or not var.dims:
+ if name not in self.coords:
+ if (
+ not numeric_only
+ or np.issubdtype(var.dtype, np.number)
+ or var.dtype == np.bool_
+ ):
+ variables[name] = var.quantile(
+ q,
+ dim=reduce_dims,
+ method=method,
+ keep_attrs=keep_attrs,
+ skipna=skipna,
+ )
+
+ else:
+ variables[name] = var
+
+ # construct the new dataset
+ coord_names = {k for k in self.coords if k in variables}
+ indexes = {k: v for k, v in self._indexes.items() if k in variables}
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+ attrs = self.attrs if keep_attrs else None
+ new = self._replace_with_new_dims(
+ variables, coord_names=coord_names, attrs=attrs, indexes=indexes
+ )
+ return new.assign_coords(quantile=q)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def rank(
+ self,
+ dim: Hashable,
+ *,
+ pct: bool = False,
+ keep_attrs: bool | None = None,
+ ) -> Self:
"""Ranks the data.
Equal values are assigned a rank that is the average of the ranks that
@@ -4811,10 +8410,37 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
ranked : Dataset
Variables that do not depend on `dim` are dropped.
"""
- pass
+ if not OPTIONS["use_bottleneck"]:
+ raise RuntimeError(
+ "rank requires bottleneck to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` to enable it."
+ )
+
+ if dim not in self.dims:
+ raise ValueError(
+ f"Dimension {dim!r} not found in data dimensions {tuple(self.dims)}"
+ )
- def differentiate(self, coord: Hashable, edge_order: Literal[1, 2]=1,
- datetime_unit: (DatetimeUnitOptions | None)=None) ->Self:
+ variables = {}
+ for name, var in self.variables.items():
+ if name in self.data_vars:
+ if dim in var.dims:
+ variables[name] = var.rank(dim, pct=pct)
+ else:
+ variables[name] = var
+
+ coord_names = set(self.coords)
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+ attrs = self.attrs if keep_attrs else None
+ return self._replace(variables, coord_names, attrs=attrs)
+
+ def differentiate(
+ self,
+ coord: Hashable,
+ edge_order: Literal[1, 2] = 1,
+ datetime_unit: DatetimeUnitOptions | None = None,
+ ) -> Self:
"""Differentiate with the second order accurate central
differences.
@@ -4828,7 +8454,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
The coordinate to be used to compute the gradient.
edge_order : {1, 2}, default: 1
N-th order accurate differences at the boundaries.
- datetime_unit : None or {"Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as", None}, default: None
+ datetime_unit : None or {"Y", "M", "W", "D", "h", "m", "s", "ms", \
+ "us", "ns", "ps", "fs", "as", None}, default: None
Unit to compute gradient. Only valid for datetime coordinate.
Returns
@@ -4839,10 +8466,52 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
numpy.gradient: corresponding numpy function
"""
- pass
+ from xarray.core.variable import Variable
+
+ if coord not in self.variables and coord not in self.dims:
+ variables_and_dims = tuple(set(self.variables.keys()).union(self.dims))
+ raise ValueError(
+ f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}."
+ )
+
+ coord_var = self[coord].variable
+ if coord_var.ndim != 1:
+ raise ValueError(
+ f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}"
+ " dimensional"
+ )
- def integrate(self, coord: (Hashable | Sequence[Hashable]),
- datetime_unit: DatetimeUnitOptions=None) ->Self:
+ dim = coord_var.dims[0]
+ if _contains_datetime_like_objects(coord_var):
+ if coord_var.dtype.kind in "mM" and datetime_unit is None:
+ datetime_unit = cast(
+ "DatetimeUnitOptions", np.datetime_data(coord_var.dtype)[0]
+ )
+ elif datetime_unit is None:
+ datetime_unit = "s" # Default to seconds for cftime objects
+ coord_var = coord_var._to_numeric(datetime_unit=datetime_unit)
+
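+        # Differentiate each data variable along the coordinate's dimension
+        # using duck_array_ops.gradient (numpy.gradient-style differences).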
+ variables = {}
+ for k, v in self.variables.items():
+ if k in self.data_vars and dim in v.dims and k not in self.coords:
+ if _contains_datetime_like_objects(v):
+ v = v._to_numeric(datetime_unit=datetime_unit)
+ grad = duck_array_ops.gradient(
+ v.data,
+ coord_var.data,
+ edge_order=edge_order,
+ axis=v.get_axis_num(dim),
+ )
+ variables[k] = Variable(v.dims, grad)
+ else:
+ variables[k] = v
+ return self._replace(variables)
+
+ def integrate(
+ self,
+ coord: Hashable | Sequence[Hashable],
+ datetime_unit: DatetimeUnitOptions = None,
+ ) -> Self:
"""Integrate along the given coordinate using the trapezoidal rule.
.. note::
@@ -4853,7 +8522,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
----------
coord : hashable, or sequence of hashable
Coordinate(s) used for the integration.
- datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', 'as', None}, optional
+ datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \
+ 'ps', 'fs', 'as', None}, optional
Specify the unit if datetime coordinate is used.
Returns
@@ -4893,10 +8563,74 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
a float64 8B 20.0
b float64 8B 4.0
"""
- pass
+ if not isinstance(coord, (list, tuple)):
+ coord = (coord,)
+ result = self
+ for c in coord:
+ result = result._integrate_one(c, datetime_unit=datetime_unit)
+ return result
+
+ def _integrate_one(self, coord, datetime_unit=None, cumulative=False):
+ from xarray.core.variable import Variable
+
+ if coord not in self.variables and coord not in self.dims:
+ variables_and_dims = tuple(set(self.variables.keys()).union(self.dims))
+ raise ValueError(
+ f"Coordinate {coord!r} not found in variables or dimensions {variables_and_dims}."
+ )
- def cumulative_integrate(self, coord: (Hashable | Sequence[Hashable]),
- datetime_unit: DatetimeUnitOptions=None) ->Self:
+ coord_var = self[coord].variable
+ if coord_var.ndim != 1:
+ raise ValueError(
+ f"Coordinate {coord} must be 1 dimensional but is {coord_var.ndim}"
+ " dimensional"
+ )
+
+ dim = coord_var.dims[0]
+ if _contains_datetime_like_objects(coord_var):
+ if coord_var.dtype.kind in "mM" and datetime_unit is None:
+ datetime_unit, _ = np.datetime_data(coord_var.dtype)
+ elif datetime_unit is None:
+ datetime_unit = "s" # Default to seconds for cftime objects
+ coord_var = coord_var._replace(
+ data=datetime_to_numeric(coord_var.data, datetime_unit=datetime_unit)
+ )
+
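+        # Integrate each data variable along `dim` with the trapezoidal rule;
+        # the integrated dimension is dropped unless cumulative=True.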
+ variables = {}
+ coord_names = set()
+ for k, v in self.variables.items():
+ if k in self.coords:
+ if dim not in v.dims or cumulative:
+ variables[k] = v
+ coord_names.add(k)
+ else:
+ if k in self.data_vars and dim in v.dims:
+ if _contains_datetime_like_objects(v):
+ v = datetime_to_numeric(v, datetime_unit=datetime_unit)
+ if cumulative:
+ integ = duck_array_ops.cumulative_trapezoid(
+ v.data, coord_var.data, axis=v.get_axis_num(dim)
+ )
+ v_dims = v.dims
+ else:
+ integ = duck_array_ops.trapz(
+ v.data, coord_var.data, axis=v.get_axis_num(dim)
+ )
+ v_dims = list(v.dims)
+ v_dims.remove(dim)
+ variables[k] = Variable(v_dims, integ)
+ else:
+ variables[k] = v
+ indexes = {k: v for k, v in self._indexes.items() if k in variables}
+ return self._replace_with_new_dims(
+ variables, coord_names=coord_names, indexes=indexes
+ )
+
+ def cumulative_integrate(
+ self,
+ coord: Hashable | Sequence[Hashable],
+ datetime_unit: DatetimeUnitOptions = None,
+ ) -> Self:
"""Integrate along the given coordinate using the trapezoidal rule.
.. note::
@@ -4911,7 +8645,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
----------
coord : hashable, or sequence of hashable
Coordinate(s) used for the integration.
- datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', 'as', None}, optional
+ datetime_unit : {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', \
+ 'ps', 'fs', 'as', None}, optional
Specify the unit if datetime coordinate is used.
Returns
@@ -4957,10 +8692,17 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
a (x) float64 32B 0.0 30.0 8.0 20.0
b (x) float64 32B 0.0 9.0 3.0 4.0
"""
- pass
+ if not isinstance(coord, (list, tuple)):
+ coord = (coord,)
+ result = self
+ for c in coord:
+ result = result._integrate_one(
+ c, datetime_unit=datetime_unit, cumulative=True
+ )
+ return result
@property
- def real(self) ->Self:
+ def real(self) -> Self:
"""
The real part of each data variable.
@@ -4968,10 +8710,10 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
numpy.ndarray.real
"""
- pass
+ return self.map(lambda x: x.real, keep_attrs=True)
@property
- def imag(self) ->Self:
+ def imag(self) -> Self:
"""
The imaginary part of each data variable.
@@ -4979,10 +8721,11 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
numpy.ndarray.imag
"""
- pass
+ return self.map(lambda x: x.imag, keep_attrs=True)
+
plot = utils.UncachedAccessor(DatasetPlotAccessor)
- def filter_by_attrs(self, **kwargs) ->Self:
+ def filter_by_attrs(self, **kwargs) -> Self:
"""Returns a ``Dataset`` with variables that match specific conditions.
Can pass in ``key=value`` or ``key=callable``. A Dataset is returned
@@ -5063,9 +8806,21 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
precipitation (x, y, time) float64 96B 5.68 9.256 0.7104 ... 4.615 7.805
"""
- pass
+ selection = []
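+        # Keep a variable only if every requested attribute matches, either by
+        # direct equality or by a callable predicate returning True.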
+ for var_name, variable in self.variables.items():
+ has_value_flag = False
+ for attr_name, pattern in kwargs.items():
+ attr_value = variable.attrs.get(attr_name)
+ if (callable(pattern) and pattern(attr_value)) or attr_value == pattern:
+ has_value_flag = True
+ else:
+ has_value_flag = False
+ break
+ if has_value_flag is True:
+ selection.append(var_name)
+ return self[selection]
- def unify_chunks(self) ->Self:
+ def unify_chunks(self) -> Self:
"""Unify chunk size along all chunked dimensions of this Dataset.
Returns
@@ -5076,11 +8831,16 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
dask.array.core.unify_chunks
"""
- pass
- def map_blocks(self, func: Callable[..., T_Xarray], args: Sequence[Any]
- =(), kwargs: (Mapping[str, Any] | None)=None, template: (DataArray |
- Dataset | None)=None) ->T_Xarray:
+ return unify_chunks(self)[0]
+
+ def map_blocks(
+ self,
+ func: Callable[..., T_Xarray],
+ args: Sequence[Any] = (),
+ kwargs: Mapping[str, Any] | None = None,
+ template: DataArray | Dataset | None = None,
+ ) -> T_Xarray:
"""
Apply a function to each block of this Dataset.
@@ -5181,11 +8941,20 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Data variables:
a (time) float64 192B dask.array<chunksize=(24,), meta=np.ndarray>
"""
- pass
+ from xarray.core.parallel import map_blocks
- def polyfit(self, dim: Hashable, deg: int, skipna: (bool | None)=None,
- rcond: (float | None)=None, w: (Hashable | Any)=None, full: bool=
- False, cov: (bool | Literal['unscaled'])=False) ->Self:
+ return map_blocks(func, self, args, kwargs, template)
+
+ def polyfit(
+ self,
+ dim: Hashable,
+ deg: int,
+ skipna: bool | None = None,
+ rcond: float | None = None,
+ w: Hashable | Any = None,
+ full: bool = False,
+ cov: bool | Literal["unscaled"] = False,
+ ) -> Self:
"""
Least squares polynomial fit.
@@ -5244,16 +9013,147 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
numpy.polyval
xarray.polyval
"""
- pass
+ from xarray.core.dataarray import DataArray
- def pad(self, pad_width: (Mapping[Any, int | tuple[int, int]] | None)=
- None, mode: PadModeOptions='constant', stat_length: (int | tuple[
- int, int] | Mapping[Any, tuple[int, int]] | None)=None,
- constant_values: (float | tuple[float, float] | Mapping[Any, tuple[
- float, float]] | None)=None, end_values: (int | tuple[int, int] |
- Mapping[Any, tuple[int, int]] | None)=None, reflect_type:
- PadReflectOptions=None, keep_attrs: (bool | None)=None, **
- pad_width_kwargs: Any) ->Self:
+ variables = {}
+ skipna_da = skipna
+
+ x = get_clean_interp_index(self, dim, strict=False)
+ xname = f"{self[dim].name}_"
+ order = int(deg) + 1
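+        # Build the Vandermonde design matrix for the least-squares fit.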
+ lhs = np.vander(x, order)
+
+ if rcond is None:
+ rcond = x.shape[0] * np.finfo(x.dtype).eps
+
+ # Weights:
+ if w is not None:
+ if isinstance(w, Hashable):
+ w = self.coords[w]
+ w = np.asarray(w)
+ if w.ndim != 1:
+ raise TypeError("Expected a 1-d array for weights.")
+ if w.shape[0] != lhs.shape[0]:
+ raise TypeError(f"Expected w and {dim} to have the same length")
+ lhs *= w[:, np.newaxis]
+
+ # Scaling
+ scale = np.sqrt((lhs * lhs).sum(axis=0))
+ lhs /= scale
+
+ degree_dim = utils.get_temp_dimname(self.dims, "degree")
+
+ rank = np.linalg.matrix_rank(lhs)
+
+ if full:
+ rank = DataArray(rank, name=xname + "matrix_rank")
+ variables[rank.name] = rank
+ _sing = np.linalg.svd(lhs, compute_uv=False)
+ sing = DataArray(
+ _sing,
+ dims=(degree_dim,),
+ coords={degree_dim: np.arange(rank - 1, -1, -1)},
+ name=xname + "singular_values",
+ )
+ variables[sing.name] = sing
+
+ for name, da in self.data_vars.items():
+ if dim not in da.dims:
+ continue
+
+ if is_duck_dask_array(da.data) and (
+ rank != order or full or skipna is None
+ ):
+ # Current algorithm with dask and skipna=False neither supports
+ # deficient ranks nor does it output the "full" info (issue dask/dask#6516)
+ skipna_da = True
+ elif skipna is None:
+ skipna_da = bool(np.any(da.isnull()))
+
+ dims_to_stack = [dimname for dimname in da.dims if dimname != dim]
+ stacked_coords: dict[Hashable, DataArray] = {}
+ if dims_to_stack:
+ stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
+ rhs = da.transpose(dim, *dims_to_stack).stack(
+ {stacked_dim: dims_to_stack}
+ )
+ stacked_coords = {stacked_dim: rhs[stacked_dim]}
+ scale_da = scale[:, np.newaxis]
+ else:
+ rhs = da
+ scale_da = scale
+
+ if w is not None:
+ rhs = rhs * w[:, np.newaxis]
+
+ with warnings.catch_warnings():
+ if full: # Copy np.polyfit behavior
+ warnings.simplefilter("ignore", RankWarning)
+ else: # Raise only once per variable
+ warnings.simplefilter("once", RankWarning)
+
+ coeffs, residuals = duck_array_ops.least_squares(
+ lhs, rhs.data, rcond=rcond, skipna=skipna_da
+ )
+
+ if isinstance(name, str):
+ name = f"{name}_"
+ else:
+ # Thus a ReprObject => polyfit was called on a DataArray
+ name = ""
+
+ coeffs = DataArray(
+ coeffs / scale_da,
+ dims=[degree_dim] + list(stacked_coords.keys()),
+ coords={degree_dim: np.arange(order)[::-1], **stacked_coords},
+ name=name + "polyfit_coefficients",
+ )
+ if dims_to_stack:
+ coeffs = coeffs.unstack(stacked_dim)
+ variables[coeffs.name] = coeffs
+
+ if full or (cov is True):
+ residuals = DataArray(
+ residuals if dims_to_stack else residuals.squeeze(),
+ dims=list(stacked_coords.keys()),
+ coords=stacked_coords,
+ name=name + "polyfit_residuals",
+ )
+ if dims_to_stack:
+ residuals = residuals.unstack(stacked_dim)
+ variables[residuals.name] = residuals
+
+ if cov:
+ Vbase = np.linalg.inv(np.dot(lhs.T, lhs))
+ Vbase /= np.outer(scale, scale)
+ if cov == "unscaled":
+ fac = 1
+ else:
+ if x.shape[0] <= order:
+ raise ValueError(
+ "The number of data points must exceed order to scale the covariance matrix."
+ )
+ fac = residuals / (x.shape[0] - order)
+ covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac
+ variables[name + "polyfit_covariance"] = covariance
+
+ return type(self)(data_vars=variables, attrs=self.attrs.copy())
+
+ def pad(
+ self,
+ pad_width: Mapping[Any, int | tuple[int, int]] | None = None,
+ mode: PadModeOptions = "constant",
+ stat_length: (
+ int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None
+ ) = None,
+ constant_values: (
+ float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None
+ ) = None,
+ end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None,
+ reflect_type: PadReflectOptions = None,
+ keep_attrs: bool | None = None,
+ **pad_width_kwargs: Any,
+ ) -> Self:
"""Pad this dataset along one or more dimensions.
.. warning::
@@ -5270,7 +9170,8 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Mapping with the form of {dim: (pad_before, pad_after)}
describing the number of values padded along each dimension.
{dim: pad} is a shortcut for pad_before = pad_after = pad
- mode : {"constant", "edge", "linear_ramp", "maximum", "mean", "median", "minimum", "reflect", "symmetric", "wrap"}, default: "constant"
+ mode : {"constant", "edge", "linear_ramp", "maximum", "mean", "median", \
+ "minimum", "reflect", "symmetric", "wrap"}, default: "constant"
How to pad the DataArray (taken from numpy docs):
- "constant": Pads with a constant value.
@@ -5364,12 +9265,74 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Data variables:
foo (x) float64 64B nan 0.0 1.0 2.0 3.0 4.0 nan nan
"""
- pass
+ pad_width = either_dict_or_kwargs(pad_width, pad_width_kwargs, "pad")
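+        # Coordinates reuse `mode` only when it is label-friendly ("edge",
+        # "reflect", "symmetric", "wrap"); otherwise they get plain constant
+        # padding, and their default indexes are rebuilt afterwards.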
- @_deprecate_positional_args('v2023.10.0')
- def idxmin(self, dim: (Hashable | None)=None, *, skipna: (bool | None)=
- None, fill_value: Any=xrdtypes.NA, keep_attrs: (bool | None)=None
- ) ->Self:
+ if mode in ("edge", "reflect", "symmetric", "wrap"):
+ coord_pad_mode = mode
+ coord_pad_options = {
+ "stat_length": stat_length,
+ "constant_values": constant_values,
+ "end_values": end_values,
+ "reflect_type": reflect_type,
+ }
+ else:
+ coord_pad_mode = "constant"
+ coord_pad_options = {}
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ variables = {}
+
+ # keep indexes that won't be affected by pad and drop all other indexes
+ xindexes = self.xindexes
+ pad_dims = set(pad_width)
+ indexes = {}
+ for k, idx in xindexes.items():
+ if not pad_dims.intersection(xindexes.get_all_dims(k)):
+ indexes[k] = idx
+
+ for name, var in self.variables.items():
+ var_pad_width = {k: v for k, v in pad_width.items() if k in var.dims}
+ if not var_pad_width:
+ variables[name] = var
+ elif name in self.data_vars:
+ variables[name] = var.pad(
+ pad_width=var_pad_width,
+ mode=mode,
+ stat_length=stat_length,
+ constant_values=constant_values,
+ end_values=end_values,
+ reflect_type=reflect_type,
+ keep_attrs=keep_attrs,
+ )
+ else:
+ variables[name] = var.pad(
+ pad_width=var_pad_width,
+ mode=coord_pad_mode,
+ keep_attrs=keep_attrs,
+ **coord_pad_options, # type: ignore[arg-type]
+ )
+ # reset default index of dimension coordinates
+ if (name,) == var.dims:
+ dim_var = {name: variables[name]}
+ index = PandasIndex.from_variables(dim_var, options={})
+ index_vars = index.create_variables(dim_var)
+ indexes[name] = index
+ variables[name] = index_vars[name]
+
+ attrs = self._attrs if keep_attrs else None
+ return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def idxmin(
+ self,
+ dim: Hashable | None = None,
+ *,
+ skipna: bool | None = None,
+ fill_value: Any = xrdtypes.NA,
+ keep_attrs: bool | None = None,
+ ) -> Self:
"""Return the coordinate label of the minimum value along a dimension.
Returns a new `Dataset` named after the dimension with the values of
@@ -5450,12 +9413,25 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
int <U1 4B 'e'
float (y) object 24B 'e' 'a' 'c'
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def idxmax(self, dim: (Hashable | None)=None, *, skipna: (bool | None)=
- None, fill_value: Any=xrdtypes.NA, keep_attrs: (bool | None)=None
- ) ->Self:
+ return self.map(
+ methodcaller(
+ "idxmin",
+ dim=dim,
+ skipna=skipna,
+ fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ )
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def idxmax(
+ self,
+ dim: Hashable | None = None,
+ *,
+ skipna: bool | None = None,
+ fill_value: Any = xrdtypes.NA,
+ keep_attrs: bool | None = None,
+ ) -> Self:
"""Return the coordinate label of the maximum value along a dimension.
Returns a new `Dataset` named after the dimension with the values of
@@ -5536,9 +9512,17 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
int <U1 4B 'b'
float (y) object 24B 'a' 'c' 'c'
"""
- pass
+ return self.map(
+ methodcaller(
+ "idxmax",
+ dim=dim,
+ skipna=skipna,
+ fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ )
+ )
- def argmin(self, dim: (Hashable | None)=None, **kwargs) ->Self:
+ def argmin(self, dim: Hashable | None = None, **kwargs) -> Self:
"""Indices of the minima of the member variables.
If there are multiple minima, the indices of the first one found will be
@@ -5614,9 +9598,34 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.idxmin
DataArray.argmin
"""
- pass
+ if dim is None:
+ warnings.warn(
+ "Once the behaviour of DataArray.argmin() and Variable.argmin() without "
+ "dim changes to return a dict of indices of each dimension, for "
+                "consistency it will be an error to call Dataset.argmin() with no argument, "
+                "since we don't return a dict of Datasets.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if (
+ dim is None
+ or (not isinstance(dim, Sequence) and dim is not ...)
+ or isinstance(dim, str)
+ ):
+ # Return int index if single dimension is passed, and is not part of a
+ # sequence
+ argmin_func = getattr(duck_array_ops, "argmin")
+ return self.reduce(
+ argmin_func, dim=None if dim is None else [dim], **kwargs
+ )
+ else:
+ raise ValueError(
+ "When dim is a sequence or ..., DataArray.argmin() returns a dict. "
+ "dicts cannot be contained in a Dataset, so cannot call "
+ "Dataset.argmin() with a sequence or ... for dim"
+ )
- def argmax(self, dim: (Hashable | None)=None, **kwargs) ->Self:
+ def argmax(self, dim: Hashable | None = None, **kwargs) -> Self:
"""Indices of the maxima of the member variables.
If there are multiple maxima, the indices of the first one found will be
@@ -5682,10 +9691,39 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
DataArray.argmax
"""
- pass
+ if dim is None:
+ warnings.warn(
+                "Once the behaviour of DataArray.argmax() and Variable.argmax() without "
+                "dim changes to return a dict of indices of each dimension, for "
+                "consistency it will be an error to call Dataset.argmax() with no argument, "
+                "since we don't return a dict of Datasets.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if (
+ dim is None
+ or (not isinstance(dim, Sequence) and dim is not ...)
+ or isinstance(dim, str)
+ ):
+ # Return int index if single dimension is passed, and is not part of a
+ # sequence
+ argmax_func = getattr(duck_array_ops, "argmax")
+ return self.reduce(
+ argmax_func, dim=None if dim is None else [dim], **kwargs
+ )
+ else:
+ raise ValueError(
+                "When dim is a sequence or ..., DataArray.argmax() returns a dict. "
+                "dicts cannot be contained in a Dataset, so cannot call "
+                "Dataset.argmax() with a sequence or ... for dim"
+ )
- def eval(self, statement: str, *, parser: QueryParserOptions='pandas') ->(
- Self | T_DataArray):
+ def eval(
+ self,
+ statement: str,
+ *,
+ parser: QueryParserOptions = "pandas",
+ ) -> Self | T_DataArray:
"""
Calculate an expression supplied as a string in the context of the dataset.
@@ -5731,12 +9769,25 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
b (x) float64 40B 0.0 0.25 0.5 0.75 1.0
c (x) float64 40B 0.0 1.25 2.5 3.75 5.0
"""
- pass
- def query(self, queries: (Mapping[Any, Any] | None)=None, parser:
- QueryParserOptions='pandas', engine: QueryEngineOptions=None,
- missing_dims: ErrorOptionsWithWarn='raise', **queries_kwargs: Any
- ) ->Self:
+ return pd.eval( # type: ignore[return-value]
+ statement,
+ resolvers=[self],
+ target=self,
+ parser=parser,
+ # Because numexpr returns a numpy array, using that engine results in
+ # different behavior. We'd be very open to a contribution handling this.
+ engine="python",
+ )
+
+ def query(
+ self,
+ queries: Mapping[Any, Any] | None = None,
+ parser: QueryParserOptions = "pandas",
+ engine: QueryEngineOptions = None,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ **queries_kwargs: Any,
+ ) -> Self:
"""Return a new dataset with each array indexed along the specified
dimension(s), where the indexers are given as strings containing
Python expressions to be evaluated against the data variables in the
@@ -5806,14 +9857,37 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
a (x) int64 16B 3 4
b (x) float64 16B 0.75 1.0
"""
- pass
- def curvefit(self, coords: (str | DataArray | Iterable[str | DataArray]
- ), func: Callable[..., Any], reduce_dims: Dims=None, skipna: bool=
- True, p0: (Mapping[str, float | DataArray] | None)=None, bounds: (
- Mapping[str, tuple[float | DataArray, float | DataArray]] | None)=
- None, param_names: (Sequence[str] | None)=None, errors:
- ErrorOptions='raise', kwargs: (dict[str, Any] | None)=None) ->Self:
+ # allow queries to be given either as a dict or as kwargs
+ queries = either_dict_or_kwargs(queries, queries_kwargs, "query")
+
+ # check queries
+ for dim, expr in queries.items():
+ if not isinstance(expr, str):
+ msg = f"expr for dim {dim} must be a string to be evaluated, {type(expr)} given"
+ raise ValueError(msg)
+
+ # evaluate the queries to create the indexers
+ indexers = {
+ dim: pd.eval(expr, resolvers=[self], parser=parser, engine=engine)
+ for dim, expr in queries.items()
+ }
+
+ # apply the selection
+ return self.isel(indexers, missing_dims=missing_dims)
+
+ def curvefit(
+ self,
+ coords: str | DataArray | Iterable[str | DataArray],
+ func: Callable[..., Any],
+ reduce_dims: Dims = None,
+ skipna: bool = True,
+ p0: Mapping[str, float | DataArray] | None = None,
+ bounds: Mapping[str, tuple[float | DataArray, float | DataArray]] | None = None,
+ param_names: Sequence[str] | None = None,
+ errors: ErrorOptions = "raise",
+ kwargs: dict[str, Any] | None = None,
+ ) -> Self:
"""
Curve fitting optimization for arbitrary functions.
@@ -5878,11 +9952,171 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.polyfit
scipy.optimize.curve_fit
"""
- pass
+ from scipy.optimize import curve_fit
+
+ from xarray.core.alignment import broadcast
+ from xarray.core.computation import apply_ufunc
+ from xarray.core.dataarray import _THIS_ARRAY, DataArray
+
+ if p0 is None:
+ p0 = {}
+ if bounds is None:
+ bounds = {}
+ if kwargs is None:
+ kwargs = {}
- @_deprecate_positional_args('v2023.10.0')
- def drop_duplicates(self, dim: (Hashable | Iterable[Hashable]), *, keep:
- Literal['first', 'last', False]='first') ->Self:
+ reduce_dims_: list[Hashable]
+ if not reduce_dims:
+ reduce_dims_ = []
+ elif isinstance(reduce_dims, str) or not isinstance(reduce_dims, Iterable):
+ reduce_dims_ = [reduce_dims]
+ else:
+ reduce_dims_ = list(reduce_dims)
+
+ if (
+ isinstance(coords, str)
+ or isinstance(coords, DataArray)
+ or not isinstance(coords, Iterable)
+ ):
+ coords = [coords]
+ coords_: Sequence[DataArray] = [
+ self[coord] if isinstance(coord, str) else coord for coord in coords
+ ]
+
+ # Determine whether any coords are dims on self
+ for coord in coords_:
+ reduce_dims_ += [c for c in self.dims if coord.equals(self[c])]
+ reduce_dims_ = list(set(reduce_dims_))
+ preserved_dims = list(set(self.dims) - set(reduce_dims_))
+ if not reduce_dims_:
+ raise ValueError(
+ "No arguments to `coords` were identified as a dimension on the calling "
+ "object, and no dims were supplied to `reduce_dims`. This would result "
+ "in fitting on scalar data."
+ )
+
+ # Check that initial guess and bounds only contain coordinates that are in preserved_dims
+ for param, guess in p0.items():
+ if isinstance(guess, DataArray):
+ unexpected = set(guess.dims) - set(preserved_dims)
+ if unexpected:
+ raise ValueError(
+ f"Initial guess for '{param}' has unexpected dimensions "
+ f"{tuple(unexpected)}. It should only have dimensions that are in data "
+ f"dimensions {preserved_dims}."
+ )
+ for param, (lb, ub) in bounds.items():
+ for label, bound in zip(("Lower", "Upper"), (lb, ub)):
+ if isinstance(bound, DataArray):
+ unexpected = set(bound.dims) - set(preserved_dims)
+ if unexpected:
+ raise ValueError(
+ f"{label} bound for '{param}' has unexpected dimensions "
+ f"{tuple(unexpected)}. It should only have dimensions that are in data "
+ f"dimensions {preserved_dims}."
+ )
+
+ if errors not in ["raise", "ignore"]:
+ raise ValueError('errors must be either "raise" or "ignore"')
+
+ # Broadcast all coords with each other
+ coords_ = broadcast(*coords_)
+ coords_ = [
+ coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_
+ ]
+ n_coords = len(coords_)
+
+ params, func_args = _get_func_args(func, param_names)
+ param_defaults, bounds_defaults = _initialize_curvefit_params(
+ params, p0, bounds, func_args
+ )
+ n_params = len(params)
+
+ def _wrapper(Y, *args, **kwargs):
+ # Wrap curve_fit with raveled coordinates and pointwise NaN handling
+ # *args contains:
+ # - the coordinates
+ # - initial guess
+ # - lower bounds
+ # - upper bounds
+ coords__ = args[:n_coords]
+ p0_ = args[n_coords + 0 * n_params : n_coords + 1 * n_params]
+ lb = args[n_coords + 1 * n_params : n_coords + 2 * n_params]
+ ub = args[n_coords + 2 * n_params :]
+
+ x = np.vstack([c.ravel() for c in coords__])
+ y = Y.ravel()
+ if skipna:
+ mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0)
+ x = x[:, mask]
+ y = y[mask]
+ if not len(y):
+ popt = np.full([n_params], np.nan)
+ pcov = np.full([n_params, n_params], np.nan)
+ return popt, pcov
+ x = np.squeeze(x)
+
+ try:
+ popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
+ except RuntimeError:
+ if errors == "raise":
+ raise
+ popt = np.full([n_params], np.nan)
+ pcov = np.full([n_params, n_params], np.nan)
+
+ return popt, pcov
+
+ result = type(self)()
+ for name, da in self.data_vars.items():
+ if name is _THIS_ARRAY:
+ name = ""
+ else:
+ name = f"{str(name)}_"
+
+ input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)]
+ input_core_dims.extend(
+ [[] for _ in range(3 * n_params)]
+ ) # core_dims for p0 and bounds
+
+ popt, pcov = apply_ufunc(
+ _wrapper,
+ da,
+ *coords_,
+ *param_defaults.values(),
+ *[b[0] for b in bounds_defaults.values()],
+ *[b[1] for b in bounds_defaults.values()],
+ vectorize=True,
+ dask="parallelized",
+ input_core_dims=input_core_dims,
+ output_core_dims=[["param"], ["cov_i", "cov_j"]],
+ dask_gufunc_kwargs={
+ "output_sizes": {
+ "param": n_params,
+ "cov_i": n_params,
+ "cov_j": n_params,
+ },
+ },
+ output_dtypes=(np.float64, np.float64),
+ exclude_dims=set(reduce_dims_),
+ kwargs=kwargs,
+ )
+ result[name + "curvefit_coefficients"] = popt
+ result[name + "curvefit_covariance"] = pcov
+
+ result = result.assign_coords(
+ {"param": params, "cov_i": params, "cov_j": params}
+ )
+ result.attrs = self.attrs.copy()
+
+ return result
+
+ @_deprecate_positional_args("v2023.10.0")
+ def drop_duplicates(
+ self,
+ dim: Hashable | Iterable[Hashable],
+ *,
+ keep: Literal["first", "last", False] = "first",
+ ) -> Self:
"""Returns a new Dataset with duplicate dimension values removed.
Parameters
@@ -5903,11 +10137,32 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
--------
DataArray.drop_duplicates
"""
- pass
+ if isinstance(dim, str):
+ dims: Iterable = (dim,)
+ elif dim is ...:
+ dims = self.dims
+ elif not isinstance(dim, Iterable):
+ dims = [dim]
+ else:
+ dims = dim
+
+ missing_dims = set(dims) - set(self.dims)
+ if missing_dims:
+ raise ValueError(
+ f"Dimensions {tuple(missing_dims)} not found in data dimensions {tuple(self.dims)}"
+ )
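+        # For each dimension keep only the labels that pandas does not flag as
+        # duplicates (respecting `keep`), then subset with isel.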
- def convert_calendar(self, calendar: CFCalendar, dim: Hashable='time',
- align_on: Literal['date', 'year', None]=None, missing: (Any | None)
- =None, use_cftime: (bool | None)=None) ->Self:
+ indexes = {dim: ~self.get_index(dim).duplicated(keep=keep) for dim in dims}
+ return self.isel(indexes)
+
+ def convert_calendar(
+ self,
+ calendar: CFCalendar,
+ dim: Hashable = "time",
+ align_on: Literal["date", "year", None] = None,
+ missing: Any | None = None,
+ use_cftime: bool | None = None,
+ ) -> Self:
"""Convert the Dataset to another calendar.
Only converts the individual timestamps, does not modify any data except
@@ -6014,10 +10269,20 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
This option is best used with data on a frequency coarser than daily.
"""
- pass
-
- def interp_calendar(self, target: (pd.DatetimeIndex | CFTimeIndex |
- DataArray), dim: Hashable='time') ->Self:
+ return convert_calendar(
+ self,
+ calendar,
+ dim=dim,
+ align_on=align_on,
+ missing=missing,
+ use_cftime=use_cftime,
+ )
+
+ def interp_calendar(
+ self,
+ target: pd.DatetimeIndex | CFTimeIndex | DataArray,
+ dim: Hashable = "time",
+ ) -> Self:
"""Interpolates the Dataset to another calendar based on decimal year measure.
Each timestamp in `source` and `target` are first converted to their decimal
@@ -6042,13 +10307,19 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
DataArray
The source interpolated on the decimal years of target,
"""
- pass
-
- @_deprecate_positional_args('v2024.07.0')
- def groupby(self, group: (Hashable | DataArray | IndexVariable |
- Mapping[Any, Grouper] | None)=None, *, squeeze: Literal[False]=
- False, restore_coord_dims: bool=False, **groupers: Grouper
- ) ->DatasetGroupBy:
+ return interp_calendar(self, target, dim=dim)
+
+ @_deprecate_positional_args("v2024.07.0")
+ def groupby(
+ self,
+ group: (
+ Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
+ ) = None,
+ *,
+ squeeze: Literal[False] = False,
+ restore_coord_dims: bool = False,
+ **groupers: Grouper,
+ ) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
Parameters
@@ -6094,14 +10365,52 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Dataset.resample
DataArray.resample
"""
- pass
+ from xarray.core.groupby import (
+ DatasetGroupBy,
+ ResolvedGrouper,
+ _validate_groupby_squeeze,
+ )
+ from xarray.groupers import UniqueGrouper
+
+ _validate_groupby_squeeze(squeeze)
+
+ if isinstance(group, Mapping):
+ groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
+ group = None
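+        # Exactly one of `group` or a single **groupers entry may be given;
+        # either way it is resolved into a single ResolvedGrouper below.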
- @_deprecate_positional_args('v2024.07.0')
- def groupby_bins(self, group: (Hashable | DataArray | IndexVariable),
- bins: Bins, right: bool=True, labels: (ArrayLike | None)=None,
- precision: int=3, include_lowest: bool=False, squeeze: Literal[
- False]=False, restore_coord_dims: bool=False, duplicates: Literal[
- 'raise', 'drop']='raise') ->DatasetGroupBy:
+ if group is not None:
+ if groupers:
+ raise ValueError(
+ "Providing a combination of `group` and **groupers is not supported."
+ )
+ rgrouper = ResolvedGrouper(UniqueGrouper(), group, self)
+ else:
+ if len(groupers) > 1:
+ raise ValueError("Grouping by multiple variables is not supported yet.")
+ elif not groupers:
+ raise ValueError("Either `group` or `**groupers` must be provided.")
+ for group, grouper in groupers.items():
+ rgrouper = ResolvedGrouper(grouper, group, self)
+
+ return DatasetGroupBy(
+ self,
+ (rgrouper,),
+ restore_coord_dims=restore_coord_dims,
+ )
+
+ @_deprecate_positional_args("v2024.07.0")
+ def groupby_bins(
+ self,
+ group: Hashable | DataArray | IndexVariable,
+ bins: Bins,
+ right: bool = True,
+ labels: ArrayLike | None = None,
+ precision: int = 3,
+ include_lowest: bool = False,
+ squeeze: Literal[False] = False,
+ restore_coord_dims: bool = False,
+ duplicates: Literal["raise", "drop"] = "raise",
+ ) -> DatasetGroupBy:
"""Returns a DatasetGroupBy object for performing grouped operations.
Rather than using all unique values of `group`, the values are discretized
@@ -6159,9 +10468,30 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
----------
.. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
"""
- pass
-
- def weighted(self, weights: DataArray) ->DatasetWeighted:
+ from xarray.core.groupby import (
+ DatasetGroupBy,
+ ResolvedGrouper,
+ _validate_groupby_squeeze,
+ )
+ from xarray.groupers import BinGrouper
+
+ _validate_groupby_squeeze(squeeze)
+ grouper = BinGrouper(
+ bins=bins,
+ right=right,
+ labels=labels,
+ precision=precision,
+ include_lowest=include_lowest,
+ )
+ rgrouper = ResolvedGrouper(grouper, group, self)
+
+ return DatasetGroupBy(
+ self,
+ (rgrouper,),
+ restore_coord_dims=restore_coord_dims,
+ )
+
+ def weighted(self, weights: DataArray) -> DatasetWeighted:
"""
Weighted Dataset operations.
@@ -6192,11 +10522,17 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Tutorial on Weighted Reduction using :py:func:`~xarray.Dataset.weighted`
"""
- pass
+ from xarray.core.weighted import DatasetWeighted
+
+ return DatasetWeighted(self, weights)
- def rolling(self, dim: (Mapping[Any, int] | None)=None, min_periods: (
- int | None)=None, center: (bool | Mapping[Any, bool])=False, **
- window_kwargs: int) ->DatasetRolling:
+ def rolling(
+ self,
+ dim: Mapping[Any, int] | None = None,
+ min_periods: int | None = None,
+ center: bool | Mapping[Any, bool] = False,
+ **window_kwargs: int,
+ ) -> DatasetRolling:
"""
Rolling window object for Datasets.
@@ -6225,10 +10561,16 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
DataArray.rolling
core.rolling.DatasetRolling
"""
- pass
+ from xarray.core.rolling import DatasetRolling
- def cumulative(self, dim: (str | Iterable[Hashable]), min_periods: int=1
- ) ->DatasetRolling:
+ dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
+ return DatasetRolling(self, dim, min_periods=min_periods, center=center)
+
+ def cumulative(
+ self,
+ dim: str | Iterable[Hashable],
+ min_periods: int = 1,
+ ) -> DatasetRolling:
"""
Accumulating object for Datasets
@@ -6251,12 +10593,32 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
DataArray.cumulative
core.rolling.DatasetRolling
"""
- pass
+ from xarray.core.rolling import DatasetRolling
+
+ if isinstance(dim, str):
+ if dim not in self.dims:
+ raise ValueError(
+ f"Dimension {dim} not found in data dimensions: {self.dims}"
+ )
+ dim = {dim: self.sizes[dim]}
+ else:
+ missing_dims = set(dim) - set(self.dims)
+ if missing_dims:
+ raise ValueError(
+ f"Dimensions {missing_dims} not found in data dimensions: {self.dims}"
+ )
+ dim = {d: self.sizes[d] for d in dim}
+
+ return DatasetRolling(self, dim, min_periods=min_periods, center=False)
- def coarsen(self, dim: (Mapping[Any, int] | None)=None, boundary:
- CoarsenBoundaryOptions='exact', side: (SideOptions | Mapping[Any,
- SideOptions])='left', coord_func: (str | Callable | Mapping[Any,
- str | Callable])='mean', **window_kwargs: int) ->DatasetCoarsen:
+ def coarsen(
+ self,
+ dim: Mapping[Any, int] | None = None,
+ boundary: CoarsenBoundaryOptions = "exact",
+ side: SideOptions | Mapping[Any, SideOptions] = "left",
+ coord_func: str | Callable | Mapping[Any, str | Callable] = "mean",
+ **window_kwargs: int,
+ ) -> DatasetCoarsen:
"""
Coarsen object for Datasets.
@@ -6292,15 +10654,30 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
Tutorial on windowed computation using :py:func:`~xarray.Dataset.coarsen`
"""
- pass
-
- @_deprecate_positional_args('v2024.07.0')
- def resample(self, indexer: (Mapping[Any, str | Resampler] | None)=None,
- *, skipna: (bool | None)=None, closed: (SideOptions | None)=None,
- label: (SideOptions | None)=None, offset: (pd.Timedelta | datetime.
- timedelta | str | None)=None, origin: (str | DatetimeLike)=
- 'start_day', restore_coord_dims: (bool | None)=None, **
- indexer_kwargs: (str | Resampler)) ->DatasetResample:
+ from xarray.core.rolling import DatasetCoarsen
+
+ dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen")
+ return DatasetCoarsen(
+ self,
+ dim,
+ boundary=boundary,
+ side=side,
+ coord_func=coord_func,
+ )
+
+ @_deprecate_positional_args("v2024.07.0")
+ def resample(
+ self,
+ indexer: Mapping[Any, str | Resampler] | None = None,
+ *,
+ skipna: bool | None = None,
+ closed: SideOptions | None = None,
+ label: SideOptions | None = None,
+ offset: pd.Timedelta | datetime.timedelta | str | None = None,
+ origin: str | DatetimeLike = "start_day",
+ restore_coord_dims: bool | None = None,
+ **indexer_kwargs: str | Resampler,
+ ) -> DatasetResample:
"""Returns a Resample object for performing resampling operations.
Handles both downsampling and upsampling. The resampled
@@ -6355,9 +10732,21 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
----------
.. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
"""
- pass
+ from xarray.core.resample import DatasetResample
+
+ return self._resample(
+ resample_cls=DatasetResample,
+ indexer=indexer,
+ skipna=skipna,
+ closed=closed,
+ label=label,
+ offset=offset,
+ origin=origin,
+ restore_coord_dims=restore_coord_dims,
+ **indexer_kwargs,
+ )
- def drop_attrs(self, *, deep: bool=True) ->Self:
+ def drop_attrs(self, *, deep: bool = True) -> Self:
"""
Removes all attributes from the Dataset and its variables.
@@ -6370,4 +10759,31 @@ class Dataset(DataWithCoords, DatasetAggregations, DatasetArithmetic,
-------
Dataset
"""
- pass
+ # Remove attributes from the dataset
+ self = self._replace(attrs={})
+
+ if not deep:
+ return self
+
+ # Remove attributes from each variable in the dataset
+ for var in self.variables:
+ # variables don't have a `._replace` method, so we copy and then remove
+ # attrs. If we added a `._replace` method, we could use that instead.
+ if var not in self.indexes:
+ self[var] = self[var].copy()
+ self[var].attrs = {}
+
+ new_idx_variables = {}
+ # Not sure this is the most elegant way of doing this, but it works.
+ # (Should we have a more general "map over all variables, including
+ # indexes" approach?)
+ for idx, idx_vars in self.xindexes.group_by_index():
+ # copy each coordinate variable of an index and drop their attrs
+ temp_idx_variables = {k: v.copy() for k, v in idx_vars.items()}
+ for v in temp_idx_variables.values():
+ v.attrs = {}
+ # re-wrap the index object in new coordinate variables
+ new_idx_variables.update(idx.create_variables(temp_idx_variables))
+ self = self.assign(new_idx_variables)
+
+ return self
diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py
index 397a6531..65ff8667 100644
--- a/xarray/core/datatree.py
+++ b/xarray/core/datatree.py
@@ -1,39 +1,166 @@
from __future__ import annotations
+
import itertools
import textwrap
from collections import ChainMap
from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping
from html import escape
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, Union, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generic,
+ Literal,
+ NoReturn,
+ Union,
+ overload,
+)
+
from xarray.core import utils
from xarray.core.alignment import align
from xarray.core.common import TreeAttrAccessMixin
from xarray.core.coordinates import DatasetCoordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset, DataVariables
-from xarray.core.datatree_mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
-from xarray.core.datatree_ops import DataTreeArithmeticMixin, MappedDatasetMethodsMixin, MappedDataWithCoords
+from xarray.core.datatree_mapping import (
+ TreeIsomorphismError,
+ check_isomorphic,
+ map_over_subtree,
+)
+from xarray.core.datatree_ops import (
+ DataTreeArithmeticMixin,
+ MappedDatasetMethodsMixin,
+ MappedDataWithCoords,
+)
from xarray.core.datatree_render import RenderDataTree
from xarray.core.formatting import datatree_repr, dims_and_coords_repr
-from xarray.core.formatting_html import datatree_repr as datatree_repr_html
+from xarray.core.formatting_html import (
+ datatree_repr as datatree_repr_html,
+)
from xarray.core.indexes import Index, Indexes
from xarray.core.merge import dataset_update_method
from xarray.core.options import OPTIONS as XR_OPTS
from xarray.core.treenode import NamedNode, NodePath, Tree
-from xarray.core.utils import Default, Frozen, HybridMappingProxy, _default, either_dict_or_kwargs, maybe_wrap_array
+from xarray.core.utils import (
+ Default,
+ Frozen,
+ HybridMappingProxy,
+ _default,
+ either_dict_or_kwargs,
+ maybe_wrap_array,
+)
from xarray.core.variable import Variable
+
try:
from xarray.core.variable import calculate_dimensions
except ImportError:
+ # for xarray versions 2022.03.0 and earlier
from xarray.core.dataset import calculate_dimensions
+
if TYPE_CHECKING:
import pandas as pd
+
from xarray.core.datatree_io import T_DataTreeNetcdfEngine, T_DataTreeNetcdfTypes
from xarray.core.merge import CoercibleMapping, CoercibleValue
from xarray.core.types import ErrorOptions, NetcdfWriteModes, ZarrWriteModes
+
+# """
+# DEVELOPERS' NOTE
+# ----------------
+# The idea of this module is to create a `DataTree` class which inherits the tree structure from TreeNode, and also copies
+# the entire API of `xarray.Dataset`, but with certain methods decorated to instead map the dataset function over every
+# node in the tree. As this API is copied without directly subclassing `xarray.Dataset` we instead create various Mixin
+# classes (in ops.py) which each define part of `xarray.Dataset`'s extensive API.
+#
+# Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine to inherit unaltered
+# (normally because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new
+# tree) and some will get overridden by the class definition of DataTree.
+# """
+
+
T_Path = Union[str, NodePath]
+def _collect_data_and_coord_variables(
+ data: Dataset,
+) -> tuple[dict[Hashable, Variable], dict[Hashable, Variable]]:
+ data_variables = {}
+ coord_variables = {}
+ for k, v in data.variables.items():
+ if k in data._coord_names:
+ coord_variables[k] = v
+ else:
+ data_variables[k] = v
+ return data_variables, coord_variables
+
+
+def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset:
+ if isinstance(data, DataArray):
+ ds = data.to_dataset()
+ elif isinstance(data, Dataset):
+ ds = data.copy(deep=False)
+ elif data is None:
+ ds = Dataset()
+ else:
+ raise TypeError(
+ f"data object is not an xarray Dataset, DataArray, or None, it is of type {type(data)}"
+ )
+ return ds
+
+
+def _join_path(root: str, name: str) -> str:
+ return str(NodePath(root) / name)
+
+
+def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset:
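+    # Merge the parent's variables, coords, dims and indexes with the child's;
+    # dict union keeps the right-hand operand, so the child wins on conflicts.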
+ return Dataset._construct_direct(
+ variables=parent._variables | ds._variables,
+ coord_names=parent._coord_names | ds._coord_names,
+ dims=parent._dims | ds._dims,
+ attrs=ds._attrs,
+ indexes=parent._indexes | ds._indexes,
+ encoding=ds._encoding,
+ close=ds._close,
+ )
+
+
+def _without_header(text: str) -> str:
+ return "\n".join(text.split("\n")[1:])
+
+
+def _indented(text: str) -> str:
+ return textwrap.indent(text, prefix=" ")
+
+
+def _check_alignment(
+ path: str,
+ node_ds: Dataset,
+ parent_ds: Dataset | None,
+ children: Mapping[str, DataTree],
+) -> None:
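+    # Recursively verify that every node aligns exactly with the coordinates
+    # inherited from its parents, raising an informative error otherwise.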
+ if parent_ds is not None:
+ try:
+ align(node_ds, parent_ds, join="exact")
+ except ValueError as e:
+ node_repr = _indented(_without_header(repr(node_ds)))
+ parent_repr = _indented(dims_and_coords_repr(parent_ds))
+ raise ValueError(
+ f"group {path!r} is not aligned with its parents:\n"
+ f"Group:\n{node_repr}\nFrom parents:\n{parent_repr}"
+ ) from e
+
+ if children:
+ if parent_ds is not None:
+ base_ds = _inherited_dataset(node_ds, parent_ds)
+ else:
+ base_ds = node_ds
+
+ for child_name, child in children.items():
+ child_path = str(NodePath(path) / child_name)
+ child_ds = child.to_dataset(inherited=False)
+ _check_alignment(child_path, child_ds, base_ds, child.children)
+
+
class DatasetView(Dataset):
"""
An immutable Dataset-like view onto the data in a single DataTree node.
@@ -44,67 +171,149 @@ class DatasetView(Dataset):
Operations returning a new result will return a new xarray.Dataset object.
This includes all API on Dataset, which will be inherited.
"""
- __slots__ = ('_attrs', '_cache', '_coord_names', '_dims', '_encoding',
- '_close', '_indexes', '_variables')
- def __init__(self, data_vars: (Mapping[Any, Any] | None)=None, coords:
- (Mapping[Any, Any] | None)=None, attrs: (Mapping[Any, Any] | None)=None
- ):
- raise AttributeError(
- 'DatasetView objects are not to be initialized directly')
+ # TODO what happens if user alters (in-place) a DataArray they extracted from this object?
+
+ __slots__ = (
+ "_attrs",
+ "_cache", # used by _CachedAccessor
+ "_coord_names",
+ "_dims",
+ "_encoding",
+ "_close",
+ "_indexes",
+ "_variables",
+ )
+
+ def __init__(
+ self,
+ data_vars: Mapping[Any, Any] | None = None,
+ coords: Mapping[Any, Any] | None = None,
+ attrs: Mapping[Any, Any] | None = None,
+ ):
+ raise AttributeError("DatasetView objects are not to be initialized directly")
@classmethod
- def _constructor(cls, variables: dict[Any, Variable], coord_names: set[
- Hashable], dims: dict[Any, int], attrs: (dict | None), indexes:
- dict[Any, Index], encoding: (dict | None), close: (Callable[[],
- None] | None)) ->DatasetView:
+ def _constructor(
+ cls,
+ variables: dict[Any, Variable],
+ coord_names: set[Hashable],
+ dims: dict[Any, int],
+ attrs: dict | None,
+ indexes: dict[Any, Index],
+ encoding: dict | None,
+ close: Callable[[], None] | None,
+ ) -> DatasetView:
"""Private constructor, from Dataset attributes."""
- pass
-
- def __setitem__(self, key, val) ->None:
+ # We override Dataset._construct_direct below, so we need a new
+ # constructor for creating DatasetView objects.
+ obj: DatasetView = object.__new__(cls)
+ obj._variables = variables
+ obj._coord_names = coord_names
+ obj._dims = dims
+ obj._indexes = indexes
+ obj._attrs = attrs
+ obj._close = close
+ obj._encoding = encoding
+ return obj
+
+ def __setitem__(self, key, val) -> None:
raise AttributeError(
- 'Mutation of the DatasetView is not allowed, please use `.__setitem__` on the wrapping DataTree node, or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`,use `.copy()` first to get a mutable version of the input dataset.'
- )
+ "Mutation of the DatasetView is not allowed, please use `.__setitem__` on the wrapping DataTree node, "
+            "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`, "
+ "use `.copy()` first to get a mutable version of the input dataset."
+ )
- @overload
- def __getitem__(self, key: Mapping) ->Dataset:
+ def update(self, other) -> NoReturn:
+ raise AttributeError(
+ "Mutation of the DatasetView is not allowed, please use `.update` on the wrapping DataTree node, "
+            "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`, "
+ "use `.copy()` first to get a mutable version of the input dataset."
+ )
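+
+    # Illustrative behaviour of the two guards above (hypothetical node ``dt``):
+    # mutating the view raises, whereas ``to_dataset()`` returns an ordinary,
+    # mutable Dataset copy.
+    #
+    # >>> dt.ds["new_var"] = 42
+    # Traceback (most recent call last):
+    #     ...
+    # AttributeError: Mutation of the DatasetView is not allowed, ...
+    # >>> ds = dt.to_dataset()
+    # >>> ds["new_var"] = 42  # fine: ds is a regular Dataset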
+
+ # FIXME https://github.com/python/mypy/issues/7328
+ @overload # type: ignore[override]
+ def __getitem__(self, key: Mapping) -> Dataset: # type: ignore[overload-overlap]
...
@overload
- def __getitem__(self, key: Hashable) ->DataArray:
+ def __getitem__(self, key: Hashable) -> DataArray: # type: ignore[overload-overlap]
...
+ # See: https://github.com/pydata/xarray/issues/8855
@overload
- def __getitem__(self, key: Any) ->Dataset:
- ...
+ def __getitem__(self, key: Any) -> Dataset: ...
- def __getitem__(self, key) ->(DataArray | Dataset):
+ def __getitem__(self, key) -> DataArray | Dataset:
+ # TODO call the `_get_item` method of DataTree to allow path-like access to contents of other nodes
+ # For now just call Dataset.__getitem__
return Dataset.__getitem__(self, key)
@classmethod
- def _construct_direct(cls, variables: dict[Any, Variable], coord_names:
- set[Hashable], dims: (dict[Any, int] | None)=None, attrs: (dict |
- None)=None, indexes: (dict[Any, Index] | None)=None, encoding: (
- dict | None)=None, close: (Callable[[], None] | None)=None) ->Dataset:
+ def _construct_direct( # type: ignore[override]
+ cls,
+ variables: dict[Any, Variable],
+ coord_names: set[Hashable],
+ dims: dict[Any, int] | None = None,
+ attrs: dict | None = None,
+ indexes: dict[Any, Index] | None = None,
+ encoding: dict | None = None,
+ close: Callable[[], None] | None = None,
+ ) -> Dataset:
"""
Overriding this method (along with ._replace) and modifying it to return a Dataset object
should hopefully ensure that the return type of any method on this object is a Dataset.
"""
- pass
-
- def _replace(self, variables: (dict[Hashable, Variable] | None)=None,
- coord_names: (set[Hashable] | None)=None, dims: (dict[Any, int] |
- None)=None, attrs: (dict[Hashable, Any] | None | Default)=_default,
- indexes: (dict[Hashable, Index] | None)=None, encoding: (dict |
- None | Default)=_default, inplace: bool=False) ->Dataset:
+ if dims is None:
+ dims = calculate_dimensions(variables)
+ if indexes is None:
+ indexes = {}
+ obj = object.__new__(Dataset)
+ obj._variables = variables
+ obj._coord_names = coord_names
+ obj._dims = dims
+ obj._indexes = indexes
+ obj._attrs = attrs
+ obj._close = close
+ obj._encoding = encoding
+ return obj
+
+ def _replace( # type: ignore[override]
+ self,
+ variables: dict[Hashable, Variable] | None = None,
+ coord_names: set[Hashable] | None = None,
+ dims: dict[Any, int] | None = None,
+ attrs: dict[Hashable, Any] | None | Default = _default,
+ indexes: dict[Hashable, Index] | None = None,
+ encoding: dict | None | Default = _default,
+ inplace: bool = False,
+ ) -> Dataset:
"""
Overriding this method (along with ._construct_direct) and modifying it to return a Dataset object
should hopefully ensure that the return type of any method on this object is a Dataset.
"""
- pass
- def map(self, func: Callable, keep_attrs: (bool | None)=None, args:
- Iterable[Any]=(), **kwargs: Any) ->Dataset:
+ if inplace:
+ raise AttributeError("In-place mutation of the DatasetView is not allowed")
+
+ return Dataset._replace(
+ self,
+ variables=variables,
+ coord_names=coord_names,
+ dims=dims,
+ attrs=attrs,
+ indexes=indexes,
+ encoding=encoding,
+ inplace=inplace,
+ )
+
+ def map( # type: ignore[override]
+ self,
+ func: Callable,
+ keep_attrs: bool | None = None,
+ args: Iterable[Any] = (),
+ **kwargs: Any,
+ ) -> Dataset:
"""Apply a function to each data variable in this dataset
Parameters
@@ -146,20 +355,53 @@ class DatasetView(Dataset):
foo (dim_0, dim_1) float64 48B 1.764 0.4002 0.9787 2.241 1.868 0.9773
bar (x) float64 16B 1.0 2.0
"""
- pass
-
-class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
- DataTreeArithmeticMixin, TreeAttrAccessMixin, Generic[Tree], Mapping):
+ # Copied from xarray.Dataset so as not to call type(self), which causes problems (see https://github.com/xarray-contrib/datatree/issues/188).
+ # TODO Refactor xarray upstream to avoid needing to overwrite this.
+ # TODO This copied version will drop all attrs - the keep_attrs stuff should be re-instated
+ variables = {
+ k: maybe_wrap_array(v, func(v, *args, **kwargs))
+ for k, v in self.data_vars.items()
+ }
+ # return type(self)(variables, attrs=attrs)
+ return Dataset(variables)
+
+
+class DataTree(
+ NamedNode,
+ MappedDatasetMethodsMixin,
+ MappedDataWithCoords,
+ DataTreeArithmeticMixin,
+ TreeAttrAccessMixin,
+ Generic[Tree],
+ Mapping,
+):
"""
A tree-like hierarchical collection of xarray objects.
Attempts to present an API like that of xarray.Dataset, but methods are wrapped to also update all the tree's child nodes.
"""
+
+ # TODO Some way of sorting children by depth
+
+ # TODO do we need a watch out for if methods intended only for root nodes are called on non-root nodes?
+
+ # TODO dataset methods which should not or cannot act over the whole tree, such as .to_array
+
+ # TODO .loc method
+
+ # TODO a lot of properties like .variables could be defined in a DataMapping class which both Dataset and DataTree inherit from
+
+ # TODO all groupby classes
+
_name: str | None
_parent: DataTree | None
_children: dict[str, DataTree]
- _cache: dict[str, Any]
+ _cache: dict[str, Any] # used by _CachedAccessor
_data_variables: dict[Hashable, Variable]
_node_coord_variables: dict[Hashable, Variable]
_node_dims: dict[Hashable, int]
@@ -167,13 +409,28 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
_attrs: dict[Hashable, Any] | None
_encoding: dict[Hashable, Any] | None
_close: Callable[[], None] | None
- __slots__ = ('_name', '_parent', '_children', '_cache',
- '_data_variables', '_node_coord_variables', '_node_dims',
- '_node_indexes', '_attrs', '_encoding', '_close')
- def __init__(self, data: (Dataset | DataArray | None)=None, parent: (
- DataTree | None)=None, children: (Mapping[str, DataTree] | None)=
- None, name: (str | None)=None):
+ __slots__ = (
+ "_name",
+ "_parent",
+ "_children",
+ "_cache", # used by _CachedAccessor
+ "_data_variables",
+ "_node_coord_variables",
+ "_node_dims",
+ "_node_indexes",
+ "_attrs",
+ "_encoding",
+ "_close",
+ )
+
+ def __init__(
+ self,
+ data: Dataset | DataArray | None = None,
+ parent: DataTree | None = None,
+ children: Mapping[str, DataTree] | None = None,
+ name: str | None = None,
+ ):
"""
Create a single node of a DataTree.
@@ -203,18 +460,88 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
"""
if children is None:
children = {}
+
super().__init__(name=name)
self._set_node_data(_coerce_to_dataset(data))
self.parent = parent
self.children = children
+ def _set_node_data(self, ds: Dataset):
+ data_vars, coord_vars = _collect_data_and_coord_variables(ds)
+ self._data_variables = data_vars
+ self._node_coord_variables = coord_vars
+ self._node_dims = ds._dims
+ self._node_indexes = ds._indexes
+ self._encoding = ds._encoding
+ self._attrs = ds._attrs
+ self._close = ds._close
+
+ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
+ super()._pre_attach(parent, name)
+ if name in parent.ds.variables:
+ raise KeyError(
+ f"parent {parent.name} already contains a variable named {name}"
+ )
+ path = str(NodePath(parent.path) / name)
+ node_ds = self.to_dataset(inherited=False)
+ parent_ds = parent._to_dataset_view(rebuild_dims=False)
+ _check_alignment(path, node_ds, parent_ds, self.children)
+
+ @property
+ def _coord_variables(self) -> ChainMap[Hashable, Variable]:
+ return ChainMap(
+ self._node_coord_variables, *(p._node_coord_variables for p in self.parents)
+ )
+
@property
- def parent(self: DataTree) ->(DataTree | None):
+ def _dims(self) -> ChainMap[Hashable, int]:
+ return ChainMap(self._node_dims, *(p._node_dims for p in self.parents))
+
+ @property
+ def _indexes(self) -> ChainMap[Hashable, Index]:
+ return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents))
+
+ @property
+ def parent(self: DataTree) -> DataTree | None:
"""Parent of this node."""
- pass
+ return self._parent
+
+ @parent.setter
+ def parent(self: DataTree, new_parent: DataTree) -> None:
+ if new_parent and self.name is None:
+ raise ValueError("Cannot set an unnamed node as a child of another node")
+ self._set_parent(new_parent, self.name)
+
+ def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView:
+ variables = dict(self._data_variables)
+ variables |= self._coord_variables
+ if rebuild_dims:
+ dims = calculate_dimensions(variables)
+ else:
+ # Note: rebuild_dims=False can create technically invalid Dataset
+ # objects because it may not contain all dimensions on its direct
+ # member variables, e.g., consider:
+ # tree = DataTree.from_dict(
+ # {
+ # "/": xr.Dataset({"a": (("x",), [1, 2])}), # x has size 2
+            #         "/b/c": xr.Dataset({"d": (("x",), [3])}),  # x has size 1
+ # }
+ # )
+ # However, they are fine for internal use cases, for align() or
+ # building a repr().
+ dims = dict(self._dims)
+ return DatasetView._constructor(
+ variables=variables,
+ coord_names=set(self._coord_variables),
+ dims=dims,
+ attrs=self._attrs,
+ indexes=dict(self._indexes),
+ encoding=self._encoding,
+ close=None,
+ )
@property
- def ds(self) ->DatasetView:
+ def ds(self) -> DatasetView:
"""
An immutable Dataset-like view onto the data in this node.
@@ -227,9 +554,14 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
--------
DataTree.to_dataset
"""
- pass
+ return self._to_dataset_view(rebuild_dims=True)
+
+ @ds.setter
+ def ds(self, data: Dataset | DataArray | None = None) -> None:
+ ds = _coerce_to_dataset(data)
+ self._replace_node(ds)
- def to_dataset(self, inherited: bool=True) ->Dataset:
+ def to_dataset(self, inherited: bool = True) -> Dataset:
"""
Return the data in this node as a new xarray.Dataset object.
@@ -243,50 +575,74 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
--------
DataTree.ds
"""
- pass
+ coord_vars = self._coord_variables if inherited else self._node_coord_variables
+ variables = dict(self._data_variables)
+ variables |= coord_vars
+ dims = calculate_dimensions(variables) if inherited else dict(self._node_dims)
+ return Dataset._construct_direct(
+ variables,
+ set(coord_vars),
+ dims,
+ None if self._attrs is None else dict(self._attrs),
+ dict(self._indexes if inherited else self._node_indexes),
+ None if self._encoding is None else dict(self._encoding),
+ self._close,
+ )
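+
+    # Sketch of the effect of ``inherited`` (hypothetical tree): coordinates
+    # defined on a parent are only included when ``inherited=True`` (default).
+    #
+    # >>> tree = DataTree.from_dict(
+    # ...     {
+    # ...         "/": Dataset(coords={"x": [0, 1]}),
+    # ...         "/child": Dataset({"a": ("x", [1, 2])}),
+    # ...     }
+    # ... )
+    # >>> "x" in tree["child"].to_dataset().coords
+    # True
+    # >>> "x" in tree["child"].to_dataset(inherited=False).coords
+    # False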
@property
- def has_data(self) ->bool:
+ def has_data(self) -> bool:
"""Whether or not there are any variables in this node."""
- pass
+ return bool(self._data_variables or self._node_coord_variables)
@property
- def has_attrs(self) ->bool:
+ def has_attrs(self) -> bool:
"""Whether or not there are any metadata attributes in this node."""
- pass
+ return len(self.attrs.keys()) > 0
@property
- def is_empty(self) ->bool:
+ def is_empty(self) -> bool:
"""False if node contains any data or attrs. Does not look at children."""
- pass
+ return not (self.has_data or self.has_attrs)
@property
- def is_hollow(self) ->bool:
+ def is_hollow(self) -> bool:
"""True if only leaf nodes contain data."""
- pass
+ return not any(node.has_data for node in self.subtree if not node.is_leaf)
@property
- def variables(self) ->Mapping[Hashable, Variable]:
+ def variables(self) -> Mapping[Hashable, Variable]:
"""Low level interface to node contents as dict of Variable objects.
This dictionary is frozen to prevent mutation that could violate
Dataset invariants. It contains all variable objects constituting this
DataTree node, including both data variables and coordinates.
"""
- pass
+ return Frozen(self._data_variables | self._coord_variables)
@property
- def attrs(self) ->dict[Hashable, Any]:
+ def attrs(self) -> dict[Hashable, Any]:
"""Dictionary of global attributes on this node object."""
- pass
+ if self._attrs is None:
+ self._attrs = {}
+ return self._attrs
+
+ @attrs.setter
+ def attrs(self, value: Mapping[Any, Any]) -> None:
+ self._attrs = dict(value)
@property
- def encoding(self) ->dict:
+ def encoding(self) -> dict:
"""Dictionary of global encoding attributes on this node object."""
- pass
+ if self._encoding is None:
+ self._encoding = {}
+ return self._encoding
+
+ @encoding.setter
+ def encoding(self, value: Mapping) -> None:
+ self._encoding = dict(value)
@property
- def dims(self) ->Mapping[Hashable, int]:
+ def dims(self) -> Mapping[Hashable, int]:
"""Mapping from dimension names to lengths.
Cannot be modified directly, but is updated when adding new variables.
@@ -295,10 +651,10 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
See `DataTree.sizes`, `Dataset.sizes`, and `DataArray.sizes` for consistently named
properties.
"""
- pass
+ return Frozen(self._dims)
@property
- def sizes(self) ->Mapping[Hashable, int]:
+ def sizes(self) -> Mapping[Hashable, int]:
"""Mapping from dimension names to lengths.
Cannot be modified directly, but is updated when adding new variables.
@@ -310,53 +666,116 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
--------
DataArray.sizes
"""
- pass
+ return self.dims
@property
- def _attr_sources(self) ->Iterable[Mapping[Hashable, Any]]:
+ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]:
"""Places to look-up items for attribute-style access"""
- pass
+ yield from self._item_sources
+ yield self.attrs
@property
- def _item_sources(self) ->Iterable[Mapping[Any, Any]]:
+ def _item_sources(self) -> Iterable[Mapping[Any, Any]]:
"""Places to look-up items for key-completion"""
- pass
+ yield self.data_vars
+ yield HybridMappingProxy(keys=self._coord_variables, mapping=self.coords)
+
+ # virtual coordinates
+ yield HybridMappingProxy(keys=self.dims, mapping=self)
- def _ipython_key_completions_(self) ->list[str]:
+ # immediate child nodes
+ yield self.children
+
+ def _ipython_key_completions_(self) -> list[str]:
"""Provide method for the key-autocompletions in IPython.
See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion
For the details.
"""
- pass
- def __contains__(self, key: object) ->bool:
+ # TODO allow auto-completing relative string paths, e.g. `dt['path/to/../ <tab> node'`
+ # Would require changes to ipython's autocompleter, see https://github.com/ipython/ipython/issues/12420
+        # Instead for now we only list direct paths to all nodes in the subtree explicitly
+
+ items_on_this_node = self._item_sources
+ full_file_like_paths_to_all_nodes_in_subtree = {
+ node.path[1:]: node for node in self.subtree
+ }
+
+ all_item_sources = itertools.chain(
+ items_on_this_node, [full_file_like_paths_to_all_nodes_in_subtree]
+ )
+
+ items = {
+ item
+ for source in all_item_sources
+ for item in source
+ if isinstance(item, str)
+ }
+ return list(items)
+
+ def __contains__(self, key: object) -> bool:
"""The 'in' operator will return true or false depending on whether
'key' is either an array stored in the datatree or a child node, or neither.
"""
return key in self.variables or key in self.children
- def __bool__(self) ->bool:
+ def __bool__(self) -> bool:
return bool(self._data_variables) or bool(self._children)
- def __iter__(self) ->Iterator[Hashable]:
+ def __iter__(self) -> Iterator[Hashable]:
return itertools.chain(self._data_variables, self._children)
def __array__(self, dtype=None, copy=None):
raise TypeError(
- 'cannot directly convert a DataTree into a numpy array. Instead, create an xarray.DataArray first, either with indexing on the DataTree or by invoking the `to_array()` method.'
- )
+ "cannot directly convert a DataTree into a "
+ "numpy array. Instead, create an xarray.DataArray "
+ "first, either with indexing on the DataTree or by "
+ "invoking the `to_array()` method."
+ )
- def __repr__(self) ->str:
+ def __repr__(self) -> str: # type: ignore[override]
return datatree_repr(self)
- def __str__(self) ->str:
+ def __str__(self) -> str:
return datatree_repr(self)
def _repr_html_(self):
"""Make html representation of datatree object"""
- pass
+ if XR_OPTS["display_style"] == "text":
+ return f"<pre>{escape(repr(self))}</pre>"
+ return datatree_repr_html(self)
+
+ def _replace_node(
+ self: DataTree,
+ data: Dataset | Default = _default,
+ children: dict[str, DataTree] | Default = _default,
+ ) -> None:
+
+ ds = self.to_dataset(inherited=False) if data is _default else data
+
+ if children is _default:
+ children = self._children
- def copy(self: DataTree, deep: bool=False) ->DataTree:
+ for child_name in children:
+ if child_name in ds.variables:
+ raise ValueError(f"node already contains a variable named {child_name}")
+
+ parent_ds = (
+ self.parent._to_dataset_view(rebuild_dims=False)
+ if self.parent is not None
+ else None
+ )
+ _check_alignment(self.path, ds, parent_ds, children)
+
+ if data is not _default:
+ self._set_node_data(ds)
+
+ self._children = children
+
+ def copy(
+ self: DataTree,
+ deep: bool = False,
+ ) -> DataTree:
"""
Returns a copy of this subtree.
@@ -384,26 +803,38 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
xarray.Dataset.copy
pandas.DataFrame.copy
"""
- pass
+ return self._copy_subtree(deep=deep)
- def _copy_subtree(self: DataTree, deep: bool=False, memo: (dict[int,
- Any] | None)=None) ->DataTree:
+ def _copy_subtree(
+ self: DataTree,
+ deep: bool = False,
+ memo: dict[int, Any] | None = None,
+ ) -> DataTree:
"""Copy entire subtree"""
- pass
-
- def _copy_node(self: DataTree, deep: bool=False) ->DataTree:
+ new_tree = self._copy_node(deep=deep)
+ for node in self.descendants:
+ path = node.relative_to(self)
+ new_tree[path] = node._copy_node(deep=deep)
+ return new_tree
+
+ def _copy_node(
+ self: DataTree,
+ deep: bool = False,
+ ) -> DataTree:
"""Copy just one node of a tree"""
- pass
+ data = self.ds.copy(deep=deep)
+ new_node: DataTree = DataTree(data, name=self.name)
+ return new_node
- def __copy__(self: DataTree) ->DataTree:
+ def __copy__(self: DataTree) -> DataTree:
return self._copy_subtree(deep=False)
- def __deepcopy__(self: DataTree, memo: (dict[int, Any] | None)=None
- ) ->DataTree:
+ def __deepcopy__(self: DataTree, memo: dict[int, Any] | None = None) -> DataTree:
return self._copy_subtree(deep=True, memo=memo)
- def get(self: DataTree, key: str, default: (DataTree | DataArray | None
- )=None) ->(DataTree | DataArray | None):
+ def get( # type: ignore[override]
+ self: DataTree, key: str, default: DataTree | DataArray | None = None
+ ) -> DataTree | DataArray | None:
"""
Access child nodes, variables, or coordinates stored in this node.
@@ -417,9 +848,14 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
default : DataTree | DataArray | None, optional
A value to return if the specified key does not exist. Default return value is None.
"""
- pass
+ if key in self.children:
+ return self.children[key]
+ elif key in self.ds:
+ return self.ds[key]
+ else:
+ return default
- def __getitem__(self: DataTree, key: str) ->(DataTree | DataArray):
+ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
"""
Access child nodes, variables, or coordinates stored anywhere in this tree.
@@ -435,27 +871,48 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
-------
DataTree | DataArray
"""
+
+ # Either:
if utils.is_dict_like(key):
- raise NotImplementedError('Should this index over whole tree?')
+ # dict-like indexing
+ raise NotImplementedError("Should this index over whole tree?")
elif isinstance(key, str):
+ # TODO should possibly deal with hashables in general?
+ # path-like: a name of a node/variable, or path to a node/variable
path = NodePath(key)
return self._get_item(path)
elif utils.is_list_like(key):
+ # iterable of variable names
raise NotImplementedError(
- 'Selecting via tags is deprecated, and selecting multiple items should be implemented via .subset'
- )
+ "Selecting via tags is deprecated, and selecting multiple items should be "
+ "implemented via .subset"
+ )
else:
- raise ValueError(f'Invalid format for key: {key}')
+ raise ValueError(f"Invalid format for key: {key}")
- def _set(self, key: str, val: (DataTree | CoercibleValue)) ->None:
+ def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
"""
Set the child node or variable with the specified key to value.
Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree.
"""
- pass
+ if isinstance(val, DataTree):
+ # create and assign a shallow copy here so as not to alter original name of node in grafted tree
+ new_node = val.copy(deep=False)
+ new_node.name = key
+ new_node.parent = self
+ else:
+ if not isinstance(val, (DataArray, Variable)):
+ # accommodate other types that can be coerced into Variables
+ val = DataArray(val)
+
+ self.update({key: val})
- def __setitem__(self, key: str, value: Any) ->None:
+ def __setitem__(
+ self,
+ key: str,
+ value: Any,
+ ) -> None:
"""
Add either a child node or an array to the tree, at any position.
@@ -464,25 +921,72 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
If there is already a node at the given location, then if value is a Node class or Dataset it will overwrite the
data already present at that node, and if value is a single array, it will be merged with it.
"""
+ # TODO xarray.Dataset accepts other possibilities, how do we exactly replicate all the behaviour?
if utils.is_dict_like(key):
raise NotImplementedError
elif isinstance(key, str):
+ # TODO should possibly deal with hashables in general?
+ # path-like: a name of a node/variable, or path to a node/variable
path = NodePath(key)
return self._set_item(path, value, new_nodes_along_path=True)
else:
- raise ValueError('Invalid format for key')
+ raise ValueError("Invalid format for key")
+
+ @overload
+ def update(self, other: Dataset) -> None: ...
+
+ @overload
+ def update(self, other: Mapping[Hashable, DataArray | Variable]) -> None: ...
- def update(self, other: (Dataset | Mapping[Hashable, DataArray |
- Variable] | Mapping[str, DataTree | DataArray | Variable])) ->None:
+ @overload
+ def update(self, other: Mapping[str, DataTree | DataArray | Variable]) -> None: ...
+
+ def update(
+ self,
+ other: (
+ Dataset
+ | Mapping[Hashable, DataArray | Variable]
+ | Mapping[str, DataTree | DataArray | Variable]
+ ),
+ ) -> None:
"""
Update this node's children and / or variables.
Just like `dict.update` this is an in-place operation.
"""
- pass
+ new_children: dict[str, DataTree] = {}
+ new_variables: CoercibleMapping
- def assign(self, items: (Mapping[Any, Any] | None)=None, **items_kwargs:
- Any) ->DataTree:
+ if isinstance(other, Dataset):
+ new_variables = other
+ else:
+ new_variables = {}
+ for k, v in other.items():
+ if isinstance(v, DataTree):
+ # avoid named node being stored under inconsistent key
+ new_child: DataTree = v.copy()
+ # Datatree's name is always a string until we fix that (#8836)
+ new_child.name = str(k)
+ new_children[str(k)] = new_child
+ elif isinstance(v, (DataArray, Variable)):
+ # TODO this should also accommodate other types that can be coerced into Variables
+ new_variables[k] = v
+ else:
+ raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree")
+
+ vars_merge_result = dataset_update_method(
+ self.to_dataset(inherited=False), new_variables
+ )
+ data = Dataset._construct_direct(**vars_merge_result._asdict())
+
+ # TODO are there any subtleties with preserving order of children like this?
+ merged_children = {**self.children, **new_children}
+
+ self._replace_node(data, children=merged_children)
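+
+    # Minimal sketch of ``update`` (hypothetical names): variables and whole
+    # child nodes can be assigned together, in place.
+    #
+    # >>> dt = DataTree(Dataset({"a": 1}))
+    # >>> dt.update({"b": DataArray(2), "child": DataTree(Dataset({"c": 3}))})
+    # >>> sorted(dt.data_vars), sorted(dt.children)
+    # (['a', 'b'], ['child'])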
+
+ def assign(
+ self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any
+ ) -> DataTree:
"""
Assign new data variables or child nodes to a DataTree, returning a new object
with all the original items in addition to the new ones.
@@ -517,10 +1021,14 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
xarray.Dataset.assign
pandas.DataFrame.assign
"""
- pass
-
- def drop_nodes(self: DataTree, names: (str | Iterable[str]), *, errors:
- ErrorOptions='raise') ->DataTree:
+ items = either_dict_or_kwargs(items, items_kwargs, "assign")
+ dt = self.copy()
+ dt.update(items)
+ return dt
+
+ def drop_nodes(
+ self: DataTree, names: str | Iterable[str], *, errors: ErrorOptions = "raise"
+ ) -> DataTree:
"""
Drop child nodes from this node.
@@ -538,11 +1046,30 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
dropped : DataTree
A copy of the node with the specified children dropped.
"""
- pass
+ # the Iterable check is required for mypy
+ if isinstance(names, str) or not isinstance(names, Iterable):
+ names = {names}
+ else:
+ names = set(names)
+
+ if errors == "raise":
+ extra = names - set(self.children)
+ if extra:
+ raise KeyError(f"Cannot drop all nodes - nodes {extra} not present")
+
+ result = self.copy()
+ children_to_keep = {
+ name: child for name, child in result.children.items() if name not in names
+ }
+ result._replace_node(children=children_to_keep)
+ return result
@classmethod
- def from_dict(cls, d: MutableMapping[str, Dataset | DataArray |
- DataTree | None], name: (str | None)=None) ->DataTree:
+ def from_dict(
+ cls,
+ d: MutableMapping[str, Dataset | DataArray | DataTree | None],
+ name: str | None = None,
+ ) -> DataTree:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.
@@ -566,9 +1093,35 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
-----
If your dictionary is nested you will need to flatten it before using this method.
"""
- pass
- def to_dict(self) ->dict[str, Dataset]:
+ # First create the root node
+ root_data = d.pop("/", None)
+ if isinstance(root_data, DataTree):
+ obj = root_data.copy()
+ obj.orphan()
+ else:
+ obj = cls(name=name, data=root_data, parent=None, children=None)
+
+ if d:
+ # Populate tree with children determined from data_objects mapping
+ for path, data in d.items():
+ # Create and set new node
+ node_name = NodePath(path).name
+ if isinstance(data, DataTree):
+ new_node = data.copy()
+ new_node.orphan()
+ else:
+ new_node = cls(name=node_name, data=data)
+ obj._set_item(
+ path,
+ new_node,
+ allow_overwrite=False,
+ new_nodes_along_path=True,
+ )
+
+ return obj
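+
+    # Sketch of typical usage (hypothetical datasets): the dictionary is flat,
+    # keyed by slash-separated paths, and missing intermediate nodes such as
+    # "/b" below are created automatically.
+    #
+    # >>> dt = DataTree.from_dict(
+    # ...     {
+    # ...         "/": Dataset({"a": 1}),
+    # ...         "/b/c": Dataset({"d": 2}),
+    # ...     }
+    # ... )
+    # >>> [node.path for node in dt.subtree]
+    # ['/', '/b', '/b/c']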
+
+ def to_dict(self) -> dict[str, Dataset]:
"""
Create a dictionary mapping of absolute node paths to the data contained in those nodes.
@@ -576,13 +1129,17 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
-------
dict[str, Dataset]
"""
- pass
+ return {node.path: node.to_dataset() for node in self.subtree}
+
+ @property
+ def nbytes(self) -> int:
+ return sum(node.to_dataset().nbytes for node in self.subtree)
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self.children) + len(self.data_vars)
@property
- def indexes(self) ->Indexes[pd.Index]:
+ def indexes(self) -> Indexes[pd.Index]:
"""Mapping of pandas.Index objects used for label based indexing.
Raises an error if this DataTree node has indexes that cannot be coerced
@@ -592,27 +1149,33 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
--------
DataTree.xindexes
"""
- pass
+ return self.xindexes.to_pandas_indexes()
@property
- def xindexes(self) ->Indexes[Index]:
+ def xindexes(self) -> Indexes[Index]:
"""Mapping of xarray Index objects used for label based indexing."""
- pass
+ return Indexes(
+ self._indexes, {k: self._coord_variables[k] for k in self._indexes}
+ )
@property
- def coords(self) ->DatasetCoordinates:
+ def coords(self) -> DatasetCoordinates:
"""Dictionary of xarray.DataArray objects corresponding to coordinate
variables
"""
- pass
+ return DatasetCoordinates(self.to_dataset())
@property
- def data_vars(self) ->DataVariables:
+ def data_vars(self) -> DataVariables:
"""Dictionary of DataArray objects corresponding to data variables"""
- pass
-
- def isomorphic(self, other: DataTree, from_root: bool=False,
- strict_names: bool=False) ->bool:
+ return DataVariables(self.to_dataset())
+
+ def isomorphic(
+ self,
+ other: DataTree,
+ from_root: bool = False,
+ strict_names: bool = False,
+ ) -> bool:
"""
Two DataTrees are considered isomorphic if every node has the same number of children.
@@ -640,9 +1203,18 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
DataTree.equals
DataTree.identical
"""
- pass
+ try:
+ check_isomorphic(
+ self,
+ other,
+ require_names_equal=strict_names,
+ check_from_root=from_root,
+ )
+ return True
+ except (TypeError, TreeIsomorphismError):
+ return False
- def equals(self, other: DataTree, from_root: bool=True) ->bool:
+ def equals(self, other: DataTree, from_root: bool = True) -> bool:
"""
Two DataTrees are equal if they have isomorphic node structures, with matching node names,
and if they have matching variables and coordinates, all of which are equal.
@@ -663,9 +1235,17 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
DataTree.isomorphic
DataTree.identical
"""
- pass
+ if not self.isomorphic(other, from_root=from_root, strict_names=True):
+ return False
+
+ return all(
+ [
+ node.ds.equals(other_node.ds)
+ for node, other_node in zip(self.subtree, other.subtree)
+ ]
+ )
- def identical(self, other: DataTree, from_root=True) ->bool:
+ def identical(self, other: DataTree, from_root=True) -> bool:
"""
Like equals, but will also check all dataset attributes and the attributes on
all variables and coordinates.
@@ -686,10 +1266,15 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
DataTree.isomorphic
DataTree.equals
"""
- pass
+ if not self.isomorphic(other, from_root=from_root, strict_names=True):
+ return False
- def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]
- ) ->DataTree:
+ return all(
+ node.ds.identical(other_node.ds)
+ for node, other_node in zip(self.subtree, other.subtree)
+ )
+
+ def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree:
"""
Filter nodes according to a specified condition.
@@ -711,9 +1296,12 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
pipe
map_over_subtree
"""
- pass
+ filtered_nodes = {
+ node.path: node.ds for node in self.subtree if filterfunc(node)
+ }
+ return DataTree.from_dict(filtered_nodes, name=self.root.name)
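+
+    # Sketch (hypothetical tree and variable name): keep only the nodes whose
+    # dataset defines a "temperature" variable; the result is rebuilt from the
+    # matching paths, so parent nodes reappear as empty groups where needed.
+    #
+    # >>> filtered = tree.filter(lambda node: "temperature" in node.ds)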
- def match(self, pattern: str) ->DataTree:
+ def match(self, pattern: str) -> DataTree:
"""
Return nodes with paths matching pattern.
@@ -752,10 +1340,19 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
└── Group: /b
└── Group: /b/B
"""
- pass
-
- def map_over_subtree(self, func: Callable, *args: Iterable[Any], **
- kwargs: Any) ->(DataTree | tuple[DataTree]):
+ matching_nodes = {
+ node.path: node.ds
+ for node in self.subtree
+ if NodePath(node.path).match(pattern)
+ }
+ return DataTree.from_dict(matching_nodes, name=self.root.name)
+
+ def map_over_subtree(
+ self,
+ func: Callable,
+ *args: Iterable[Any],
+ **kwargs: Any,
+ ) -> DataTree | tuple[DataTree]:
"""
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
@@ -781,10 +1378,17 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
subtrees : DataTree, tuple of DataTrees
One or more subtrees containing results from applying ``func`` to the data at each node.
"""
- pass
+ # TODO this signature means that func has no way to know which node it is being called upon - change?
+
+ # TODO fix this typing error
+ return map_over_subtree(func)(self, *args, **kwargs)
- def map_over_subtree_inplace(self, func: Callable, *args: Iterable[Any],
- **kwargs: Any) ->None:
+ def map_over_subtree_inplace(
+ self,
+ func: Callable,
+ *args: Iterable[Any],
+ **kwargs: Any,
+ ) -> None:
"""
Apply a function to every dataset in this subtree, updating data in place.
@@ -800,10 +1404,16 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
**kwargs : Any
Keyword arguments passed on to `func`.
"""
- pass
- def pipe(self, func: (Callable | tuple[Callable, str]), *args: Any, **
- kwargs: Any) ->Any:
+ # TODO if func fails on some node then the previous nodes will still have been updated...
+
+ for node in self.subtree:
+ if node.has_data:
+ node.ds = func(node.ds, *args, **kwargs)
+
+ def pipe(
+ self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
+ ) -> Any:
"""Apply ``func(self, *args, **kwargs)``
This method replicates the pandas method of the same name.
@@ -850,29 +1460,54 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
(dt.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c))
"""
- pass
+ if isinstance(func, tuple):
+ func, target = func
+ if target in kwargs:
+ raise ValueError(
+ f"{target} is both the pipe target and a keyword argument"
+ )
+ kwargs[target] = self
+ else:
+ args = (self,) + args
+ return func(*args, **kwargs)
def render(self):
"""Print tree structure, including any data stored at each node."""
- pass
+ for pre, fill, node in RenderDataTree(self):
+            print(f"{pre}DataTree('{node.name}')")
+            for ds_line in repr(node.ds).splitlines()[1:]:
+                print(f"{fill}{ds_line}")
- def merge(self, datatree: DataTree) ->DataTree:
+ def merge(self, datatree: DataTree) -> DataTree:
"""Merge all the leaves of a second DataTree into this one."""
- pass
+ raise NotImplementedError
- def merge_child_nodes(self, *paths, new_path: T_Path) ->DataTree:
+ def merge_child_nodes(self, *paths, new_path: T_Path) -> DataTree:
"""Merge a set of child nodes into a single new node."""
- pass
+ raise NotImplementedError
+
+ # TODO some kind of .collapse() or .flatten() method to merge a subtree
+
+ def to_dataarray(self) -> DataArray:
+ return self.ds.to_dataarray()
@property
def groups(self):
"""Return all netCDF4 groups in the tree, given as a tuple of path-like strings."""
- pass
-
- def to_netcdf(self, filepath, mode: NetcdfWriteModes='w', encoding=None,
- unlimited_dims=None, format: (T_DataTreeNetcdfTypes | None)=None,
- engine: (T_DataTreeNetcdfEngine | None)=None, group: (str | None)=
- None, compute: bool=True, **kwargs):
+ return tuple(node.path for node in self.subtree)
+
+ def to_netcdf(
+ self,
+ filepath,
+ mode: NetcdfWriteModes = "w",
+ encoding=None,
+ unlimited_dims=None,
+ format: T_DataTreeNetcdfTypes | None = None,
+ engine: T_DataTreeNetcdfEngine | None = None,
+ group: str | None = None,
+ compute: bool = True,
+ **kwargs,
+ ):
"""
Write datatree contents to a netCDF file.
@@ -913,11 +1548,31 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
kwargs :
Additional keyword arguments to be passed to ``xarray.Dataset.to_netcdf``
"""
- pass
-
- def to_zarr(self, store, mode: ZarrWriteModes='w-', encoding=None,
- consolidated: bool=True, group: (str | None)=None, compute: Literal
- [True]=True, **kwargs):
+ from xarray.core.datatree_io import _datatree_to_netcdf
+
+ _datatree_to_netcdf(
+ self,
+ filepath,
+ mode=mode,
+ encoding=encoding,
+ unlimited_dims=unlimited_dims,
+ format=format,
+ engine=engine,
+ group=group,
+ compute=compute,
+ **kwargs,
+ )
+
+ def to_zarr(
+ self,
+ store,
+ mode: ZarrWriteModes = "w-",
+ encoding=None,
+ consolidated: bool = True,
+ group: str | None = None,
+ compute: Literal[True] = True,
+ **kwargs,
+ ):
"""
Write datatree contents to a Zarr store.
@@ -948,4 +1603,18 @@ class DataTree(NamedNode, MappedDatasetMethodsMixin, MappedDataWithCoords,
kwargs :
Additional keyword arguments to be passed to ``xarray.Dataset.to_zarr``
"""
- pass
+ from xarray.core.datatree_io import _datatree_to_zarr
+
+ _datatree_to_zarr(
+ self,
+ store,
+ mode=mode,
+ encoding=encoding,
+ consolidated=consolidated,
+ group=group,
+ compute=compute,
+ **kwargs,
+ )
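+
+    # Hedged usage sketch (hypothetical tree and paths; the optional netCDF4/
+    # h5netcdf or zarr backends must be installed): every node is written as a
+    # group of the same name inside the target file or store.
+    #
+    # >>> dt.to_netcdf("tree.nc")
+    # >>> dt.to_zarr("tree.zarr")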
+
+ def plot(self):
+ raise NotImplementedError
diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py
index ac9c6ad9..36665a0d 100644
--- a/xarray/core/datatree_io.py
+++ b/xarray/core/datatree_io.py
@@ -1,33 +1,171 @@
from __future__ import annotations
+
from collections.abc import Mapping, MutableMapping
from os import PathLike
from typing import Any, Literal, get_args
+
from xarray.core.datatree import DataTree
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
-T_DataTreeNetcdfEngine = Literal['netcdf4', 'h5netcdf']
-T_DataTreeNetcdfTypes = Literal['NETCDF4']
+T_DataTreeNetcdfEngine = Literal["netcdf4", "h5netcdf"]
+T_DataTreeNetcdfTypes = Literal["NETCDF4"]
+
+
+def _get_nc_dataset_class(engine: T_DataTreeNetcdfEngine | None):
+ if engine == "netcdf4":
+ from netCDF4 import Dataset
+ elif engine == "h5netcdf":
+ from h5netcdf.legacyapi import Dataset
+ elif engine is None:
+ try:
+ from netCDF4 import Dataset
+ except ImportError:
+ from h5netcdf.legacyapi import Dataset
+ else:
+ raise ValueError(f"unsupported engine: {engine}")
+ return Dataset
+
+
+def _create_empty_netcdf_group(
+ filename: str | PathLike,
+ group: str,
+ mode: NetcdfWriteModes,
+ engine: T_DataTreeNetcdfEngine | None,
+):
+ ncDataset = _get_nc_dataset_class(engine)
+
+ with ncDataset(filename, mode=mode) as rootgrp:
+ rootgrp.createGroup(group)
-def _datatree_to_netcdf(dt: DataTree, filepath: (str | PathLike), mode:
- NetcdfWriteModes='w', encoding: (Mapping[str, Any] | None)=None,
- unlimited_dims: (Mapping | None)=None, format: (T_DataTreeNetcdfTypes |
- None)=None, engine: (T_DataTreeNetcdfEngine | None)=None, group: (str |
- None)=None, compute: bool=True, **kwargs):
+
+def _datatree_to_netcdf(
+ dt: DataTree,
+ filepath: str | PathLike,
+ mode: NetcdfWriteModes = "w",
+ encoding: Mapping[str, Any] | None = None,
+ unlimited_dims: Mapping | None = None,
+ format: T_DataTreeNetcdfTypes | None = None,
+ engine: T_DataTreeNetcdfEngine | None = None,
+ group: str | None = None,
+ compute: bool = True,
+ **kwargs,
+):
"""This function creates an appropriate datastore for writing a datatree to
disk as a netCDF file.
See `DataTree.to_netcdf` for full API docs.
"""
- pass
+ if format not in [None, *get_args(T_DataTreeNetcdfTypes)]:
+ raise ValueError("to_netcdf only supports the NETCDF4 format")
+
+ if engine not in [None, *get_args(T_DataTreeNetcdfEngine)]:
+ raise ValueError("to_netcdf only supports the netcdf4 and h5netcdf engines")
+
+ if group is not None:
+ raise NotImplementedError(
+ "specifying a root group for the tree has not been implemented"
+ )
+
+ if not compute:
+ raise NotImplementedError("compute=False has not been implemented yet")
+
+ if encoding is None:
+ encoding = {}
+
+    # In the future, we may want to expand this check to ensure all the provided encoding
+ # options are valid. For now, this simply checks that all provided encoding keys are
+ # groups in the datatree.
+ if set(encoding) - set(dt.groups):
+ raise ValueError(
+ f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}"
+ )
-def _datatree_to_zarr(dt: DataTree, store: (MutableMapping | str | PathLike
- [str]), mode: ZarrWriteModes='w-', encoding: (Mapping[str, Any] | None)
- =None, consolidated: bool=True, group: (str | None)=None, compute:
- Literal[True]=True, **kwargs):
+ if unlimited_dims is None:
+ unlimited_dims = {}
+
+ for node in dt.subtree:
+ ds = node.to_dataset(inherited=False)
+ group_path = node.path
+ if ds is None:
+ _create_empty_netcdf_group(filepath, group_path, mode, engine)
+ else:
+ ds.to_netcdf(
+ filepath,
+ group=group_path,
+ mode=mode,
+ encoding=encoding.get(node.path),
+ unlimited_dims=unlimited_dims.get(node.path),
+ engine=engine,
+ format=format,
+ compute=compute,
+ **kwargs,
+ )
+ mode = "a"
+
+
+def _create_empty_zarr_group(
+ store: MutableMapping | str | PathLike[str], group: str, mode: ZarrWriteModes
+):
+ import zarr
+
+ root = zarr.open_group(store, mode=mode)
+ root.create_group(group, overwrite=True)
+
+
+def _datatree_to_zarr(
+ dt: DataTree,
+ store: MutableMapping | str | PathLike[str],
+ mode: ZarrWriteModes = "w-",
+ encoding: Mapping[str, Any] | None = None,
+ consolidated: bool = True,
+ group: str | None = None,
+ compute: Literal[True] = True,
+ **kwargs,
+):
"""This function creates an appropriate datastore for writing a datatree
to a zarr store.
See `DataTree.to_zarr` for full API docs.
"""
- pass
+
+ from zarr.convenience import consolidate_metadata
+
+ if group is not None:
+ raise NotImplementedError(
+ "specifying a root group for the tree has not been implemented"
+ )
+
+ if not compute:
+ raise NotImplementedError("compute=False has not been implemented yet")
+
+ if encoding is None:
+ encoding = {}
+
+    # In the future, we may want to expand this check to ensure all the provided encoding
+ # options are valid. For now, this simply checks that all provided encoding keys are
+ # groups in the datatree.
+ if set(encoding) - set(dt.groups):
+ raise ValueError(
+ f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}"
+ )
+
+ for node in dt.subtree:
+ ds = node.to_dataset(inherited=False)
+ group_path = node.path
+ if ds is None:
+ _create_empty_zarr_group(store, group_path, mode)
+ else:
+ ds.to_zarr(
+ store,
+ group=group_path,
+ mode=mode,
+ encoding=encoding.get(node.path),
+ consolidated=False,
+ **kwargs,
+ )
+ if "w" in mode:
+ mode = "a"
+
+ if consolidated:
+ consolidate_metadata(store)
diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py
index bb01488b..6e5aae15 100644
--- a/xarray/core/datatree_mapping.py
+++ b/xarray/core/datatree_mapping.py
@@ -1,23 +1,31 @@
from __future__ import annotations
+
import functools
import sys
from itertools import repeat
from typing import TYPE_CHECKING, Callable
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.formatting import diff_treestructure
from xarray.core.treenode import NodePath, TreeNode
+
if TYPE_CHECKING:
from xarray.core.datatree import DataTree
class TreeIsomorphismError(ValueError):
"""Error raised if two tree objects do not share the same node structure."""
+
pass
-def check_isomorphic(a: DataTree, b: DataTree, require_names_equal: bool=
- False, check_from_root: bool=True):
+def check_isomorphic(
+ a: DataTree,
+ b: DataTree,
+ require_names_equal: bool = False,
+ check_from_root: bool = True,
+):
"""
Check that two trees have the same structure, raising an error if not.
@@ -47,10 +55,23 @@ def check_isomorphic(a: DataTree, b: DataTree, require_names_equal: bool=
Also optionally raised if their structure is isomorphic, but the names of any two
respective nodes are not equal.
"""
- pass
+ if not isinstance(a, TreeNode):
+ raise TypeError(f"Argument `a` is not a tree, it is of type {type(a)}")
+ if not isinstance(b, TreeNode):
+ raise TypeError(f"Argument `b` is not a tree, it is of type {type(b)}")
-def map_over_subtree(func: Callable) ->Callable:
+ if check_from_root:
+ a = a.root
+ b = b.root
+
+ diff = diff_treestructure(a, b, require_names_equal=require_names_equal)
+
+ if diff:
+ raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)
+
+
+def map_over_subtree(func: Callable) -> Callable:
"""
Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.
@@ -94,20 +115,200 @@ def map_over_subtree(func: Callable) ->Callable:
DataTree.map_over_subtree_inplace
DataTree.subtree
"""
- pass
+
+ # TODO examples in the docstring
+
+ # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?
+
+ @functools.wraps(func)
+ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
+ """Internal function which maps func over every node in tree, returning a tree of the results."""
+ from xarray.core.datatree import DataTree
+
+ all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
+ a for a in kwargs.values() if isinstance(a, DataTree)
+ ]
+
+ if len(all_tree_inputs) > 0:
+ first_tree, *other_trees = all_tree_inputs
+ else:
+ raise TypeError("Must pass at least one tree object")
+
+ for other_tree in other_trees:
+ # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
+ check_isomorphic(
+ first_tree, other_tree, require_names_equal=False, check_from_root=False
+ )
+
+ # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
+ # We don't know which arguments are DataTrees so we zip all arguments together as iterables
+ # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
+ out_data_objects = {}
+ args_as_tree_length_iterables = [
+ a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
+ ]
+ n_args = len(args_as_tree_length_iterables)
+ kwargs_as_tree_length_iterables = {
+ k: v.subtree if isinstance(v, DataTree) else repeat(v)
+ for k, v in kwargs.items()
+ }
+ for node_of_first_tree, *all_node_args in zip(
+ first_tree.subtree,
+ *args_as_tree_length_iterables,
+ *list(kwargs_as_tree_length_iterables.values()),
+ ):
+ node_args_as_datasetviews = [
+ a.ds if isinstance(a, DataTree) else a for a in all_node_args[:n_args]
+ ]
+ node_kwargs_as_datasetviews = dict(
+ zip(
+ [k for k in kwargs_as_tree_length_iterables.keys()],
+ [
+ v.ds if isinstance(v, DataTree) else v
+ for v in all_node_args[n_args:]
+ ],
+ )
+ )
+ func_with_error_context = _handle_errors_with_path_context(
+ node_of_first_tree.path
+ )(func)
+
+ if node_of_first_tree.has_data:
+ # call func on the data in this particular set of corresponding nodes
+ results = func_with_error_context(
+ *node_args_as_datasetviews, **node_kwargs_as_datasetviews
+ )
+ elif node_of_first_tree.has_attrs:
+ # propagate attrs
+ results = node_of_first_tree.ds
+ else:
+ # nothing to propagate so use fastpath to create empty node in new tree
+ results = None
+
+ # TODO implement mapping over multiple trees in-place using if conditions from here on?
+ out_data_objects[node_of_first_tree.path] = results
+
+ # Find out how many return values we received
+ num_return_values = _check_all_return_values(out_data_objects)
+
+ # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
+ original_root_path = first_tree.path
+ result_trees = []
+ for i in range(num_return_values):
+ out_tree_contents = {}
+ for n in first_tree.subtree:
+ p = n.path
+ if p in out_data_objects.keys():
+ if isinstance(out_data_objects[p], tuple):
+ output_node_data = out_data_objects[p][i]
+ else:
+ output_node_data = out_data_objects[p]
+ else:
+ output_node_data = None
+
+ # Discard parentage so that new trees don't include parents of input nodes
+ relative_path = str(NodePath(p).relative_to(original_root_path))
+ relative_path = "/" if relative_path == "." else relative_path
+ out_tree_contents[relative_path] = output_node_data
+
+ new_tree = DataTree.from_dict(
+ out_tree_contents,
+ name=first_tree.name,
+ )
+ result_trees.append(new_tree)
+
+ # If only one result then don't wrap it in a tuple
+ if len(result_trees) == 1:
+ return result_trees[0]
+ else:
+ return tuple(result_trees)
+
+ return _map_over_subtree
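+
+
+# Illustrative use of the decorator above (hypothetical function and tree
+# ``tree``): the wrapped function receives the Dataset stored at each node and
+# the results are reassembled into a new DataTree with the same structure.
+#
+# >>> @map_over_subtree
+# ... def times_ten(ds):
+# ...     return 10.0 * ds
+# >>> scaled = times_ten(tree)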
def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""
- pass
+ def decorator(func):
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ # Add the context information to the error message
+ add_note(
+ e, f"Raised whilst mapping function over node with path {path}"
+ )
+ raise
+
+ return wrapper
+
+ return decorator
+
+
+def add_note(err: BaseException, msg: str) -> None:
+ # TODO: remove once python 3.10 can be dropped
+ if sys.version_info < (3, 11):
+ err.__notes__ = getattr(err, "__notes__", []) + [msg] # type: ignore[attr-defined]
+ else:
+ err.add_note(msg)
-def _check_single_set_return_values(path_to_node: str, obj: (Dataset |
- DataArray | tuple[Dataset | DataArray])):
+
+def _check_single_set_return_values(
+ path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
+):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
- pass
+ if isinstance(obj, (Dataset, DataArray)):
+ return 1
+ elif isinstance(obj, tuple):
+ for r in obj:
+ if not isinstance(r, (Dataset, DataArray)):
+ raise TypeError(
+ f"One of the results of calling func on datasets on the nodes at position {path_to_node} is "
+ f"of type {type(r)}, not Dataset or DataArray."
+ )
+ return len(obj)
+ else:
+ raise TypeError(
+ f"The result of calling func on the node at position {path_to_node} is of type {type(obj)}, not "
+ f"Dataset or DataArray, nor a tuple of such types."
+ )
def _check_all_return_values(returned_objects):
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""
- pass
+
+ if all(r is None for r in returned_objects.values()):
+ raise TypeError(
+            "Called supplied function on all nodes but found a return value of None for "
+            "all of them."
+ )
+
+ result_data_objects = [
+ (path_to_node, r)
+ for path_to_node, r in returned_objects.items()
+ if r is not None
+ ]
+
+ if len(result_data_objects) == 1:
+ # Only one node in the tree: no need to check consistency of results between nodes
+ path_to_node, result = result_data_objects[0]
+ num_return_values = _check_single_set_return_values(path_to_node, result)
+ else:
+ prev_path, _ = result_data_objects[0]
+ prev_num_return_values, num_return_values = None, None
+ for path_to_node, obj in result_data_objects[1:]:
+ num_return_values = _check_single_set_return_values(path_to_node, obj)
+
+ if (
+ num_return_values != prev_num_return_values
+ and prev_num_return_values is not None
+ ):
+ raise TypeError(
+ f"Calling func on the nodes at position {path_to_node} returns {num_return_values} separate return "
+ f"values, whereas calling func on the nodes at position {prev_path} instead returns "
+ f"{prev_num_return_values} separate return values."
+ )
+
+ prev_path, prev_num_return_values = path_to_node, num_return_values
+
+ return num_return_values
diff --git a/xarray/core/datatree_ops.py b/xarray/core/datatree_ops.py
index 77c69078..bc64b44a 100644
--- a/xarray/core/datatree_ops.py
+++ b/xarray/core/datatree_ops.py
@@ -1,52 +1,178 @@
from __future__ import annotations
+
import re
import textwrap
+
from xarray.core.dataset import Dataset
from xarray.core.datatree_mapping import map_over_subtree
+
"""
Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree.
Structured to mirror the way xarray defines Dataset's various operations internally, but does not actually import from
xarray's internals directly, only the public-facing xarray.Dataset class.
"""
+
+
_MAPPED_DOCSTRING_ADDENDUM = (
- 'This method was copied from xarray.Dataset, but has been altered to call the method on the Datasets stored in every node of the subtree. See the `map_over_subtree` function for more details.'
- )
-_DATASET_DASK_METHODS_TO_MAP = ['load', 'compute', 'persist',
- 'unify_chunks', 'chunk', 'map_blocks']
-_DATASET_METHODS_TO_MAP = ['as_numpy', 'set_coords', 'reset_coords', 'info',
- 'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
- 'reindex', 'interp', 'interp_like', 'rename', 'rename_dims',
- 'rename_vars', 'swap_dims', 'expand_dims', 'set_index', 'reset_index',
- 'reorder_levels', 'stack', 'unstack', 'merge', 'drop_vars', 'drop_sel',
- 'drop_isel', 'drop_dims', 'transpose', 'dropna', 'fillna',
- 'interpolate_na', 'ffill', 'bfill', 'combine_first', 'reduce', 'map',
- 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank', 'differentiate',
- 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
- 'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
-_ALL_DATASET_METHODS_TO_MAP = (_DATASET_DASK_METHODS_TO_MAP +
- _DATASET_METHODS_TO_MAP)
-_DATA_WITH_COORDS_METHODS_TO_MAP = ['squeeze', 'clip', 'assign_coords',
- 'where', 'close', 'isnull', 'notnull', 'isin', 'astype']
-REDUCE_METHODS = ['all', 'any']
-NAN_REDUCE_METHODS = ['max', 'min', 'mean', 'prod', 'sum', 'std', 'var',
- 'median']
-NAN_CUM_METHODS = ['cumsum', 'cumprod']
-_TYPED_DATASET_OPS_TO_MAP = ['__add__', '__sub__', '__mul__', '__pow__',
- '__truediv__', '__floordiv__', '__mod__', '__and__', '__xor__',
- '__or__', '__lt__', '__le__', '__gt__', '__ge__', '__eq__', '__ne__',
- '__radd__', '__rsub__', '__rmul__', '__rpow__', '__rtruediv__',
- '__rfloordiv__', '__rmod__', '__rand__', '__rxor__', '__ror__',
- '__iadd__', '__isub__', '__imul__', '__ipow__', '__itruediv__',
- '__ifloordiv__', '__imod__', '__iand__', '__ixor__', '__ior__',
- '__neg__', '__pos__', '__abs__', '__invert__', 'round', 'argsort',
- 'conj', 'conjugate']
-_ARITHMETIC_METHODS_TO_MAP = (REDUCE_METHODS + NAN_REDUCE_METHODS +
- NAN_CUM_METHODS + _TYPED_DATASET_OPS_TO_MAP + ['__array_ufunc__'])
-
-
-def _wrap_then_attach_to_cls(target_cls_dict, source_cls, methods_to_set,
- wrap_func=None):
+ "This method was copied from xarray.Dataset, but has been altered to "
+ "call the method on the Datasets stored in every node of the subtree. "
+ "See the `map_over_subtree` function for more details."
+)
+
+# TODO equals, broadcast_equals etc.
+# TODO do dask-related private methods need to be exposed?
+_DATASET_DASK_METHODS_TO_MAP = [
+ "load",
+ "compute",
+ "persist",
+ "unify_chunks",
+ "chunk",
+ "map_blocks",
+]
+_DATASET_METHODS_TO_MAP = [
+ "as_numpy",
+ "set_coords",
+ "reset_coords",
+ "info",
+ "isel",
+ "sel",
+ "head",
+ "tail",
+ "thin",
+ "broadcast_like",
+ "reindex_like",
+ "reindex",
+ "interp",
+ "interp_like",
+ "rename",
+ "rename_dims",
+ "rename_vars",
+ "swap_dims",
+ "expand_dims",
+ "set_index",
+ "reset_index",
+ "reorder_levels",
+ "stack",
+ "unstack",
+ "merge",
+ "drop_vars",
+ "drop_sel",
+ "drop_isel",
+ "drop_dims",
+ "transpose",
+ "dropna",
+ "fillna",
+ "interpolate_na",
+ "ffill",
+ "bfill",
+ "combine_first",
+ "reduce",
+ "map",
+ "diff",
+ "shift",
+ "roll",
+ "sortby",
+ "quantile",
+ "rank",
+ "differentiate",
+ "integrate",
+ "cumulative_integrate",
+ "filter_by_attrs",
+ "polyfit",
+ "pad",
+ "idxmin",
+ "idxmax",
+ "argmin",
+ "argmax",
+ "query",
+ "curvefit",
+]
+_ALL_DATASET_METHODS_TO_MAP = _DATASET_DASK_METHODS_TO_MAP + _DATASET_METHODS_TO_MAP
+
+_DATA_WITH_COORDS_METHODS_TO_MAP = [
+ "squeeze",
+ "clip",
+ "assign_coords",
+ "where",
+ "close",
+ "isnull",
+ "notnull",
+ "isin",
+ "astype",
+]
+
+REDUCE_METHODS = ["all", "any"]
+NAN_REDUCE_METHODS = [
+ "max",
+ "min",
+ "mean",
+ "prod",
+ "sum",
+ "std",
+ "var",
+ "median",
+]
+NAN_CUM_METHODS = ["cumsum", "cumprod"]
+_TYPED_DATASET_OPS_TO_MAP = [
+ "__add__",
+ "__sub__",
+ "__mul__",
+ "__pow__",
+ "__truediv__",
+ "__floordiv__",
+ "__mod__",
+ "__and__",
+ "__xor__",
+ "__or__",
+ "__lt__",
+ "__le__",
+ "__gt__",
+ "__ge__",
+ "__eq__",
+ "__ne__",
+ "__radd__",
+ "__rsub__",
+ "__rmul__",
+ "__rpow__",
+ "__rtruediv__",
+ "__rfloordiv__",
+ "__rmod__",
+ "__rand__",
+ "__rxor__",
+ "__ror__",
+ "__iadd__",
+ "__isub__",
+ "__imul__",
+ "__ipow__",
+ "__itruediv__",
+ "__ifloordiv__",
+ "__imod__",
+ "__iand__",
+ "__ixor__",
+ "__ior__",
+ "__neg__",
+ "__pos__",
+ "__abs__",
+ "__invert__",
+ "round",
+ "argsort",
+ "conj",
+ "conjugate",
+]
+# TODO NUM_BINARY_OPS apparently aren't defined on DatasetArithmetic, and don't appear to be injected anywhere...
+_ARITHMETIC_METHODS_TO_MAP = (
+ REDUCE_METHODS
+ + NAN_REDUCE_METHODS
+ + NAN_CUM_METHODS
+ + _TYPED_DATASET_OPS_TO_MAP
+ + ["__array_ufunc__"]
+)
+
+
+def _wrap_then_attach_to_cls(
+ target_cls_dict, source_cls, methods_to_set, wrap_func=None
+):
"""
    Attach given methods to a class, optionally wrapping each method first (i.e. with map_over_subtree).
@@ -73,10 +199,25 @@ def _wrap_then_attach_to_cls(target_cls_dict, source_cls, methods_to_set,
wrap_func : callable, optional
Function to decorate each method with. Must have the same return type as the method.
"""
- pass
+ for method_name in methods_to_set:
+ orig_method = getattr(source_cls, method_name)
+ wrapped_method = (
+ wrap_func(orig_method) if wrap_func is not None else orig_method
+ )
+ target_cls_dict[method_name] = wrapped_method
+ if wrap_func is map_over_subtree:
+ # Add a paragraph to the method's docstring explaining how it's been mapped
+ orig_method_docstring = orig_method.__doc__
-def insert_doc_addendum(docstring: (str | None), addendum: str) ->(str | None):
+ if orig_method_docstring is not None:
+ new_method_docstring = insert_doc_addendum(
+ orig_method_docstring, _MAPPED_DOCSTRING_ADDENDUM
+ )
+ setattr(target_cls_dict[method_name], "__doc__", new_method_docstring)
+
+
+def insert_doc_addendum(docstring: str | None, addendum: str) -> str | None:
"""Insert addendum after first paragraph or at the end of the docstring.
There are a number of Dataset's functions that are wrapped. These come from
@@ -86,7 +227,44 @@ def insert_doc_addendum(docstring: (str | None), addendum: str) ->(str | None):
    don't, just have the addendum appended after. None values are returned.
"""
- pass
+ if docstring is None:
+ return None
+
+ pattern = re.compile(
+ r"^(?P<start>(\S+)?(.*?))(?P<paragraph_break>\n\s*\n)(?P<whitespace>[ ]*)(?P<rest>.*)",
+ re.DOTALL,
+ )
+ capture = re.match(pattern, docstring)
+ if capture is None:
+ ### single line docstring.
+ return (
+ docstring
+ + "\n\n"
+ + textwrap.fill(
+ addendum,
+ subsequent_indent=" ",
+ width=79,
+ )
+ )
+
+ if len(capture.groups()) == 6:
+ return (
+ capture["start"]
+ + capture["paragraph_break"]
+ + capture["whitespace"]
+ + ".. note::\n"
+ + textwrap.fill(
+ addendum,
+ initial_indent=capture["whitespace"] + " ",
+ subsequent_indent=capture["whitespace"] + " ",
+ width=79,
+ )
+ + capture["paragraph_break"]
+ + capture["whitespace"]
+ + capture["rest"]
+ )
+ else:
+ return docstring
class MappedDatasetMethodsMixin:
@@ -94,17 +272,27 @@ class MappedDatasetMethodsMixin:
Mixin to add methods defined specifically on the Dataset class such as .query(), but wrapped to map over all nodes
in the subtree.
"""
- _wrap_then_attach_to_cls(target_cls_dict=vars(), source_cls=Dataset,
- methods_to_set=_ALL_DATASET_METHODS_TO_MAP, wrap_func=map_over_subtree)
+
+ _wrap_then_attach_to_cls(
+ target_cls_dict=vars(),
+ source_cls=Dataset,
+ methods_to_set=_ALL_DATASET_METHODS_TO_MAP,
+ wrap_func=map_over_subtree,
+ )
class MappedDataWithCoords:
"""
Mixin to add coordinate-aware Dataset methods such as .where(), but wrapped to map over all nodes in the subtree.
"""
- _wrap_then_attach_to_cls(target_cls_dict=vars(), source_cls=Dataset,
- methods_to_set=_DATA_WITH_COORDS_METHODS_TO_MAP, wrap_func=
- map_over_subtree)
+
+ # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample
+ _wrap_then_attach_to_cls(
+ target_cls_dict=vars(),
+ source_cls=Dataset,
+ methods_to_set=_DATA_WITH_COORDS_METHODS_TO_MAP,
+ wrap_func=map_over_subtree,
+ )
class DataTreeArithmeticMixin:
@@ -112,5 +300,10 @@ class DataTreeArithmeticMixin:
Mixin to add Dataset arithmetic operations such as __add__, reduction methods such as .mean(), and enable numpy
ufuncs such as np.sin(), but wrapped to map over all nodes in the subtree.
"""
- _wrap_then_attach_to_cls(target_cls_dict=vars(), source_cls=Dataset,
- methods_to_set=_ARITHMETIC_METHODS_TO_MAP, wrap_func=map_over_subtree)
+
+ _wrap_then_attach_to_cls(
+ target_cls_dict=vars(),
+ source_cls=Dataset,
+ methods_to_set=_ARITHMETIC_METHODS_TO_MAP,
+ wrap_func=map_over_subtree,
+ )
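The mixin classes above rely on the fact that `vars()` evaluated inside a class body returns the namespace dictionary the class is being assembled from, so writing into it attaches the wrapped methods as ordinary class attributes. A minimal standalone sketch of that wrap-then-attach pattern, using toy names rather than the xarray classes:

```python
# Toy sketch of the wrap-then-attach pattern; Source, shout and Mixin are
# illustrative names only, not part of xarray.
def shout(func):
    def wrapper(self, *args, **kwargs):
        return str(func(self, *args, **kwargs)).upper()
    wrapper.__doc__ = func.__doc__
    return wrapper

class Source:
    def greet(self):
        """Return a greeting."""
        return "hello"

def wrap_then_attach(target_cls_dict, source_cls, names, wrap_func=None):
    for name in names:
        method = getattr(source_cls, name)
        target_cls_dict[name] = wrap_func(method) if wrap_func else method

class Mixin:
    # vars() in a class body is the namespace being built, so the wrapped
    # methods end up as ordinary attributes of Mixin.
    wrap_then_attach(vars(), Source, ["greet"], wrap_func=shout)

print(Mixin().greet())  # HELLO
```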
diff --git a/xarray/core/datatree_render.py b/xarray/core/datatree_render.py
index 184464bb..f10f2540 100644
--- a/xarray/core/datatree_render.py
+++ b/xarray/core/datatree_render.py
@@ -5,17 +5,20 @@ Minor changes to `RenderDataTree` include accessing `children.values()`, and
type hints.
"""
+
from __future__ import annotations
+
from collections import namedtuple
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING
+
if TYPE_CHECKING:
from xarray.core.datatree import DataTree
-Row = namedtuple('Row', ('pre', 'fill', 'node'))
+Row = namedtuple("Row", ("pre", "fill", "node"))
-class AbstractStyle:
+class AbstractStyle:
def __init__(self, vertical: str, cont: str, end: str):
"""
Tree Render Style.
@@ -28,20 +31,20 @@ class AbstractStyle:
self.vertical = vertical
self.cont = cont
self.end = end
- assert len(cont) == len(vertical) == len(end
- ), f"'{vertical}', '{cont}' and '{end}' need to have equal length"
+ assert (
+ len(cont) == len(vertical) == len(end)
+ ), f"'{vertical}', '{cont}' and '{end}' need to have equal length"
@property
- def empty(self) ->str:
+ def empty(self) -> str:
"""Empty string as placeholder."""
- pass
+ return " " * len(self.end)
- def __repr__(self) ->str:
- return f'{self.__class__.__name__}()'
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}()"
class ContStyle(AbstractStyle):
-
def __init__(self):
"""
Continued style, without gaps.
@@ -61,13 +64,17 @@ class ContStyle(AbstractStyle):
│ └── Group: /sub0/sub0A
└── Group: /sub1
"""
- super().__init__('│ ', '├── ', '└── ')
+ super().__init__("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ")
class RenderDataTree:
-
- def __init__(self, node: DataTree, style=ContStyle(), childiter: type=
- list, maxlevel: (int | None)=None):
+ def __init__(
+ self,
+ node: DataTree,
+ style=ContStyle(),
+ childiter: type = list,
+ maxlevel: int | None = None,
+ ):
"""
Render tree starting at `node`.
Keyword Args:
@@ -151,19 +158,47 @@ class RenderDataTree:
self.childiter = childiter
self.maxlevel = maxlevel
- def __iter__(self) ->Iterator[Row]:
+ def __iter__(self) -> Iterator[Row]:
return self.__next(self.node, tuple())
- def __str__(self) ->str:
+ def __next(
+ self, node: DataTree, continues: tuple[bool, ...], level: int = 0
+ ) -> Iterator[Row]:
+ yield RenderDataTree.__item(node, continues, self.style)
+ children = node.children.values()
+ level += 1
+ if children and (self.maxlevel is None or level < self.maxlevel):
+ children = self.childiter(children)
+ for child, is_last in _is_last(children):
+ yield from self.__next(child, continues + (not is_last,), level=level)
+
+ @staticmethod
+ def __item(
+ node: DataTree, continues: tuple[bool, ...], style: AbstractStyle
+ ) -> Row:
+ if not continues:
+ return Row("", "", node)
+ else:
+ items = [style.vertical if cont else style.empty for cont in continues]
+ indent = "".join(items[:-1])
+ branch = style.cont if continues[-1] else style.end
+ pre = indent + branch
+ fill = "".join(items)
+ return Row(pre, fill, node)
+
+ def __str__(self) -> str:
return str(self.node)
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
classname = self.__class__.__name__
- args = [repr(self.node), f'style={repr(self.style)}',
- f'childiter={repr(self.childiter)}']
+ args = [
+ repr(self.node),
+ f"style={repr(self.style)}",
+ f"childiter={repr(self.childiter)}",
+ ]
return f"{classname}({', '.join(args)})"
- def by_attr(self, attrname: str='name') ->str:
+ def by_attr(self, attrname: str = "name") -> str:
"""
Return rendered tree with node attribute `attrname`.
@@ -195,4 +230,38 @@ class RenderDataTree:
└── sub1C
└── sub1Ca
"""
+
+ def get() -> Iterator[str]:
+ for pre, fill, node in self:
+ attr = (
+ attrname(node)
+ if callable(attrname)
+ else getattr(node, attrname, "")
+ )
+ if isinstance(attr, (list, tuple)):
+ lines = attr
+ else:
+ lines = str(attr).split("\n")
+ yield f"{pre}{lines[0]}"
+ for line in lines[1:]:
+ yield f"{fill}{line}"
+
+ return "\n".join(get())
+
+
+def _is_last(iterable: Iterable) -> Iterator[tuple[DataTree, bool]]:
+ iter_ = iter(iterable)
+ try:
+ nextitem = next(iter_)
+ except StopIteration:
pass
+ else:
+ item = nextitem
+ while True:
+ try:
+ nextitem = next(iter_)
+ yield item, False
+ except StopIteration:
+ yield nextitem, True
+ break
+ item = nextitem
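The `_is_last` generator above is what lets the renderer choose between the `├──` and `└──` connectors: every child is paired with a flag that is `True` only for the final sibling. A roughly equivalent, simplified sketch:

```python
# Simplified equivalent of _is_last: pair each element with a flag that is
# True only for the last one, without needing len().
def is_last(iterable):
    it = iter(iterable)
    try:
        item = next(it)
    except StopIteration:
        return
    for nxt in it:
        yield item, False
        item = nxt
    yield item, True

print(list(is_last(["sub0", "sub1", "sub2"])))
# [('sub0', False), ('sub1', False), ('sub2', True)]
```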
diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py
index 217e8240..2c3a43ee 100644
--- a/xarray/core/dtypes.py
+++ b/xarray/core/dtypes.py
@@ -1,15 +1,19 @@
from __future__ import annotations
+
import functools
from typing import Any
+
import numpy as np
from pandas.api.types import is_extension_array_dtype
+
from xarray.core import array_api_compat, npcompat, utils
-NA = utils.ReprObject('<NA>')
+
+# Use as a sentinel value to indicate a dtype appropriate NA value.
+NA = utils.ReprObject("<NA>")
@functools.total_ordering
class AlwaysGreaterThan:
-
def __gt__(self, other):
return True
@@ -19,7 +23,6 @@ class AlwaysGreaterThan:
@functools.total_ordering
class AlwaysLessThan:
-
def __lt__(self, other):
return True
@@ -27,13 +30,23 @@ class AlwaysLessThan:
return isinstance(other, type(self))
+# Equivalence to np.inf (-np.inf) for object-type
INF = AlwaysGreaterThan()
NINF = AlwaysLessThan()
-PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ((
- np.number, np.character), (np.bool_, np.character), (np.bytes_, np.str_))
-def maybe_promote(dtype: np.dtype) ->tuple[np.dtype, Any]:
+# Pairs of types that, if both found, should be promoted to object dtype
+# instead of following NumPy's own type-promotion rules. These type promotion
+# rules match pandas instead. For reference, see the NumPy type hierarchy:
+# https://numpy.org/doc/stable/reference/arrays.scalars.html
+PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
+ (np.number, np.character), # numpy promotes to character
+ (np.bool_, np.character), # numpy promotes to character
+ (np.bytes_, np.str_), # numpy promotes to unicode
+)
+
+
+def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
"""Simpler equivalent of pandas.core.common._maybe_promote
Parameters
@@ -45,10 +58,37 @@ def maybe_promote(dtype: np.dtype) ->tuple[np.dtype, Any]:
dtype : Promoted dtype that can hold missing values.
fill_value : Valid missing value for the promoted dtype.
"""
- pass
-
-
-NAT_TYPES = {np.datetime64('NaT').dtype, np.timedelta64('NaT').dtype}
+ # N.B. these casting rules should match pandas
+ dtype_: np.typing.DTypeLike
+ fill_value: Any
+ if isdtype(dtype, "real floating"):
+ dtype_ = dtype
+ fill_value = np.nan
+ elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64):
+ # See https://github.com/numpy/numpy/issues/10685
+ # np.timedelta64 is a subclass of np.integer
+ # Check np.timedelta64 before np.integer
+ fill_value = np.timedelta64("NaT")
+ dtype_ = dtype
+ elif isdtype(dtype, "integral"):
+ dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
+ fill_value = np.nan
+ elif isdtype(dtype, "complex floating"):
+ dtype_ = dtype
+ fill_value = np.nan + np.nan * 1j
+ elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64):
+ dtype_ = dtype
+ fill_value = np.datetime64("NaT")
+ else:
+ dtype_ = object
+ fill_value = np.nan
+
+ dtype_out = np.dtype(dtype_)
+ fill_value = dtype_out.type(fill_value)
+ return dtype_out, fill_value
+
+
+NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
def get_fill_value(dtype):
@@ -62,7 +102,8 @@ def get_fill_value(dtype):
-------
fill_value : Missing value corresponding to this dtype.
"""
- pass
+ _, fill_value = maybe_promote(dtype)
+ return fill_value
def get_pos_infinity(dtype, max_for_int=False):
@@ -78,7 +119,22 @@ def get_pos_infinity(dtype, max_for_int=False):
-------
fill_value : positive infinity value corresponding to this dtype.
"""
- pass
+ if isdtype(dtype, "real floating"):
+ return np.inf
+
+ if isdtype(dtype, "integral"):
+ if max_for_int:
+ return np.iinfo(dtype).max
+ else:
+ return np.inf
+
+ if isdtype(dtype, "complex floating"):
+ return np.inf + 1j * np.inf
+
+ if isdtype(dtype, "bool"):
+ return True
+
+ return np.array(INF, dtype=object)
def get_neg_infinity(dtype, min_for_int=False):
@@ -94,34 +150,82 @@ def get_neg_infinity(dtype, min_for_int=False):
-------
    fill_value : negative infinity value corresponding to this dtype.
"""
- pass
+ if isdtype(dtype, "real floating"):
+ return -np.inf
+
+ if isdtype(dtype, "integral"):
+ if min_for_int:
+ return np.iinfo(dtype).min
+ else:
+ return -np.inf
+
+ if isdtype(dtype, "complex floating"):
+ return -np.inf - 1j * np.inf
+
+ if isdtype(dtype, "bool"):
+ return False
+ return np.array(NINF, dtype=object)
-def is_datetime_like(dtype) ->bool:
+
+def is_datetime_like(dtype) -> bool:
"""Check if a dtype is a subclass of the numpy datetime types"""
- pass
+ return _is_numpy_subdtype(dtype, (np.datetime64, np.timedelta64))
-def is_object(dtype) ->bool:
+def is_object(dtype) -> bool:
"""Check if a dtype is object"""
- pass
+ return _is_numpy_subdtype(dtype, object)
-def is_string(dtype) ->bool:
+def is_string(dtype) -> bool:
"""Check if a dtype is a string dtype"""
- pass
+ return _is_numpy_subdtype(dtype, (np.str_, np.character))
+
+
+def _is_numpy_subdtype(dtype, kind) -> bool:
+ if not isinstance(dtype, np.dtype):
+ return False
+ kinds = kind if isinstance(kind, tuple) else (kind,)
+ return any(np.issubdtype(dtype, kind) for kind in kinds)
-def isdtype(dtype, kind: (str | tuple[str, ...]), xp=None) ->bool:
+
+def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
"""Compatibility wrapper for isdtype() from the array API standard.
Unlike xp.isdtype(), kind must be a string.
"""
- pass
-
-
-def result_type(*arrays_and_dtypes: (np.typing.ArrayLike | np.typing.
- DTypeLike), xp=None) ->np.dtype:
+ # TODO(shoyer): remove this wrapper when Xarray requires
+ # numpy>=2 and pandas extensions arrays are implemented in
+ # Xarray via the array API
+ if not isinstance(kind, str) and not (
+ isinstance(kind, tuple) and all(isinstance(k, str) for k in kind)
+ ):
+ raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
+
+ if isinstance(dtype, np.dtype):
+ return npcompat.isdtype(dtype, kind)
+ elif is_extension_array_dtype(dtype):
+ # we never want to match pandas extension array dtypes
+ return False
+ else:
+ if xp is None:
+ xp = np
+ return xp.isdtype(dtype, kind)
+
+
+def preprocess_scalar_types(t):
+ if isinstance(t, (str, bytes)):
+ return type(t)
+ else:
+ return t
+
+
+def result_type(
+ *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
+ xp=None,
+) -> np.dtype:
"""Like np.result_type, but with type promotion rules matching pandas.
Examples of changed behavior:
@@ -137,4 +241,26 @@ def result_type(*arrays_and_dtypes: (np.typing.ArrayLike | np.typing.
-------
numpy.dtype for the result.
"""
- pass
+ # TODO (keewis): replace `array_api_compat.result_type` with `xp.result_type` once we
+ # can require a version of the Array API that supports passing scalars to it.
+ from xarray.core.duck_array_ops import get_array_namespace
+
+ if xp is None:
+ xp = get_array_namespace(arrays_and_dtypes)
+
+ types = {
+ array_api_compat.result_type(preprocess_scalar_types(t), xp=xp)
+ for t in arrays_and_dtypes
+ }
+ if any(isinstance(t, np.dtype) for t in types):
+ # only check if there's numpy dtypes – the array API does not
+ # define the types we're checking for
+ for left, right in PROMOTE_TO_OBJECT:
+ if any(np.issubdtype(t, left) for t in types) and any(
+ np.issubdtype(t, right) for t in types
+ ):
+ return np.dtype(object)
+
+ return array_api_compat.result_type(
+ *map(preprocess_scalar_types, arrays_and_dtypes), xp=xp
+ )
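The promotion rules in `maybe_promote` intentionally follow pandas rather than NumPy: integer dtypes widen to a float dtype so `NaN` can represent missing values, while datetime and timedelta dtypes keep `NaT`. A NumPy-only restatement of those rules, given here as a sketch rather than the xarray implementation:

```python
# Sketch of the promotion rules described above, expressed with plain NumPy;
# the real implementation is xarray.core.dtypes.maybe_promote.
import numpy as np

def promote(dtype):
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.floating):
        return dtype, np.nan
    if np.issubdtype(dtype, np.timedelta64):  # must come before the integer check
        return dtype, np.timedelta64("NaT")
    if np.issubdtype(dtype, np.integer):
        # small ints fit losslessly in float32; wider ints go to float64
        return np.dtype(np.float32 if dtype.itemsize <= 2 else np.float64), np.nan
    if np.issubdtype(dtype, np.complexfloating):
        return dtype, np.nan + np.nan * 1j
    if np.issubdtype(dtype, np.datetime64):
        return dtype, np.datetime64("NaT")
    return np.dtype(object), np.nan

print(promote(np.int16))  # (dtype('float32'), nan)
print(promote(np.int64))  # (dtype('float64'), nan)
```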
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 872667f1..8993c136 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -3,63 +3,290 @@
Currently, this means Dask or NumPy arrays. None of these functions should
accept or return xarray objects.
"""
+
from __future__ import annotations
+
import contextlib
import datetime
import inspect
import warnings
from functools import partial
from importlib import import_module
+
import numpy as np
import pandas as pd
-from numpy import all as array_all
-from numpy import any as array_any
-from numpy import around, full_like, gradient, isclose, isin, isnat, take, tensordot, transpose, unravel_index
+from numpy import all as array_all # noqa
+from numpy import any as array_any # noqa
+from numpy import ( # noqa
+ around, # noqa
+ full_like,
+ gradient,
+ isclose,
+ isin,
+ isnat,
+ take,
+ tensordot,
+ transpose,
+ unravel_index,
+)
from numpy import concatenate as _concatenate
-from numpy.lib.stride_tricks import sliding_window_view
+from numpy.lib.stride_tricks import sliding_window_view # noqa
from packaging.version import Version
from pandas.api.types import is_extension_array_dtype
+
from xarray.core import dask_array_ops, dtypes, nputils
from xarray.core.options import OPTIONS
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available
from xarray.namedarray import pycompat
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import array_type, is_chunked_array
-if module_available('numpy', minversion='2.0.0.dev0'):
- from numpy.lib.array_utils import normalize_axis_index
+
+# remove once numpy 2.0 is the oldest supported version
+if module_available("numpy", minversion="2.0.0.dev0"):
+ from numpy.lib.array_utils import ( # type: ignore[import-not-found,unused-ignore]
+ normalize_axis_index,
+ )
else:
- from numpy.core.multiarray import normalize_axis_index
-dask_available = module_available('dask')
+ from numpy.core.multiarray import ( # type: ignore[attr-defined,no-redef,unused-ignore]
+ normalize_axis_index,
+ )
+
+
+dask_available = module_available("dask")
+
+
+def get_array_namespace(*values):
+ def _get_array_namespace(x):
+ if hasattr(x, "__array_namespace__"):
+ return x.__array_namespace__()
+ else:
+ return np
+
+ namespaces = {_get_array_namespace(t) for t in values}
+ non_numpy = namespaces - {np}
+
+ if len(non_numpy) > 1:
+ raise TypeError(
+ "cannot deal with more than one type supporting the array API at the same time"
+ )
+ elif non_numpy:
+ [xp] = non_numpy
+ else:
+ xp = np
+
+ return xp
+
+
+def einsum(*args, **kwargs):
+ from xarray.core.options import OPTIONS
+
+ if OPTIONS["use_opt_einsum"] and module_available("opt_einsum"):
+ import opt_einsum
+ return opt_einsum.contract(*args, **kwargs)
+ else:
+ return np.einsum(*args, **kwargs)
-def _dask_or_eager_func(name, eager_module=np, dask_module='dask.array'):
+
+def _dask_or_eager_func(
+ name,
+ eager_module=np,
+ dask_module="dask.array",
+):
"""Create a function that dispatches to dask for dask array inputs."""
- pass
-
-
-pandas_isnull = _dask_or_eager_func('isnull', eager_module=pd, dask_module=
- 'dask.array')
-around.__doc__ = str.replace(around.__doc__ or '', 'array([0., 2.])',
- 'array([0., 2.])')
-around.__doc__ = str.replace(around.__doc__ or '', 'array([0., 2.])',
- 'array([0., 2.])')
-around.__doc__ = str.replace(around.__doc__ or '', 'array([0.4, 1.6])',
- 'array([0.4, 1.6])')
-around.__doc__ = str.replace(around.__doc__ or '',
- 'array([0., 2., 2., 4., 4.])', 'array([0., 2., 2., 4., 4.])')
-around.__doc__ = str.replace(around.__doc__ or '',
- """ .. [2] "How Futile are Mindless Assessments of
- Roundoff in Floating-Point Computation?", William Kahan,
- https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf
-"""
- , '')
-masked_invalid = _dask_or_eager_func('masked_invalid', eager_module=np.ma,
- dask_module='dask.array.ma')
+
+ def f(*args, **kwargs):
+ if any(is_duck_dask_array(a) for a in args):
+ mod = (
+ import_module(dask_module)
+ if isinstance(dask_module, str)
+ else dask_module
+ )
+ wrapped = getattr(mod, name)
+ else:
+ wrapped = getattr(eager_module, name)
+ return wrapped(*args, **kwargs)
+
+ return f
+
+
+def fail_on_dask_array_input(values, msg=None, func_name=None):
+ if is_duck_dask_array(values):
+ if msg is None:
+ msg = "%r is not yet a valid method on dask arrays"
+ if func_name is None:
+ func_name = inspect.stack()[1][3]
+ raise NotImplementedError(msg % func_name)
+
+
+# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
+pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module="dask.array")
+
+# np.around has failing doctests, overwrite it so they pass:
+# https://github.com/numpy/numpy/issues/19759
+around.__doc__ = str.replace(
+ around.__doc__ or "",
+ "array([0., 2.])",
+ "array([0., 2.])",
+)
+around.__doc__ = str.replace(
+ around.__doc__ or "",
+ "array([0., 2.])",
+ "array([0., 2.])",
+)
+around.__doc__ = str.replace(
+ around.__doc__ or "",
+ "array([0.4, 1.6])",
+ "array([0.4, 1.6])",
+)
+around.__doc__ = str.replace(
+ around.__doc__ or "",
+ "array([0., 2., 2., 4., 4.])",
+ "array([0., 2., 2., 4., 4.])",
+)
+around.__doc__ = str.replace(
+ around.__doc__ or "",
+ (
+ ' .. [2] "How Futile are Mindless Assessments of\n'
+ ' Roundoff in Floating-Point Computation?", William Kahan,\n'
+ " https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf\n"
+ ),
+ "",
+)
+
+
+def isnull(data):
+ data = asarray(data)
+
+ xp = get_array_namespace(data)
+ scalar_type = data.dtype
+ if dtypes.is_datetime_like(scalar_type):
+ # datetime types use NaT for null
+ # note: must check timedelta64 before integers, because currently
+ # timedelta64 inherits from np.integer
+ return isnat(data)
+ elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
+ # float types use NaN for null
+ xp = get_array_namespace(data)
+ return xp.isnan(data)
+ elif dtypes.isdtype(scalar_type, ("bool", "integral"), xp=xp) or (
+ isinstance(scalar_type, np.dtype)
+ and (
+ np.issubdtype(scalar_type, np.character)
+ or np.issubdtype(scalar_type, np.void)
+ )
+ ):
+ # these types cannot represent missing values
+ return full_like(data, dtype=bool, fill_value=False)
+ else:
+ # at this point, array should have dtype=object
+ if isinstance(data, np.ndarray) or is_extension_array_dtype(data):
+ return pandas_isnull(data)
+ else:
+ # Not reachable yet, but intended for use with other duck array
+ # types. For full consistency with pandas, we should accept None as
+ # a null value as well as NaN, but it isn't clear how to do this
+ # with duck typing.
+ return data != data
+
+
+def notnull(data):
+ return ~isnull(data)
+
+
+# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
+masked_invalid = _dask_or_eager_func(
+ "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma"
+)
+
+
+def trapz(y, x, axis):
+ if axis < 0:
+ axis = y.ndim + axis
+ x_sl1 = (slice(1, None),) + (None,) * (y.ndim - axis - 1)
+ x_sl2 = (slice(None, -1),) + (None,) * (y.ndim - axis - 1)
+ slice1 = (slice(None),) * axis + (slice(1, None),)
+ slice2 = (slice(None),) * axis + (slice(None, -1),)
+ dx = x[x_sl1] - x[x_sl2]
+ integrand = dx * 0.5 * (y[tuple(slice1)] + y[tuple(slice2)])
+ return sum(integrand, axis=axis, skipna=False)
+
+
+def cumulative_trapezoid(y, x, axis):
+ if axis < 0:
+ axis = y.ndim + axis
+ x_sl1 = (slice(1, None),) + (None,) * (y.ndim - axis - 1)
+ x_sl2 = (slice(None, -1),) + (None,) * (y.ndim - axis - 1)
+ slice1 = (slice(None),) * axis + (slice(1, None),)
+ slice2 = (slice(None),) * axis + (slice(None, -1),)
+ dx = x[x_sl1] - x[x_sl2]
+ integrand = dx * 0.5 * (y[tuple(slice1)] + y[tuple(slice2)])
+
+ # Pad so that 'axis' has same length in result as it did in y
+ pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)]
+ integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0)
+
+ return cumsum(integrand, axis=axis, skipna=False)
+
+
+def astype(data, dtype, **kwargs):
+ if hasattr(data, "__array_namespace__"):
+ xp = get_array_namespace(data)
+ if xp == np:
+            # numpy currently doesn't have an astype function:
+ return data.astype(dtype, **kwargs)
+ return xp.astype(data, dtype, **kwargs)
+ return data.astype(dtype, **kwargs)
+
+
+def asarray(data, xp=np, dtype=None):
+ converted = data if is_duck_array(data) else xp.asarray(data)
+
+ if dtype is None or converted.dtype == dtype:
+ return converted
+
+ if xp is np or not hasattr(xp, "astype"):
+ return converted.astype(dtype)
+ else:
+ return xp.astype(converted, dtype)
def as_shared_dtype(scalars_or_arrays, xp=None):
    """Cast arrays to a shared dtype using xarray's type promotion rules."""
- pass
+ if any(is_extension_array_dtype(x) for x in scalars_or_arrays):
+ extension_array_types = [
+ x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
+ ]
+ if len(extension_array_types) == len(scalars_or_arrays) and all(
+ isinstance(x, type(extension_array_types[0])) for x in extension_array_types
+ ):
+ return scalars_or_arrays
+ raise ValueError(
+ "Cannot cast arrays to shared type, found"
+ f" array types {[x.dtype for x in scalars_or_arrays]}"
+ )
+
+    # Avoid calling array_type("cupy") repeatedly in the any check
+ array_type_cupy = array_type("cupy")
+ if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays):
+ import cupy as cp
+
+ xp = cp
+ elif xp is None:
+ xp = get_array_namespace(scalars_or_arrays)
+
+ # Pass arrays directly instead of dtypes to result_type so scalars
+ # get handled properly.
+ # Note that result_type() safely gets the dtype from dask arrays without
+ # evaluating them.
+ dtype = dtypes.result_type(*scalars_or_arrays, xp=xp)
+
+ return [asarray(x, dtype=dtype, xp=xp) for x in scalars_or_arrays]
+
+
+def broadcast_to(array, shape):
+ xp = get_array_namespace(array)
+ return xp.broadcast_to(array, shape)
def lazy_array_equiv(arr1, arr2):
@@ -69,67 +296,219 @@ def lazy_array_equiv(arr1, arr2):
    Returns None when equality cannot be determined: one or both of arr1, arr2 are numpy arrays;
or their dask tokens are not equal
"""
- pass
-
-
-def allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08):
+ if arr1 is arr2:
+ return True
+ arr1 = asarray(arr1)
+ arr2 = asarray(arr2)
+ if arr1.shape != arr2.shape:
+ return False
+ if dask_available and is_duck_dask_array(arr1) and is_duck_dask_array(arr2):
+ from dask.base import tokenize
+
+ # GH3068, GH4221
+ if tokenize(arr1) == tokenize(arr2):
+ return True
+ else:
+ return None
+ return None
+
+
+def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
"""Like np.allclose, but also allows values to be NaN in both arrays"""
- pass
+ arr1 = asarray(arr1)
+ arr2 = asarray(arr2)
+
+ lazy_equiv = lazy_array_equiv(arr1, arr2)
+ if lazy_equiv is None:
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
+ return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
+ else:
+ return lazy_equiv
def array_equiv(arr1, arr2):
"""Like np.array_equal, but also allows values to be NaN in both arrays"""
- pass
+ arr1 = asarray(arr1)
+ arr2 = asarray(arr2)
+ lazy_equiv = lazy_array_equiv(arr1, arr2)
+ if lazy_equiv is None:
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
+ flag_array = (arr1 == arr2) | (isnull(arr1) & isnull(arr2))
+ return bool(flag_array.all())
+ else:
+ return lazy_equiv
def array_notnull_equiv(arr1, arr2):
"""Like np.array_equal, but also allows values to be NaN in either or both
arrays
"""
- pass
+ arr1 = asarray(arr1)
+ arr2 = asarray(arr2)
+ lazy_equiv = lazy_array_equiv(arr1, arr2)
+ if lazy_equiv is None:
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "In the future, 'NAT == x'")
+ flag_array = (arr1 == arr2) | isnull(arr1) | isnull(arr2)
+ return bool(flag_array.all())
+ else:
+ return lazy_equiv
def count(data, axis=None):
    """Count the number of non-NA values in this array along the given axis or axes"""
- pass
+ return np.sum(np.logical_not(isnull(data)), axis=axis)
+
+
+def sum_where(data, axis=None, dtype=None, where=None):
+ xp = get_array_namespace(data)
+ if where is not None:
+ a = where_method(xp.zeros_like(data), where, data)
+ else:
+ a = data
+ result = xp.sum(a, axis=axis, dtype=dtype)
+ return result
def where(condition, x, y):
"""Three argument where() with better dtype promotion rules."""
- pass
+ xp = get_array_namespace(condition)
+ return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
+
+
+def where_method(data, cond, other=dtypes.NA):
+ if other is dtypes.NA:
+ other = dtypes.get_fill_value(data.dtype)
+ return where(cond, data, other)
+
+
+def fillna(data, other):
+ # we need to pass data first so pint has a chance of returning the
+ # correct unit
+ # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed
+ return where(notnull(data), data, other)
def concatenate(arrays, axis=0):
"""concatenate() with better dtype promotion rules."""
- pass
+ # TODO: remove the additional check once `numpy` adds `concat` to its array namespace
+ if hasattr(arrays[0], "__array_namespace__") and not isinstance(
+ arrays[0], np.ndarray
+ ):
+ xp = get_array_namespace(arrays[0])
+ return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
+ return _concatenate(as_shared_dtype(arrays), axis=axis)
def stack(arrays, axis=0):
"""stack() with better dtype promotion rules."""
- pass
-
-
-argmax = _create_nan_agg_method('argmax', coerce_strings=True)
-argmin = _create_nan_agg_method('argmin', coerce_strings=True)
-max = _create_nan_agg_method('max', coerce_strings=True, invariant_0d=True)
-min = _create_nan_agg_method('min', coerce_strings=True, invariant_0d=True)
-sum = _create_nan_agg_method('sum', invariant_0d=True)
+ xp = get_array_namespace(arrays[0])
+ return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
+
+
+def reshape(array, shape):
+ xp = get_array_namespace(array)
+ return xp.reshape(array, shape)
+
+
+def ravel(array):
+ return reshape(array, (-1,))
+
+
+@contextlib.contextmanager
+def _ignore_warnings_if(condition):
+ if condition:
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ yield
+ else:
+ yield
+
+
+def _create_nan_agg_method(name, coerce_strings=False, invariant_0d=False):
+ from xarray.core import nanops
+
+ def f(values, axis=None, skipna=None, **kwargs):
+ if kwargs.pop("out", None) is not None:
+ raise TypeError(f"`out` is not valid for {name}")
+
+ # The data is invariant in the case of 0d data, so do not
+ # change the data (and dtype)
+ # See https://github.com/pydata/xarray/issues/4885
+ if invariant_0d and axis == ():
+ return values
+
+ xp = get_array_namespace(values)
+ values = asarray(values, xp=xp)
+
+ if coerce_strings and dtypes.is_string(values.dtype):
+ values = astype(values, object)
+
+ func = None
+ if skipna or (
+ skipna is None
+ and (
+ dtypes.isdtype(
+ values.dtype, ("complex floating", "real floating"), xp=xp
+ )
+ or dtypes.is_object(values.dtype)
+ )
+ ):
+ nanname = "nan" + name
+ func = getattr(nanops, nanname)
+ else:
+ if name in ["sum", "prod"]:
+ kwargs.pop("min_count", None)
+
+ xp = get_array_namespace(values)
+ func = getattr(xp, name)
+
+ try:
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "All-NaN slice encountered")
+ return func(values, axis=axis, **kwargs)
+ except AttributeError:
+ if not is_duck_dask_array(values):
+ raise
+ try: # dask/dask#3133 dask sometimes needs dtype argument
+ # if func does not accept dtype, then raises TypeError
+ return func(values, axis=axis, dtype=values.dtype, **kwargs)
+ except (AttributeError, TypeError):
+ raise NotImplementedError(
+ f"{name} is not yet implemented on dask arrays"
+ )
+
+ f.__name__ = name
+ return f
+
+
+# Attributes `numeric_only`, `available_min_count` is used for docs.
+# See ops.inject_reduce_methods
+argmax = _create_nan_agg_method("argmax", coerce_strings=True)
+argmin = _create_nan_agg_method("argmin", coerce_strings=True)
+max = _create_nan_agg_method("max", coerce_strings=True, invariant_0d=True)
+min = _create_nan_agg_method("min", coerce_strings=True, invariant_0d=True)
+sum = _create_nan_agg_method("sum", invariant_0d=True)
sum.numeric_only = True
sum.available_min_count = True
-std = _create_nan_agg_method('std')
+std = _create_nan_agg_method("std")
std.numeric_only = True
-var = _create_nan_agg_method('var')
+var = _create_nan_agg_method("var")
var.numeric_only = True
-median = _create_nan_agg_method('median', invariant_0d=True)
+median = _create_nan_agg_method("median", invariant_0d=True)
median.numeric_only = True
-prod = _create_nan_agg_method('prod', invariant_0d=True)
+prod = _create_nan_agg_method("prod", invariant_0d=True)
prod.numeric_only = True
prod.available_min_count = True
-cumprod_1d = _create_nan_agg_method('cumprod', invariant_0d=True)
+cumprod_1d = _create_nan_agg_method("cumprod", invariant_0d=True)
cumprod_1d.numeric_only = True
-cumsum_1d = _create_nan_agg_method('cumsum', invariant_0d=True)
+cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True)
cumsum_1d.numeric_only = True
-_mean = _create_nan_agg_method('mean', invariant_0d=True)
+
+
+_mean = _create_nan_agg_method("mean", invariant_0d=True)
def _datetime_nanmin(array):
@@ -141,7 +520,15 @@ def _datetime_nanmin(array):
    - numpy nanmin() doesn't work on datetime64 (all versions at the moment of writing)
- dask min() does not work on datetime64 (all versions at the moment of writing)
"""
- pass
+ dtype = array.dtype
+ assert dtypes.is_datetime_like(dtype)
+ # (NaT).astype(float) does not produce NaN...
+ array = where(pandas_isnull(array), np.nan, array.astype(float))
+ array = min(array, skipna=True)
+ if isinstance(array, float):
+ array = np.array(array)
+ # ...but (NaN).astype("M8") does produce NaT
+ return array.astype(dtype)
def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
@@ -169,10 +556,42 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
though some calendars would allow for them (e.g. no_leap). This is because there
is no `cftime.timedelta` object.
"""
- pass
-
-
-def timedelta_to_numeric(value, datetime_unit='ns', dtype=float):
+ # Set offset to minimum if not given
+ if offset is None:
+ if dtypes.is_datetime_like(array.dtype):
+ offset = _datetime_nanmin(array)
+ else:
+ offset = min(array)
+
+ # Compute timedelta object.
+ # For np.datetime64, this can silently yield garbage due to overflow.
+ # One option is to enforce 1970-01-01 as the universal offset.
+
+ # This map_blocks call is for backwards compatibility.
+ # dask == 2021.04.1 does not support subtracting object arrays
+ # which is required for cftime
+ if is_duck_dask_array(array) and dtypes.is_object(array.dtype):
+ array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
+ else:
+ array = array - offset
+
+ # Scalar is converted to 0d-array
+ if not hasattr(array, "dtype"):
+ array = np.array(array)
+
+ # Convert timedelta objects to float by first converting to microseconds.
+ if dtypes.is_object(array.dtype):
+ return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype)
+
+ # Convert np.NaT to np.nan
+ elif dtypes.is_datetime_like(array.dtype):
+ # Convert to specified timedelta units.
+ if datetime_unit:
+ array = array / np.timedelta64(1, datetime_unit)
+ return np.where(isnull(array), np.nan, array.astype(dtype))
+
+
+def timedelta_to_numeric(value, datetime_unit="ns", dtype=float):
"""Convert a timedelta-like object to numerical values.
Parameters
@@ -186,7 +605,32 @@ def timedelta_to_numeric(value, datetime_unit='ns', dtype=float):
The output data type.
"""
- pass
+ import datetime as dt
+
+ if isinstance(value, dt.timedelta):
+ out = py_timedelta_to_float(value, datetime_unit)
+ elif isinstance(value, np.timedelta64):
+ out = np_timedelta64_to_float(value, datetime_unit)
+ elif isinstance(value, pd.Timedelta):
+ out = pd_timedelta_to_float(value, datetime_unit)
+ elif isinstance(value, str):
+ try:
+ a = pd.to_timedelta(value)
+ except ValueError:
+ raise ValueError(
+ f"Could not convert {value!r} to timedelta64 using pandas.to_timedelta"
+ )
+ return py_timedelta_to_float(a, datetime_unit)
+ else:
+ raise TypeError(
+ f"Expected value of type str, pandas.Timedelta, datetime.timedelta "
+ f"or numpy.timedelta64, but received {type(value).__name__}"
+ )
+ return out.astype(dtype)
+
+
+def _to_pytimedelta(array, unit="us"):
+ return array.astype(f"timedelta64[{unit}]").astype(datetime.timedelta)
def np_timedelta64_to_float(array, datetime_unit):
@@ -197,7 +641,9 @@ def np_timedelta64_to_float(array, datetime_unit):
    The array is first converted to nanoseconds, which is less likely to
cause overflow errors.
"""
- pass
+ array = array.astype("timedelta64[ns]").astype(np.float64)
+ conversion_factor = np.timedelta64(1, "ns") / np.timedelta64(1, datetime_unit)
+ return conversion_factor * array
def pd_timedelta_to_float(value, datetime_unit):
@@ -208,50 +654,180 @@ def pd_timedelta_to_float(value, datetime_unit):
Built on the assumption that pandas timedelta values are in nanoseconds,
which is also the numpy default resolution.
"""
- pass
+ value = value.to_timedelta64()
+ return np_timedelta64_to_float(value, datetime_unit)
+
+
+def _timedelta_to_seconds(array):
+ if isinstance(array, datetime.timedelta):
+ return array.total_seconds() * 1e6
+ else:
+ return np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6
def py_timedelta_to_float(array, datetime_unit):
"""Convert a timedelta object to a float, possibly at a loss of resolution."""
- pass
+ array = asarray(array)
+ if is_duck_dask_array(array):
+ array = array.map_blocks(
+ _timedelta_to_seconds, meta=np.array([], dtype=np.float64)
+ )
+ else:
+ array = _timedelta_to_seconds(array)
+ conversion_factor = np.timedelta64(1, "us") / np.timedelta64(1, datetime_unit)
+ return conversion_factor * array
def mean(array, axis=None, skipna=None, **kwargs):
    """in-house mean that can handle np.datetime64 or cftime.datetime
dtypes"""
- pass
+ from xarray.core.common import _contains_cftime_datetimes
+
+ array = asarray(array)
+ if dtypes.is_datetime_like(array.dtype):
+ offset = _datetime_nanmin(array)
+ # xarray always uses np.datetime64[ns] for np.datetime64 data
+ dtype = "timedelta64[ns]"
+ return (
+ _mean(
+ datetime_to_numeric(array, offset), axis=axis, skipna=skipna, **kwargs
+ ).astype(dtype)
+ + offset
+ )
+ elif _contains_cftime_datetimes(array):
+ offset = min(array)
+ timedeltas = datetime_to_numeric(array, offset, datetime_unit="us")
+ mean_timedeltas = _mean(timedeltas, axis=axis, skipna=skipna, **kwargs)
+ return _to_pytimedelta(mean_timedeltas, unit="us") + offset
+ else:
+ return _mean(array, axis=axis, skipna=skipna, **kwargs)
-mean.numeric_only = True
+
+mean.numeric_only = True # type: ignore[attr-defined]
+
+
+def _nd_cum_func(cum_func, array, axis, **kwargs):
+ array = asarray(array)
+ if axis is None:
+ axis = tuple(range(array.ndim))
+ if isinstance(axis, int):
+ axis = (axis,)
+
+ out = array
+ for ax in axis:
+ out = cum_func(out, axis=ax, **kwargs)
+ return out
def cumprod(array, axis=None, **kwargs):
"""N-dimensional version of cumprod."""
- pass
+ return _nd_cum_func(cumprod_1d, array, axis, **kwargs)
def cumsum(array, axis=None, **kwargs):
"""N-dimensional version of cumsum."""
- pass
+ return _nd_cum_func(cumsum_1d, array, axis, **kwargs)
def first(values, axis, skipna=None):
"""Return the first non-NA elements in this array along the given axis"""
- pass
+ if (skipna or skipna is None) and not (
+ dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
+ ):
+ # only bother for dtypes that can hold NaN
+ if is_chunked_array(values):
+ return chunked_nanfirst(values, axis)
+ else:
+ return nputils.nanfirst(values, axis)
+ return take(values, 0, axis=axis)
def last(values, axis, skipna=None):
"""Return the last non-NA elements in this array along the given axis"""
- pass
+ if (skipna or skipna is None) and not (
+ dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
+ ):
+ # only bother for dtypes that can hold NaN
+ if is_chunked_array(values):
+ return chunked_nanlast(values, axis)
+ else:
+ return nputils.nanlast(values, axis)
+ return take(values, -1, axis=axis)
def least_squares(lhs, rhs, rcond=None, skipna=False):
"""Return the coefficients and residuals of a least-squares fit."""
- pass
+ if is_duck_dask_array(rhs):
+ return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
+ else:
+ return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
-def _push(array, n: (int | None)=None, axis: int=-1):
+def _push(array, n: int | None = None, axis: int = -1):
"""
Use either bottleneck or numbagg depending on options & what's available
"""
- pass
+
+ if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
+ raise RuntimeError(
+ "ffill & bfill requires bottleneck or numbagg to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
+ )
+ if OPTIONS["use_numbagg"] and module_available("numbagg"):
+ import numbagg
+
+ if pycompat.mod_version("numbagg") < Version("0.6.2"):
+ warnings.warn(
+ f"numbagg >= 0.6.2 is required for bfill & ffill; {pycompat.mod_version('numbagg')} is installed. We'll attempt with bottleneck instead."
+ )
+ else:
+ return numbagg.ffill(array, limit=n, axis=axis)
+
+ # work around for bottleneck 178
+ limit = n if n is not None else array.shape[axis]
+
+ import bottleneck as bn
+
+ return bn.push(array, limit, axis)
+
+
+def push(array, n, axis):
+ if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]:
+ raise RuntimeError(
+ "ffill & bfill requires bottleneck or numbagg to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one."
+ )
+ if is_duck_dask_array(array):
+ return dask_array_ops.push(array, n, axis)
+ else:
+ return _push(array, n, axis)
+
+
+def _first_last_wrapper(array, *, axis, op, keepdims):
+ return op(array, axis, keepdims=keepdims)
+
+
+def _chunked_first_or_last(darray, axis, op):
+ chunkmanager = get_chunked_array_type(darray)
+
+ # This will raise the same error message seen for numpy
+ axis = normalize_axis_index(axis, darray.ndim)
+
+ wrapped_op = partial(_first_last_wrapper, op=op)
+ return chunkmanager.reduction(
+ darray,
+ func=wrapped_op,
+ aggregate_func=wrapped_op,
+ axis=axis,
+ dtype=darray.dtype,
+ keepdims=False, # match numpy version
+ )
+
+
+def chunked_nanfirst(darray, axis):
+ return _chunked_first_or_last(darray, axis, op=nputils.nanfirst)
+
+
+def chunked_nanlast(darray, axis):
+ return _chunked_first_or_last(darray, axis, op=nputils.nanlast)
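As a usage note for the timedelta helpers above: `np_timedelta64_to_float` goes through nanoseconds as floats and then rescales by a unit ratio (Python `timedelta` objects instead go through microseconds in `py_timedelta_to_float`). A rough sketch of the same arithmetic with plain NumPy:

```python
# Plain-NumPy sketch of the conversion performed by np_timedelta64_to_float:
# cast to nanoseconds as float64, then rescale by the ratio of the units.
import numpy as np

td = np.timedelta64(90, "s")
ns = td.astype("timedelta64[ns]").astype(np.float64)          # 9.0e10 ns
minutes = ns * (np.timedelta64(1, "ns") / np.timedelta64(1, "m"))
print(minutes)  # 1.5
```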
diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py
index cf1511fe..b0361ef0 100644
--- a/xarray/core/extension_array.py
+++ b/xarray/core/extension_array.py
@@ -1,16 +1,65 @@
from __future__ import annotations
+
from collections.abc import Sequence
from typing import Callable, Generic, cast
+
import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype
+
from xarray.core.types import DTypeLikeSave, T_ExtensionArray
+
HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {}
def implements(numpy_function):
"""Register an __array_function__ implementation for MyArray objects."""
- pass
+
+ def decorator(func):
+ HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func
+ return func
+
+ return decorator
+
+
+@implements(np.issubdtype)
+def __extension_duck_array__issubdtype(
+ extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave
+) -> bool:
+ return False # never want a function to think a pandas extension dtype is a subtype of numpy
+
+
+@implements(np.broadcast_to)
+def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple):
+ if shape[0] == len(arr) and len(shape) == 1:
+ return arr
+ raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.")
+
+
+@implements(np.stack)
+def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int):
+ raise NotImplementedError("Cannot stack 1d-only pandas categorical array.")
+
+
+@implements(np.concatenate)
+def __extension_duck_array__concatenate(
+ arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None
+) -> T_ExtensionArray:
+ return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined]
+
+
+@implements(np.where)
+def __extension_duck_array__where(
+ condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray
+) -> T_ExtensionArray:
+ if (
+ isinstance(x, pd.Categorical)
+ and isinstance(y, pd.Categorical)
+ and x.dtype != y.dtype
+ ):
+ x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment]
+ y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment]
+ return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array)
class PandasExtensionArray(Generic[T_ExtensionArray]):
@@ -26,23 +75,25 @@ class PandasExtensionArray(Generic[T_ExtensionArray]):
```
"""
if not isinstance(array, pd.api.extensions.ExtensionArray):
- raise TypeError(f'{array} is not an pandas ExtensionArray.')
+            raise TypeError(f"{array} is not a pandas ExtensionArray.")
self.array = array
def __array_function__(self, func, types, args, kwargs):
-
- def replace_duck_with_extension_array(args) ->list:
+ def replace_duck_with_extension_array(args) -> list:
args_as_list = list(args)
for index, value in enumerate(args_as_list):
if isinstance(value, PandasExtensionArray):
args_as_list[index] = value.array
- elif isinstance(value, tuple):
+ elif isinstance(
+ value, tuple
+ ): # should handle more than just tuple? iterable?
args_as_list[index] = tuple(
- replace_duck_with_extension_array(value))
+ replace_duck_with_extension_array(value)
+ )
elif isinstance(value, list):
- args_as_list[index] = replace_duck_with_extension_array(
- value)
+ args_as_list[index] = replace_duck_with_extension_array(value)
return args_as_list
+
args = tuple(replace_duck_with_extension_array(args))
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
return func(*args, **kwargs)
@@ -55,17 +106,17 @@ class PandasExtensionArray(Generic[T_ExtensionArray]):
return ufunc(*inputs, **kwargs)
def __repr__(self):
- return f'{type(self)}(array={repr(self.array)})'
+ return f"{type(self)}(array={repr(self.array)})"
- def __getattr__(self, attr: str) ->object:
+ def __getattr__(self, attr: str) -> object:
return getattr(self.array, attr)
- def __getitem__(self, key) ->PandasExtensionArray[T_ExtensionArray]:
+ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]:
item = self.array[key]
if is_extension_array_dtype(item):
return type(self)(item)
if np.isscalar(item):
- return type(self)(type(self.array)([item]))
+ return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed
return item
def __setitem__(self, key, val):
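The `np.where` handler registered above first has to reconcile two `pd.Categorical` arrays whose category sets differ, since pandas will not mix their values otherwise. A standalone demonstration of that reconciliation step, outside of xarray's `__array_function__` dispatch:

```python
# Standalone illustration of the category reconciliation done by the
# np.where handler above, using pandas directly.
import numpy as np
import pandas as pd

x = pd.Categorical(["a", "b", "a"])
y = pd.Categorical(["c", "c", "c"])

# Give both arrays the union of categories so .where() can mix their values.
x = x.add_categories(sorted(set(y.categories) - set(x.categories)))
y = y.add_categories(sorted(set(x.categories) - set(y.categories)))

cond = np.array([True, False, True])
print(pd.Series(x).where(cond, pd.Series(y)).array)
# ['a', 'c', 'a'], with categories ['a', 'b', 'c']
```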
diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py
index c7897a61..9ebbd564 100644
--- a/xarray/core/extensions.py
+++ b/xarray/core/extensions.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+
import warnings
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
@@ -18,23 +20,48 @@ class _CachedAccessor:
def __get__(self, obj, cls):
if obj is None:
+ # we're accessing the attribute of the class, i.e., Dataset.geo
return self._accessor
+
+ # Use the same dict as @pandas.util.cache_readonly.
+ # It must be explicitly declared in obj.__slots__.
try:
cache = obj._cache
except AttributeError:
cache = obj._cache = {}
+
try:
return cache[self._name]
except KeyError:
pass
+
try:
accessor_obj = self._accessor(obj)
except AttributeError:
- raise RuntimeError(f'error initializing {self._name!r} accessor.')
+ # __getattr__ on data object will swallow any AttributeErrors
+ # raised when initializing the accessor, so we need to raise as
+ # something else (GH933):
+ raise RuntimeError(f"error initializing {self._name!r} accessor.")
+
cache[self._name] = accessor_obj
return accessor_obj
+def _register_accessor(name, cls):
+ def decorator(accessor):
+ if hasattr(cls, name):
+ warnings.warn(
+ f"registration of accessor {accessor!r} under name {name!r} for type {cls!r} is "
+ "overriding a preexisting attribute with the same name.",
+ AccessorRegistrationWarning,
+ stacklevel=2,
+ )
+ setattr(cls, name, _CachedAccessor(name, accessor))
+ return accessor
+
+ return decorator
+
+
def register_dataarray_accessor(name):
"""Register a custom accessor on xarray.DataArray objects.
@@ -48,7 +75,7 @@ def register_dataarray_accessor(name):
--------
register_dataset_accessor
"""
- pass
+ return _register_accessor(name, DataArray)
def register_dataset_accessor(name):
@@ -94,7 +121,7 @@ def register_dataset_accessor(name):
--------
register_dataarray_accessor
"""
- pass
+ return _register_accessor(name, Dataset)
def register_datatree_accessor(name):
@@ -111,4 +138,4 @@ def register_datatree_accessor(name):
xarray.register_dataarray_accessor
xarray.register_dataset_accessor
"""
- pass
+ return _register_accessor(name, DataTree)
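For context, the registration helpers above are the public decorators used to attach custom namespaces; a typical accessor looks like the sketch below. The `geo`, `lon`, `lat` and `center` names are made up for illustration.

```python
# Hypothetical accessor built on register_dataset_accessor; the accessor name
# and coordinate names are illustrative only.
import xarray as xr

@xr.register_dataset_accessor("geo")
class GeoAccessor:
    def __init__(self, dataset):
        self._ds = dataset

    @property
    def center(self):
        # toy computation over two assumed coordinates
        return float(self._ds.lon.mean()), float(self._ds.lat.mean())

ds = xr.Dataset(coords={"lon": [0.0, 10.0], "lat": [40.0, 50.0]})
print(ds.geo.center)  # (5.0, 45.0)
```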
diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py
index 8c1f8e0c..6571b288 100644
--- a/xarray/core/formatting.py
+++ b/xarray/core/formatting.py
@@ -1,6 +1,8 @@
"""String formatting routines for __repr__.
"""
+
from __future__ import annotations
+
import contextlib
import functools
import math
@@ -11,9 +13,11 @@ from itertools import chain, zip_longest
from reprlib import recursive_repr
from textwrap import dedent
from typing import TYPE_CHECKING
+
import numpy as np
import pandas as pd
from pandas.errors import OutOfBoundsDatetime
+
from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
@@ -21,10 +25,12 @@ from xarray.core.iterators import LevelOrderIter
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.utils import is_duck_array
from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy
+
if TYPE_CHECKING:
from xarray.core.coordinates import AbstractCoordinates
from xarray.core.datatree import DataTree
-UNITS = 'B', 'kB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'
+
+UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
def pretty_print(x, numchars: int):
@@ -32,63 +38,235 @@ def pretty_print(x, numchars: int):
that it is numchars long, padding with trailing spaces or truncating with
ellipses as necessary
"""
- pass
+ s = maybe_truncate(x, numchars)
+ return s + " " * max(numchars - len(s), 0)
+
+
+def maybe_truncate(obj, maxlen=500):
+ s = str(obj)
+ if len(s) > maxlen:
+ s = s[: (maxlen - 3)] + "..."
+ return s
+
+
+def wrap_indent(text, start="", length=None):
+ if length is None:
+ length = len(start)
+ indent = "\n" + " " * length
+ return start + indent.join(x for x in text.splitlines())
+
+
+def _get_indexer_at_least_n_items(shape, n_desired, from_end):
+ assert 0 < n_desired <= math.prod(shape)
+ cum_items = np.cumprod(shape[::-1])
+ n_steps = np.argmax(cum_items >= n_desired)
+ stop = math.ceil(float(n_desired) / np.r_[1, cum_items][n_steps])
+ indexer = (
+ ((-1 if from_end else 0),) * (len(shape) - 1 - n_steps)
+ + ((slice(-stop, None) if from_end else slice(stop)),)
+ + (slice(None),) * n_steps
+ )
+ return indexer
def first_n_items(array, n_desired):
"""Returns the first n_desired items of an array"""
- pass
+ # Unfortunately, we can't just do array.flat[:n_desired] here because it
+ # might not be a numpy.ndarray. Moreover, access to elements of the array
+ # could be very expensive (e.g. if it's only available over DAP), so go out
+ # of our way to get them in a single call to __getitem__ using only slices.
+ from xarray.core.variable import Variable
+
+ if n_desired < 1:
+ raise ValueError("must request at least one item")
+
+ if array.size == 0:
+ # work around for https://github.com/numpy/numpy/issues/5195
+ return []
+
+ if n_desired < array.size:
+ indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=False)
+ array = array[indexer]
+
+ # We pass variable objects in to handle indexing
+ # with indexer above. It would not work with our
+ # lazy indexing classes at the moment, so we cannot
+ # pass Variable._data
+ if isinstance(array, Variable):
+ array = array._data
+ return np.ravel(to_duck_array(array))[:n_desired]
def last_n_items(array, n_desired):
"""Returns the last n_desired items of an array"""
- pass
+ # Unfortunately, we can't just do array.flat[-n_desired:] here because it
+ # might not be a numpy.ndarray. Moreover, access to elements of the array
+ # could be very expensive (e.g. if it's only available over DAP), so go out
+ # of our way to get them in a single call to __getitem__ using only slices.
+ from xarray.core.variable import Variable
+
+ if (n_desired == 0) or (array.size == 0):
+ return []
+
+ if n_desired < array.size:
+ indexer = _get_indexer_at_least_n_items(array.shape, n_desired, from_end=True)
+ array = array[indexer]
+
+ # We pass variable objects in to handle indexing
+ # with indexer above. It would not work with our
+ # lazy indexing classes at the moment, so we cannot
+ # pass Variable._data
+ if isinstance(array, Variable):
+ array = array._data
+ return np.ravel(to_duck_array(array))[-n_desired:]
def last_item(array):
"""Returns the last item of an array in a list or an empty list."""
- pass
+ if array.size == 0:
+ # work around for https://github.com/numpy/numpy/issues/5195
+ return []
+
+ indexer = (slice(-1, None),) * array.ndim
+ # to_numpy since dask doesn't support tolist
+ return np.ravel(to_numpy(array[indexer])).tolist()
-def calc_max_rows_first(max_rows: int) ->int:
+def calc_max_rows_first(max_rows: int) -> int:
"""Calculate the first rows to maintain the max number of rows."""
- pass
+ return max_rows // 2 + max_rows % 2
-def calc_max_rows_last(max_rows: int) ->int:
+def calc_max_rows_last(max_rows: int) -> int:
"""Calculate the last rows to maintain the max number of rows."""
- pass
+ return max_rows // 2
def format_timestamp(t):
"""Cast given object to a Timestamp and return a nicely formatted string"""
- pass
+ try:
+ timestamp = pd.Timestamp(t)
+ datetime_str = timestamp.isoformat(sep=" ")
+ except OutOfBoundsDatetime:
+ datetime_str = str(t)
+
+ try:
+ date_str, time_str = datetime_str.split()
+ except ValueError:
+ # catch NaT and others that don't split nicely
+ return datetime_str
+ else:
+ if time_str == "00:00:00":
+ return date_str
+ else:
+ return f"{date_str}T{time_str}"
def format_timedelta(t, timedelta_format=None):
"""Cast given object to a Timestamp and return a nicely formatted string"""
- pass
+ timedelta_str = str(pd.Timedelta(t))
+ try:
+ days_str, time_str = timedelta_str.split(" days ")
+ except ValueError:
+ # catch NaT and others that don't split nicely
+ return timedelta_str
+ else:
+ if timedelta_format == "date":
+ return days_str + " days"
+ elif timedelta_format == "time":
+ return time_str
+ else:
+ return timedelta_str
def format_item(x, timedelta_format=None, quote_strings=True):
"""Returns a succinct summary of an object as a string"""
- pass
+ if isinstance(x, (np.datetime64, datetime)):
+ return format_timestamp(x)
+ if isinstance(x, (np.timedelta64, timedelta)):
+ return format_timedelta(x, timedelta_format=timedelta_format)
+ elif isinstance(x, (str, bytes)):
+ if hasattr(x, "dtype"):
+ x = x.item()
+ return repr(x) if quote_strings else x
+ elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating):
+ return f"{x.item():.4}"
+ else:
+ return str(x)
def format_items(x):
"""Returns a succinct summaries of all items in a sequence as strings"""
- pass
+ x = to_duck_array(x)
+ timedelta_format = "datetime"
+ if np.issubdtype(x.dtype, np.timedelta64):
+ x = astype(x, dtype="timedelta64[ns]")
+ day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
+ time_needed = x[~pd.isnull(x)] != day_part
+ day_needed = day_part != np.timedelta64(0, "ns")
+ if np.logical_not(day_needed).all():
+ timedelta_format = "time"
+ elif np.logical_not(time_needed).all():
+ timedelta_format = "date"
+
+ formatted = [format_item(xi, timedelta_format) for xi in x]
+ return formatted
def format_array_flat(array, max_width: int):
"""Return a formatted string for as many items in the flattened version of
the array as will fit within max_width characters.
"""
- pass
+ # every item will take up at least two characters, but we always want to
+ # print at least the first and last items
+ max_possibly_relevant = min(max(array.size, 1), max(math.ceil(max_width / 2.0), 2))
+ relevant_front_items = format_items(
+ first_n_items(array, (max_possibly_relevant + 1) // 2)
+ )
+ relevant_back_items = format_items(last_n_items(array, max_possibly_relevant // 2))
+ # interleave relevant front and back items:
+ # [a, b, c] and [y, z] -> [a, z, b, y, c]
+ relevant_items = sum(
+ zip_longest(relevant_front_items, reversed(relevant_back_items)), ()
+ )[:max_possibly_relevant]
+ cum_len = np.cumsum([len(s) + 1 for s in relevant_items]) - 1
+ if (array.size > 2) and (
+ (max_possibly_relevant < array.size) or (cum_len > max_width).any()
+ ):
+ padding = " ... "
+ max_len = max(int(np.argmax(cum_len + len(padding) - 1 > max_width)), 2)
+ count = min(array.size, max_len)
+ else:
+ count = array.size
+ padding = "" if (count <= 1) else " "
-_KNOWN_TYPE_REPRS = {('numpy', 'ndarray'): 'np.ndarray', (
- 'sparse._coo.core', 'COO'): 'sparse.COO'}
+ num_front = (count + 1) // 2
+ num_back = count - num_front
+ # note that num_back is 0 <--> array.size is 0 or 1
+ # <--> relevant_back_items is []
+ pprint_str = "".join(
+ [
+ " ".join(relevant_front_items[:num_front]),
+ padding,
+ " ".join(relevant_back_items[-num_back:]),
+ ]
+ )
+
+ # As a final check, if it's still too long even with the limit in values,
+ # replace the end with an ellipsis
+ # NB: this will still return a full 3-character ellipsis when max_width < 3
+ if len(pprint_str) > max_width:
+ pprint_str = pprint_str[: max(max_width - 3, 0)] + "..."
+
+ return pprint_str
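# Editor's note: a rough end-to-end illustration of the interleave-then-truncate
# strategy above (an editorial sketch, not part of the patch):
#
#     format_array_flat(np.arange(100), max_width=13)   # -> "0 1 ... 98 99"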
+
+
+# mapping of tuple[modulename, classname] to repr
+_KNOWN_TYPE_REPRS = {
+ ("numpy", "ndarray"): "np.ndarray",
+ ("sparse._coo.core", "COO"): "sparse.COO",
+}
def inline_dask_repr(array):
@@ -96,40 +274,254 @@ def inline_dask_repr(array):
redundant information that's already printed by the repr
function of the xarray wrapper.
"""
- pass
+ assert isinstance(array, array_type("dask")), array
+
+ chunksize = tuple(c[0] for c in array.chunks)
+
+ if hasattr(array, "_meta"):
+ meta = array._meta
+ identifier = (type(meta).__module__, type(meta).__name__)
+ meta_repr = _KNOWN_TYPE_REPRS.get(identifier, ".".join(identifier))
+ meta_string = f", meta={meta_repr}"
+ else:
+ meta_string = ""
+
+ return f"dask.array<chunksize={chunksize}{meta_string}>"
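# Editor's note: a minimal sketch of the inline dask summary, assuming dask is
# installed (not part of the patch):
#
#     import dask.array as da
#     x = da.ones((200, 200), chunks=(100, 100))
#     inline_dask_repr(x)   # -> "dask.array<chunksize=(100, 100), meta=np.ndarray>"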
def inline_sparse_repr(array):
"""Similar to sparse.COO.__repr__, but without the redundant shape/dtype."""
- pass
+ sparse_array_type = array_type("sparse")
+ assert isinstance(array, sparse_array_type), array
+ return (
+ f"<{type(array).__name__}: nnz={array.nnz:d}, fill_value={array.fill_value!s}>"
+ )
def inline_variable_array_repr(var, max_width):
"""Build a one-line summary of a variable's data."""
- pass
+ if hasattr(var._data, "_repr_inline_"):
+ return var._data._repr_inline_(max_width)
+ if var._in_memory:
+ return format_array_flat(var, max_width)
+ dask_array_type = array_type("dask")
+ if isinstance(var._data, dask_array_type):
+ return inline_dask_repr(var.data)
+ sparse_array_type = array_type("sparse")
+ if isinstance(var._data, sparse_array_type):
+ return inline_sparse_repr(var.data)
+ if hasattr(var._data, "__array_function__"):
+ return maybe_truncate(repr(var._data).replace("\n", " "), max_width)
+ # internal xarray array type
+ return "..."
-def summarize_variable(name: Hashable, var, col_width: int, max_width: (int |
- None)=None, is_index: bool=False):
+def summarize_variable(
+ name: Hashable,
+ var,
+ col_width: int,
+ max_width: int | None = None,
+ is_index: bool = False,
+):
"""Summarize a variable in one line, e.g., for the Dataset.__repr__."""
- pass
+ variable = getattr(var, "variable", var)
+
+ if max_width is None:
+ max_width_options = OPTIONS["display_width"]
+ if not isinstance(max_width_options, int):
+ raise TypeError(f"`max_width` value of `{max_width}` is not a valid int")
+ else:
+ max_width = max_width_options
+
+ marker = "*" if is_index else " "
+ first_col = pretty_print(f" {marker} {name} ", col_width)
+
+ if variable.dims:
+ dims_str = "({}) ".format(", ".join(map(str, variable.dims)))
+ else:
+ dims_str = ""
+
+ nbytes_str = f" {render_human_readable_nbytes(variable.nbytes)}"
+ front_str = f"{first_col}{dims_str}{variable.dtype}{nbytes_str} "
+
+ values_width = max_width - len(front_str)
+ values_str = inline_variable_array_repr(variable, values_width)
+
+ return front_str + values_str
def summarize_attr(key, value, col_width=None):
"""Summary for __repr__ - use ``X.attrs[key]`` for full value."""
- pass
+ # Indent key and add ':', then right-pad if col_width is not None
+ k_str = f" {key}:"
+ if col_width is not None:
+ k_str = pretty_print(k_str, col_width)
+ # Replace tabs and newlines, so we print on one line in known width
+ v_str = str(value).replace("\t", "\\t").replace("\n", "\\n")
+ # Finally, truncate to the desired display width
+ return maybe_truncate(f"{k_str} {v_str}", OPTIONS["display_width"])
+
+
+EMPTY_REPR = " *empty*"
+
+
+def _calculate_col_width(col_items):
+ max_name_length = max((len(str(s)) for s in col_items), default=0)
+ col_width = max(max_name_length, 7) + 6
+ return col_width
+
+
+def _mapping_repr(
+ mapping,
+ title,
+ summarizer,
+ expand_option_name,
+ col_width=None,
+ max_rows=None,
+ indexes=None,
+):
+ if col_width is None:
+ col_width = _calculate_col_width(mapping)
+
+ summarizer_kwargs = defaultdict(dict)
+ if indexes is not None:
+ summarizer_kwargs = {k: {"is_index": k in indexes} for k in mapping}
+
+ summary = [f"{title}:"]
+ if mapping:
+ len_mapping = len(mapping)
+ if not _get_boolean_with_default(expand_option_name, default=True):
+ summary = [f"{summary[0]} ({len_mapping})"]
+ elif max_rows is not None and len_mapping > max_rows:
+ summary = [f"{summary[0]} ({max_rows}/{len_mapping})"]
+ first_rows = calc_max_rows_first(max_rows)
+ keys = list(mapping.keys())
+ summary += [
+ summarizer(k, mapping[k], col_width, **summarizer_kwargs[k])
+ for k in keys[:first_rows]
+ ]
+ if max_rows > 1:
+ last_rows = calc_max_rows_last(max_rows)
+ summary += [pretty_print(" ...", col_width) + " ..."]
+ summary += [
+ summarizer(k, mapping[k], col_width, **summarizer_kwargs[k])
+ for k in keys[-last_rows:]
+ ]
+ else:
+ summary += [
+ summarizer(k, v, col_width, **summarizer_kwargs[k])
+ for k, v in mapping.items()
+ ]
+ else:
+ summary += [EMPTY_REPR]
+ return "\n".join(summary)
-EMPTY_REPR = ' *empty*'
-data_vars_repr = functools.partial(_mapping_repr, title='Data variables',
- summarizer=summarize_variable, expand_option_name=
- 'display_expand_data_vars')
-attrs_repr = functools.partial(_mapping_repr, title='Attributes',
- summarizer=summarize_attr, expand_option_name='display_expand_attrs')
+data_vars_repr = functools.partial(
+ _mapping_repr,
+ title="Data variables",
+ summarizer=summarize_variable,
+ expand_option_name="display_expand_data_vars",
+)
+attrs_repr = functools.partial(
+ _mapping_repr,
+ title="Attributes",
+ summarizer=summarize_attr,
+ expand_option_name="display_expand_attrs",
+)
-def _element_formatter(elements: Collection[Hashable], col_width: int,
- max_rows: (int | None)=None, delimiter: str=', ') ->str:
+
+def coords_repr(coords: AbstractCoordinates, col_width=None, max_rows=None):
+ if col_width is None:
+ col_width = _calculate_col_width(coords)
+ return _mapping_repr(
+ coords,
+ title="Coordinates",
+ summarizer=summarize_variable,
+ expand_option_name="display_expand_coords",
+ col_width=col_width,
+ indexes=coords.xindexes,
+ max_rows=max_rows,
+ )
+
+
+def inline_index_repr(index: pd.Index, max_width=None):
+ if hasattr(index, "_repr_inline_"):
+ repr_ = index._repr_inline_(max_width=max_width)
+ else:
+ # fallback for the `pandas.Index` subclasses from
+ # `Indexes.get_pandas_indexes` / `xr_obj.indexes`
+ repr_ = repr(index)
+
+ return repr_
+
+
+def summarize_index(
+ names: tuple[Hashable, ...],
+ index,
+ col_width: int,
+ max_width: int | None = None,
+) -> str:
+ if max_width is None:
+ max_width = OPTIONS["display_width"]
+
+ def prefixes(length: int) -> list[str]:
+ if length in (0, 1):
+ return [" "]
+
+ return ["┌"] + ["│"] * max(length - 2, 0) + ["└"]
+
+ preformatted = [
+ pretty_print(f" {prefix} {name}", col_width)
+ for prefix, name in zip(prefixes(len(names)), names)
+ ]
+
+ head, *tail = preformatted
+ index_width = max_width - len(head)
+ repr_ = inline_index_repr(index, max_width=index_width)
+ return "\n".join([head + repr_] + [line.rstrip() for line in tail])
+
+
+def filter_nondefault_indexes(indexes, filter_indexes: bool):
+ from xarray.core.indexes import PandasIndex, PandasMultiIndex
+
+ if not filter_indexes:
+ return indexes
+
+ default_indexes = (PandasIndex, PandasMultiIndex)
+
+ return {
+ key: index
+ for key, index in indexes.items()
+ if not isinstance(index, default_indexes)
+ }
+
+
+def indexes_repr(indexes, max_rows: int | None = None) -> str:
+ col_width = _calculate_col_width(chain.from_iterable(indexes))
+
+ return _mapping_repr(
+ indexes,
+ "Indexes",
+ summarize_index,
+ "display_expand_indexes",
+ col_width=col_width,
+ max_rows=max_rows,
+ )
+
+
+def dim_summary(obj):
+ elements = [f"{k}: {v}" for k, v in obj.sizes.items()]
+ return ", ".join(elements)
+
+
+def _element_formatter(
+ elements: Collection[Hashable],
+ col_width: int,
+ max_rows: int | None = None,
+ delimiter: str = ", ",
+) -> str:
"""
Formats elements for better readability.
@@ -149,7 +541,61 @@ def _element_formatter(elements: Collection[Hashable], col_width: int,
delimiter : str, optional
Delimiter to use between each element. The default is ", ".
"""
- pass
+ elements_len = len(elements)
+ out = [""]
+ length_row = 0
+ for i, v in enumerate(elements):
+ delim = delimiter if i < elements_len - 1 else ""
+ v_delim = f"{v}{delim}"
+ length_element = len(v_delim)
+ length_row += length_element
+
+ # Create a new row if the next element makes the print wider than
+ # the maximum display width:
+ if col_width + length_row > OPTIONS["display_width"]:
+ out[-1] = out[-1].rstrip() # Remove trailing whitespace.
+ out.append("\n" + pretty_print("", col_width) + v_delim)
+ length_row = length_element
+ else:
+ out[-1] += v_delim
+
+ # If there are too many rows of dimensions trim some away:
+ if max_rows and (len(out) > max_rows):
+ first_rows = calc_max_rows_first(max_rows)
+ last_rows = calc_max_rows_last(max_rows)
+ out = (
+ out[:first_rows]
+ + ["\n" + pretty_print("", col_width) + "..."]
+ + (out[-last_rows:] if max_rows > 1 else [])
+ )
+ return "".join(out)
+
+
+def dim_summary_limited(obj, col_width: int, max_rows: int | None = None) -> str:
+ elements = [f"{k}: {v}" for k, v in obj.sizes.items()]
+ return _element_formatter(elements, col_width, max_rows)
+
+
+def unindexed_dims_repr(dims, coords, max_rows: int | None = None):
+ unindexed_dims = [d for d in dims if d not in coords]
+ if unindexed_dims:
+ dims_start = "Dimensions without coordinates: "
+ dims_str = _element_formatter(
+ unindexed_dims, col_width=len(dims_start), max_rows=max_rows
+ )
+ return dims_start + dims_str
+ else:
+ return None
+
+
+@contextlib.contextmanager
+def set_numpy_options(*args, **kwargs):
+ original = np.get_printoptions()
+ np.set_printoptions(*args, **kwargs)
+ try:
+ yield
+ finally:
+ np.set_printoptions(**original)
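# Editor's note: set_numpy_options is a plain save/restore wrapper around
# np.set_printoptions; a usage sketch:
#
#     with set_numpy_options(precision=2, threshold=10):
#         print(np.linspace(0, 1, 1000))   # abbreviated, 2-decimal repr
#     # global print options are restored on exit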
def limit_lines(string: str, *, limit: int):
@@ -157,51 +603,504 @@ def limit_lines(string: str, *, limit: int):
If the string is more lines than the limit,
this returns the middle lines replaced by an ellipsis
"""
- pass
+ lines = string.splitlines()
+ if len(lines) > limit:
+ string = "\n".join(chain(lines[: limit // 2], ["..."], lines[-limit // 2 :]))
+ return string
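# Editor's note (sketch): limit_lines keeps the head and tail of an over-long
# repr around an ellipsis line, e.g.
#
#     limit_lines("\n".join(map(str, range(100))), limit=6)
#     # keeps lines "0"-"2", then "...", then lines "97"-"99"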
+
+
+def short_array_repr(array):
+ from xarray.core.common import AbstractArray
+
+ if isinstance(array, AbstractArray):
+ array = array.data
+ array = to_duck_array(array)
+
+ # default to lower precision so a full (abbreviated) line can fit on
+ # one line with the default display_width
+ options = {
+ "precision": 6,
+ "linewidth": OPTIONS["display_width"],
+ "threshold": OPTIONS["display_values_threshold"],
+ }
+ if array.ndim < 3:
+ edgeitems = 3
+ elif array.ndim == 3:
+ edgeitems = 2
+ else:
+ edgeitems = 1
+ options["edgeitems"] = edgeitems
+ with set_numpy_options(**options):
+ return repr(array)
def short_data_repr(array):
"""Format "data" for DataArray and Variable."""
- pass
+ internal_data = getattr(array, "variable", array)._data
+ if isinstance(array, np.ndarray):
+ return short_array_repr(array)
+ elif is_duck_array(internal_data):
+ return limit_lines(repr(array.data), limit=40)
+ elif getattr(array, "_in_memory", None):
+ return short_array_repr(array)
+ else:
+ # internal xarray array type
+ return f"[{array.size} values with dtype={array.dtype}]"
+
+
+def _get_indexes_dict(indexes):
+ return {
+ tuple(index_vars.keys()): idx for idx, index_vars in indexes.group_by_index()
+ }
+
+
+@recursive_repr("<recursive array>")
+def array_repr(arr):
+ from xarray.core.variable import Variable
+
+ max_rows = OPTIONS["display_max_rows"]
+
+ # used for DataArray, Variable and IndexVariable
+ if hasattr(arr, "name") and arr.name is not None:
+ name_str = f"{arr.name!r} "
+ else:
+ name_str = ""
+
+ if (
+ isinstance(arr, Variable)
+ or _get_boolean_with_default("display_expand_data", default=True)
+ or isinstance(arr.variable._data, MemoryCachedArray)
+ ):
+ data_repr = short_data_repr(arr)
+ else:
+ data_repr = inline_variable_array_repr(arr.variable, OPTIONS["display_width"])
+
+ start = f"<xarray.{type(arr).__name__} {name_str}"
+ dims = dim_summary_limited(arr, col_width=len(start) + 1, max_rows=max_rows)
+ nbytes_str = render_human_readable_nbytes(arr.nbytes)
+ summary = [
+ f"{start}({dims})> Size: {nbytes_str}",
+ data_repr,
+ ]
+ if hasattr(arr, "coords"):
+ if arr.coords:
+ col_width = _calculate_col_width(arr.coords)
+ summary.append(
+ coords_repr(arr.coords, col_width=col_width, max_rows=max_rows)
+ )
+ unindexed_dims_str = unindexed_dims_repr(
+ arr.dims, arr.coords, max_rows=max_rows
+ )
+ if unindexed_dims_str:
+ summary.append(unindexed_dims_str)
-def dims_and_coords_repr(ds) ->str:
+ display_default_indexes = _get_boolean_with_default(
+ "display_default_indexes", False
+ )
+
+ xindexes = filter_nondefault_indexes(
+ _get_indexes_dict(arr.xindexes), not display_default_indexes
+ )
+
+ if xindexes:
+ summary.append(indexes_repr(xindexes, max_rows=max_rows))
+
+ if arr.attrs:
+ summary.append(attrs_repr(arr.attrs, max_rows=max_rows))
+
+ return "\n".join(summary)
+
+
+@recursive_repr("<recursive Dataset>")
+def dataset_repr(ds):
+ nbytes_str = render_human_readable_nbytes(ds.nbytes)
+ summary = [f"<xarray.{type(ds).__name__}> Size: {nbytes_str}"]
+
+ col_width = _calculate_col_width(ds.variables)
+ max_rows = OPTIONS["display_max_rows"]
+
+ dims_start = pretty_print("Dimensions:", col_width)
+ dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
+ summary.append(f"{dims_start}({dims_values})")
+
+ if ds.coords:
+ summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows))
+
+ unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows)
+ if unindexed_dims_str:
+ summary.append(unindexed_dims_str)
+
+ summary.append(data_vars_repr(ds.data_vars, col_width=col_width, max_rows=max_rows))
+
+ display_default_indexes = _get_boolean_with_default(
+ "display_default_indexes", False
+ )
+ xindexes = filter_nondefault_indexes(
+ _get_indexes_dict(ds.xindexes), not display_default_indexes
+ )
+ if xindexes:
+ summary.append(indexes_repr(xindexes, max_rows=max_rows))
+
+ if ds.attrs:
+ summary.append(attrs_repr(ds.attrs, max_rows=max_rows))
+
+ return "\n".join(summary)
+
+
+def dims_and_coords_repr(ds) -> str:
"""Partial Dataset repr for use inside DataTree inheritance errors."""
- pass
+ summary = []
+
+ col_width = _calculate_col_width(ds.coords)
+ max_rows = OPTIONS["display_max_rows"]
+
+ dims_start = pretty_print("Dimensions:", col_width)
+ dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
+ summary.append(f"{dims_start}({dims_values})")
+
+ if ds.coords:
+ summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows))
+
+ unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows)
+ if unindexed_dims_str:
+ summary.append(unindexed_dims_str)
+
+ return "\n".join(summary)
+
+
+def diff_dim_summary(a, b):
+ if a.sizes != b.sizes:
+ return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})"
+ else:
+ return ""
+
+
+def _diff_mapping_repr(
+ a_mapping,
+ b_mapping,
+ compat,
+ title,
+ summarizer,
+ col_width=None,
+ a_indexes=None,
+ b_indexes=None,
+):
+ def compare_attr(a, b):
+ if is_duck_array(a) or is_duck_array(b):
+ return array_equiv(a, b)
+ else:
+ return a == b
+
+ def extra_items_repr(extra_keys, mapping, ab_side, kwargs):
+ extra_repr = [
+ summarizer(k, mapping[k], col_width, **kwargs[k]) for k in extra_keys
+ ]
+ if extra_repr:
+ header = f"{title} only on the {ab_side} object:"
+ return [header] + extra_repr
+ else:
+ return []
+
+ a_keys = set(a_mapping)
+ b_keys = set(b_mapping)
+
+ summary = []
+
+ diff_items = []
+
+ a_summarizer_kwargs = defaultdict(dict)
+ if a_indexes is not None:
+ a_summarizer_kwargs = {k: {"is_index": k in a_indexes} for k in a_mapping}
+ b_summarizer_kwargs = defaultdict(dict)
+ if b_indexes is not None:
+ b_summarizer_kwargs = {k: {"is_index": k in b_indexes} for k in b_mapping}
+
+ for k in a_keys & b_keys:
+ try:
+ # compare xarray variable
+ if not callable(compat):
+ compatible = getattr(a_mapping[k].variable, compat)(
+ b_mapping[k].variable
+ )
+ else:
+ compatible = compat(a_mapping[k].variable, b_mapping[k].variable)
+ is_variable = True
+ except AttributeError:
+ # compare attribute value
+ compatible = compare_attr(a_mapping[k], b_mapping[k])
+ is_variable = False
+
+ if not compatible:
+ temp = [
+ summarizer(k, a_mapping[k], col_width, **a_summarizer_kwargs[k]),
+ summarizer(k, b_mapping[k], col_width, **b_summarizer_kwargs[k]),
+ ]
+
+ if compat == "identical" and is_variable:
+ attrs_summary = []
+ a_attrs = a_mapping[k].attrs
+ b_attrs = b_mapping[k].attrs
+
+ attrs_to_print = set(a_attrs) ^ set(b_attrs)
+ attrs_to_print.update(
+ {
+ k
+ for k in set(a_attrs) & set(b_attrs)
+ if not compare_attr(a_attrs[k], b_attrs[k])
+ }
+ )
+ for m in (a_mapping, b_mapping):
+ attr_s = "\n".join(
+ " " + summarize_attr(ak, av)
+ for ak, av in m[k].attrs.items()
+ if ak in attrs_to_print
+ )
+ if attr_s:
+ attr_s = " Differing variable attributes:\n" + attr_s
+ attrs_summary.append(attr_s)
+
+ temp = [
+ "\n".join([var_s, attr_s]) if attr_s else var_s
+ for var_s, attr_s in zip(temp, attrs_summary)
+ ]
+
+ # TODO: It should be possible to use _diff_mapping_repr recursively
+ # instead of explicitly handling variable attrs specially.
+ # That would require some refactoring.
+ # newdiff = _diff_mapping_repr(
+ # {k: v for k,v in a_attrs.items() if k in attrs_to_print},
+ # {k: v for k,v in b_attrs.items() if k in attrs_to_print},
+ # compat=compat,
+ # summarizer=summarize_attr,
+ # title="Variable Attributes"
+ # )
+ # temp += [newdiff]
+
+ diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)]
+
+ if diff_items:
+ summary += [f"Differing {title.lower()}:"] + diff_items
+
+ summary += extra_items_repr(a_keys - b_keys, a_mapping, "left", a_summarizer_kwargs)
+ summary += extra_items_repr(
+ b_keys - a_keys, b_mapping, "right", b_summarizer_kwargs
+ )
+
+ return "\n".join(summary)
+
+
+def diff_coords_repr(a, b, compat, col_width=None):
+ return _diff_mapping_repr(
+ a,
+ b,
+ compat,
+ "Coordinates",
+ summarize_variable,
+ col_width=col_width,
+ a_indexes=a.xindexes,
+ b_indexes=b.xindexes,
+ )
+
+diff_data_vars_repr = functools.partial(
+ _diff_mapping_repr, title="Data variables", summarizer=summarize_variable
+)
-diff_data_vars_repr = functools.partial(_diff_mapping_repr, title=
- 'Data variables', summarizer=summarize_variable)
-diff_attrs_repr = functools.partial(_diff_mapping_repr, title='Attributes',
- summarizer=summarize_attr)
+diff_attrs_repr = functools.partial(
+ _diff_mapping_repr, title="Attributes", summarizer=summarize_attr
+)
-def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool
- ) ->str:
+
+def _compat_to_str(compat):
+ if callable(compat):
+ compat = compat.__name__
+
+ if compat == "equals":
+ return "equal"
+ elif compat == "allclose":
+ return "close"
+ else:
+ return compat
+
+
+def diff_array_repr(a, b, compat):
+ # used for DataArray, Variable and IndexVariable
+ summary = [
+ f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}"
+ ]
+
+ summary.append(diff_dim_summary(a, b))
+ if callable(compat):
+ equiv = compat
+ else:
+ equiv = array_equiv
+
+ if not equiv(a.data, b.data):
+ temp = [wrap_indent(short_array_repr(obj), start=" ") for obj in (a, b)]
+ diff_data_repr = [
+ ab_side + "\n" + ab_data_repr
+ for ab_side, ab_data_repr in zip(("L", "R"), temp)
+ ]
+ summary += ["Differing values:"] + diff_data_repr
+
+ if hasattr(a, "coords"):
+ col_width = _calculate_col_width(set(a.coords) | set(b.coords))
+ summary.append(
+ diff_coords_repr(a.coords, b.coords, compat, col_width=col_width)
+ )
+
+ if compat == "identical":
+ summary.append(diff_attrs_repr(a.attrs, b.attrs, compat))
+
+ return "\n".join(summary)
+
+
+def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
"""
Return a summary of why two trees are not isomorphic.
If they are isomorphic return an empty string.
"""
- pass
+
+ # Walking nodes in "level-order" fashion means walking down from the root breadth-first.
+ # Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree
+ # (which it is so long as children are stored in a tuple or list rather than in a set).
+ for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
+ path_a, path_b = node_a.path, node_b.path
+
+ if require_names_equal and node_a.name != node_b.name:
+ diff = dedent(
+ f"""\
+ Node '{path_a}' in the left object has name '{node_a.name}'
+ Node '{path_b}' in the right object has name '{node_b.name}'"""
+ )
+ return diff
+
+ if len(node_a.children) != len(node_b.children):
+ diff = dedent(
+ f"""\
+ Number of children on node '{path_a}' of the left object: {len(node_a.children)}
+ Number of children on node '{path_b}' of the right object: {len(node_b.children)}"""
+ )
+ return diff
+
+ return ""
+
+
+def diff_dataset_repr(a, b, compat):
+ summary = [
+ f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}"
+ ]
+
+ col_width = _calculate_col_width(set(list(a.variables) + list(b.variables)))
+
+ summary.append(diff_dim_summary(a, b))
+ summary.append(diff_coords_repr(a.coords, b.coords, compat, col_width=col_width))
+ summary.append(
+ diff_data_vars_repr(a.data_vars, b.data_vars, compat, col_width=col_width)
+ )
+
+ if compat == "identical":
+ summary.append(diff_attrs_repr(a.attrs, b.attrs, compat))
+
+ return "\n".join(summary)
def diff_nodewise_summary(a: DataTree, b: DataTree, compat):
"""Iterates over all corresponding nodes, recording differences between data at each location."""
- pass
+ compat_str = _compat_to_str(compat)
-def _single_node_repr(node: DataTree) ->str:
+ summary = []
+ for node_a, node_b in zip(a.subtree, b.subtree):
+ a_ds, b_ds = node_a.ds, node_b.ds
+
+ if not a_ds._all_compat(b_ds, compat):
+ dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str)
+ data_diff = "\n".join(dataset_diff.split("\n", 1)[1:])
+
+ nodediff = (
+ f"\nData in nodes at position '{node_a.path}' do not match:"
+ f"{data_diff}"
+ )
+ summary.append(nodediff)
+
+ return "\n".join(summary)
+
+
+def diff_datatree_repr(a: DataTree, b: DataTree, compat):
+ summary = [
+ f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}"
+ ]
+
+ strict_names = True if compat in ["equals", "identical"] else False
+ treestructure_diff = diff_treestructure(a, b, strict_names)
+
+ # If the tree structures differ there is no point comparing each node
+ # TODO: we could show any differences in nodes up to the first place where the structure differs
+ if treestructure_diff or compat == "isomorphic":
+ summary.append("\n" + treestructure_diff)
+ else:
+ nodewise_diff = diff_nodewise_summary(a, b, compat)
+ summary.append("\n" + nodewise_diff)
+
+ return "\n".join(summary)
+
+
+def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
- pass
+ if node.has_data or node.has_attrs:
+ ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False))
+ else:
+ ds_info = ""
+ return f"Group: {node.path}{ds_info}"
def datatree_repr(dt: DataTree):
"""A printable representation of the structure of this entire tree."""
- pass
+ renderer = RenderDataTree(dt)
+
+ name_info = "" if dt.name is None else f" {dt.name!r}"
+ header = f"<xarray.DataTree{name_info}>"
+ lines = [header]
+ for pre, fill, node in renderer:
+ node_repr = _single_node_repr(node)
-def render_human_readable_nbytes(nbytes: int, /, *, attempt_constant_width:
- bool=False) ->str:
+ node_line = f"{pre}{node_repr.splitlines()[0]}"
+ lines.append(node_line)
+
+ if node.has_data or node.has_attrs:
+ ds_repr = node_repr.splitlines()[2:]
+ for line in ds_repr:
+ if len(node.children) > 0:
+ lines.append(f"{fill}{renderer.style.vertical}{line}")
+ else:
+ lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}")
+
+ return "\n".join(lines)
+
+
+def shorten_list_repr(items: Sequence, max_items: int) -> str:
+ if len(items) <= max_items:
+ return repr(items)
+ else:
+ first_half = repr(items[: max_items // 2])[
+ 1:-1
+ ] # Convert to string and remove brackets
+ second_half = repr(items[-max_items // 2 :])[
+ 1:-1
+ ] # Convert to string and remove brackets
+ return f"[{first_half}, ..., {second_half}]"
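# Editor's note (sketch): shorten_list_repr(list(range(10)), max_items=5)
# returns "[0, 1, ..., 7, 8, 9]" -- the floor division on the negative index
# keeps one extra trailing item when max_items is odd.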
+
+
+def render_human_readable_nbytes(
+ nbytes: int,
+ /,
+ *,
+ attempt_constant_width: bool = False,
+) -> str:
"""Renders simple human-readable byte count representation
This is only a quick representation that should not be relied upon for precise needs.
@@ -219,4 +1118,21 @@ def render_human_readable_nbytes(nbytes: int, /, *, attempt_constant_width:
-------
Human-readable representation of the byte count
"""
- pass
+ dividend = float(nbytes)
+ divisor = 1000.0
+ last_unit_available = UNITS[-1]
+
+ for unit in UNITS:
+ if dividend < divisor or unit == last_unit_available:
+ break
+ dividend /= divisor
+
+ dividend_str = f"{dividend:.0f}"
+ unit_str = f"{unit}"
+
+ if attempt_constant_width:
+ dividend_str = dividend_str.rjust(3)
+ unit_str = unit_str.ljust(2)
+
+ string = f"{dividend_str}{unit_str}"
+ return string
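# Editor's note: a sketch of the scaling above, assuming UNITS is the usual
# ("B", "kB", "MB", ...) tuple defined earlier in this module:
#
#     render_human_readable_nbytes(999)          # -> "999B"
#     render_human_readable_nbytes(1_500_000)    # -> "2MB"  (two divisions by 1000, then rounding)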
diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py
index a85cdfc5..24b29003 100644
--- a/xarray/core/formatting_html.py
+++ b/xarray/core/formatting_html.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import uuid
from collections import OrderedDict
from collections.abc import Mapping
@@ -6,10 +7,19 @@ from functools import lru_cache, partial
from html import escape
from importlib.resources import files
from typing import TYPE_CHECKING
-from xarray.core.formatting import inline_index_repr, inline_variable_array_repr, short_data_repr
+
+from xarray.core.formatting import (
+ inline_index_repr,
+ inline_variable_array_repr,
+ short_data_repr,
+)
from xarray.core.options import _get_boolean_with_default
-STATIC_FILES = ('xarray.static.html', 'icons-svg-inline.html'), (
- 'xarray.static.css', 'style.css')
+
+STATIC_FILES = (
+ ("xarray.static.html", "icons-svg-inline.html"),
+ ("xarray.static.css", "style.css"),
+)
+
if TYPE_CHECKING:
from xarray.core.datatree import DataTree
@@ -17,26 +27,257 @@ if TYPE_CHECKING:
@lru_cache(None)
def _load_static_files():
"""Lazily load the resource files into memory the first time they are needed"""
- pass
+ return [
+ files(package).joinpath(resource).read_text(encoding="utf-8")
+ for package, resource in STATIC_FILES
+ ]
-def short_data_repr_html(array) ->str:
+def short_data_repr_html(array) -> str:
"""Format "data" for DataArray and Variable."""
- pass
+ internal_data = getattr(array, "variable", array)._data
+ if hasattr(internal_data, "_repr_html_"):
+ return internal_data._repr_html_()
+ text = escape(short_data_repr(array))
+ return f"<pre>{text}</pre>"
+
+
+def format_dims(dim_sizes, dims_with_index) -> str:
+ if not dim_sizes:
+ return ""
+
+ dim_css_map = {
+ dim: " class='xr-has-index'" if dim in dims_with_index else ""
+ for dim in dim_sizes
+ }
+
+ dims_li = "".join(
+ f"<li><span{dim_css_map[dim]}>{escape(str(dim))}</span>: {size}</li>"
+ for dim, size in dim_sizes.items()
+ )
+
+ return f"<ul class='xr-dim-list'>{dims_li}</ul>"
+
+
+def summarize_attrs(attrs) -> str:
+ attrs_dl = "".join(
+ f"<dt><span>{escape(str(k))} :</span></dt><dd>{escape(str(v))}</dd>"
+ for k, v in attrs.items()
+ )
+
+ return f"<dl class='xr-attrs'>{attrs_dl}</dl>"
+
+
+def _icon(icon_name) -> str:
+ # icon_name should be defined in xarray/static/html/icon-svg-inline.html
+ return (
+ f"<svg class='icon xr-{icon_name}'>"
+ f"<use xlink:href='#{icon_name}'>"
+ "</use>"
+ "</svg>"
+ )
+
+
+def summarize_variable(name, var, is_index=False, dtype=None) -> str:
+ variable = var.variable if hasattr(var, "variable") else var
+
+ cssclass_idx = " class='xr-has-index'" if is_index else ""
+ dims_str = f"({', '.join(escape(dim) for dim in var.dims)})"
+ name = escape(str(name))
+ dtype = dtype or escape(str(var.dtype))
+
+ # "unique" ids required to expand/collapse subsections
+ attrs_id = "attrs-" + str(uuid.uuid4())
+ data_id = "data-" + str(uuid.uuid4())
+ disabled = "" if len(var.attrs) else "disabled"
+
+ preview = escape(inline_variable_array_repr(variable, 35))
+ attrs_ul = summarize_attrs(var.attrs)
+ data_repr = short_data_repr_html(variable)
+
+ attrs_icon = _icon("icon-file-text2")
+ data_icon = _icon("icon-database")
+
+ return (
+ f"<div class='xr-var-name'><span{cssclass_idx}>{name}</span></div>"
+ f"<div class='xr-var-dims'>{dims_str}</div>"
+ f"<div class='xr-var-dtype'>{dtype}</div>"
+ f"<div class='xr-var-preview xr-preview'>{preview}</div>"
+ f"<input id='{attrs_id}' class='xr-var-attrs-in' "
+ f"type='checkbox' {disabled}>"
+ f"<label for='{attrs_id}' title='Show/Hide attributes'>"
+ f"{attrs_icon}</label>"
+ f"<input id='{data_id}' class='xr-var-data-in' type='checkbox'>"
+ f"<label for='{data_id}' title='Show/Hide data repr'>"
+ f"{data_icon}</label>"
+ f"<div class='xr-var-attrs'>{attrs_ul}</div>"
+ f"<div class='xr-var-data'>{data_repr}</div>"
+ )
+
+
+def summarize_coords(variables) -> str:
+ li_items = []
+ for k, v in variables.items():
+ li_content = summarize_variable(k, v, is_index=k in variables.xindexes)
+ li_items.append(f"<li class='xr-var-item'>{li_content}</li>")
+
+ vars_li = "".join(li_items)
+
+ return f"<ul class='xr-var-list'>{vars_li}</ul>"
+
+
+def summarize_vars(variables) -> str:
+ vars_li = "".join(
+ f"<li class='xr-var-item'>{summarize_variable(k, v)}</li>"
+ for k, v in variables.items()
+ )
+
+ return f"<ul class='xr-var-list'>{vars_li}</ul>"
+
+
+def short_index_repr_html(index) -> str:
+ if hasattr(index, "_repr_html_"):
+ return index._repr_html_()
+
+ return f"<pre>{escape(repr(index))}</pre>"
+
+
+def summarize_index(coord_names, index) -> str:
+ name = "<br>".join([escape(str(n)) for n in coord_names])
+
+ index_id = f"index-{uuid.uuid4()}"
+ preview = escape(inline_index_repr(index))
+ details = short_index_repr_html(index)
+
+ data_icon = _icon("icon-database")
+
+ return (
+ f"<div class='xr-index-name'><div>{name}</div></div>"
+ f"<div class='xr-index-preview'>{preview}</div>"
+ f"<div></div>"
+ f"<input id='{index_id}' class='xr-index-data-in' type='checkbox'/>"
+ f"<label for='{index_id}' title='Show/Hide index repr'>{data_icon}</label>"
+ f"<div class='xr-index-data'>{details}</div>"
+ )
+
+
+def summarize_indexes(indexes) -> str:
+ indexes_li = "".join(
+ f"<li class='xr-var-item'>{summarize_index(v, i)}</li>"
+ for v, i in indexes.items()
+ )
+ return f"<ul class='xr-var-list'>{indexes_li}</ul>"
+
+
+def collapsible_section(
+ name, inline_details="", details="", n_items=None, enabled=True, collapsed=False
+) -> str:
+ # "unique" id to expand/collapse the section
+ data_id = "section-" + str(uuid.uuid4())
+
+ has_items = n_items is not None and n_items
+ n_items_span = "" if n_items is None else f" <span>({n_items})</span>"
+ enabled = "" if enabled and has_items else "disabled"
+ collapsed = "" if collapsed or not has_items else "checked"
+ tip = " title='Expand/collapse section'" if enabled else ""
+
+ return (
+ f"<input id='{data_id}' class='xr-section-summary-in' "
+ f"type='checkbox' {enabled} {collapsed}>"
+ f"<label for='{data_id}' class='xr-section-summary' {tip}>"
+ f"{name}:{n_items_span}</label>"
+ f"<div class='xr-section-inline-details'>{inline_details}</div>"
+ f"<div class='xr-section-details'>{details}</div>"
+ )
-coord_section = partial(_mapping_section, name='Coordinates', details_func=
- summarize_coords, max_items_collapse=25, expand_option_name=
- 'display_expand_coords')
-datavar_section = partial(_mapping_section, name='Data variables',
- details_func=summarize_vars, max_items_collapse=15, expand_option_name=
- 'display_expand_data_vars')
-index_section = partial(_mapping_section, name='Indexes', details_func=
- summarize_indexes, max_items_collapse=0, expand_option_name=
- 'display_expand_indexes')
-attr_section = partial(_mapping_section, name='Attributes', details_func=
- summarize_attrs, max_items_collapse=10, expand_option_name=
- 'display_expand_attrs')
+def _mapping_section(
+ mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True
+) -> str:
+ n_items = len(mapping)
+ expanded = _get_boolean_with_default(
+ expand_option_name, n_items < max_items_collapse
+ )
+ collapsed = not expanded
+
+ return collapsible_section(
+ name,
+ details=details_func(mapping),
+ n_items=n_items,
+ enabled=enabled,
+ collapsed=collapsed,
+ )
+
+
+def dim_section(obj) -> str:
+ dim_list = format_dims(obj.sizes, obj.xindexes.dims)
+
+ return collapsible_section(
+ "Dimensions", inline_details=dim_list, enabled=False, collapsed=True
+ )
+
+
+def array_section(obj) -> str:
+ # "unique" id to expand/collapse the section
+ data_id = "section-" + str(uuid.uuid4())
+ collapsed = (
+ "checked"
+ if _get_boolean_with_default("display_expand_data", default=True)
+ else ""
+ )
+ variable = getattr(obj, "variable", obj)
+ preview = escape(inline_variable_array_repr(variable, max_width=70))
+ data_repr = short_data_repr_html(obj)
+ data_icon = _icon("icon-database")
+
+ return (
+ "<div class='xr-array-wrap'>"
+ f"<input id='{data_id}' class='xr-array-in' type='checkbox' {collapsed}>"
+ f"<label for='{data_id}' title='Show/hide data repr'>{data_icon}</label>"
+ f"<div class='xr-array-preview xr-preview'><span>{preview}</span></div>"
+ f"<div class='xr-array-data'>{data_repr}</div>"
+ "</div>"
+ )
+
+
+coord_section = partial(
+ _mapping_section,
+ name="Coordinates",
+ details_func=summarize_coords,
+ max_items_collapse=25,
+ expand_option_name="display_expand_coords",
+)
+
+
+datavar_section = partial(
+ _mapping_section,
+ name="Data variables",
+ details_func=summarize_vars,
+ max_items_collapse=15,
+ expand_option_name="display_expand_data_vars",
+)
+
+index_section = partial(
+ _mapping_section,
+ name="Indexes",
+ details_func=summarize_indexes,
+ max_items_collapse=0,
+ expand_option_name="display_expand_indexes",
+)
+
+attr_section = partial(
+ _mapping_section,
+ name="Attributes",
+ details_func=summarize_attrs,
+ max_items_collapse=10,
+ expand_option_name="display_expand_attrs",
+)
+
+
+def _get_indexes_dict(indexes):
+ return {
+ tuple(index_vars.keys()): idx for idx, index_vars in indexes.group_by_index()
+ }
def _obj_repr(obj, header_components, sections):
@@ -45,15 +286,120 @@ def _obj_repr(obj, header_components, sections):
If CSS is not injected (untrusted notebook), fallback to the plain text repr.
"""
- pass
+ header = f"<div class='xr-header'>{''.join(h for h in header_components)}</div>"
+ sections = "".join(f"<li class='xr-section-item'>{s}</li>" for s in sections)
+
+ icons_svg, css_style = _load_static_files()
+ return (
+ "<div>"
+ f"{icons_svg}<style>{css_style}</style>"
+ f"<pre class='xr-text-repr-fallback'>{escape(repr(obj))}</pre>"
+ "<div class='xr-wrap' style='display:none'>"
+ f"{header}"
+ f"<ul class='xr-sections'>{sections}</ul>"
+ "</div>"
+ "</div>"
+ )
+
+
+def array_repr(arr) -> str:
+ dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape))
+ if hasattr(arr, "xindexes"):
+ indexed_dims = arr.xindexes.dims
+ else:
+ indexed_dims = {}
+
+ obj_type = f"xarray.{type(arr).__name__}"
+ arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else ""
+
+ header_components = [
+ f"<div class='xr-obj-type'>{obj_type}</div>",
+ f"<div class='xr-array-name'>{arr_name}</div>",
+ format_dims(dims, indexed_dims),
+ ]
+
+ sections = [array_section(arr)]
+ if hasattr(arr, "coords"):
+ sections.append(coord_section(arr.coords))
-children_section = partial(_mapping_section, name='Groups', details_func=
- summarize_datatree_children, max_items_collapse=1, expand_option_name=
- 'display_expand_groups')
+ if hasattr(arr, "xindexes"):
+ indexes = _get_indexes_dict(arr.xindexes)
+ sections.append(index_section(indexes))
+ sections.append(attr_section(arr.attrs))
-def _wrap_datatree_repr(r: str, end: bool=False) ->str:
+ return _obj_repr(arr, header_components, sections)
+
+
+def dataset_repr(ds) -> str:
+ obj_type = f"xarray.{type(ds).__name__}"
+
+ header_components = [f"<div class='xr-obj-type'>{escape(obj_type)}</div>"]
+
+ sections = [
+ dim_section(ds),
+ coord_section(ds.coords),
+ datavar_section(ds.data_vars),
+ index_section(_get_indexes_dict(ds.xindexes)),
+ attr_section(ds.attrs),
+ ]
+
+ return _obj_repr(ds, header_components, sections)
+
+
+def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
+ N_CHILDREN = len(children) - 1
+
+ # Get result from datatree_node_repr and wrap it
+ lines_callback = lambda n, c, end: _wrap_datatree_repr(
+ datatree_node_repr(n, c), end=end
+ )
+
+ children_html = "".join(
+ (
+ lines_callback(n, c, end=False) # Long lines
+ if i < N_CHILDREN
+ else lines_callback(n, c, end=True)
+ ) # Short lines
+ for i, (n, c) in enumerate(children.items())
+ )
+
+ return "".join(
+ [
+ "<div style='display: inline-grid; grid-template-columns: 100%; grid-column: 1 / -1'>",
+ children_html,
+ "</div>",
+ ]
+ )
+
+
+children_section = partial(
+ _mapping_section,
+ name="Groups",
+ details_func=summarize_datatree_children,
+ max_items_collapse=1,
+ expand_option_name="display_expand_groups",
+)
+
+
+def datatree_node_repr(group_title: str, dt: DataTree) -> str:
+ header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>"]
+
+ ds = dt._to_dataset_view(rebuild_dims=False)
+
+ sections = [
+ children_section(dt.children),
+ dim_section(ds),
+ coord_section(ds.coords),
+ datavar_section(ds.data_vars),
+ attr_section(ds.attrs),
+ ]
+
+ return _obj_repr(ds, header_components, sections)
+
+
+def _wrap_datatree_repr(r: str, end: bool = False) -> str:
"""
Wrap HTML representation with a tee to the left of it.
@@ -90,4 +436,39 @@ def _wrap_datatree_repr(r: str, end: bool=False) ->str:
Tee color is set to the variable :code:`--xr-border-color`.
"""
- pass
+ # height of line
+ end = bool(end)
+ height = "100%" if end is False else "1.2em"
+ return "".join(
+ [
+ "<div style='display: inline-grid; grid-template-columns: 0px 20px auto; width: 100%;'>",
+ "<div style='",
+ "grid-column-start: 1;",
+ "border-right: 0.2em solid;",
+ "border-color: var(--xr-border-color);",
+ f"height: {height};",
+ "width: 0px;",
+ "'>",
+ "</div>",
+ "<div style='",
+ "grid-column-start: 2;",
+ "grid-row-start: 1;",
+ "height: 1em;",
+ "width: 20px;",
+ "border-bottom: 0.2em solid;",
+ "border-color: var(--xr-border-color);",
+ "'>",
+ "</div>",
+ "<div style='",
+ "grid-column-start: 3;",
+ "'>",
+ r,
+ "</div>",
+ "</div>",
+ ]
+ )
+
+
+def datatree_repr(dt: DataTree) -> str:
+ obj_type = f"datatree.{type(dt).__name__}"
+ return datatree_node_repr(obj_type, dt)
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index 8807366d..9b0758d0 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -1,28 +1,55 @@
from __future__ import annotations
+
import copy
import warnings
from collections.abc import Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union
+
import numpy as np
import pandas as pd
from packaging.version import Version
+
from xarray.core import dtypes, duck_array_ops, nputils, ops
-from xarray.core._aggregations import DataArrayGroupByAggregations, DatasetGroupByAggregations
+from xarray.core._aggregations import (
+ DataArrayGroupByAggregations,
+ DatasetGroupByAggregations,
+)
from xarray.core.alignment import align
from xarray.core.arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
from xarray.core.concat import concat
from xarray.core.coordinates import Coordinates
from xarray.core.formatting import format_array_flat
-from xarray.core.indexes import PandasIndex, create_default_index_implicit, filter_indexes_from_coords
+from xarray.core.indexes import (
+ PandasIndex,
+ create_default_index_implicit,
+ filter_indexes_from_coords,
+)
from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_DataWithCoords, T_Xarray
-from xarray.core.utils import FrozenMappingWarningOnValuesAccess, contains_only_chunked_or_numpy, either_dict_or_kwargs, hashable, is_scalar, maybe_wrap_array, module_available, peek_at
+from xarray.core.types import (
+ Dims,
+ QuantileMethods,
+ T_DataArray,
+ T_DataWithCoords,
+ T_Xarray,
+)
+from xarray.core.utils import (
+ FrozenMappingWarningOnValuesAccess,
+ contains_only_chunked_or_numpy,
+ either_dict_or_kwargs,
+ hashable,
+ is_scalar,
+ maybe_wrap_array,
+ module_available,
+ peek_at,
+)
from xarray.core.variable import IndexVariable, Variable
from xarray.util.deprecation_helpers import _deprecate_positional_args
+
if TYPE_CHECKING:
from numpy.typing import ArrayLike
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import GroupIndex, GroupIndices, GroupKey
@@ -30,13 +57,86 @@ if TYPE_CHECKING:
from xarray.groupers import Grouper
-def _consolidate_slices(slices: list[slice]) ->list[slice]:
+def check_reduce_dims(reduce_dims, dimensions):
+ if reduce_dims is not ...:
+ if is_scalar(reduce_dims):
+ reduce_dims = [reduce_dims]
+ if any(dim not in dimensions for dim in reduce_dims):
+ raise ValueError(
+ f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' "
+ f"to reduce over all dimensions or one or more of {dimensions!r}."
+ )
+
+
+def _codes_to_group_indices(inverse: np.ndarray, N: int) -> GroupIndices:
+ assert inverse.ndim == 1
+ groups: GroupIndices = tuple([] for _ in range(N))
+ for n, g in enumerate(inverse):
+ if g >= 0:
+ groups[g].append(n)
+ return groups
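# Editor's note (sketch): codes of -1 mark "no group" and are skipped, e.g.
#
#     _codes_to_group_indices(np.array([0, 1, 0, -1, 1]), 2)   # -> ([0, 2], [1, 4])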
+
+
+def _dummy_copy(xarray_obj):
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ if isinstance(xarray_obj, Dataset):
+ res = Dataset(
+ {
+ k: dtypes.get_fill_value(v.dtype)
+ for k, v in xarray_obj.data_vars.items()
+ },
+ {
+ k: dtypes.get_fill_value(v.dtype)
+ for k, v in xarray_obj.coords.items()
+ if k not in xarray_obj.dims
+ },
+ xarray_obj.attrs,
+ )
+ elif isinstance(xarray_obj, DataArray):
+ res = DataArray(
+ dtypes.get_fill_value(xarray_obj.dtype),
+ {
+ k: dtypes.get_fill_value(v.dtype)
+ for k, v in xarray_obj.coords.items()
+ if k not in xarray_obj.dims
+ },
+ dims=[],
+ name=xarray_obj.name,
+ attrs=xarray_obj.attrs,
+ )
+ else: # pragma: no cover
+ raise AssertionError
+ return res
+
+
+def _is_one_or_none(obj) -> bool:
+ return obj == 1 or obj is None
+
+
+def _consolidate_slices(slices: list[slice]) -> list[slice]:
"""Consolidate adjacent slices in a list of slices."""
- pass
+ result: list[slice] = []
+ last_slice = slice(None)
+ for slice_ in slices:
+ if not isinstance(slice_, slice):
+ raise ValueError(f"list element is not a slice: {slice_!r}")
+ if (
+ result
+ and last_slice.stop == slice_.start
+ and _is_one_or_none(last_slice.step)
+ and _is_one_or_none(slice_.step)
+ ):
+ last_slice = slice(last_slice.start, slice_.stop, slice_.step)
+ result[-1] = last_slice
+ else:
+ result.append(slice_)
+ last_slice = slice_
+ return result
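# Editor's note (sketch): adjacent unit-step slices are merged, e.g.
#
#     _consolidate_slices([slice(0, 3), slice(3, 5), slice(7, 9)])
#     # -> [slice(0, 5, None), slice(7, 9, None)]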
-def _inverse_permutation_indices(positions, N: (int | None)=None) ->(np.
- ndarray | None):
+def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray | None:
"""Like inverse_permutation, but also handles slices.
Parameters
@@ -48,7 +148,17 @@ def _inverse_permutation_indices(positions, N: (int | None)=None) ->(np.
-------
np.ndarray of indices or None, if no permutation is necessary.
"""
- pass
+ if not positions:
+ return None
+
+ if isinstance(positions[0], slice):
+ positions = _consolidate_slices(positions)
+ if positions == slice(None):
+ return None
+ positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions]
+
+ newpositions = nputils.inverse_permutation(np.concatenate(positions), N)
+ return newpositions[newpositions != -1]
class _DummyGroup(Generic[T_Xarray]):
@@ -56,27 +166,91 @@ class _DummyGroup(Generic[T_Xarray]):
Should not be user visible.
"""
- __slots__ = 'name', 'coords', 'size', 'dataarray'
- def __init__(self, obj: T_Xarray, name: Hashable, coords) ->None:
+ __slots__ = ("name", "coords", "size", "dataarray")
+
+ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None:
self.name = name
self.coords = coords
self.size = obj.sizes[name]
- def __array__(self) ->np.ndarray:
+ @property
+ def dims(self) -> tuple[Hashable]:
+ return (self.name,)
+
+ @property
+ def ndim(self) -> Literal[1]:
+ return 1
+
+ @property
+ def values(self) -> range:
+ return range(self.size)
+
+ @property
+ def data(self) -> range:
+ return range(self.size)
+
+ def __array__(self) -> np.ndarray:
return np.arange(self.size)
+ @property
+ def shape(self) -> tuple[int]:
+ return (self.size,)
+
+ @property
+ def attrs(self) -> dict:
+ return {}
+
def __getitem__(self, key):
if isinstance(key, tuple):
- key, = key
+ (key,) = key
return self.values[key]
- def to_array(self) ->DataArray:
+ def to_index(self) -> pd.Index:
+ # could be pd.RangeIndex?
+ return pd.Index(np.arange(self.size))
+
+ def copy(self, deep: bool = True, data: Any = None):
+ raise NotImplementedError
+
+ def to_dataarray(self) -> DataArray:
+ from xarray.core.dataarray import DataArray
+
+ return DataArray(
+ data=self.data, dims=(self.name,), coords=self.coords, name=self.name
+ )
+
+ def to_array(self) -> DataArray:
"""Deprecated version of to_dataarray."""
- pass
+ return self.to_dataarray()
-T_Group = Union['T_DataArray', _DummyGroup]
+T_Group = Union["T_DataArray", _DummyGroup]
+
+
+def _ensure_1d(group: T_Group, obj: T_DataWithCoords) -> tuple[
+ T_Group,
+ T_DataWithCoords,
+ Hashable | None,
+ list[Hashable],
+]:
+ # 1D cases: do nothing
+ if isinstance(group, _DummyGroup) or group.ndim == 1:
+ return group, obj, None, []
+
+ from xarray.core.dataarray import DataArray
+
+ if isinstance(group, DataArray):
+ # try to stack the dims of the group into a single dim
+ orig_dims = group.dims
+ stacked_dim = "stacked_" + "_".join(map(str, orig_dims))
+ # these dimensions get created by the stack operation
+ inserted_dims = [dim for dim in group.dims if dim not in group.coords]
+ newgroup = group.stack({stacked_dim: orig_dims})
+ newobj = obj.stack({stacked_dim: orig_dims})
+ return newgroup, newobj, stacked_dim, inserted_dims
+
+ raise TypeError(f"group must be DataArray or _DummyGroup, got {type(group)!r}.")
@dataclass
@@ -93,39 +267,148 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
This class is private API, while Groupers are public.
"""
+
grouper: Grouper
group: T_Group
obj: T_DataWithCoords
+
+ # returned by factorize:
codes: DataArray = field(init=False, repr=False)
full_index: pd.Index = field(init=False, repr=False)
group_indices: GroupIndices = field(init=False, repr=False)
unique_coord: Variable | _DummyGroup = field(init=False, repr=False)
+
+ # _ensure_1d:
group1d: T_Group = field(init=False, repr=False)
stacked_obj: T_DataWithCoords = field(init=False, repr=False)
stacked_dim: Hashable | None = field(init=False, repr=False)
inserted_dims: list[Hashable] = field(init=False, repr=False)
- def __post_init__(self) ->None:
+ def __post_init__(self) -> None:
+ # This copy allows the BinGrouper.factorize() method
+ # to update BinGrouper.bins when provided as int, using the output
+ # of pd.cut
+ # We do not want to modify the original object, since the same grouper
+ # might be used multiple times.
self.grouper = copy.deepcopy(self.grouper)
+
self.group = _resolve_group(self.obj, self.group)
- (self.group1d, self.stacked_obj, self.stacked_dim, self.inserted_dims
- ) = _ensure_1d(group=self.group, obj=self.obj)
+
+ (
+ self.group1d,
+ self.stacked_obj,
+ self.stacked_dim,
+ self.inserted_dims,
+ ) = _ensure_1d(group=self.group, obj=self.obj)
+
self.factorize()
@property
- def name(self) ->Hashable:
+ def name(self) -> Hashable:
"""Name for the grouped coordinate after reduction."""
- pass
+ # the name has to come from unique_coord because we need `_bins` suffix for BinGrouper
+ (name,) = self.unique_coord.dims
+ return name
@property
- def size(self) ->int:
+ def size(self) -> int:
"""Number of groups."""
- pass
+ return len(self)
- def __len__(self) ->int:
+ def __len__(self) -> int:
"""Number of groups."""
return len(self.full_index)
+ @property
+ def dims(self):
+ return self.group1d.dims
+
+ def factorize(self) -> None:
+ encoded = self.grouper.factorize(self.group1d)
+
+ self.codes = encoded.codes
+ self.full_index = encoded.full_index
+
+ if encoded.group_indices is not None:
+ self.group_indices = encoded.group_indices
+ else:
+ self.group_indices = tuple(
+ g
+ for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
+ if g
+ )
+ if encoded.unique_coord is None:
+ unique_values = self.full_index[np.unique(encoded.codes)]
+ self.unique_coord = Variable(
+ dims=self.codes.name, data=unique_values, attrs=self.group.attrs
+ )
+ else:
+ self.unique_coord = encoded.unique_coord
+
+
+def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
+ # While we don't generally check the type of every arg, passing
+ # multiple dimensions as multiple arguments is common enough, and the
+ # consequences hidden enough (strings evaluate as true) to warrant
+ # checking here.
+ # A future version could make squeeze kwarg only, but would face
+ # backward-compat issues.
+ if squeeze is not False:
+ raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.")
+
+
+def _resolve_group(
+ obj: T_DataWithCoords, group: T_Group | Hashable | IndexVariable
+) -> T_Group:
+ from xarray.core.dataarray import DataArray
+
+ error_msg = (
+ "the group variable's length does not "
+ "match the length of this variable along its "
+ "dimensions"
+ )
+
+ newgroup: T_Group
+ if isinstance(group, DataArray):
+ try:
+ align(obj, group, join="exact", copy=False)
+ except ValueError:
+ raise ValueError(error_msg)
+
+ newgroup = group.copy(deep=False)
+ newgroup.name = group.name or "group"
+
+ elif isinstance(group, IndexVariable):
+ # This assumption is built in to _ensure_1d.
+ if group.ndim != 1:
+ raise ValueError(
+ "Grouping by multi-dimensional IndexVariables is not allowed."
+ "Convert to and pass a DataArray instead."
+ )
+ (group_dim,) = group.dims
+ if len(group) != obj.sizes[group_dim]:
+ raise ValueError(error_msg)
+ newgroup = DataArray(group)
+
+ else:
+ if not hashable(group):
+ raise TypeError(
+ "`group` must be an xarray.DataArray or the "
+ "name of an xarray variable or dimension. "
+ f"Received {group!r} instead."
+ )
+ group_da: DataArray = obj[group]
+ if group_da.name not in obj._indexes and group_da.name in obj.dims:
+ # DummyGroups should not appear on groupby results
+ newgroup = _DummyGroup(obj, group_da.name, group_da.coords)
+ else:
+ newgroup = group_da
+
+ if newgroup.size == 0:
+ raise ValueError(f"{newgroup.name} must not be empty")
+
+ return newgroup
+
class GroupBy(Generic[T_Xarray]):
"""A object that implements the split-apply-combine pattern.
@@ -143,24 +426,47 @@ class GroupBy(Generic[T_Xarray]):
Dataset.groupby
DataArray.groupby
"""
- __slots__ = ('_full_index', '_inserted_dims', '_group', '_group_dim',
- '_group_indices', '_groups', 'groupers', '_obj',
- '_restore_coord_dims', '_stacked_dim', '_unique_coord', '_dims',
- '_sizes', '_original_obj', '_original_group', '_bins', '_codes')
+
+ __slots__ = (
+ "_full_index",
+ "_inserted_dims",
+ "_group",
+ "_group_dim",
+ "_group_indices",
+ "_groups",
+ "groupers",
+ "_obj",
+ "_restore_coord_dims",
+ "_stacked_dim",
+ "_unique_coord",
+ "_dims",
+ "_sizes",
+ # Save unstacked object for flox
+ "_original_obj",
+ "_original_group",
+ "_bins",
+ "_codes",
+ )
_obj: T_Xarray
groupers: tuple[ResolvedGrouper]
_restore_coord_dims: bool
+
_original_obj: T_Xarray
_original_group: T_Group
_group_indices: GroupIndices
_codes: DataArray
_group_dim: Hashable
+
_groups: dict[GroupKey, GroupIndex] | None
_dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None
_sizes: Mapping[Hashable, int] | None
- def __init__(self, obj: T_Xarray, groupers: tuple[ResolvedGrouper],
- restore_coord_dims: bool=True) ->None:
+ def __init__(
+ self,
+ obj: T_Xarray,
+ groupers: tuple[ResolvedGrouper],
+ restore_coord_dims: bool = True,
+ ) -> None:
"""Create a GroupBy object
Parameters
@@ -174,20 +480,28 @@ class GroupBy(Generic[T_Xarray]):
coordinates.
"""
self.groupers = groupers
+
self._original_obj = obj
- grouper, = self.groupers
+
+ (grouper,) = self.groupers
self._original_group = grouper.group
+
+ # specification for the groupby operation
self._obj = grouper.stacked_obj
self._restore_coord_dims = restore_coord_dims
+
+ # These should generalize to multiple groupers
self._group_indices = grouper.group_indices
self._codes = self._maybe_unstack(grouper.codes)
- self._group_dim, = grouper.group1d.dims
+
+ (self._group_dim,) = grouper.group1d.dims
+ # cached attributes
self._groups = None
self._dims = None
self._sizes = None
@property
- def sizes(self) ->Mapping[Hashable, int]:
+ def sizes(self) -> Mapping[Hashable, int]:
"""Ordered mapping from dimension names to lengths.
Immutable.
@@ -197,57 +511,327 @@ class GroupBy(Generic[T_Xarray]):
DataArray.sizes
Dataset.sizes
"""
- pass
+ if self._sizes is None:
+ (grouper,) = self.groupers
+ index = self._group_indices[0]
+ self._sizes = self._obj.isel({self._group_dim: index}).sizes
+ return self._sizes
+
+ def map(
+ self,
+ func: Callable,
+ args: tuple[Any, ...] = (),
+ shortcut: bool | None = None,
+ **kwargs: Any,
+ ) -> T_Xarray:
+ raise NotImplementedError()
+
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ shortcut: bool = True,
+ **kwargs: Any,
+ ) -> T_Xarray:
+ raise NotImplementedError()
@property
- def groups(self) ->dict[GroupKey, GroupIndex]:
+ def groups(self) -> dict[GroupKey, GroupIndex]:
"""
Mapping from group labels to indices. The indices can be used to index the underlying object.
"""
- pass
+ # provided to mimic pandas.groupby
+ if self._groups is None:
+ (grouper,) = self.groupers
+ self._groups = dict(zip(grouper.unique_coord.values, self._group_indices))
+ return self._groups
- def __getitem__(self, key: GroupKey) ->T_Xarray:
+ def __getitem__(self, key: GroupKey) -> T_Xarray:
"""
Get DataArray or Dataset corresponding to a particular group label.
"""
- grouper, = self.groupers
+ (grouper,) = self.groupers
return self._obj.isel({self._group_dim: self.groups[key]})
- def __len__(self) ->int:
- grouper, = self.groupers
+ def __len__(self) -> int:
+ (grouper,) = self.groupers
return grouper.size
- def __iter__(self) ->Iterator[tuple[GroupKey, T_Xarray]]:
- grouper, = self.groupers
+ def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]:
+ (grouper,) = self.groupers
return zip(grouper.unique_coord.data, self._iter_grouped())
- def __repr__(self) ->str:
- grouper, = self.groupers
- return '{}, grouped over {!r}\n{!r} groups with labels {}.'.format(self
- .__class__.__name__, grouper.name, grouper.full_index.size,
- ', '.join(format_array_flat(grouper.full_index, 30).split()))
+ def __repr__(self) -> str:
+ (grouper,) = self.groupers
+ return "{}, grouped over {!r}\n{!r} groups with labels {}.".format(
+ self.__class__.__name__,
+ grouper.name,
+ grouper.full_index.size,
+ ", ".join(format_array_flat(grouper.full_index, 30).split()),
+ )
- def _iter_grouped(self) ->Iterator[T_Xarray]:
+ def _iter_grouped(self) -> Iterator[T_Xarray]:
"""Iterate over each element in this group"""
- pass
+ (grouper,) = self.groupers
+ for idx, indices in enumerate(self._group_indices):
+ yield self._obj.isel({self._group_dim: indices})
+
+ def _infer_concat_args(self, applied_example):
+ from xarray.groupers import BinGrouper
+
+ (grouper,) = self.groupers
+ if self._group_dim in applied_example.dims:
+ coord = grouper.group1d
+ positions = self._group_indices
+ else:
+ coord = grouper.unique_coord
+ positions = None
+ (dim,) = coord.dims
+ if isinstance(grouper.group, _DummyGroup) and not isinstance(
+ grouper.grouper, BinGrouper
+ ):
+ # When binning we actually do set the index
+ coord = None
+ coord = getattr(coord, "variable", coord)
+ return coord, dim, positions
+
+ def _binary_op(self, other, f, reflexive=False):
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ g = f if not reflexive else lambda x, y: f(y, x)
+
+ (grouper,) = self.groupers
+ obj = self._original_obj
+ name = grouper.name
+ group = grouper.group
+ codes = self._codes
+ dims = group.dims
+
+ if isinstance(group, _DummyGroup):
+ group = coord = group.to_dataarray()
+ else:
+ coord = grouper.unique_coord
+ if isinstance(coord, Variable):
+ assert coord.ndim == 1
+ (coord_dim,) = coord.dims
+ # TODO: explicitly create Index here
+ coord = DataArray(coord, coords={coord_dim: coord.data})
+
+ if not isinstance(other, (Dataset, DataArray)):
+ raise TypeError(
+ "GroupBy objects only support binary ops "
+ "when the other argument is a Dataset or "
+ "DataArray"
+ )
+
+ if name not in other.dims:
+ raise ValueError(
+ "incompatible dimensions for a grouped "
+ f"binary operation: the group variable {name!r} "
+ "is not a dimension on the other argument "
+ f"with dimensions {other.dims!r}"
+ )
+
+ # Broadcast out scalars for backwards compatibility
+ # TODO: get rid of this when fixing GH2145
+ for var in other.coords:
+ if other[var].ndim == 0:
+ other[var] = (
+ other[var].drop_vars(var).expand_dims({name: other.sizes[name]})
+ )
+
+ # need to handle NaNs in group or elements that don't belong to any bins
+ mask = codes == -1
+ if mask.any():
+ obj = obj.where(~mask, drop=True)
+ group = group.where(~mask, drop=True)
+ codes = codes.where(~mask, drop=True).astype(int)
+
+ # if other is dask-backed, that's a hint that the
+ # "expanded" dataset is too big to hold in memory.
+ # this can be the case when `other` was read from disk
+ # and contains our lazy indexing classes
+ # We need to check for dask-backed Datasets
+ # so utils.is_duck_dask_array does not work for this check
+ if obj.chunks and not other.chunks:
+ # TODO: What about datasets with some dask vars, and others not?
+ # This handles dims other than `name`
+ chunks = {k: v for k, v in obj.chunksizes.items() if k in other.dims}
+ # a chunk size of 1 seems reasonable since we expect individual elements of
+ # other to be repeated multiple times across the reduced dimension(s)
+ chunks[name] = 1
+ other = other.chunk(chunks)
+
+ # codes are defined for coord, so we align `other` with `coord`
+ # before indexing
+ other, _ = align(other, coord, join="right", copy=False)
+ expanded = other.isel({name: codes})
+
+ result = g(obj, expanded)
+
+ if group.ndim > 1:
+ # backcompat:
+ # TODO: get rid of this when fixing GH2145
+ for var in set(obj.coords) - set(obj.xindexes):
+ if set(obj[var].dims) < set(group.dims):
+ result[var] = obj[var].reset_coords(drop=True).broadcast_like(group)
+
+ if isinstance(result, Dataset) and isinstance(obj, Dataset):
+ for var in set(result):
+ for d in dims:
+ if d not in obj[var].dims:
+ result[var] = result[var].transpose(d, ...)
+ return result
+
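# Illustrative sketch (separate from the patch) of the grouped binary-op path
# above: subtracting a per-group aggregate broadcasts it back over the members
# of each group.
import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="x", coords={"letters": ("x", ["a", "b", "b", "a"])})
anomaly = da.groupby("letters") - da.groupby("letters").mean()
print(anomaly.values)  # each value minus its group mean: [-1.5, -0.5, 0.5, 1.5]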
+ def _restore_dim_order(self, stacked):
+ raise NotImplementedError
def _maybe_restore_empty_groups(self, combined):
"""Our index contained empty groups (e.g., from a resampling or binning). If we
reduced on that dimension, we want to restore the full index.
"""
- pass
+ from xarray.groupers import BinGrouper, TimeResampler
+
+ (grouper,) = self.groupers
+ if (
+ isinstance(grouper.grouper, (BinGrouper, TimeResampler))
+ and grouper.name in combined.dims
+ ):
+ indexers = {grouper.name: grouper.full_index}
+ combined = combined.reindex(**indexers)
+ return combined
def _maybe_unstack(self, obj):
"""This gets called if we are applying on an array with a
multidimensional group."""
- pass
-
- def _flox_reduce(self, dim: Dims, keep_attrs: (bool | None)=None, **
- kwargs: Any):
+ (grouper,) = self.groupers
+ stacked_dim = grouper.stacked_dim
+ inserted_dims = grouper.inserted_dims
+ if stacked_dim is not None and stacked_dim in obj.dims:
+ obj = obj.unstack(stacked_dim)
+ for dim in inserted_dims:
+ if dim in obj.coords:
+ del obj.coords[dim]
+ obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords))
+ return obj
+
+ def _flox_reduce(
+ self,
+ dim: Dims,
+ keep_attrs: bool | None = None,
+ **kwargs: Any,
+ ):
"""Adaptor function that translates our groupby API to that of flox."""
- pass
-
- def fillna(self, value: Any) ->T_Xarray:
+ import flox
+ from flox.xarray import xarray_reduce
+
+ from xarray.core.dataset import Dataset
+ from xarray.groupers import BinGrouper
+
+ obj = self._original_obj
+ (grouper,) = self.groupers
+ name = grouper.name
+ isbin = isinstance(grouper.grouper, BinGrouper)
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ if Version(flox.__version__) < Version("0.9"):
+ # preserve current strategy (approximately) for dask groupby
+ # on older flox versions to prevent surprises.
+ # flox >=0.9 will choose this on its own.
+ kwargs.setdefault("method", "cohorts")
+
+ numeric_only = kwargs.pop("numeric_only", None)
+ if numeric_only:
+ non_numeric = {
+ name: var
+ for name, var in obj.data_vars.items()
+ if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_))
+ }
+ else:
+ non_numeric = {}
+
+ if "min_count" in kwargs:
+ if kwargs["func"] not in ["sum", "prod"]:
+ raise TypeError("Received an unexpected keyword argument 'min_count'")
+ elif kwargs["min_count"] is None:
+ # set explicitly to avoid unnecessarily accumulating count
+ kwargs["min_count"] = 0
+
+ unindexed_dims: tuple[Hashable, ...] = tuple()
+ if isinstance(grouper.group, _DummyGroup) and not isbin:
+ unindexed_dims = (name,)
+
+ parsed_dim: tuple[Hashable, ...]
+ if isinstance(dim, str):
+ parsed_dim = (dim,)
+ elif dim is None:
+ parsed_dim = grouper.group.dims
+ elif dim is ...:
+ parsed_dim = tuple(obj.dims)
+ else:
+ parsed_dim = tuple(dim)
+
+ # Do this so we raise the same error message whether flox is present or not.
+ # Better to control it here than in flox.
+ if any(d not in grouper.group.dims and d not in obj.dims for d in parsed_dim):
+ raise ValueError(f"cannot reduce over dimensions {dim}.")
+
+ if kwargs["func"] not in ["all", "any", "count"]:
+ kwargs.setdefault("fill_value", np.nan)
+ if isbin and kwargs["func"] == "count":
+ # This is an annoying hack. Xarray returns np.nan
+ # when there are no observations in a bin, instead of 0.
+ # We can fake that here by forcing min_count=1.
+ # note min_count makes no sense in the xarray world
+ # as a kwarg for count, so this should be OK
+ kwargs.setdefault("fill_value", np.nan)
+ kwargs.setdefault("min_count", 1)
+
+ output_index = grouper.full_index
+ result = xarray_reduce(
+ obj.drop_vars(non_numeric.keys()),
+ self._codes,
+ dim=parsed_dim,
+ # pass RangeIndex as a hint to flox that `by` is already factorized
+ expected_groups=(pd.RangeIndex(len(output_index)),),
+ isbin=False,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ # we did end up reducing over dimension(s) that are
+ # in the grouped variable
+ group_dims = grouper.group.dims
+ if set(group_dims).issubset(set(parsed_dim)):
+ result = result.assign_coords(
+ Coordinates(
+ coords={name: (name, np.array(output_index))},
+ indexes={name: PandasIndex(output_index, dim=name)},
+ )
+ )
+ result = result.drop_vars(unindexed_dims)
+
+ # broadcast and restore non-numeric data variables (backcompat)
+ for name, var in non_numeric.items():
+ if all(d not in var.dims for d in parsed_dim):
+ result[name] = var.variable.set_dims(
+ (name,) + var.dims, (result.sizes[name],) + var.shape
+ )
+
+ if not isinstance(result, Dataset):
+ # only restore dimension order for arrays
+ result = self._restore_dim_order(result)
+
+ return result
+
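# Illustrative sketch (separate from the patch): the flox adaptor above is only
# taken when the ``use_flox`` option is enabled (and flox is importable);
# otherwise the plain map/reduce fallback runs. Both paths should agree:
import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="x", coords={"letters": ("x", ["a", "b", "b", "a"])})
with xr.set_options(use_flox=False):
    fallback = da.groupby("letters").sum()
with xr.set_options(use_flox=True):
    maybe_flox = da.groupby("letters").sum()
assert fallback.equals(maybe_flox)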
+ def fillna(self, value: Any) -> T_Xarray:
"""Fill missing values in this object by group.
This operation follows the normal broadcasting and alignment rules that
@@ -271,13 +855,19 @@ class GroupBy(Generic[T_Xarray]):
Dataset.fillna
DataArray.fillna
"""
- pass
-
- @_deprecate_positional_args('v2023.10.0')
- def quantile(self, q: ArrayLike, dim: Dims=None, *, method:
- QuantileMethods='linear', keep_attrs: (bool | None)=None, skipna: (
- bool | None)=None, interpolation: (QuantileMethods | None)=None
- ) ->T_Xarray:
+ return ops.fillna(self, value)
+
+ @_deprecate_positional_args("v2023.10.0")
+ def quantile(
+ self,
+ q: ArrayLike,
+ dim: Dims = None,
+ *,
+ method: QuantileMethods = "linear",
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ interpolation: QuantileMethods | None = None,
+ ) -> T_Xarray:
"""Compute the qth quantile over each array in the groups and
concatenate them together into a new array.
@@ -394,9 +984,36 @@ class GroupBy(Generic[T_Xarray]):
"Sample quantiles in statistical packages,"
The American Statistician, 50(4), pp. 361-365, 1996
"""
- pass
-
- def where(self, cond, other=dtypes.NA) ->T_Xarray:
+ if dim is None:
+ (grouper,) = self.groupers
+ dim = grouper.group1d.dims
+
+ # Dataset.quantile does this, do it for flox to ensure same output.
+ q = np.asarray(q, dtype=np.float64)
+
+ if (
+ method == "linear"
+ and OPTIONS["use_flox"]
+ and contains_only_chunked_or_numpy(self._obj)
+ and module_available("flox", minversion="0.9.4")
+ ):
+ result = self._flox_reduce(
+ func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna
+ )
+ return result
+ else:
+ return self.map(
+ self._obj.__class__.quantile,
+ shortcut=False,
+ q=q,
+ dim=dim,
+ method=method,
+ keep_attrs=keep_attrs,
+ skipna=skipna,
+ interpolation=interpolation,
+ )
+
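# Illustrative sketch (separate from the patch): grouped quantiles are
# concatenated along the group labels; an array-valued ``q`` adds a
# "quantile" dimension.
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(8.0), dims="x", coords={"letters": ("x", list("aabbccdd"))})
median = da.groupby("letters").quantile(0.5)            # one value per letter
spread = da.groupby("letters").quantile([0.25, 0.75])   # adds a "quantile" dim
print(median.dims, spread.dims)  # ('letters',) and ('letters', 'quantile'); order may vary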
+ def where(self, cond, other=dtypes.NA) -> T_Xarray:
"""Return elements from `self` or `other` depending on `cond`.
Parameters
@@ -415,16 +1032,30 @@ class GroupBy(Generic[T_Xarray]):
--------
Dataset.where
"""
- pass
+ return ops.where_method(self, cond, other)
- def first(self, skipna: (bool | None)=None, keep_attrs: (bool | None)=None
+ def _first_or_last(self, op, skipna, keep_attrs):
+ if all(
+ isinstance(maybe_slice, slice)
+ and (maybe_slice.stop == maybe_slice.start + 1)
+ for maybe_slice in self._group_indices
):
+ # NB. this is currently only used for reductions along an existing
+ # dimension
+ return self._obj
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+ return self.reduce(
+ op, dim=[self._group_dim], skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ def first(self, skipna: bool | None = None, keep_attrs: bool | None = None):
"""Return the first element of each group along the group dimension"""
- pass
+ return self._first_or_last(duck_array_ops.first, skipna, keep_attrs)
- def last(self, skipna: (bool | None)=None, keep_attrs: (bool | None)=None):
+ def last(self, skipna: bool | None = None, keep_attrs: bool | None = None):
"""Return the last element of each group along the group dimension"""
- pass
+ return self._first_or_last(duck_array_ops.last, skipna, keep_attrs)
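# Illustrative sketch (separate from the patch): ``first``/``last`` pick the
# first/last element of each group along the grouped dimension, skipping NaN
# for float data by default.
import numpy as np
import xarray as xr

da = xr.DataArray([np.nan, 1.0, 2.0, 3.0], dims="x", coords={"letters": ("x", ["a", "a", "b", "b"])})
print(da.groupby("letters").first().values)  # [1., 2.]  (NaN in group 'a' skipped)
print(da.groupby("letters").last().values)   # [1., 3.]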
def assign_coords(self, coords=None, **coords_kwargs):
"""Assign coordinates by group.
@@ -434,22 +1065,78 @@ class GroupBy(Generic[T_Xarray]):
Dataset.assign_coords
Dataset.swap_dims
"""
- pass
+ coords_kwargs = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords")
+ return self.map(lambda ds: ds.assign_coords(**coords_kwargs))
+
+
+def _maybe_reorder(xarray_obj, dim, positions, N: int | None):
+ order = _inverse_permutation_indices(positions, N)
+ if order is None or len(order) != xarray_obj.sizes[dim]:
+ return xarray_obj
+ else:
+ return xarray_obj[{dim: order}]
-class DataArrayGroupByBase(GroupBy['DataArray'], DataArrayGroupbyArithmetic):
+
+class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic):
"""GroupBy object specialized to grouping DataArray objects"""
+
__slots__ = ()
_dims: tuple[Hashable, ...] | None
+ @property
+ def dims(self) -> tuple[Hashable, ...]:
+ if self._dims is None:
+ (grouper,) = self.groupers
+ index = self._group_indices[0]
+ self._dims = self._obj.isel({self._group_dim: index}).dims
+ return self._dims
+
def _iter_grouped_shortcut(self):
"""Fast version of `_iter_grouped` that yields Variables without
metadata
"""
- pass
-
- def map(self, func: Callable[..., DataArray], args: tuple[Any, ...]=(),
- shortcut: (bool | None)=None, **kwargs: Any) ->DataArray:
+ var = self._obj.variable
+ (grouper,) = self.groupers
+ for idx, indices in enumerate(self._group_indices):
+ yield var[{self._group_dim: indices}]
+
+ def _concat_shortcut(self, applied, dim, positions=None):
+ # nb. don't worry too much about maintaining this method -- it does
+ # speed things up, but it's not very interpretable and there are much
+ # faster alternatives (e.g., doing the grouped aggregation in a
+ # compiled language)
+ # TODO: benbovy - explicit indexes: this fast implementation doesn't
+ # create an explicit index for the stacked dim coordinate
+ stacked = Variable.concat(applied, dim, shortcut=True)
+
+ (grouper,) = self.groupers
+ reordered = _maybe_reorder(stacked, dim, positions, N=grouper.group.size)
+ return self._obj._replace_maybe_drop_dims(reordered)
+
+ def _restore_dim_order(self, stacked: DataArray) -> DataArray:
+ (grouper,) = self.groupers
+ group = grouper.group1d
+
+ def lookup_order(dimension):
+ if dimension == grouper.name:
+ (dimension,) = group.dims
+ if dimension in self._obj.dims:
+ axis = self._obj.get_axis_num(dimension)
+ else:
+ axis = 1e6 # some arbitrarily high value
+ return axis
+
+ new_order = sorted(stacked.dims, key=lookup_order)
+ return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims)
+
+ def map(
+ self,
+ func: Callable[..., DataArray],
+ args: tuple[Any, ...] = (),
+ shortcut: bool | None = None,
+ **kwargs: Any,
+ ) -> DataArray:
"""Apply a function to each array in the group and concatenate them
together into a new array.
@@ -491,7 +1178,9 @@ class DataArrayGroupByBase(GroupBy['DataArray'], DataArrayGroupbyArithmetic):
applied : DataArray
The result of splitting, applying and combining this array.
"""
- pass
+ grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped()
+ applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped)
+ return self._combine(applied, shortcut=shortcut)
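# Illustrative sketch (separate from the patch): ``map`` applies the function
# to each group and stitches the pieces back together in the original order.
import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="x", coords={"letters": ("x", ["a", "b", "b", "a"])})
fraction_of_group = da.groupby("letters").map(lambda g: g / g.sum())
print(fraction_of_group.values)  # [0.2, 0.4, 0.6, 0.8]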
def apply(self, func, shortcut=False, args=(), **kwargs):
"""
@@ -501,15 +1190,47 @@ class DataArrayGroupByBase(GroupBy['DataArray'], DataArrayGroupbyArithmetic):
--------
DataArrayGroupBy.map
"""
- pass
+ warnings.warn(
+ "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged",
+ PendingDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.map(func, shortcut=shortcut, args=args, **kwargs)
def _combine(self, applied, shortcut=False):
"""Recombine the applied objects like the original."""
- pass
-
- def reduce(self, func: Callable[..., Any], dim: Dims=None, *, axis: (
- int | Sequence[int] | None)=None, keep_attrs: (bool | None)=None,
- keepdims: bool=False, shortcut: bool=True, **kwargs: Any) ->DataArray:
+ applied_example, applied = peek_at(applied)
+ coord, dim, positions = self._infer_concat_args(applied_example)
+ if shortcut:
+ combined = self._concat_shortcut(applied, dim, positions)
+ else:
+ combined = concat(applied, dim)
+ (grouper,) = self.groupers
+ combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size)
+
+ if isinstance(combined, type(self._obj)):
+ # only restore dimension order for arrays
+ combined = self._restore_dim_order(combined)
+ # assign coord and index when the applied function does not return that coord
+ if coord is not None and dim not in applied_example.dims:
+ index, index_vars = create_default_index_implicit(coord)
+ indexes = {k: index for k in index_vars}
+ combined = combined._overwrite_indexes(indexes, index_vars)
+ combined = self._maybe_restore_empty_groups(combined)
+ combined = self._maybe_unstack(combined)
+ return combined
+
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ shortcut: bool = True,
+ **kwargs: Any,
+ ) -> DataArray:
"""Reduce the items in this group by applying `func` along some
dimension(s).
@@ -539,20 +1260,56 @@ class DataArrayGroupByBase(GroupBy['DataArray'], DataArrayGroupbyArithmetic):
Array with summarized data and the indicated dimension(s)
removed.
"""
- pass
+ if dim is None:
+ dim = [self._group_dim]
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+ def reduce_array(ar: DataArray) -> DataArray:
+ return ar.reduce(
+ func=func,
+ dim=dim,
+ axis=axis,
+ keep_attrs=keep_attrs,
+ keepdims=keepdims,
+ **kwargs,
+ )
-class DataArrayGroupBy(DataArrayGroupByBase, DataArrayGroupByAggregations,
- ImplementsArrayReduce):
+ check_reduce_dims(dim, self.dims)
+
+ return self.map(reduce_array, shortcut=shortcut)
+
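# Illustrative sketch (separate from the patch): ``reduce`` accepts a raw
# numpy-style reduction and applies it within each group.
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="x", coords={"letters": ("x", ["a", "b", "b", "a"])})
group_range = da.groupby("letters").reduce(np.ptp)  # max - min within each group
print(group_range.values)  # [3., 1.] for groups 'a' and 'b'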
+
+# https://github.com/python/mypy/issues/9031
+class DataArrayGroupBy( # type: ignore[misc]
+ DataArrayGroupByBase,
+ DataArrayGroupByAggregations,
+ ImplementsArrayReduce,
+):
__slots__ = ()
-class DatasetGroupByBase(GroupBy['Dataset'], DatasetGroupbyArithmetic):
+class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic):
__slots__ = ()
_dims: Frozen[Hashable, int] | None
- def map(self, func: Callable[..., Dataset], args: tuple[Any, ...]=(),
- shortcut: (bool | None)=None, **kwargs: Any) ->Dataset:
+ @property
+ def dims(self) -> Frozen[Hashable, int]:
+ if self._dims is None:
+ (grouper,) = self.groupers
+ index = self._group_indices[0]
+ self._dims = self._obj.isel({self._group_dim: index}).dims
+
+ return FrozenMappingWarningOnValuesAccess(self._dims)
+
+ def map(
+ self,
+ func: Callable[..., Dataset],
+ args: tuple[Any, ...] = (),
+ shortcut: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""Apply a function to each Dataset in the group and concatenate them
together into a new Dataset.
@@ -582,7 +1339,9 @@ class DatasetGroupByBase(GroupBy['Dataset'], DatasetGroupbyArithmetic):
applied : Dataset
The result of splitting, applying and combining this dataset.
"""
- pass
+ # ignore shortcut if set (for now)
+ applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
+ return self._combine(applied)
def apply(self, func, args=(), shortcut=None, **kwargs):
"""
@@ -592,15 +1351,41 @@ class DatasetGroupByBase(GroupBy['Dataset'], DatasetGroupbyArithmetic):
--------
DatasetGroupBy.map
"""
- pass
+
+ warnings.warn(
+ "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged",
+ PendingDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.map(func, shortcut=shortcut, args=args, **kwargs)
def _combine(self, applied):
"""Recombine the applied objects like the original."""
- pass
-
- def reduce(self, func: Callable[..., Any], dim: Dims=None, *, axis: (
- int | Sequence[int] | None)=None, keep_attrs: (bool | None)=None,
- keepdims: bool=False, shortcut: bool=True, **kwargs: Any) ->Dataset:
+ applied_example, applied = peek_at(applied)
+ coord, dim, positions = self._infer_concat_args(applied_example)
+ combined = concat(applied, dim)
+ (grouper,) = self.groupers
+ combined = _maybe_reorder(combined, dim, positions, N=grouper.group.size)
+ # assign coord when the applied function does not return that coord
+ if coord is not None and dim not in applied_example.dims:
+ index, index_vars = create_default_index_implicit(coord)
+ indexes = {k: index for k in index_vars}
+ combined = combined._overwrite_indexes(indexes, index_vars)
+ combined = self._maybe_restore_empty_groups(combined)
+ combined = self._maybe_unstack(combined)
+ return combined
+
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ shortcut: bool = True,
+ **kwargs: Any,
+ ) -> Dataset:
"""Reduce the items in this group by applying `func` along some
dimension(s).
@@ -630,18 +1415,40 @@ class DatasetGroupByBase(GroupBy['Dataset'], DatasetGroupbyArithmetic):
Array with summarized data and the indicated dimension(s)
removed.
"""
- pass
+ if dim is None:
+ dim = [self._group_dim]
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ def reduce_dataset(ds: Dataset) -> Dataset:
+ return ds.reduce(
+ func=func,
+ dim=dim,
+ axis=axis,
+ keep_attrs=keep_attrs,
+ keepdims=keepdims,
+ **kwargs,
+ )
+
+ check_reduce_dims(dim, self.dims)
+
+ return self.map(reduce_dataset)
- def assign(self, **kwargs: Any) ->Dataset:
+ def assign(self, **kwargs: Any) -> Dataset:
"""Assign data variables by group.
See Also
--------
Dataset.assign
"""
- pass
+ return self.map(lambda ds: ds.assign(**kwargs))
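# Illustrative sketch (separate from the patch): the Dataset flavour works the
# same way; ``map`` and ``assign`` are applied per group and recombined.
import xarray as xr

ds = xr.Dataset(
    {"t": ("x", [10.0, 12.0, 8.0, 9.0])},
    coords={"letters": ("x", ["a", "a", "b", "b"])},
)
demeaned = ds.groupby("letters").map(lambda g: g - g.mean())
doubled = ds.groupby("letters").assign(t2=lambda g: g["t"] * 2)
print(demeaned["t"].values)  # [-1., 1., -0.5, 0.5]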
-class DatasetGroupBy(DatasetGroupByBase, DatasetGroupByAggregations,
- ImplementsDatasetReduce):
+# https://github.com/python/mypy/issues/9031
+class DatasetGroupBy( # type: ignore[misc]
+ DatasetGroupByBase,
+ DatasetGroupByAggregations,
+ ImplementsDatasetReduce,
+):
__slots__ = ()
diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 78726197..9d8a68ed 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -1,18 +1,34 @@
from __future__ import annotations
+
import collections.abc
import copy
from collections import defaultdict
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
+
import numpy as np
import pandas as pd
+
from xarray.core import formatting, nputils, utils
-from xarray.core.indexing import IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter
-from xarray.core.utils import Frozen, emit_user_level_warning, get_valid_numpy_dtype, is_dict_like, is_scalar
+from xarray.core.indexing import (
+ IndexSelResult,
+ PandasIndexingAdapter,
+ PandasMultiIndexingAdapter,
+)
+from xarray.core.utils import (
+ Frozen,
+ emit_user_level_warning,
+ get_valid_numpy_dtype,
+ is_dict_like,
+ is_scalar,
+)
+
if TYPE_CHECKING:
from xarray.core.types import ErrorOptions, JoinOptions, Self
from xarray.core.variable import Variable
-IndexVars = dict[Any, 'Variable']
+
+
+IndexVars = dict[Any, "Variable"]
class Index:
@@ -43,8 +59,12 @@ class Index:
"""
@classmethod
- def from_variables(cls, variables: Mapping[Any, Variable], *, options:
- Mapping[str, Any]) ->Self:
+ def from_variables(
+ cls,
+ variables: Mapping[Any, Variable],
+ *,
+ options: Mapping[str, Any],
+ ) -> Self:
"""Create a new index object from one or more coordinate variables.
This factory method must be implemented in all subclasses of Index.
@@ -64,11 +84,15 @@ class Index:
index : Index
A new Index object.
"""
- pass
+ raise NotImplementedError()
@classmethod
- def concat(cls, indexes: Sequence[Self], dim: Hashable, positions: (
- Iterable[Iterable[int]] | None)=None) ->Self:
+ def concat(
+ cls,
+ indexes: Sequence[Self],
+ dim: Hashable,
+ positions: Iterable[Iterable[int]] | None = None,
+ ) -> Self:
"""Create a new index by concatenating one or more indexes of the same
type.
@@ -93,10 +117,10 @@ class Index:
index : Index
A new Index object.
"""
- pass
+ raise NotImplementedError()
@classmethod
- def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) ->Self:
+ def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self:
"""Create a new index by stacking coordinate variables into a single new
dimension.
@@ -116,9 +140,11 @@ class Index:
index
A new Index object.
"""
- pass
+ raise NotImplementedError(
+ f"{cls!r} cannot be used for creating an index of stacked coordinates"
+ )
- def unstack(self) ->tuple[dict[Hashable, Index], pd.MultiIndex]:
+ def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]:
"""Unstack a (multi-)index into multiple (single) indexes.
Implementation is optional but required in order to support unstacking
@@ -132,10 +158,11 @@ class Index:
object used to unstack unindexed coordinate variables or data
variables.
"""
- pass
+ raise NotImplementedError()
- def create_variables(self, variables: (Mapping[Any, Variable] | None)=None
- ) ->IndexVars:
+ def create_variables(
+ self, variables: Mapping[Any, Variable] | None = None
+ ) -> IndexVars:
"""Maybe create new coordinate variables from this index.
This method is useful if the index data can be reused as coordinate
@@ -160,9 +187,13 @@ class Index:
Dictionary of :py:class:`Variable` or :py:class:`IndexVariable`
objects.
"""
- pass
+ if variables is not None:
+ # pass through
+ return dict(**variables)
+ else:
+ return {}
- def to_pandas_index(self) ->pd.Index:
+ def to_pandas_index(self) -> pd.Index:
"""Cast this xarray index to a pandas.Index object or raise a
``TypeError`` if this is not supported.
@@ -172,10 +203,11 @@ class Index:
By default it raises a ``TypeError``, unless it is re-implemented in
subclasses of Index.
"""
- pass
+ raise TypeError(f"{self!r} cannot be cast to a pandas.Index object")
- def isel(self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
- ) ->(Self | None):
+ def isel(
+ self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
+ ) -> Self | None:
"""Maybe returns a new index from the current index itself indexed by
positional indexers.
@@ -202,9 +234,9 @@ class Index:
maybe_index : Index
A new Index object or ``None``.
"""
- pass
+ return None
- def sel(self, labels: dict[Any, Any]) ->IndexSelResult:
+ def sel(self, labels: dict[Any, Any]) -> IndexSelResult:
"""Query the index with arbitrary coordinate label indexers.
Implementation is optional but required in order to support label-based
@@ -228,9 +260,9 @@ class Index:
An index query result object that contains dimension positional indexers.
It may also contain new indexes, coordinate variables, etc.
"""
- pass
+ raise NotImplementedError(f"{self!r} doesn't support label-based selection")
- def join(self, other: Self, how: JoinOptions='inner') ->Self:
+ def join(self, other: Self, how: JoinOptions = "inner") -> Self:
"""Return a new index from the combination of this index with another
index of the same type.
@@ -248,9 +280,11 @@ class Index:
joined : Index
A new Index object.
"""
- pass
+ raise NotImplementedError(
+ f"{self!r} doesn't support alignment with inner/outer join method"
+ )
- def reindex_like(self, other: Self) ->dict[Hashable, Any]:
+ def reindex_like(self, other: Self) -> dict[Hashable, Any]:
"""Query the index with another index of the same type.
Implementation is optional but required in order to support alignment.
@@ -266,9 +300,9 @@ class Index:
A dictionary where keys are dimension names and values are positional
indexers.
"""
- pass
+ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels")
- def equals(self, other: Self) ->bool:
+ def equals(self, other: Self) -> bool:
"""Compare this index with another index of the same type.
Implementation is optional but required in order to support alignment.
@@ -283,9 +317,9 @@ class Index:
is_equal : bool
``True`` if the indexes are equal, ``False`` otherwise.
"""
- pass
+ raise NotImplementedError()
- def roll(self, shifts: Mapping[Any, int]) ->(Self | None):
+ def roll(self, shifts: Mapping[Any, int]) -> Self | None:
"""Roll this index by an offset along one or more dimensions.
This method can be re-implemented in subclasses of Index, e.g., when the
@@ -308,10 +342,13 @@ class Index:
rolled : Index
A new index with rolled data.
"""
- pass
+ return None
- def rename(self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[
- Any, Hashable]) ->Self:
+ def rename(
+ self,
+ name_dict: Mapping[Any, Hashable],
+ dims_dict: Mapping[Any, Hashable],
+ ) -> Self:
"""Maybe update the index with new coordinate and dimension names.
This method should be re-implemented in subclasses of Index if it has
@@ -336,9 +373,9 @@ class Index:
renamed : Index
Index with renamed attributes.
"""
- pass
+ return self
- def copy(self, deep: bool=True) ->Self:
+ def copy(self, deep: bool = True) -> Self:
"""Return a (deep) copy of this index.
Implementation in subclasses of Index is optional. The base class
@@ -355,19 +392,44 @@ class Index:
index : Index
A new Index object.
"""
- pass
+ return self._copy(deep=deep)
- def __copy__(self) ->Self:
+ def __copy__(self) -> Self:
return self.copy(deep=False)
- def __deepcopy__(self, memo: (dict[int, Any] | None)=None) ->Index:
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index:
return self._copy(deep=True, memo=memo)
- def __getitem__(self, indexer: Any) ->Self:
+ def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self:
+ cls = self.__class__
+ copied = cls.__new__(cls)
+ if deep:
+ for k, v in self.__dict__.items():
+ setattr(copied, k, copy.deepcopy(v, memo))
+ else:
+ copied.__dict__.update(self.__dict__)
+ return copied
+
+ def __getitem__(self, indexer: Any) -> Self:
raise NotImplementedError()
+ def _repr_inline_(self, max_width):
+ return self.__class__.__name__
+
+
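# Hedged sketch (separate from the patch) of the Index extension point above:
# a custom index only needs to implement the hooks it supports; everything
# else falls back to the base-class defaults (NotImplementedError, ``None``,
# or pass-through). ``NoSelectIndex`` is a made-up toy class; attaching it via
# ``Dataset.set_xindex`` assumes a recent xarray version that provides it.
import xarray as xr
from xarray.core.indexes import Index


class NoSelectIndex(Index):
    """Toy index: can be built from a coordinate but supports no selection."""

    @classmethod
    def from_variables(cls, variables, *, options):
        # a real index would validate/convert the variables here
        return cls()


ds = xr.Dataset(coords={"labels": ("x", ["a", "b", "c"])})
ds2 = ds.set_xindex("labels", NoSelectIndex)
print(ds2.xindexes)  # "labels" is now backed by NoSelectIndex; .sel on it raises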
+def _maybe_cast_to_cftimeindex(index: pd.Index) -> pd.Index:
+ from xarray.coding.cftimeindex import CFTimeIndex
-def safe_cast_to_index(array: Any) ->pd.Index:
+ if len(index) > 0 and index.dtype == "O" and not isinstance(index, CFTimeIndex):
+ try:
+ return CFTimeIndex(index)
+ except (ImportError, TypeError):
+ return index
+ else:
+ return index
+
+
+def safe_cast_to_index(array: Any) -> pd.Index:
"""Given an array, safely cast it to a pandas.Index.
If it is already a pandas.Index, return it unchanged.
@@ -376,7 +438,76 @@ def safe_cast_to_index(array: Any) ->pd.Index:
this function will not attempt to do automatic type conversion but will
always return an index with dtype=object.
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.variable import Variable
+
+ if isinstance(array, pd.Index):
+ index = array
+ elif isinstance(array, (DataArray, Variable)):
+ # returns the original multi-index for pandas.MultiIndex level coordinates
+ index = array._to_index()
+ elif isinstance(array, Index):
+ index = array.to_pandas_index()
+ elif isinstance(array, PandasIndexingAdapter):
+ index = array.array
+ else:
+ kwargs: dict[str, Any] = {}
+ if hasattr(array, "dtype"):
+ if array.dtype.kind == "O":
+ kwargs["dtype"] = "object"
+ elif array.dtype == "float16":
+ emit_user_level_warning(
+ (
+ "`pandas.Index` does not support the `float16` dtype."
+ " Casting to `float64` for you, but in the future please"
+ " manually cast to either `float32` and `float64`."
+ ),
+ category=DeprecationWarning,
+ )
+ kwargs["dtype"] = "float64"
+
+ index = pd.Index(np.asarray(array), **kwargs)
+
+ return _maybe_cast_to_cftimeindex(index)
+
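# Illustrative sketch (separate from the patch) of safe_cast_to_index:
# existing pandas indexes pass through unchanged, object-dtype input is kept
# as object (no type guessing), and float16 is upcast with a warning.
import numpy as np
import pandas as pd
from xarray.core.indexes import safe_cast_to_index

idx = pd.Index([1, 2, 3])
assert safe_cast_to_index(idx) is idx               # already an index: returned as-is

tuples = np.empty(2, dtype=object)
tuples[:] = [(1, 2), (3, 4)]
assert safe_cast_to_index(tuples).dtype == object   # stays object, no coercion

half = np.array([1.0, 2.0], dtype="float16")
print(safe_cast_to_index(half).dtype)               # float64 (pd.Index has no float16)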
+
+def _sanitize_slice_element(x):
+ from xarray.core.dataarray import DataArray
+ from xarray.core.variable import Variable
+
+ if not isinstance(x, tuple) and len(np.shape(x)) != 0:
+ raise ValueError(
+ f"cannot use non-scalar arrays in a slice for xarray indexing: {x}"
+ )
+
+ if isinstance(x, (Variable, DataArray)):
+ x = x.values
+
+ if isinstance(x, np.ndarray):
+ x = x[()]
+
+ return x
+
+
+def _query_slice(index, label, coord_name="", method=None, tolerance=None):
+ if method is not None or tolerance is not None:
+ raise NotImplementedError(
+ "cannot use ``method`` argument if any indexers are slice objects"
+ )
+ indexer = index.slice_indexer(
+ _sanitize_slice_element(label.start),
+ _sanitize_slice_element(label.stop),
+ _sanitize_slice_element(label.step),
+ )
+ if not isinstance(indexer, slice):
+ # unlike pandas, in xarray we never want to silently convert a
+ # slice indexer into an array indexer
+ raise KeyError(
+ "cannot represent labeled-based slice indexer for coordinate "
+ f"{coord_name!r} with a slice over integer positions; the index is "
+ "unsorted or non-unique"
+ )
+ return indexer
def _asarray_tuplesafe(values):
@@ -386,96 +517,501 @@ def _asarray_tuplesafe(values):
Adapted from pandas.core.common._asarray_tuplesafe
"""
- pass
+ if isinstance(values, tuple):
+ result = utils.to_0d_object_array(values)
+ else:
+ result = np.asarray(values)
+ if result.ndim == 2:
+ result = np.empty(len(values), dtype=object)
+ result[:] = values
+
+ return result
+
+
+def _is_nested_tuple(possible_tuple):
+ return isinstance(possible_tuple, tuple) and any(
+ isinstance(value, (tuple, list, slice)) for value in possible_tuple
+ )
-def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None
- ) ->np.ndarray:
+def normalize_label(value, dtype=None) -> np.ndarray:
+ if getattr(value, "ndim", 1) <= 1:
+ value = _asarray_tuplesafe(value)
+ if dtype is not None and dtype.kind == "f" and value.dtype.kind != "b":
+ # pd.Index built from coordinate with float precision != 64
+ # see https://github.com/pydata/xarray/pull/3153 for details
+ # bypass coercing dtype for boolean indexers (ignore index)
+ # see https://github.com/pydata/xarray/issues/5727
+ value = np.asarray(value, dtype=dtype)
+ return value
+
+
+def as_scalar(value: np.ndarray):
+ # see https://github.com/pydata/xarray/pull/4292 for details
+ return value[()] if value.dtype.kind in "mM" else value.item()
+
+
+def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.ndarray:
"""Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional
labels
"""
- pass
+ flat_labels = np.ravel(labels)
+ if flat_labels.dtype == "float16":
+ flat_labels = flat_labels.astype("float64")
+ flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)
+ indexer = flat_indexer.reshape(labels.shape)
+ return indexer
-T_PandasIndex = TypeVar('T_PandasIndex', bound='PandasIndex')
+T_PandasIndex = TypeVar("T_PandasIndex", bound="PandasIndex")
class PandasIndex(Index):
"""Wrap a pandas.Index as an xarray compatible index."""
+
index: pd.Index
dim: Hashable
coord_dtype: Any
- __slots__ = 'index', 'dim', 'coord_dtype'
- def __init__(self, array: Any, dim: Hashable, coord_dtype: Any=None, *,
- fastpath: bool=False):
+ __slots__ = ("index", "dim", "coord_dtype")
+
+ def __init__(
+ self,
+ array: Any,
+ dim: Hashable,
+ coord_dtype: Any = None,
+ *,
+ fastpath: bool = False,
+ ):
if fastpath:
index = array
else:
index = safe_cast_to_index(array)
+
if index.name is None:
+ # make a shallow copy: cheap and because the index name may be updated
+ # here or in other constructors (cannot use pd.Index.rename as this
+ # constructor is also called from PandasMultiIndex)
index = index.copy()
index.name = dim
+
self.index = index
self.dim = dim
+
if coord_dtype is None:
coord_dtype = get_valid_numpy_dtype(index)
self.coord_dtype = coord_dtype
+ def _replace(self, index, dim=None, coord_dtype=None):
+ if dim is None:
+ dim = self.dim
+ if coord_dtype is None:
+ coord_dtype = self.coord_dtype
+ return type(self)(index, dim, coord_dtype, fastpath=True)
+
+ @classmethod
+ def from_variables(
+ cls,
+ variables: Mapping[Any, Variable],
+ *,
+ options: Mapping[str, Any],
+ ) -> PandasIndex:
+ if len(variables) != 1:
+ raise ValueError(
+ f"PandasIndex only accepts one variable, found {len(variables)} variables"
+ )
+
+ name, var = next(iter(variables.items()))
+
+ if var.ndim == 0:
+ raise ValueError(
+ f"cannot set a PandasIndex from the scalar variable {name!r}, "
+ "only 1-dimensional variables are supported. "
+ f"Note: you might want to use `obj.expand_dims({name!r})` to create a "
+ f"new dimension and turn {name!r} as an indexed dimension coordinate."
+ )
+ elif var.ndim != 1:
+ raise ValueError(
+ "PandasIndex only accepts a 1-dimensional variable, "
+ f"variable {name!r} has {var.ndim} dimensions"
+ )
+
+ dim = var.dims[0]
+
+ # TODO: (benbovy - explicit indexes): add __index__ to ExplicitlyIndexesNDArrayMixin?
+ # this could be eventually used by Variable.to_index() and would remove the need to perform
+ # the checks below.
+
+ # preserve wrapped pd.Index (if any)
+ # accessing `.data` can load data from disk, so we only access if needed
+ data = getattr(var._data, "array") if hasattr(var._data, "array") else var.data
+ # multi-index level variable: get level index
+ if isinstance(var._data, PandasMultiIndexingAdapter):
+ level = var._data.level
+ if level is not None:
+ data = var._data.array.get_level_values(level)
+
+ obj = cls(data, dim, coord_dtype=var.dtype)
+ assert not isinstance(obj.index, pd.MultiIndex)
+ # Rename safely
+ # make a shallow copy: cheap and because the index name may be updated
+ # here or in other constructors (cannot use pd.Index.rename as this
+ # constructor is also called from PandasMultiIndex)
+ obj.index = obj.index.copy()
+ obj.index.name = name
+
+ return obj
+
+ @staticmethod
+ def _concat_indexes(indexes, dim, positions=None) -> pd.Index:
+ new_pd_index: pd.Index
+
+ if not indexes:
+ new_pd_index = pd.Index([])
+ else:
+ if not all(idx.dim == dim for idx in indexes):
+ dims = ",".join({f"{idx.dim!r}" for idx in indexes})
+ raise ValueError(
+ f"Cannot concatenate along dimension {dim!r} indexes with "
+ f"dimensions: {dims}"
+ )
+ pd_indexes = [idx.index for idx in indexes]
+ new_pd_index = pd_indexes[0].append(pd_indexes[1:])
+
+ if positions is not None:
+ indices = nputils.inverse_permutation(np.concatenate(positions))
+ new_pd_index = new_pd_index.take(indices)
+
+ return new_pd_index
+
+ @classmethod
+ def concat(
+ cls,
+ indexes: Sequence[Self],
+ dim: Hashable,
+ positions: Iterable[Iterable[int]] | None = None,
+ ) -> Self:
+ new_pd_index = cls._concat_indexes(indexes, dim, positions)
+
+ if not indexes:
+ coord_dtype = None
+ else:
+ coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes])
+
+ return cls(new_pd_index, dim=dim, coord_dtype=coord_dtype)
+
+ def create_variables(
+ self, variables: Mapping[Any, Variable] | None = None
+ ) -> IndexVars:
+ from xarray.core.variable import IndexVariable
+
+ name = self.index.name
+ attrs: Mapping[Hashable, Any] | None
+ encoding: Mapping[Hashable, Any] | None
+
+ if variables is not None and name in variables:
+ var = variables[name]
+ attrs = var.attrs
+ encoding = var.encoding
+ else:
+ attrs = None
+ encoding = None
+
+ data = PandasIndexingAdapter(self.index, dtype=self.coord_dtype)
+ var = IndexVariable(self.dim, data, attrs=attrs, encoding=encoding)
+ return {name: var}
+
+ def to_pandas_index(self) -> pd.Index:
+ return self.index
+
+ def isel(
+ self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
+ ) -> PandasIndex | None:
+ from xarray.core.variable import Variable
+
+ indxr = indexers[self.dim]
+ if isinstance(indxr, Variable):
+ if indxr.dims != (self.dim,):
+ # can't preserve an index if the result has new dimensions
+ return None
+ else:
+ indxr = indxr.data
+ if not isinstance(indxr, slice) and is_scalar(indxr):
+ # scalar indexer: drop index
+ return None
+
+ return self._replace(self.index[indxr]) # type: ignore[index]
+
+ def sel(
+ self, labels: dict[Any, Any], method=None, tolerance=None
+ ) -> IndexSelResult:
+ from xarray.core.dataarray import DataArray
+ from xarray.core.variable import Variable
+
+ if method is not None and not isinstance(method, str):
+ raise TypeError("``method`` must be a string")
+
+ assert len(labels) == 1
+ coord_name, label = next(iter(labels.items()))
+
+ if isinstance(label, slice):
+ indexer = _query_slice(self.index, label, coord_name, method, tolerance)
+ elif is_dict_like(label):
+ raise ValueError(
+ "cannot use a dict-like object for selection on "
+ "a dimension that does not have a MultiIndex"
+ )
+ else:
+ label_array = normalize_label(label, dtype=self.coord_dtype)
+ if label_array.ndim == 0:
+ label_value = as_scalar(label_array)
+ if isinstance(self.index, pd.CategoricalIndex):
+ if method is not None:
+ raise ValueError(
+ "'method' is not supported when indexing using a CategoricalIndex."
+ )
+ if tolerance is not None:
+ raise ValueError(
+ "'tolerance' is not supported when indexing using a CategoricalIndex."
+ )
+ indexer = self.index.get_loc(label_value)
+ else:
+ if method is not None:
+ indexer = get_indexer_nd(
+ self.index, label_array, method, tolerance
+ )
+ if np.any(indexer < 0):
+ raise KeyError(
+ f"not all values found in index {coord_name!r}"
+ )
+ else:
+ try:
+ indexer = self.index.get_loc(label_value)
+ except KeyError as e:
+ raise KeyError(
+ f"not all values found in index {coord_name!r}. "
+ "Try setting the `method` keyword argument (example: method='nearest')."
+ ) from e
+
+ elif label_array.dtype.kind == "b":
+ indexer = label_array
+ else:
+ indexer = get_indexer_nd(self.index, label_array, method, tolerance)
+ if np.any(indexer < 0):
+ raise KeyError(f"not all values found in index {coord_name!r}")
+
+ # attach dimension names and/or coordinates to positional indexer
+ if isinstance(label, Variable):
+ indexer = Variable(label.dims, indexer)
+ elif isinstance(label, DataArray):
+ indexer = DataArray(indexer, coords=label._coords, dims=label.dims)
+
+ return IndexSelResult({self.dim: indexer})
+
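# Illustrative sketch (separate from the patch) of the selection logic above,
# exercised through the public ``.sel()``:
import xarray as xr

da = xr.DataArray([10.0, 20.0, 30.0], dims="x", coords={"x": [0.0, 0.5, 1.0]})
print(da.sel(x=0.5).item())                    # exact label lookup -> 20.0
print(da.sel(x=0.4, method="nearest").item())  # inexact lookup -> 20.0
print(da.sel(x=slice(0.0, 0.5)).values)        # label slice, stop inclusive -> [10. 20.]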
+ def equals(self, other: Index):
+ if not isinstance(other, PandasIndex):
+ return False
+ return self.index.equals(other.index) and self.dim == other.dim
+
+ def join(
+ self,
+ other: Self,
+ how: str = "inner",
+ ) -> Self:
+ if how == "outer":
+ index = self.index.union(other.index)
+ else:
+ # how = "inner"
+ index = self.index.intersection(other.index)
+
+ coord_dtype = np.result_type(self.coord_dtype, other.coord_dtype)
+ return type(self)(index, self.dim, coord_dtype=coord_dtype)
+
+ def reindex_like(
+ self, other: Self, method=None, tolerance=None
+ ) -> dict[Hashable, Any]:
+ if not self.index.is_unique:
+ raise ValueError(
+ f"cannot reindex or align along dimension {self.dim!r} because the "
+ "(pandas) index has duplicate values"
+ )
+
+ return {self.dim: get_indexer_nd(self.index, other.index, method, tolerance)}
+
+ def roll(self, shifts: Mapping[Any, int]) -> PandasIndex:
+ shift = shifts[self.dim] % self.index.shape[0]
+
+ if shift != 0:
+ new_pd_idx = self.index[-shift:].append(self.index[:-shift])
+ else:
+ new_pd_idx = self.index[:]
+
+ return self._replace(new_pd_idx)
+
+ def rename(self, name_dict, dims_dict):
+ if self.index.name not in name_dict and self.dim not in dims_dict:
+ return self
+
+ new_name = name_dict.get(self.index.name, self.index.name)
+ index = self.index.rename(new_name)
+ new_dim = dims_dict.get(self.dim, self.dim)
+ return self._replace(index, dim=new_dim)
+
+ def _copy(
+ self: T_PandasIndex, deep: bool = True, memo: dict[int, Any] | None = None
+ ) -> T_PandasIndex:
+ if deep:
+ # pandas is not using the memo
+ index = self.index.copy(deep=True)
+ else:
+ # index will be copied in constructor
+ index = self.index
+ return self._replace(index)
+
def __getitem__(self, indexer: Any):
return self._replace(self.index[indexer])
def __repr__(self):
- return f'PandasIndex({repr(self.index)})'
+ return f"PandasIndex({repr(self.index)})"
-def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str='equal'
- ):
+def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal"):
"""Check that all multi-index variable candidates are 1-dimensional and
either share the same (single) dimension or each have a different dimension.
"""
- pass
+ if any([var.ndim != 1 for var in variables.values()]):
+ raise ValueError("PandasMultiIndex only accepts 1-dimensional variables")
+
+ dims = {var.dims for var in variables.values()}
+ if all_dims == "equal" and len(dims) > 1:
+ raise ValueError(
+ "unmatched dimensions for multi-index variables "
+ + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()])
+ )
-T_PDIndex = TypeVar('T_PDIndex', bound=pd.Index)
+ if all_dims == "different" and len(dims) < len(variables):
+ raise ValueError(
+ "conflicting dimensions for multi-index product variables "
+ + ", ".join([f"{k!r} {v.dims}" for k, v in variables.items()])
+ )
-def remove_unused_levels_categories(index: T_PDIndex) ->T_PDIndex:
+T_PDIndex = TypeVar("T_PDIndex", bound=pd.Index)
+
+
+def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex:
"""
Remove unused levels from MultiIndex and unused categories from CategoricalIndex
"""
- pass
+ if isinstance(index, pd.MultiIndex):
+ new_index = cast(pd.MultiIndex, index.remove_unused_levels())
+ # if it contains CategoricalIndex, we need to remove unused categories
+ # manually. See https://github.com/pandas-dev/pandas/issues/30846
+ if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels):
+ levels = []
+ for i, level in enumerate(new_index.levels):
+ if isinstance(level, pd.CategoricalIndex):
+ level = level[new_index.codes[i]].remove_unused_categories()
+ else:
+ level = level[new_index.codes[i]]
+ levels.append(level)
+ # TODO: calling from_array() reorders MultiIndex levels. It would
+ # be best to avoid this, if possible, e.g., by using
+ # MultiIndex.remove_unused_levels() (which does not reorder) on the
+ # part of the MultiIndex that is not categorical, or by fixing this
+ # upstream in pandas.
+ new_index = pd.MultiIndex.from_arrays(levels, names=new_index.names)
+ return cast(T_PDIndex, new_index)
+
+ if isinstance(index, pd.CategoricalIndex):
+ return index.remove_unused_categories() # type: ignore[attr-defined]
+
+ return index
class PandasMultiIndex(PandasIndex):
"""Wrap a pandas.MultiIndex as an xarray compatible index."""
+
index: pd.MultiIndex
dim: Hashable
coord_dtype: Any
level_coords_dtype: dict[str, Any]
- __slots__ = 'index', 'dim', 'coord_dtype', 'level_coords_dtype'
- def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any=None
- ):
+ __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype")
+
+ def __init__(self, array: Any, dim: Hashable, level_coords_dtype: Any = None):
super().__init__(array, dim)
+
+ # default index level names
names = []
for i, idx in enumerate(self.index.levels):
- name = idx.name or f'{dim}_level_{i}'
+ name = idx.name or f"{dim}_level_{i}"
if name == dim:
raise ValueError(
- f'conflicting multi-index level name {name!r} with dimension {dim!r}'
- )
+ f"conflicting multi-index level name {name!r} with dimension {dim!r}"
+ )
names.append(name)
self.index.names = names
+
if level_coords_dtype is None:
- level_coords_dtype = {idx.name: get_valid_numpy_dtype(idx) for
- idx in self.index.levels}
+ level_coords_dtype = {
+ idx.name: get_valid_numpy_dtype(idx) for idx in self.index.levels
+ }
self.level_coords_dtype = level_coords_dtype
+ def _replace(self, index, dim=None, level_coords_dtype=None) -> PandasMultiIndex:
+ if dim is None:
+ dim = self.dim
+ index.name = dim
+ if level_coords_dtype is None:
+ level_coords_dtype = self.level_coords_dtype
+ return type(self)(index, dim, level_coords_dtype)
+
+ @classmethod
+ def from_variables(
+ cls,
+ variables: Mapping[Any, Variable],
+ *,
+ options: Mapping[str, Any],
+ ) -> PandasMultiIndex:
+ _check_dim_compat(variables)
+ dim = next(iter(variables.values())).dims[0]
+
+ index = pd.MultiIndex.from_arrays(
+ [var.values for var in variables.values()], names=variables.keys()
+ )
+ index.name = dim
+ level_coords_dtype = {name: var.dtype for name, var in variables.items()}
+ obj = cls(index, dim, level_coords_dtype=level_coords_dtype)
+
+ return obj
+
@classmethod
- def stack(cls, variables: Mapping[Any, Variable], dim: Hashable
- ) ->PandasMultiIndex:
+ def concat(
+ cls,
+ indexes: Sequence[Self],
+ dim: Hashable,
+ positions: Iterable[Iterable[int]] | None = None,
+ ) -> Self:
+ new_pd_index = cls._concat_indexes(indexes, dim, positions)
+
+ if not indexes:
+ level_coords_dtype = None
+ else:
+ level_coords_dtype = {}
+ for name in indexes[0].level_coords_dtype:
+ level_coords_dtype[name] = np.result_type(
+ *[idx.level_coords_dtype[name] for idx in indexes]
+ )
+
+ return cls(new_pd_index, dim=dim, level_coords_dtype=level_coords_dtype)
+
+ @classmethod
+ def stack(
+ cls, variables: Mapping[Any, Variable], dim: Hashable
+ ) -> PandasMultiIndex:
"""Create a new Pandas MultiIndex from the product of 1-d variables (levels) along a
new dimension.
@@ -485,38 +1021,356 @@ class PandasMultiIndex(PandasIndex):
labels after a stack/unstack roundtrip.
"""
- pass
+ _check_dim_compat(variables, all_dims="different")
+
+ level_indexes = [safe_cast_to_index(var) for var in variables.values()]
+ for name, idx in zip(variables, level_indexes):
+ if isinstance(idx, pd.MultiIndex):
+ raise ValueError(
+ f"cannot create a multi-index along stacked dimension {dim!r} "
+ f"from variable {name!r} that wraps a multi-index"
+ )
+
+ split_labels, levels = zip(*[lev.factorize() for lev in level_indexes])
+ labels_mesh = np.meshgrid(*split_labels, indexing="ij")
+ labels = [x.ravel() for x in labels_mesh]
+
+ index = pd.MultiIndex(levels, labels, sortorder=0, names=variables.keys())
+ level_coords_dtype = {k: var.dtype for k, var in variables.items()}
+
+ return cls(index, dim, level_coords_dtype=level_coords_dtype)
+
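# Illustrative sketch (separate from the patch): the ``stack()`` classmethod
# above backs the public DataArray/Dataset ``stack``, building a
# pandas.MultiIndex from the product of the level coordinates.
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(6.0).reshape(2, 3),
    dims=("x", "y"),
    coords={"x": [10, 20], "y": ["a", "b", "c"]},
)
stacked = da.stack(z=("x", "y"))
print(stacked.indexes["z"])              # MultiIndex with levels "x" and "y"
assert stacked.unstack("z").equals(da)   # roundtrip restores the original labels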
+ def unstack(self) -> tuple[dict[Hashable, Index], pd.MultiIndex]:
+ clean_index = remove_unused_levels_categories(self.index)
+
+ if not clean_index.is_unique:
+ raise ValueError(
+ "Cannot unstack MultiIndex containing duplicates. Make sure entries "
+ f"are unique, e.g., by calling ``.drop_duplicates('{self.dim}')``, "
+ "before unstacking."
+ )
+
+ new_indexes: dict[Hashable, Index] = {}
+ for name, lev in zip(clean_index.names, clean_index.levels):
+ idx = PandasIndex(
+ lev.copy(), name, coord_dtype=self.level_coords_dtype[name]
+ )
+ new_indexes[name] = idx
+
+ return new_indexes, clean_index
@classmethod
- def from_variables_maybe_expand(cls, dim: Hashable, current_variables:
- Mapping[Any, Variable], variables: Mapping[Any, Variable]) ->tuple[
- PandasMultiIndex, IndexVars]:
+ def from_variables_maybe_expand(
+ cls,
+ dim: Hashable,
+ current_variables: Mapping[Any, Variable],
+ variables: Mapping[Any, Variable],
+ ) -> tuple[PandasMultiIndex, IndexVars]:
"""Create a new multi-index maybe by expanding an existing one with
new variables as index levels.
The index and its corresponding coordinates may be created along a new dimension.
"""
- pass
+ names: list[Hashable] = []
+ codes: list[Iterable[int]] = []
+ levels: list[Iterable[Any]] = []
+ level_variables: dict[Any, Variable] = {}
+
+ _check_dim_compat({**current_variables, **variables})
+
+ if len(current_variables) > 1:
+ # expand from an existing multi-index
+ data = cast(
+ PandasMultiIndexingAdapter, next(iter(current_variables.values()))._data
+ )
+ current_index = data.array
+ names.extend(current_index.names)
+ codes.extend(current_index.codes)
+ levels.extend(current_index.levels)
+ for name in current_index.names:
+ level_variables[name] = current_variables[name]
+
+ elif len(current_variables) == 1:
+ # expand from one 1D variable (no multi-index): convert it to an index level
+ var = next(iter(current_variables.values()))
+ new_var_name = f"{dim}_level_0"
+ names.append(new_var_name)
+ cat = pd.Categorical(var.values, ordered=True)
+ codes.append(cat.codes)
+ levels.append(cat.categories)
+ level_variables[new_var_name] = var
+
+ for name, var in variables.items():
+ names.append(name)
+ cat = pd.Categorical(var.values, ordered=True)
+ codes.append(cat.codes)
+ levels.append(cat.categories)
+ level_variables[name] = var
- def keep_levels(self, level_variables: Mapping[Any, Variable]) ->(
- PandasMultiIndex | PandasIndex):
+ index = pd.MultiIndex(levels, codes, names=names)
+ level_coords_dtype = {k: var.dtype for k, var in level_variables.items()}
+ obj = cls(index, dim, level_coords_dtype=level_coords_dtype)
+ index_vars = obj.create_variables(level_variables)
+
+ return obj, index_vars
+
+ def keep_levels(
+ self, level_variables: Mapping[Any, Variable]
+ ) -> PandasMultiIndex | PandasIndex:
"""Keep only the provided levels and return a new multi-index with its
corresponding coordinates.
"""
- pass
+ index = self.index.droplevel(
+ [k for k in self.index.names if k not in level_variables]
+ )
- def reorder_levels(self, level_variables: Mapping[Any, Variable]
- ) ->PandasMultiIndex:
+ if isinstance(index, pd.MultiIndex):
+ level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names}
+ return self._replace(index, level_coords_dtype=level_coords_dtype)
+ else:
+ # backward compatibility: rename the level coordinate to the dimension name
+ return PandasIndex(
+ index.rename(self.dim),
+ self.dim,
+ coord_dtype=self.level_coords_dtype[index.name],
+ )
+
+ def reorder_levels(
+ self, level_variables: Mapping[Any, Variable]
+ ) -> PandasMultiIndex:
"""Re-arrange index levels using input order and return a new multi-index with
its corresponding coordinates.
"""
- pass
+ index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys()))
+ level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names}
+ return self._replace(index, level_coords_dtype=level_coords_dtype)
+
+ def create_variables(
+ self, variables: Mapping[Any, Variable] | None = None
+ ) -> IndexVars:
+ from xarray.core.variable import IndexVariable
+
+ if variables is None:
+ variables = {}
+
+ index_vars: IndexVars = {}
+ for name in (self.dim,) + tuple(self.index.names):
+ if name == self.dim:
+ level = None
+ dtype = None
+ else:
+ level = name
+ dtype = self.level_coords_dtype[name] # type: ignore[index] # TODO: are Hashables ok?
+
+ var = variables.get(name, None)
+ if var is not None:
+ attrs = var.attrs
+ encoding = var.encoding
+ else:
+ attrs = {}
+ encoding = {}
+
+ data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) # type: ignore[arg-type] # TODO: are Hashables ok?
+ index_vars[name] = IndexVariable(
+ self.dim,
+ data,
+ attrs=attrs,
+ encoding=encoding,
+ fastpath=True,
+ )
+
+ return index_vars
+
+ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult:
+ from xarray.core.dataarray import DataArray
+ from xarray.core.variable import Variable
+
+ if method is not None or tolerance is not None:
+ raise ValueError(
+ "multi-index does not support ``method`` and ``tolerance``"
+ )
+
+ new_index = None
+ scalar_coord_values = {}
+
+ indexer: int | slice | np.ndarray | Variable | DataArray
+
+ # label(s) given for multi-index level(s)
+ if all([lbl in self.index.names for lbl in labels]):
+ label_values = {}
+ for k, v in labels.items():
+ label_array = normalize_label(v, dtype=self.level_coords_dtype[k])
+ try:
+ label_values[k] = as_scalar(label_array)
+ except ValueError:
+ # label should be an item, not an array-like
+ raise ValueError(
+ "Vectorized selection is not "
+ f"available along coordinate {k!r} (multi-index level)"
+ )
+
+ has_slice = any([isinstance(v, slice) for v in label_values.values()])
+
+ if len(label_values) == self.index.nlevels and not has_slice:
+ indexer = self.index.get_loc(
+ tuple(label_values[k] for k in self.index.names)
+ )
+ else:
+ indexer, new_index = self.index.get_loc_level(
+ tuple(label_values.values()), level=tuple(label_values.keys())
+ )
+ scalar_coord_values.update(label_values)
+ # GH2619. Raise a KeyError if nothing is chosen
+ if indexer.dtype.kind == "b" and indexer.sum() == 0: # type: ignore[union-attr]
+ raise KeyError(f"{labels} not found")
+
+ # assume one label value given for the multi-index "array" (dimension)
+ else:
+ if len(labels) > 1:
+ coord_name = next(iter(set(labels) - set(self.index.names)))
+ raise ValueError(
+ f"cannot provide labels for both coordinate {coord_name!r} (multi-index array) "
+ f"and one or more coordinates among {self.index.names!r} (multi-index levels)"
+ )
+
+ coord_name, label = next(iter(labels.items()))
+
+ if is_dict_like(label):
+ invalid_levels = tuple(
+ name for name in label if name not in self.index.names
+ )
+ if invalid_levels:
+ raise ValueError(
+ f"multi-index level names {invalid_levels} not found in indexes {tuple(self.index.names)}"
+ )
+ return self.sel(label)
+
+ elif isinstance(label, slice):
+ indexer = _query_slice(self.index, label, coord_name)
+
+ elif isinstance(label, tuple):
+ if _is_nested_tuple(label):
+ indexer = self.index.get_locs(label)
+ elif len(label) == self.index.nlevels:
+ indexer = self.index.get_loc(label)
+ else:
+ levels = [self.index.names[i] for i in range(len(label))]
+ indexer, new_index = self.index.get_loc_level(label, level=levels)
+ scalar_coord_values.update({k: v for k, v in zip(levels, label)})
+
+ else:
+ label_array = normalize_label(label)
+ if label_array.ndim == 0:
+ label_value = as_scalar(label_array)
+ indexer, new_index = self.index.get_loc_level(label_value, level=0)
+ scalar_coord_values[self.index.names[0]] = label_value
+ elif label_array.dtype.kind == "b":
+ indexer = label_array
+ else:
+ if label_array.ndim > 1:
+ raise ValueError(
+ "Vectorized selection is not available along "
+ f"coordinate {coord_name!r} with a multi-index"
+ )
+ indexer = get_indexer_nd(self.index, label_array)
+ if np.any(indexer < 0):
+ raise KeyError(f"not all values found in index {coord_name!r}")
+
+ # attach dimension names and/or coordinates to positional indexer
+ if isinstance(label, Variable):
+ indexer = Variable(label.dims, indexer)
+ elif isinstance(label, DataArray):
+ # do not include label-indexer DataArray coordinates that conflict
+ # with the level names of this index
+ coords = {
+ k: v
+ for k, v in label._coords.items()
+ if k not in self.index.names
+ }
+ indexer = DataArray(indexer, coords=coords, dims=label.dims)
+
+ if new_index is not None:
+ if isinstance(new_index, pd.MultiIndex):
+ level_coords_dtype = {
+ k: self.level_coords_dtype[k] for k in new_index.names
+ }
+ new_index = self._replace(
+ new_index, level_coords_dtype=level_coords_dtype
+ )
+ dims_dict = {}
+ drop_coords = []
+ else:
+ new_index = PandasIndex(
+ new_index,
+ new_index.name,
+ coord_dtype=self.level_coords_dtype[new_index.name],
+ )
+ dims_dict = {self.dim: new_index.index.name}
+ drop_coords = [self.dim]
+
+ # variable(s) attrs and encoding metadata are propagated
+ # when replacing the indexes in the resulting xarray object
+ new_vars = new_index.create_variables()
+ indexes = cast(dict[Any, Index], {k: new_index for k in new_vars})
+
+ # add scalar variable for each dropped level
+ variables = new_vars
+ for name, val in scalar_coord_values.items():
+ variables[name] = Variable([], val)
+
+ return IndexSelResult(
+ {self.dim: indexer},
+ indexes=indexes,
+ variables=variables,
+ drop_indexes=list(scalar_coord_values),
+ drop_coords=drop_coords,
+ rename_dims=dims_dict,
+ )
+
+ else:
+ return IndexSelResult({self.dim: indexer})
+
+ def join(self, other, how: str = "inner"):
+ if how == "outer":
+ # bug in pandas? need to reset index.name
+ other_index = other.index.copy()
+ other_index.name = None
+ index = self.index.union(other_index)
+ index.name = self.dim
+ else:
+ # how = "inner"
+ index = self.index.intersection(other.index)
+
+ level_coords_dtype = {
+ k: np.result_type(lvl_dtype, other.level_coords_dtype[k])
+ for k, lvl_dtype in self.level_coords_dtype.items()
+ }
+
+ return type(self)(index, self.dim, level_coords_dtype=level_coords_dtype)
+
+ def rename(self, name_dict, dims_dict):
+ if not set(self.index.names) & set(name_dict) and self.dim not in dims_dict:
+ return self
+
+ # pandas 1.3.0: could simply do `self.index.rename(name_dict)`
+ new_names = [name_dict.get(k, k) for k in self.index.names]
+ index = self.index.rename(new_names)
+ new_dim = dims_dict.get(self.dim, self.dim)
+ new_level_coords_dtype = {
+ k: v for k, v in zip(new_names, self.level_coords_dtype.values())
+ }
+ return self._replace(
+ index, dim=new_dim, level_coords_dtype=new_level_coords_dtype
+ )
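# Illustrative sketch (not part of the upstream patch): exercising the
# PandasMultiIndex.sel logic above through the public API. Assumes a recent
# xarray exposing Coordinates.from_pandas_multiindex.
import pandas as pd
import xarray as xr

midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
coords = xr.Coordinates.from_pandas_multiindex(midx, "x")
ds = xr.Dataset({"v": ("x", range(4))}, coords=coords)

ds.sel(one="a", two=2)  # full selection: picks the single element labelled ("a", 2)
ds.sel(one="a")         # partial selection: "one" becomes a scalar coordinate and,
                        # per dims_dict above, dimension "x" is renamed to "two"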
-def create_default_index_implicit(dim_variable: Variable, all_variables: (
- Mapping | Iterable[Hashable] | None)=None) ->tuple[PandasIndex, IndexVars]:
+
+def create_default_index_implicit(
+ dim_variable: Variable,
+ all_variables: Mapping | Iterable[Hashable] | None = None,
+) -> tuple[PandasIndex, IndexVars]:
"""Create a default index from a dimension variable.
Create a PandasMultiIndex if the given variable wraps a pandas.MultiIndex,
@@ -524,10 +1378,48 @@ def create_default_index_implicit(dim_variable: Variable, all_variables: (
deprecate implicitly passing a pandas.MultiIndex as a coordinate).
"""
- pass
+ if all_variables is None:
+ all_variables = {}
+ if not isinstance(all_variables, Mapping):
+ all_variables = {k: None for k in all_variables}
+
+ name = dim_variable.dims[0]
+ array = getattr(dim_variable._data, "array", None)
+ index: PandasIndex
+
+ if isinstance(array, pd.MultiIndex):
+ index = PandasMultiIndex(array, name)
+ index_vars = index.create_variables()
+ # check for conflict between level names and variable names
+ duplicate_names = [k for k in index_vars if k in all_variables and k != name]
+ if duplicate_names:
+ # dirty workaround for an edge case where both the dimension
+ # coordinate and the level coordinates are given for the same
+ # multi-index object => do not raise an error
+ # TODO: remove this check when removing the multi-index dimension coordinate
+ if len(duplicate_names) < len(index.index.names):
+ conflict = True
+ else:
+ duplicate_vars = [all_variables[k] for k in duplicate_names]
+ conflict = any(
+ v is None or not dim_variable.equals(v) for v in duplicate_vars
+ )
+
+ if conflict:
+ conflict_str = "\n".join(duplicate_names)
+ raise ValueError(
+ f"conflicting MultiIndex level / variable name(s):\n{conflict_str}"
+ )
+ else:
+ dim_var = {name: dim_variable}
+ index = PandasIndex.from_variables(dim_var, options={})
+ index_vars = index.create_variables(dim_var)
+
+ return index, index_vars
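# Illustrative sketch (not part of the upstream patch): a dimension variable that
# wraps a pandas.MultiIndex yields a PandasMultiIndex plus one variable per level
# (and one for the dimension itself). This relies on the deprecated implicit
# MultiIndex promotion path described in the docstring above.
import pandas as pd
from xarray.core.indexes import PandasMultiIndex, create_default_index_implicit
from xarray.core.variable import IndexVariable

midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("lvl1", "lvl2"))
dim_var = IndexVariable("x", midx)
index, index_vars = create_default_index_implicit(dim_var)
assert isinstance(index, PandasMultiIndex)
assert set(index_vars) == {"x", "lvl1", "lvl2"}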
-T_PandasOrXarrayIndex = TypeVar('T_PandasOrXarrayIndex', Index, pd.Index)
+# generic type that represents either a pandas or an xarray index
+T_PandasOrXarrayIndex = TypeVar("T_PandasOrXarrayIndex", Index, pd.Index)
class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
@@ -540,15 +1432,27 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
methods.
"""
+
_index_type: type[Index] | type[pd.Index]
_indexes: dict[Any, T_PandasOrXarrayIndex]
_variables: dict[Any, Variable]
- __slots__ = ('_index_type', '_indexes', '_variables', '_dims',
- '__coord_name_id', '__id_index', '__id_coord_names')
- def __init__(self, indexes: (Mapping[Any, T_PandasOrXarrayIndex] | None
- )=None, variables: (Mapping[Any, Variable] | None)=None, index_type:
- (type[Index] | type[pd.Index])=Index):
+ __slots__ = (
+ "_index_type",
+ "_indexes",
+ "_variables",
+ "_dims",
+ "__coord_name_id",
+ "__id_index",
+ "__id_coord_names",
+ )
+
+ def __init__(
+ self,
+ indexes: Mapping[Any, T_PandasOrXarrayIndex] | None = None,
+ variables: Mapping[Any, Variable] | None = None,
+ index_type: type[Index] | type[pd.Index] = Index,
+ ):
"""Constructor not for public consumption.
Parameters
@@ -567,35 +1471,89 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
indexes = {}
if variables is None:
variables = {}
+
unmatched_keys = set(indexes) ^ set(variables)
if unmatched_keys:
raise ValueError(
- f'unmatched keys found in indexes and variables: {unmatched_keys}'
- )
+ f"unmatched keys found in indexes and variables: {unmatched_keys}"
+ )
+
if any(not isinstance(idx, index_type) for idx in indexes.values()):
- index_type_str = f'{index_type.__module__}.{index_type.__name__}'
+ index_type_str = f"{index_type.__module__}.{index_type.__name__}"
raise TypeError(
- f'values of indexes must all be instances of {index_type_str}')
+ f"values of indexes must all be instances of {index_type_str}"
+ )
+
self._index_type = index_type
self._indexes = dict(**indexes)
self._variables = dict(**variables)
+
self._dims: Mapping[Hashable, int] | None = None
self.__coord_name_id: dict[Any, int] | None = None
self.__id_index: dict[int, T_PandasOrXarrayIndex] | None = None
self.__id_coord_names: dict[int, tuple[Hashable, ...]] | None = None
- def get_unique(self) ->list[T_PandasOrXarrayIndex]:
+ @property
+ def _coord_name_id(self) -> dict[Any, int]:
+ if self.__coord_name_id is None:
+ self.__coord_name_id = {k: id(idx) for k, idx in self._indexes.items()}
+ return self.__coord_name_id
+
+ @property
+ def _id_index(self) -> dict[int, T_PandasOrXarrayIndex]:
+ if self.__id_index is None:
+ self.__id_index = {id(idx): idx for idx in self.get_unique()}
+ return self.__id_index
+
+ @property
+ def _id_coord_names(self) -> dict[int, tuple[Hashable, ...]]:
+ if self.__id_coord_names is None:
+ id_coord_names: Mapping[int, list[Hashable]] = defaultdict(list)
+ for k, v in self._coord_name_id.items():
+ id_coord_names[v].append(k)
+ self.__id_coord_names = {k: tuple(v) for k, v in id_coord_names.items()}
+
+ return self.__id_coord_names
+
+ @property
+ def variables(self) -> Mapping[Hashable, Variable]:
+ return Frozen(self._variables)
+
+ @property
+ def dims(self) -> Mapping[Hashable, int]:
+ from xarray.core.variable import calculate_dimensions
+
+ if self._dims is None:
+ self._dims = calculate_dimensions(self._variables)
+
+ return Frozen(self._dims)
+
+ def copy(self) -> Indexes:
+ return type(self)(dict(self._indexes), dict(self._variables))
+
+ def get_unique(self) -> list[T_PandasOrXarrayIndex]:
"""Return a list of unique indexes, preserving order."""
- pass
- def is_multi(self, key: Hashable) ->bool:
+ unique_indexes: list[T_PandasOrXarrayIndex] = []
+ seen: set[int] = set()
+
+ for index in self._indexes.values():
+ index_id = id(index)
+ if index_id not in seen:
+ unique_indexes.append(index)
+ seen.add(index_id)
+
+ return unique_indexes
+
+ def is_multi(self, key: Hashable) -> bool:
"""Return True if ``key`` maps to a multi-coordinate index,
False otherwise.
"""
- pass
+ return len(self._id_coord_names[self._coord_name_id[key]]) > 1
- def get_all_coords(self, key: Hashable, errors: ErrorOptions='raise'
- ) ->dict[Hashable, Variable]:
+ def get_all_coords(
+ self, key: Hashable, errors: ErrorOptions = "raise"
+ ) -> dict[Hashable, Variable]:
"""Return all coordinates having the same index.
Parameters
@@ -612,10 +1570,21 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
A dictionary of all coordinate variables having the same index.
"""
- pass
+ if errors not in ["raise", "ignore"]:
+ raise ValueError('errors must be either "raise" or "ignore"')
+
+ if key not in self._indexes:
+ if errors == "raise":
+ raise ValueError(f"no index found for {key!r} coordinate")
+ else:
+ return {}
- def get_all_dims(self, key: Hashable, errors: ErrorOptions='raise'
- ) ->Mapping[Hashable, int]:
+ all_coord_names = self._id_coord_names[self._coord_name_id[key]]
+ return {k: self._variables[k] for k in all_coord_names}
+
+ def get_all_dims(
+ self, key: Hashable, errors: ErrorOptions = "raise"
+ ) -> Mapping[Hashable, int]:
"""Return all dimensions shared by an index.
Parameters
@@ -632,25 +1601,42 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
A dictionary of all dimensions shared by an index.
"""
- pass
+ from xarray.core.variable import calculate_dimensions
+
+ return calculate_dimensions(self.get_all_coords(key, errors=errors))
- def group_by_index(self) ->list[tuple[T_PandasOrXarrayIndex, dict[
- Hashable, Variable]]]:
+ def group_by_index(
+ self,
+ ) -> list[tuple[T_PandasOrXarrayIndex, dict[Hashable, Variable]]]:
"""Returns a list of unique indexes and their corresponding coordinates."""
- pass
- def to_pandas_indexes(self) ->Indexes[pd.Index]:
+ index_coords = []
+ for i, index in self._id_index.items():
+ coords = {k: self._variables[k] for k in self._id_coord_names[i]}
+ index_coords.append((index, coords))
+
+ return index_coords
+
+ def to_pandas_indexes(self) -> Indexes[pd.Index]:
"""Returns an immutable proxy for Dataset or DataArrary pandas indexes.
Raises an error if this proxy contains indexes that cannot be coerced to
pandas.Index objects.
"""
- pass
+ indexes: dict[Hashable, pd.Index] = {}
+
+ for k, idx in self._indexes.items():
+ if isinstance(idx, pd.Index):
+ indexes[k] = idx
+ elif isinstance(idx, Index):
+ indexes[k] = idx.to_pandas_index()
+
+ return Indexes(indexes, self._variables, index_type=pd.Index)
- def copy_indexes(self, deep: bool=True, memo: (dict[int,
- T_PandasOrXarrayIndex] | None)=None) ->tuple[dict[Hashable,
- T_PandasOrXarrayIndex], dict[Hashable, Variable]]:
+ def copy_indexes(
+ self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None
+ ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]:
"""Return a new dictionary with copies of indexes, preserving
unique indexes.
@@ -663,18 +1649,44 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
in this dict.
"""
- pass
-
- def __iter__(self) ->Iterator[T_PandasOrXarrayIndex]:
+ new_indexes: dict[Hashable, T_PandasOrXarrayIndex] = {}
+ new_index_vars: dict[Hashable, Variable] = {}
+
+ xr_idx: Index
+ new_idx: T_PandasOrXarrayIndex
+ for idx, coords in self.group_by_index():
+ if isinstance(idx, pd.Index):
+ convert_new_idx = True
+ dim = next(iter(coords.values())).dims[0]
+ if isinstance(idx, pd.MultiIndex):
+ xr_idx = PandasMultiIndex(idx, dim)
+ else:
+ xr_idx = PandasIndex(idx, dim)
+ else:
+ convert_new_idx = False
+ xr_idx = idx
+
+ new_idx = xr_idx._copy(deep=deep, memo=memo) # type: ignore[assignment]
+ idx_vars = xr_idx.create_variables(coords)
+
+ if convert_new_idx:
+ new_idx = new_idx.index # type: ignore[attr-defined]
+
+ new_indexes.update({k: new_idx for k in coords})
+ new_index_vars.update(idx_vars)
+
+ return new_indexes, new_index_vars
+
+ def __iter__(self) -> Iterator[T_PandasOrXarrayIndex]:
return iter(self._indexes)
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self._indexes)
- def __contains__(self, key) ->bool:
+ def __contains__(self, key) -> bool:
return key in self._indexes
- def __getitem__(self, key) ->T_PandasOrXarrayIndex:
+ def __getitem__(self, key) -> T_PandasOrXarrayIndex:
return self._indexes[key]
def __repr__(self):
@@ -682,8 +1694,9 @@ class Indexes(collections.abc.Mapping, Generic[T_PandasOrXarrayIndex]):
return formatting.indexes_repr(indexes)
-def default_indexes(coords: Mapping[Any, Variable], dims: Iterable) ->dict[
- Hashable, Index]:
+def default_indexes(
+ coords: Mapping[Any, Variable], dims: Iterable
+) -> dict[Hashable, Index]:
"""Default indexes for a Dataset/DataArray.
Parameters
@@ -698,44 +1711,209 @@ def default_indexes(coords: Mapping[Any, Variable], dims: Iterable) ->dict[
Mapping from indexing keys (levels/dimension names) to indexes used for
indexing along that dimension.
"""
- pass
+ indexes: dict[Hashable, Index] = {}
+ coord_names = set(coords)
+
+ for name, var in coords.items():
+ if name in dims and var.ndim == 1:
+ index, index_vars = create_default_index_implicit(var, coords)
+ if set(index_vars) <= coord_names:
+ indexes.update({k: index for k in index_vars})
+ return indexes
-def indexes_equal(index: Index, other_index: Index, variable: Variable,
- other_variable: Variable, cache: (dict[tuple[int, int], bool | None] |
- None)=None) ->bool:
+
+def indexes_equal(
+ index: Index,
+ other_index: Index,
+ variable: Variable,
+ other_variable: Variable,
+ cache: dict[tuple[int, int], bool | None] | None = None,
+) -> bool:
"""Check if two indexes are equal, possibly with cached results.
If the two indexes are not of the same type or they do not implement
equality, fallback to coordinate labels equality check.
"""
- pass
+ if cache is None:
+ # dummy cache
+ cache = {}
+
+ key = (id(index), id(other_index))
+ equal: bool | None = None
+
+ if key not in cache:
+ if type(index) is type(other_index):
+ try:
+ equal = index.equals(other_index)
+ except NotImplementedError:
+ equal = None
+ else:
+ cache[key] = equal
+ else:
+ equal = None
+ else:
+ equal = cache[key]
+
+ if equal is None:
+ equal = variable.equals(other_variable)
+ return cast(bool, equal)
-def indexes_all_equal(elements: Sequence[tuple[Index, dict[Hashable,
- Variable]]]) ->bool:
+
+def indexes_all_equal(
+ elements: Sequence[tuple[Index, dict[Hashable, Variable]]]
+) -> bool:
"""Check if indexes are all equal.
If they are not of the same type or they do not implement this check, check
if their coordinate variables are all equal instead.
"""
- pass
-
-def filter_indexes_from_coords(indexes: Mapping[Any, Index],
- filtered_coord_names: set) ->dict[Hashable, Index]:
+ def check_variables():
+ variables = [e[1] for e in elements]
+ return any(
+ not variables[0][k].equals(other_vars[k])
+ for other_vars in variables[1:]
+ for k in variables[0]
+ )
+
+ indexes = [e[0] for e in elements]
+
+ same_objects = all(indexes[0] is other_idx for other_idx in indexes[1:])
+ if same_objects:
+ return True
+
+ same_type = all(type(indexes[0]) is type(other_idx) for other_idx in indexes[1:])
+ if same_type:
+ try:
+ not_equal = any(
+ not indexes[0].equals(other_idx) for other_idx in indexes[1:]
+ )
+ except NotImplementedError:
+ not_equal = check_variables()
+ else:
+ not_equal = check_variables()
+
+ return not not_equal
+
+
+def _apply_indexes_fast(indexes: Indexes[Index], args: Mapping[Any, Any], func: str):
+ # This function avoids the call to indexes.group_by_index
+ # which is really slow when repeatidly iterating through
+ # an array. However, it fails to return the correct ID for
+ # multi-index arrays
+ indexes_fast, coords = indexes._indexes, indexes._variables
+
+ new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes_fast.items()}
+ new_index_variables: dict[Hashable, Variable] = {}
+ for name, index in indexes_fast.items():
+ coord = coords[name]
+ if hasattr(coord, "_indexes"):
+ index_vars = {n: coords[n] for n in coord._indexes}
+ else:
+ index_vars = {name: coord}
+ index_dims = {d for var in index_vars.values() for d in var.dims}
+ index_args = {k: v for k, v in args.items() if k in index_dims}
+
+ if index_args:
+ new_index = getattr(index, func)(index_args)
+ if new_index is not None:
+ new_indexes.update({k: new_index for k in index_vars})
+ new_index_vars = new_index.create_variables(index_vars)
+ new_index_variables.update(new_index_vars)
+ else:
+ for k in index_vars:
+ new_indexes.pop(k, None)
+ return new_indexes, new_index_variables
+
+
+def _apply_indexes(
+ indexes: Indexes[Index],
+ args: Mapping[Any, Any],
+ func: str,
+) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+ new_indexes: dict[Hashable, Index] = {k: v for k, v in indexes.items()}
+ new_index_variables: dict[Hashable, Variable] = {}
+
+ for index, index_vars in indexes.group_by_index():
+ index_dims = {d for var in index_vars.values() for d in var.dims}
+ index_args = {k: v for k, v in args.items() if k in index_dims}
+ if index_args:
+ new_index = getattr(index, func)(index_args)
+ if new_index is not None:
+ new_indexes.update({k: new_index for k in index_vars})
+ new_index_vars = new_index.create_variables(index_vars)
+ new_index_variables.update(new_index_vars)
+ else:
+ for k in index_vars:
+ new_indexes.pop(k, None)
+
+ return new_indexes, new_index_variables
+
+
+def isel_indexes(
+ indexes: Indexes[Index],
+ indexers: Mapping[Any, Any],
+) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+ # TODO: remove if clause in the future. It should be unnecessary.
+ # See failure introduced when removed
+ # https://github.com/pydata/xarray/pull/9002#discussion_r1590443756
+ if any(isinstance(v, PandasMultiIndex) for v in indexes._indexes.values()):
+ return _apply_indexes(indexes, indexers, "isel")
+ else:
+ return _apply_indexes_fast(indexes, indexers, "isel")
+
+
+def roll_indexes(
+ indexes: Indexes[Index],
+ shifts: Mapping[Any, int],
+) -> tuple[dict[Hashable, Index], dict[Hashable, Variable]]:
+ return _apply_indexes(indexes, shifts, "roll")
+
+
+def filter_indexes_from_coords(
+ indexes: Mapping[Any, Index],
+ filtered_coord_names: set,
+) -> dict[Hashable, Index]:
"""Filter index items given a (sub)set of coordinate names.
Drop all multi-coordinate related index items for any key missing in the set
of coordinate names.
"""
- pass
+ filtered_indexes: dict[Any, Index] = dict(indexes)
+ index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set)
+ for name, idx in indexes.items():
+ index_coord_names[id(idx)].add(name)
-def assert_no_index_corrupted(indexes: Indexes[Index], coord_names: set[
- Hashable], action: str='remove coordinate(s)') ->None:
+ for idx_coord_names in index_coord_names.values():
+ if not idx_coord_names <= filtered_coord_names:
+ for k in idx_coord_names:
+ del filtered_indexes[k]
+
+ return filtered_indexes
+
+
+def assert_no_index_corrupted(
+ indexes: Indexes[Index],
+ coord_names: set[Hashable],
+ action: str = "remove coordinate(s)",
+) -> None:
"""Assert removing coordinates or indexes will not corrupt indexes."""
- pass
+
+ # An index may be corrupted when the set of its corresponding coordinate name(s)
+ # partially overlaps the set of coordinate names to remove
+ for index, index_coords in indexes.group_by_index():
+ common_names = set(index_coords) & coord_names
+ if common_names and len(common_names) != len(index_coords):
+ common_names_str = ", ".join(f"{k!r}" for k in common_names)
+ index_names_str = ", ".join(f"{k!r}" for k in index_coords)
+ raise ValueError(
+ f"cannot {action} {common_names_str}, which would corrupt "
+ f"the following index built from coordinates {index_names_str}:\n"
+ f"{index}"
+ )
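# Illustrative sketch (not part of the upstream patch): removing only one level
# coordinate of a multi-index trips the corruption check above.
import pandas as pd
import xarray as xr
from xarray.core.indexes import assert_no_index_corrupted

midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=("one", "two"))
ds = xr.Dataset(coords=xr.Coordinates.from_pandas_multiindex(midx, "x"))
try:
    assert_no_index_corrupted(ds.xindexes, {"one"})
except ValueError as err:
    print(err)  # cannot remove coordinate(s) 'one', which would corrupt ...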
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index d02970b0..19937270 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import enum
import functools
import operator
@@ -9,17 +10,29 @@ from dataclasses import dataclass, field
from datetime import timedelta
from html import escape
from typing import TYPE_CHECKING, Any, Callable, overload
+
import numpy as np
import pandas as pd
+
from xarray.core import duck_array_ops
from xarray.core.nputils import NumpyVIndexAdapter
from xarray.core.options import OPTIONS
from xarray.core.types import T_Xarray
-from xarray.core.utils import NDArrayMixin, either_dict_or_kwargs, get_valid_numpy_dtype, is_duck_array, is_duck_dask_array, is_scalar, to_0d_array
+from xarray.core.utils import (
+ NDArrayMixin,
+ either_dict_or_kwargs,
+ get_valid_numpy_dtype,
+ is_duck_array,
+ is_duck_dask_array,
+ is_scalar,
+ to_0d_array,
+)
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array
+
if TYPE_CHECKING:
from numpy.typing import DTypeLike
+
from xarray.core.indexes import Index
from xarray.core.types import Self
from xarray.core.variable import Variable
@@ -49,6 +62,7 @@ class IndexSelResult:
rename in the resulting DataArray or Dataset.
"""
+
dim_indexers: dict[Any, Any]
indexes: dict[Any, Index] = field(default_factory=dict)
variables: dict[Any, Variable] = field(default_factory=dict)
@@ -62,23 +76,135 @@ class IndexSelResult:
See https://stackoverflow.com/a/51802661
"""
- pass
-
-
-def group_indexers_by_index(obj: T_Xarray, indexers: Mapping[Any, Any],
- options: Mapping[str, Any]) ->list[tuple[Index, dict[Any, Any]]]:
+ return (
+ self.dim_indexers,
+ self.indexes,
+ self.variables,
+ self.drop_coords,
+ self.drop_indexes,
+ self.rename_dims,
+ )
+
+
+def merge_sel_results(results: list[IndexSelResult]) -> IndexSelResult:
+ all_dims_count = Counter([dim for res in results for dim in res.dim_indexers])
+ duplicate_dims = {k: v for k, v in all_dims_count.items() if v > 1}
+
+ if duplicate_dims:
+ # TODO: this message is not right when combining index(es) queries with
+ # location-based indexing on a dimension with no dimension-coordinate (fallback)
+ fmt_dims = [
+ f"{dim!r}: {count} indexes involved"
+ for dim, count in duplicate_dims.items()
+ ]
+ raise ValueError(
+ "Xarray does not support label-based selection with more than one index "
+ "over the following dimension(s):\n"
+ + "\n".join(fmt_dims)
+ + "\nSuggestion: use a multi-index for each of those dimension(s)."
+ )
+
+ dim_indexers = {}
+ indexes = {}
+ variables = {}
+ drop_coords = []
+ drop_indexes = []
+ rename_dims = {}
+
+ for res in results:
+ dim_indexers.update(res.dim_indexers)
+ indexes.update(res.indexes)
+ variables.update(res.variables)
+ drop_coords += res.drop_coords
+ drop_indexes += res.drop_indexes
+ rename_dims.update(res.rename_dims)
+
+ return IndexSelResult(
+ dim_indexers, indexes, variables, drop_coords, drop_indexes, rename_dims
+ )
+
+
+def group_indexers_by_index(
+ obj: T_Xarray,
+ indexers: Mapping[Any, Any],
+ options: Mapping[str, Any],
+) -> list[tuple[Index, dict[Any, Any]]]:
"""Returns a list of unique indexes and their corresponding indexers."""
- pass
-
-
-def map_index_queries(obj: T_Xarray, indexers: Mapping[Any, Any], method=
- None, tolerance: (int | float | Iterable[int | float] | None)=None, **
- indexers_kwargs: Any) ->IndexSelResult:
+ unique_indexes = {}
+ grouped_indexers: Mapping[int | None, dict] = defaultdict(dict)
+
+ for key, label in indexers.items():
+ index: Index = obj.xindexes.get(key, None)
+
+ if index is not None:
+ index_id = id(index)
+ unique_indexes[index_id] = index
+ grouped_indexers[index_id][key] = label
+ elif key in obj.coords:
+ raise KeyError(f"no index found for coordinate {key!r}")
+ elif key not in obj.dims:
+ raise KeyError(
+ f"{key!r} is not a valid dimension or coordinate for "
+ f"{obj.__class__.__name__} with dimensions {obj.dims!r}"
+ )
+ elif len(options):
+ raise ValueError(
+ f"cannot supply selection options {options!r} for dimension {key!r}"
+ "that has no associated coordinate or index"
+ )
+ else:
+ # key is a dimension without a "dimension-coordinate"
+ # fall back to location-based selection
+ # TODO: deprecate this implicit behavior and suggest using isel instead?
+ unique_indexes[None] = None
+ grouped_indexers[None][key] = label
+
+ return [(unique_indexes[k], grouped_indexers[k]) for k in unique_indexes]
+
+
+def map_index_queries(
+ obj: T_Xarray,
+ indexers: Mapping[Any, Any],
+ method=None,
+ tolerance: int | float | Iterable[int | float] | None = None,
+ **indexers_kwargs: Any,
+) -> IndexSelResult:
"""Execute index queries from a DataArray / Dataset and label-based indexers
and return the (merged) query results.
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ # TODO benbovy - flexible indexes: remove when custom index options are available
+ if method is None and tolerance is None:
+ options = {}
+ else:
+ options = {"method": method, "tolerance": tolerance}
+
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "map_index_queries")
+ grouped_indexers = group_indexers_by_index(obj, indexers, options)
+
+ results = []
+ for index, labels in grouped_indexers:
+ if index is None:
+ # forward dimension indexers with no index/coordinate
+ results.append(IndexSelResult(labels))
+ else:
+ results.append(index.sel(labels, **options))
+
+ merged = merge_sel_results(results)
+
+ # drop dimension coordinates found in dimension indexers
+ # (also drop multi-index if any)
+ # (.sel() already ensures alignment)
+ for k, v in merged.dim_indexers.items():
+ if isinstance(v, DataArray):
+ if k in v._indexes:
+ v = v.reset_index(k)
+ drop_coords = [name for name in v._coords if name in merged.dim_indexers]
+ merged.dim_indexers[k] = v.drop_vars(drop_coords)
+
+ return merged
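# Illustrative sketch (not part of the upstream patch): a label query is grouped
# per index, resolved by Index.sel, and merged into positional dimension indexers.
import xarray as xr
from xarray.core.indexing import map_index_queries

ds = xr.Dataset({"v": ("x", [10.0, 20.0, 30.0])}, coords={"x": [1, 2, 3]})
result = map_index_queries(ds, {"x": 2})
assert result.dim_indexers == {"x": 1}  # label 2 maps to position 1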
def expanded_indexer(key, ndim):
@@ -89,10 +215,29 @@ def expanded_indexer(key, ndim):
number of full slices and then padding the key with full slices so that it
reaches the appropriate dimensionality.
"""
- pass
+ if not isinstance(key, tuple):
+ # numpy treats non-tuple keys as equivalent to tuples of length 1
+ key = (key,)
+ new_key = []
+ # handling Ellipsis right is a little tricky, see:
+ # https://numpy.org/doc/stable/reference/arrays.indexing.html#advanced-indexing
+ found_ellipsis = False
+ for k in key:
+ if k is Ellipsis:
+ if not found_ellipsis:
+ new_key.extend((ndim + 1 - len(key)) * [slice(None)])
+ found_ellipsis = True
+ else:
+ new_key.append(slice(None))
+ else:
+ new_key.append(k)
+ if len(new_key) > ndim:
+ raise IndexError("too many indices")
+ new_key.extend((ndim - len(new_key)) * [slice(None)])
+ return tuple(new_key)
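# Illustrative sketch (not part of the upstream patch): partial keys are padded
# with full slices, and a single Ellipsis expands to the missing axes.
from xarray.core.indexing import expanded_indexer

assert expanded_indexer(1, 3) == (1, slice(None), slice(None))
assert expanded_indexer((0, Ellipsis, -1), 4) == (0, slice(None), slice(None), -1)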
-def _normalize_slice(sl: slice, size: int) ->slice:
+def _normalize_slice(sl: slice, size: int) -> slice:
"""
Ensure that given slice only contains positive start and stop values
(stop can be -1 for full-size slices with negative steps, e.g. [-10::-1])
@@ -104,11 +249,10 @@ def _normalize_slice(sl: slice, size: int) ->slice:
>>> _normalize_slice(slice(0, -1), 10)
slice(0, 9, 1)
"""
- pass
+ return slice(*sl.indices(size))
-def _expand_slice(slice_: slice, size: int) ->np.ndarray[Any, np.dtype[np.
- integer]]:
+def _expand_slice(slice_: slice, size: int) -> np.ndarray[Any, np.dtype[np.integer]]:
"""
Expand slice to an array containing only positive integers.
@@ -119,15 +263,53 @@ def _expand_slice(slice_: slice, size: int) ->np.ndarray[Any, np.dtype[np.
>>> _expand_slice(slice(0, -1), 10)
array([0, 1, 2, 3, 4, 5, 6, 7, 8])
"""
- pass
+ sl = _normalize_slice(slice_, size)
+ return np.arange(sl.start, sl.stop, sl.step)
-def slice_slice(old_slice: slice, applied_slice: slice, size: int) ->slice:
+def slice_slice(old_slice: slice, applied_slice: slice, size: int) -> slice:
"""Given a slice and the size of the dimension to which it will be applied,
index it with another slice to return a new slice equivalent to applying
the slices sequentially
"""
- pass
+ old_slice = _normalize_slice(old_slice, size)
+
+ size_after_old_slice = len(range(old_slice.start, old_slice.stop, old_slice.step))
+ if size_after_old_slice == 0:
+ # nothing left after applying first slice
+ return slice(0)
+
+ applied_slice = _normalize_slice(applied_slice, size_after_old_slice)
+
+ start = old_slice.start + applied_slice.start * old_slice.step
+ if start < 0:
+ # nothing left after applying second slice
+ # (can only happen for old_slice.step < 0, e.g. [10::-1], [20:])
+ return slice(0)
+
+ stop = old_slice.start + applied_slice.stop * old_slice.step
+ if stop < 0:
+ stop = None
+
+ step = old_slice.step * applied_slice.step
+
+ return slice(start, stop, step)
+
+
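# Illustrative sketch (not part of the upstream patch): composing two slices with
# slice_slice is equivalent to applying them one after the other.
from xarray.core.indexing import slice_slice

data = list(range(30))
combined = slice_slice(slice(2, 20, 2), slice(1, 5), size=30)
assert data[combined] == data[2:20:2][1:5]  # [4, 6, 8, 10]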
+def _index_indexer_1d(old_indexer, applied_indexer, size: int):
+ if isinstance(applied_indexer, slice) and applied_indexer == slice(None):
+ # shortcut for the usual case
+ return old_indexer
+ if isinstance(old_indexer, slice):
+ if isinstance(applied_indexer, slice):
+ indexer = slice_slice(old_indexer, applied_indexer, size)
+ elif isinstance(applied_indexer, integer_types):
+ indexer = range(*old_indexer.indices(size))[applied_indexer] # type: ignore[assignment]
+ else:
+ indexer = _expand_slice(old_indexer, size)[applied_indexer]
+ else:
+ indexer = old_indexer[applied_indexer]
+ return indexer
class ExplicitIndexer:
@@ -140,33 +322,56 @@ class ExplicitIndexer:
Do not instantiate BaseIndexer objects directly: instead, use one of the
sub-classes BasicIndexer, OuterIndexer or VectorizedIndexer.
"""
- __slots__ = '_key',
+
+ __slots__ = ("_key",)
def __init__(self, key: tuple[Any, ...]):
if type(self) is ExplicitIndexer:
- raise TypeError('cannot instantiate base ExplicitIndexer objects')
+ raise TypeError("cannot instantiate base ExplicitIndexer objects")
self._key = tuple(key)
- def __repr__(self) ->str:
- return f'{type(self).__name__}({self.tuple})'
+ @property
+ def tuple(self) -> tuple[Any, ...]:
+ return self._key
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}({self.tuple})"
+
+
+@overload
+def as_integer_or_none(value: int) -> int: ...
+@overload
+def as_integer_or_none(value: None) -> None: ...
+def as_integer_or_none(value: int | None) -> int | None:
+ return None if value is None else operator.index(value)
+
+
+def as_integer_slice(value: slice) -> slice:
+ start = as_integer_or_none(value.start)
+ stop = as_integer_or_none(value.stop)
+ step = as_integer_or_none(value.step)
+ return slice(start, stop, step)
class IndexCallable:
"""Provide getitem and setitem syntax for callable objects."""
- __slots__ = 'getter', 'setter'
- def __init__(self, getter: Callable[..., Any], setter: (Callable[...,
- Any] | None)=None):
+ __slots__ = ("getter", "setter")
+
+ def __init__(
+ self, getter: Callable[..., Any], setter: Callable[..., Any] | None = None
+ ):
self.getter = getter
self.setter = setter
- def __getitem__(self, key: Any) ->Any:
+ def __getitem__(self, key: Any) -> Any:
return self.getter(key)
- def __setitem__(self, key: Any, value: Any) ->None:
+ def __setitem__(self, key: Any, value: Any) -> None:
if self.setter is None:
raise NotImplementedError(
- 'Setting values is not supported for this indexer.')
+ "Setting values is not supported for this indexer."
+ )
self.setter(key, value)
@@ -177,11 +382,13 @@ class BasicIndexer(ExplicitIndexer):
rules for basic indexing: each axis is independently sliced and axes
indexed with an integer are dropped from the result.
"""
+
__slots__ = ()
def __init__(self, key: tuple[int | np.integer | slice, ...]):
if not isinstance(key, tuple):
- raise TypeError(f'key must be a tuple: {key!r}')
+ raise TypeError(f"key must be a tuple: {key!r}")
+
new_key = []
for k in key:
if isinstance(k, integer_types):
@@ -190,9 +397,10 @@ class BasicIndexer(ExplicitIndexer):
k = as_integer_slice(k)
else:
raise TypeError(
- f'unexpected indexer type for {type(self).__name__}: {k!r}'
- )
+ f"unexpected indexer type for {type(self).__name__}: {k!r}"
+ )
new_key.append(k)
+
super().__init__(tuple(new_key))
@@ -204,12 +412,18 @@ class OuterIndexer(ExplicitIndexer):
axes indexed with an integer are dropped from the result. This type of
indexing works like MATLAB/Fortran.
"""
+
__slots__ = ()
- def __init__(self, key: tuple[int | np.integer | slice | np.ndarray[Any,
- np.dtype[np.generic]], ...]):
+ def __init__(
+ self,
+ key: tuple[
+ int | np.integer | slice | np.ndarray[Any, np.dtype[np.generic]], ...
+ ],
+ ):
if not isinstance(key, tuple):
- raise TypeError(f'key must be a tuple: {key!r}')
+ raise TypeError(f"key must be a tuple: {key!r}")
+
new_key = []
for k in key:
if isinstance(k, integer_types):
@@ -219,18 +433,20 @@ class OuterIndexer(ExplicitIndexer):
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
- f'invalid indexer array, does not have integer dtype: {k!r}'
- )
- if k.ndim > 1:
+ f"invalid indexer array, does not have integer dtype: {k!r}"
+ )
+ if k.ndim > 1: # type: ignore[union-attr]
raise TypeError(
- f'invalid indexer array for {type(self).__name__}; must be scalar or have 1 dimension: {k!r}'
- )
- k = k.astype(np.int64)
+ f"invalid indexer array for {type(self).__name__}; must be scalar "
+ f"or have 1 dimension: {k!r}"
+ )
+ k = k.astype(np.int64) # type: ignore[union-attr]
else:
raise TypeError(
- f'unexpected indexer type for {type(self).__name__}: {k!r}'
- )
+ f"unexpected indexer type for {type(self).__name__}: {k!r}"
+ )
new_key.append(k)
+
super().__init__(tuple(new_key))
@@ -243,12 +459,13 @@ class VectorizedIndexer(ExplicitIndexer):
(including broadcasting) except sliced axes are always moved to the end:
https://github.com/numpy/numpy/pull/6256
"""
+
__slots__ = ()
- def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.
- generic]], ...]):
+ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...]):
if not isinstance(key, tuple):
- raise TypeError(f'key must be a tuple: {key!r}')
+ raise TypeError(f"key must be a tuple: {key!r}")
+
new_key = []
ndim = None
for k in key:
@@ -256,70 +473,129 @@ class VectorizedIndexer(ExplicitIndexer):
k = as_integer_slice(k)
elif is_duck_dask_array(k):
raise ValueError(
- 'Vectorized indexing with Dask arrays is not supported. Please pass a numpy array by calling ``.compute``. See https://github.com/dask/dask/issues/8958.'
- )
+ "Vectorized indexing with Dask arrays is not supported. "
+ "Please pass a numpy array by calling ``.compute``. "
+ "See https://github.com/dask/dask/issues/8958."
+ )
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
- f'invalid indexer array, does not have integer dtype: {k!r}'
- )
+ f"invalid indexer array, does not have integer dtype: {k!r}"
+ )
if ndim is None:
- ndim = k.ndim
+ ndim = k.ndim # type: ignore[union-attr]
elif ndim != k.ndim:
ndims = [k.ndim for k in key if isinstance(k, np.ndarray)]
raise ValueError(
- f'invalid indexer key: ndarray arguments have different numbers of dimensions: {ndims}'
- )
- k = k.astype(np.int64)
+ "invalid indexer key: ndarray arguments "
+ f"have different numbers of dimensions: {ndims}"
+ )
+ k = k.astype(np.int64) # type: ignore[union-attr]
else:
raise TypeError(
- f'unexpected indexer type for {type(self).__name__}: {k!r}'
- )
+ f"unexpected indexer type for {type(self).__name__}: {k!r}"
+ )
new_key.append(k)
+
super().__init__(tuple(new_key))
class ExplicitlyIndexed:
"""Mixin to mark support for Indexer subclasses in indexing."""
+
__slots__ = ()
- def __array__(self, dtype: np.typing.DTypeLike=None) ->np.ndarray:
+ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
+ # Leave casting to an array up to the underlying array type.
return np.asarray(self.get_duck_array(), dtype=dtype)
+ def get_duck_array(self):
+ return self.array
+
class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed):
__slots__ = ()
- def __array__(self, dtype: np.typing.DTypeLike=None) ->np.ndarray:
+ def get_duck_array(self):
+ key = BasicIndexer((slice(None),) * self.ndim)
+ return self[key]
+
+ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
+ # This is necessary because we apply the indexing key in self.get_duck_array()
+ # Note this is the base class for all lazy indexing classes
return np.asarray(self.get_duck_array(), dtype=dtype)
+ def _oindex_get(self, indexer: OuterIndexer):
+ raise NotImplementedError(
+ f"{self.__class__.__name__}._oindex_get method should be overridden"
+ )
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ raise NotImplementedError(
+ f"{self.__class__.__name__}._vindex_get method should be overridden"
+ )
+
+ def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
+ raise NotImplementedError(
+ f"{self.__class__.__name__}._oindex_set method should be overridden"
+ )
+
+ def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None:
+ raise NotImplementedError(
+ f"{self.__class__.__name__}._vindex_set method should be overridden"
+ )
+
+ def _check_and_raise_if_non_basic_indexer(self, indexer: ExplicitIndexer) -> None:
+ if isinstance(indexer, (VectorizedIndexer, OuterIndexer)):
+ raise TypeError(
+ "Vectorized indexing with vectorized or outer indexers is not supported. "
+ "Please use .vindex and .oindex properties to index the array."
+ )
+
+ @property
+ def oindex(self) -> IndexCallable:
+ return IndexCallable(self._oindex_get, self._oindex_set)
+
+ @property
+ def vindex(self) -> IndexCallable:
+ return IndexCallable(self._vindex_get, self._vindex_set)
+
class ImplicitToExplicitIndexingAdapter(NDArrayMixin):
"""Wrap an array, converting tuples into the indicated explicit indexer."""
- __slots__ = 'array', 'indexer_cls'
- def __init__(self, array, indexer_cls: type[ExplicitIndexer]=BasicIndexer):
+ __slots__ = ("array", "indexer_cls")
+
+ def __init__(self, array, indexer_cls: type[ExplicitIndexer] = BasicIndexer):
self.array = as_indexable(array)
self.indexer_cls = indexer_cls
- def __array__(self, dtype: np.typing.DTypeLike=None) ->np.ndarray:
+ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)
+ def get_duck_array(self):
+ return self.array.get_duck_array()
+
def __getitem__(self, key: Any):
key = expanded_indexer(key, self.ndim)
indexer = self.indexer_cls(key)
+
result = apply_indexer(self.array, indexer)
+
if isinstance(result, ExplicitlyIndexed):
return type(self)(result, self.indexer_cls)
else:
+ # Sometimes explicitly indexed arrays return NumPy arrays or
+ # scalars.
return result
class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array to make basic and outer indexing lazy."""
- __slots__ = 'array', 'key', '_shape'
- def __init__(self, array: Any, key: (ExplicitIndexer | None)=None):
+ __slots__ = ("array", "key", "_shape")
+
+ def __init__(self, array: Any, key: ExplicitIndexer | None = None):
"""
Parameters
----------
@@ -330,39 +606,99 @@ class LazilyIndexedArray(ExplicitlyIndexedNDArrayMixin):
canonical expanded form.
"""
if isinstance(array, type(self)) and key is None:
- key = array.key
- array = array.array
+ # unwrap
+ key = array.key # type: ignore[has-type]
+ array = array.array # type: ignore[has-type]
+
if key is None:
key = BasicIndexer((slice(None),) * array.ndim)
+
self.array = as_indexable(array)
self.key = key
+
shape: _Shape = ()
for size, k in zip(self.array.shape, self.key.tuple):
if isinstance(k, slice):
- shape += len(range(*k.indices(size))),
+ shape += (len(range(*k.indices(size))),)
elif isinstance(k, np.ndarray):
- shape += k.size,
+ shape += (k.size,)
self._shape = shape
+ def _updated_key(self, new_key: ExplicitIndexer) -> BasicIndexer | OuterIndexer:
+ iter_new_key = iter(expanded_indexer(new_key.tuple, self.ndim))
+ full_key = []
+ for size, k in zip(self.array.shape, self.key.tuple):
+ if isinstance(k, integer_types):
+ full_key.append(k)
+ else:
+ full_key.append(_index_indexer_1d(k, next(iter_new_key), size))
+ full_key_tuple = tuple(full_key)
+
+ if all(isinstance(k, integer_types + (slice,)) for k in full_key_tuple):
+ return BasicIndexer(full_key_tuple)
+ return OuterIndexer(full_key_tuple)
+
+ @property
+ def shape(self) -> _Shape:
+ return self._shape
+
+ def get_duck_array(self):
+ if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
+ array = apply_indexer(self.array, self.key)
+ else:
+ # If the array is not an ExplicitlyIndexedNDArrayMixin,
+ # it may wrap a BackendArray so use its __getitem__
+ array = self.array[self.key]
+
+ # self.array[self.key] is now a numpy array when
+ # self.array is a BackendArray subclass
+ # and self.key is BasicIndexer((slice(None, None, None),))
+ # so we need the explicit check for ExplicitlyIndexed
+ if isinstance(array, ExplicitlyIndexed):
+ array = array.get_duck_array()
+ return _wrap_numpy_scalars(array)
+
+ def transpose(self, order):
+ return LazilyVectorizedIndexedArray(self.array, self.key).transpose(order)
+
+ def _oindex_get(self, indexer: OuterIndexer):
+ return type(self)(self.array, self._updated_key(indexer))
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ array = LazilyVectorizedIndexedArray(self.array, self.key)
+ return array.vindex[indexer]
+
def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return type(self)(self.array, self._updated_key(indexer))
- def __setitem__(self, key: BasicIndexer, value: Any) ->None:
+ def _vindex_set(self, key: VectorizedIndexer, value: Any) -> None:
+ raise NotImplementedError(
+ "Lazy item assignment with the vectorized indexer is not yet "
+ "implemented. Load your data first by .load() or compute()."
+ )
+
+ def _oindex_set(self, key: OuterIndexer, value: Any) -> None:
+ full_key = self._updated_key(key)
+ self.array.oindex[full_key] = value
+
+ def __setitem__(self, key: BasicIndexer, value: Any) -> None:
self._check_and_raise_if_non_basic_indexer(key)
full_key = self._updated_key(key)
self.array[full_key] = value
- def __repr__(self) ->str:
- return f'{type(self).__name__}(array={self.array!r}, key={self.key!r})'
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})"
+# keep an alias to the old name for external backends pydata/xarray#5111
LazilyOuterIndexedArray = LazilyIndexedArray
class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array to make vectorized indexing lazy."""
- __slots__ = 'array', 'key'
+
+ __slots__ = ("array", "key")
def __init__(self, array: duckarray[Any, Any], key: ExplicitIndexer):
"""
@@ -378,61 +714,149 @@ class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin):
self.key = _arrayize_vectorized_indexer(key, array.shape)
self.array = as_indexable(array)
+ @property
+ def shape(self) -> _Shape:
+ return np.broadcast(*self.key.tuple).shape
+
+ def get_duck_array(self):
+ if isinstance(self.array, ExplicitlyIndexedNDArrayMixin):
+ array = apply_indexer(self.array, self.key)
+ else:
+ # If the array is not an ExplicitlyIndexedNDArrayMixin,
+ # it may wrap a BackendArray so use its __getitem__
+ array = self.array[self.key]
+ # self.array[self.key] is now a numpy array when
+ # self.array is a BackendArray subclass
+ # and self.key is BasicIndexer((slice(None, None, None),))
+ # so we need the explicit check for ExplicitlyIndexed
+ if isinstance(array, ExplicitlyIndexed):
+ array = array.get_duck_array()
+ return _wrap_numpy_scalars(array)
+
+ def _updated_key(self, new_key: ExplicitIndexer):
+ return _combine_indexers(self.key, self.shape, new_key)
+
+ def _oindex_get(self, indexer: OuterIndexer):
+ return type(self)(self.array, self._updated_key(indexer))
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ return type(self)(self.array, self._updated_key(indexer))
+
def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
+ # If the indexed array becomes a scalar, return LazilyIndexedArray
if all(isinstance(ind, integer_types) for ind in indexer.tuple):
key = BasicIndexer(tuple(k[indexer.tuple] for k in self.key.tuple))
return LazilyIndexedArray(self.array, key)
return type(self)(self.array, self._updated_key(indexer))
- def __setitem__(self, indexer: ExplicitIndexer, value: Any) ->None:
+ def transpose(self, order):
+ key = VectorizedIndexer(tuple(k.transpose(order) for k in self.key.tuple))
+ return type(self)(self.array, key)
+
+ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
raise NotImplementedError(
- 'Lazy item assignment with the vectorized indexer is not yet implemented. Load your data first by .load() or compute().'
- )
+ "Lazy item assignment with the vectorized indexer is not yet "
+ "implemented. Load your data first by .load() or compute()."
+ )
- def __repr__(self) ->str:
- return f'{type(self).__name__}(array={self.array!r}, key={self.key!r})'
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}(array={self.array!r}, key={self.key!r})"
def _wrap_numpy_scalars(array):
"""Wrap NumPy scalars in 0d arrays."""
- pass
+ if np.isscalar(array):
+ return np.array(array)
+ else:
+ return array
class CopyOnWriteArray(ExplicitlyIndexedNDArrayMixin):
- __slots__ = 'array', '_copied'
+ __slots__ = ("array", "_copied")
def __init__(self, array: duckarray[Any, Any]):
self.array = as_indexable(array)
self._copied = False
+ def _ensure_copied(self):
+ if not self._copied:
+ self.array = as_indexable(np.array(self.array))
+ self._copied = True
+
+ def get_duck_array(self):
+ return self.array.get_duck_array()
+
+ def _oindex_get(self, indexer: OuterIndexer):
+ return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer]))
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer]))
+
def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return type(self)(_wrap_numpy_scalars(self.array[indexer]))
- def __setitem__(self, indexer: ExplicitIndexer, value: Any) ->None:
+ def transpose(self, order):
+ return self.array.transpose(order)
+
+ def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None:
+ self._ensure_copied()
+ self.array.vindex[indexer] = value
+
+ def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
+ self._ensure_copied()
+ self.array.oindex[indexer] = value
+
+ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
self._check_and_raise_if_non_basic_indexer(indexer)
self._ensure_copied()
+
self.array[indexer] = value
def __deepcopy__(self, memo):
+ # CopyOnWriteArray is used to wrap backend array objects, which might
+ # point to files on disk, so we can't rely on the default deepcopy
+ # implementation.
return type(self)(self.array)
class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin):
- __slots__ = 'array',
+ __slots__ = ("array",)
def __init__(self, array):
self.array = _wrap_numpy_scalars(as_indexable(array))
- def __array__(self, dtype: np.typing.DTypeLike=None) ->np.ndarray:
+ def _ensure_cached(self):
+ self.array = as_indexable(self.array.get_duck_array())
+
+ def __array__(self, dtype: np.typing.DTypeLike = None) -> np.ndarray:
return np.asarray(self.get_duck_array(), dtype=dtype)
+ def get_duck_array(self):
+ self._ensure_cached()
+ return self.array.get_duck_array()
+
+ def _oindex_get(self, indexer: OuterIndexer):
+ return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer]))
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ return type(self)(_wrap_numpy_scalars(self.array.vindex[indexer]))
+
def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return type(self)(_wrap_numpy_scalars(self.array[indexer]))
- def __setitem__(self, indexer: ExplicitIndexer, value: Any) ->None:
+ def transpose(self, order):
+ return self.array.transpose(order)
+
+ def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None:
+ self.array.vindex[indexer] = value
+
+ def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
+ self.array.oindex[indexer] = value
+
+ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
self._check_and_raise_if_non_basic_indexer(indexer)
self.array[indexer] = value
@@ -443,11 +867,25 @@ def as_indexable(array):
so that the vectorized indexing is always possible with the returned
object.
"""
- pass
-
-
-def _outer_to_vectorized_indexer(indexer: (BasicIndexer | OuterIndexer),
- shape: _Shape) ->VectorizedIndexer:
+ if isinstance(array, ExplicitlyIndexed):
+ return array
+ if isinstance(array, np.ndarray):
+ return NumpyIndexingAdapter(array)
+ if isinstance(array, pd.Index):
+ return PandasIndexingAdapter(array)
+ if is_duck_dask_array(array):
+ return DaskIndexingAdapter(array)
+ if hasattr(array, "__array_function__"):
+ return NdArrayLikeIndexingAdapter(array)
+ if hasattr(array, "__array_namespace__"):
+ return ArrayApiIndexingAdapter(array)
+
+ raise TypeError(f"Invalid array type: {type(array)}")
+
+
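# Illustrative sketch (not part of the upstream patch): plain arrays get wrapped
# in the adapter matching their type, so .oindex / .vindex are always available.
import numpy as np
from xarray.core.indexing import NumpyIndexingAdapter, OuterIndexer, as_indexable

wrapped = as_indexable(np.arange(12).reshape(3, 4))
assert isinstance(wrapped, NumpyIndexingAdapter)
subset = wrapped.oindex[OuterIndexer((np.array([0, 2]), slice(None)))]
assert subset.shape == (2, 4)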
+def _outer_to_vectorized_indexer(
+ indexer: BasicIndexer | OuterIndexer, shape: _Shape
+) -> VectorizedIndexer:
"""Convert an OuterIndexer into an vectorized indexer.
Parameters
@@ -464,11 +902,25 @@ def _outer_to_vectorized_indexer(indexer: (BasicIndexer | OuterIndexer),
Each element is an array: broadcasting them together gives the shape
of the result.
"""
- pass
+ key = indexer.tuple
+
+ n_dim = len([k for k in key if not isinstance(k, integer_types)])
+ i_dim = 0
+ new_key = []
+ for k, size in zip(key, shape):
+ if isinstance(k, integer_types):
+ new_key.append(np.array(k).reshape((1,) * n_dim))
+ else: # np.ndarray or slice
+ if isinstance(k, slice):
+ k = np.arange(*k.indices(size))
+ assert k.dtype.kind in {"i", "u"}
+ new_shape = [(1,) * i_dim + (k.size,) + (1,) * (n_dim - i_dim - 1)]
+ new_key.append(k.reshape(*new_shape))
+ i_dim += 1
+ return VectorizedIndexer(tuple(new_key))
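# Illustrative sketch (not part of the upstream patch): each element of an outer
# key becomes an integer array shaped to broadcast orthogonally against the others.
import numpy as np
from xarray.core.indexing import OuterIndexer, _outer_to_vectorized_indexer

vec = _outer_to_vectorized_indexer(
    OuterIndexer((np.array([0, 2]), slice(0, 3))), shape=(4, 5)
)
assert vec.tuple[0].shape == (2, 1)
assert vec.tuple[1].shape == (1, 3)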
-def _outer_to_numpy_indexer(indexer: (BasicIndexer | OuterIndexer), shape:
- _Shape):
+def _outer_to_numpy_indexer(indexer: BasicIndexer | OuterIndexer, shape: _Shape):
"""Convert an OuterIndexer into an indexer for NumPy.
Parameters
@@ -483,10 +935,16 @@ def _outer_to_numpy_indexer(indexer: (BasicIndexer | OuterIndexer), shape:
tuple
Tuple suitable for use to index a NumPy array.
"""
- pass
+ if len([k for k in indexer.tuple if not isinstance(k, slice)]) <= 1:
+ # If there is only one vector and all others are slice,
+ # it can be safely used in mixed basic/advanced indexing.
+ # Boolean index should already be converted to integer array.
+ return indexer.tuple
+ else:
+ return _outer_to_vectorized_indexer(indexer, shape).tuple
-def _combine_indexers(old_key, shape: _Shape, new_key) ->VectorizedIndexer:
+def _combine_indexers(old_key, shape: _Shape, new_key) -> VectorizedIndexer:
"""Combine two indexers.
Parameters
@@ -498,20 +956,40 @@ def _combine_indexers(old_key, shape: _Shape, new_key) ->VectorizedIndexer:
new_key
The second indexer for indexing original[old_key]
"""
- pass
+ if not isinstance(old_key, VectorizedIndexer):
+ old_key = _outer_to_vectorized_indexer(old_key, shape)
+ if len(old_key.tuple) == 0:
+ return new_key
+
+ new_shape = np.broadcast(*old_key.tuple).shape
+ if isinstance(new_key, VectorizedIndexer):
+ new_key = _arrayize_vectorized_indexer(new_key, new_shape)
+ else:
+ new_key = _outer_to_vectorized_indexer(new_key, new_shape)
+
+ return VectorizedIndexer(
+ tuple(o[new_key.tuple] for o in np.broadcast_arrays(*old_key.tuple))
+ )
@enum.unique
class IndexingSupport(enum.Enum):
+ # for backends that support only basic indexer
BASIC = 0
+ # for backends that support basic / outer indexer
OUTER = 1
+ # for backends that support outer indexer including at most 1 vector.
OUTER_1VECTOR = 2
+ # for backends that support full vectorized indexer.
VECTORIZED = 3
-def explicit_indexing_adapter(key: ExplicitIndexer, shape: _Shape,
- indexing_support: IndexingSupport, raw_indexing_method: Callable[..., Any]
- ) ->Any:
+def explicit_indexing_adapter(
+ key: ExplicitIndexer,
+ shape: _Shape,
+ indexing_support: IndexingSupport,
+ raw_indexing_method: Callable[..., Any],
+) -> Any:
"""Support explicit indexing by delegating to a raw indexing method.
Outer and/or vectorized indexers are supported by indexing a second time
@@ -533,20 +1011,46 @@ def explicit_indexing_adapter(key: ExplicitIndexer, shape: _Shape,
-------
Indexing result, in the form of a duck numpy-array.
"""
- pass
+ raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support)
+ result = raw_indexing_method(raw_key.tuple)
+ if numpy_indices.tuple:
+ # index the loaded np.ndarray
+ indexable = NumpyIndexingAdapter(result)
+ result = apply_indexer(indexable, numpy_indices)
+ return result
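# Illustrative sketch (not part of the upstream patch): a vectorized request
# served by a backend that only supports outer indexing. The lambda below is a
# stand-in raw_indexing_method that interprets the decomposed key as an outer
# (np.ix_-style) index; it suffices here because this key contains no slices.
import numpy as np
from xarray.core.indexing import (
    IndexingSupport,
    VectorizedIndexer,
    explicit_indexing_adapter,
)

data = np.arange(20).reshape(4, 5)
key = VectorizedIndexer((np.array([0, 3]), np.array([1, 4])))
out = explicit_indexing_adapter(
    key, data.shape, IndexingSupport.OUTER, lambda tup: data[np.ix_(*tup)]
)
assert list(out) == [data[0, 1], data[3, 4]]  # pointwise result: [1, 19]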
def apply_indexer(indexable, indexer: ExplicitIndexer):
"""Apply an indexer to an indexable object."""
- pass
+ if isinstance(indexer, VectorizedIndexer):
+ return indexable.vindex[indexer]
+ elif isinstance(indexer, OuterIndexer):
+ return indexable.oindex[indexer]
+ else:
+ return indexable[indexer]
-def set_with_indexer(indexable, indexer: ExplicitIndexer, value: Any) ->None:
+def set_with_indexer(indexable, indexer: ExplicitIndexer, value: Any) -> None:
"""Set values in an indexable object using an indexer."""
- pass
+ if isinstance(indexer, VectorizedIndexer):
+ indexable.vindex[indexer] = value
+ elif isinstance(indexer, OuterIndexer):
+ indexable.oindex[indexer] = value
+ else:
+ indexable[indexer] = value
+
+
+def decompose_indexer(
+ indexer: ExplicitIndexer, shape: _Shape, indexing_support: IndexingSupport
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
+ if isinstance(indexer, VectorizedIndexer):
+ return _decompose_vectorized_indexer(indexer, shape, indexing_support)
+ if isinstance(indexer, (BasicIndexer, OuterIndexer)):
+ return _decompose_outer_indexer(indexer, shape, indexing_support)
+ raise TypeError(f"unexpected key type: {indexer}")
-def _decompose_slice(key: slice, size: int) ->tuple[slice, slice]:
+def _decompose_slice(key: slice, size: int) -> tuple[slice, slice]:
"""convert a slice to successive two slices. The first slice always has
a positive step.
@@ -562,12 +1066,23 @@ def _decompose_slice(key: slice, size: int) ->tuple[slice, slice]:
>>> _decompose_slice(slice(360, None, -10), 361)
(slice(0, 361, 10), slice(None, None, -1))
"""
- pass
-
-
-def _decompose_vectorized_indexer(indexer: VectorizedIndexer, shape: _Shape,
- indexing_support: IndexingSupport) ->tuple[ExplicitIndexer, ExplicitIndexer
- ]:
+ start, stop, step = key.indices(size)
+ if step > 0:
+ # If key already has a positive step, use it as is in the backend
+ return key, slice(None)
+ else:
+ # determine stop precisely for the abs(step) > 1 case
+ # Use the range object to do the calculation
+ # e.g. [98:2:-2] -> [98:3:-2]
+ exact_stop = range(start, stop, step)[-1]
+ return slice(exact_stop, start + 1, -step), slice(None, None, -1)
+
+
+def _decompose_vectorized_indexer(
+ indexer: VectorizedIndexer,
+ shape: _Shape,
+ indexing_support: IndexingSupport,
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
"""
Decompose vectorized indexer to the successive two indexers, where the
first indexer will be used to index backend arrays, while the second one
@@ -602,12 +1117,54 @@ def _decompose_vectorized_indexer(indexer: VectorizedIndexer, shape: _Shape,
... NumpyIndexingAdapter(array).vindex[np_indexer]
array([ 2, 21, 8])
"""
- pass
-
-
-def _decompose_outer_indexer(indexer: (BasicIndexer | OuterIndexer), shape:
- _Shape, indexing_support: IndexingSupport) ->tuple[ExplicitIndexer,
- ExplicitIndexer]:
+ assert isinstance(indexer, VectorizedIndexer)
+
+ if indexing_support is IndexingSupport.VECTORIZED:
+ return indexer, BasicIndexer(())
+
+ backend_indexer_elems = []
+ np_indexer_elems = []
+ # convert negative indices
+ indexer_elems = [
+ np.where(k < 0, k + s, k) if isinstance(k, np.ndarray) else k
+ for k, s in zip(indexer.tuple, shape)
+ ]
+
+ for k, s in zip(indexer_elems, shape):
+ if isinstance(k, slice):
+ # If it is a slice, then we will slice it as-is
+ # (but make its step positive) in the backend,
+ # and then use all of it (slice(None)) for the in-memory portion.
+ bk_slice, np_slice = _decompose_slice(k, s)
+ backend_indexer_elems.append(bk_slice)
+ np_indexer_elems.append(np_slice)
+ else:
+ # If it is a (multidimensional) np.ndarray, just pick up the used
+ # keys without duplication and store them as a 1d-np.ndarray.
+ oind, vind = np.unique(k, return_inverse=True)
+ backend_indexer_elems.append(oind)
+ np_indexer_elems.append(vind.reshape(*k.shape))
+
+ backend_indexer = OuterIndexer(tuple(backend_indexer_elems))
+ np_indexer = VectorizedIndexer(tuple(np_indexer_elems))
+
+ if indexing_support is IndexingSupport.OUTER:
+ return backend_indexer, np_indexer
+
+ # If the backend does not support outer indexing,
+ # backend_indexer (OuterIndexer) is also decomposed.
+ backend_indexer1, np_indexer1 = _decompose_outer_indexer(
+ backend_indexer, shape, indexing_support
+ )
+ np_indexer = _combine_indexers(np_indexer1, shape, np_indexer)
+ return backend_indexer1, np_indexer
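The core trick here is `np.unique(..., return_inverse=True)`: the sorted unique values form an outer request the backend can serve, and the inverse indices replay the original pointwise pattern in memory. A minimal standalone NumPy sketch with invented data:

```python
import numpy as np

data = np.arange(30).reshape(5, 6)
rows, cols = np.array([4, 1, 1]), np.array([2, 3, 2])

# backend step: duplicate-free outer (orthogonal) request
urows, rinv = np.unique(rows, return_inverse=True)
ucols, cinv = np.unique(cols, return_inverse=True)
block = data[np.ix_(urows, ucols)]

# in-memory step: pointwise (vectorized) indexing of the small block
assert (block[rinv, cinv] == data[rows, cols]).all()
```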
+
+
+def _decompose_outer_indexer(
+ indexer: BasicIndexer | OuterIndexer,
+ shape: _Shape,
+ indexing_support: IndexingSupport,
+) -> tuple[ExplicitIndexer, ExplicitIndexer]:
"""
Decompose an outer indexer into two successive indexers, where the
first indexer will be used to index backend arrays, while the second one
@@ -644,23 +1201,175 @@ def _decompose_outer_indexer(indexer: (BasicIndexer | OuterIndexer), shape:
[14, 15, 14],
[ 8, 9, 8]])
"""
- pass
+ backend_indexer: list[Any] = []
+ np_indexer: list[Any] = []
+ assert isinstance(indexer, (OuterIndexer, BasicIndexer))
-def _arrayize_vectorized_indexer(indexer: VectorizedIndexer, shape: _Shape
- ) ->VectorizedIndexer:
+ if indexing_support == IndexingSupport.VECTORIZED:
+ for k, s in zip(indexer.tuple, shape):
+ if isinstance(k, slice):
+ # If it is a slice, then we will slice it as-is
+ # (but make its step positive) in the backend,
+ bk_slice, np_slice = _decompose_slice(k, s)
+ backend_indexer.append(bk_slice)
+ np_indexer.append(np_slice)
+ else:
+ backend_indexer.append(k)
+ if not is_scalar(k):
+ np_indexer.append(slice(None))
+ return type(indexer)(tuple(backend_indexer)), BasicIndexer(tuple(np_indexer))
+
+ # make indexer positive
+ pos_indexer: list[np.ndarray | int | np.number] = []
+ for k, s in zip(indexer.tuple, shape):
+ if isinstance(k, np.ndarray):
+ pos_indexer.append(np.where(k < 0, k + s, k))
+ elif isinstance(k, integer_types) and k < 0:
+ pos_indexer.append(k + s)
+ else:
+ pos_indexer.append(k)
+ indexer_elems = pos_indexer
+
+ if indexing_support is IndexingSupport.OUTER_1VECTOR:
+ # some backends such as h5py support only 1 vector in indexers
+ # We choose the most efficient axis
+ gains = [
+ (
+ (np.max(k) - np.min(k) + 1.0) / len(np.unique(k))
+ if isinstance(k, np.ndarray)
+ else 0
+ )
+ for k in indexer_elems
+ ]
+ array_index = np.argmax(np.array(gains)) if len(gains) > 0 else None
+
+ for i, (k, s) in enumerate(zip(indexer_elems, shape)):
+ if isinstance(k, np.ndarray) and i != array_index:
+ # np.ndarray key is converted to a slice that covers the entire
+ # range of this key.
+ backend_indexer.append(slice(np.min(k), np.max(k) + 1))
+ np_indexer.append(k - np.min(k))
+ elif isinstance(k, np.ndarray):
+ # Remove duplicates and sort them in the increasing order
+ pkey, ekey = np.unique(k, return_inverse=True)
+ backend_indexer.append(pkey)
+ np_indexer.append(ekey)
+ elif isinstance(k, integer_types):
+ backend_indexer.append(k)
+ else: # slice: convert positive step slice for backend
+ bk_slice, np_slice = _decompose_slice(k, s)
+ backend_indexer.append(bk_slice)
+ np_indexer.append(np_slice)
+
+ return (OuterIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer)))
+
+ if indexing_support == IndexingSupport.OUTER:
+ for k, s in zip(indexer_elems, shape):
+ if isinstance(k, slice):
+ # slice: convert positive step slice for backend
+ bk_slice, np_slice = _decompose_slice(k, s)
+ backend_indexer.append(bk_slice)
+ np_indexer.append(np_slice)
+ elif isinstance(k, integer_types):
+ backend_indexer.append(k)
+ elif isinstance(k, np.ndarray) and (np.diff(k) >= 0).all():
+ backend_indexer.append(k)
+ np_indexer.append(slice(None))
+ else:
+ # Remove duplicates and sort them in the increasing order
+ oind, vind = np.unique(k, return_inverse=True)
+ backend_indexer.append(oind)
+ np_indexer.append(vind.reshape(*k.shape))
+
+ return (OuterIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer)))
+
+ # basic indexer
+ assert indexing_support == IndexingSupport.BASIC
+
+ for k, s in zip(indexer_elems, shape):
+ if isinstance(k, np.ndarray):
+ # np.ndarray key is converted to a slice that covers the entire
+ # range of this key.
+ backend_indexer.append(slice(np.min(k), np.max(k) + 1))
+ np_indexer.append(k - np.min(k))
+ elif isinstance(k, integer_types):
+ backend_indexer.append(k)
+ else: # slice: convert positive step slice for backend
+ bk_slice, np_slice = _decompose_slice(k, s)
+ backend_indexer.append(bk_slice)
+ np_indexer.append(np_slice)
+
+ return (BasicIndexer(tuple(backend_indexer)), OuterIndexer(tuple(np_indexer)))
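The `OUTER_1VECTOR` branch picks which array key to keep as a vector using a span-to-unique "gain" ratio; the other array keys are downgraded to covering slices. A small sketch of that heuristic with made-up keys:

```python
import numpy as np

keys = [np.array([0, 100, 200]), np.array([1, 2, 3, 2])]

# a key that would force a huge covering slice has a large gain
# and is the one worth keeping as the single vector
gains = [(k.max() - k.min() + 1.0) / len(np.unique(k)) for k in keys]
print(gains)                   # [67.0, 1.0]
print(int(np.argmax(gains)))   # 0 -> keep the first key as the vector
```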
+
+
+def _arrayize_vectorized_indexer(
+ indexer: VectorizedIndexer, shape: _Shape
+) -> VectorizedIndexer:
"""Return an identical vindex but slices are replaced by arrays"""
- pass
+ slices = [v for v in indexer.tuple if isinstance(v, slice)]
+ if len(slices) == 0:
+ return indexer
+
+ arrays = [v for v in indexer.tuple if isinstance(v, np.ndarray)]
+ n_dim = arrays[0].ndim if len(arrays) > 0 else 0
+ i_dim = 0
+ new_key = []
+ for v, size in zip(indexer.tuple, shape):
+ if isinstance(v, np.ndarray):
+ new_key.append(np.reshape(v, v.shape + (1,) * len(slices)))
+ else: # slice
+ shape = (1,) * (n_dim + i_dim) + (-1,) + (1,) * (len(slices) - i_dim - 1)
+ new_key.append(np.arange(*v.indices(size)).reshape(shape))
+ i_dim += 1
+ return VectorizedIndexer(tuple(new_key))
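What `_arrayize_vectorized_indexer` does, in miniature: each slice becomes an `arange` reshaped onto its own broadcast axis, so the whole key is array-only. A standalone NumPy sketch with invented shapes:

```python
import numpy as np

data = np.arange(24).reshape(4, 6)
points = np.array([0, 3, 2])   # pointwise key on axis 0
window = slice(1, 5, 2)        # slice key on axis 1

# give the pointwise key a trailing axis per slice, and turn the slice
# into an arange with its own broadcast axis
rows = points.reshape(3, 1)
cols = np.arange(*window.indices(6)).reshape(1, -1)

assert (data[rows, cols] == data[points][:, window]).all()
```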
+
+
+def _chunked_array_with_chunks_hint(
+ array, chunks, chunkmanager: ChunkManagerEntrypoint[Any]
+):
+ """Create a chunked array using the chunks hint for dimensions of size > 1."""
+ if len(chunks) < array.ndim:
+ raise ValueError("not enough chunks in hint")
+ new_chunks = []
+ for chunk, size in zip(chunks, array.shape):
+ new_chunks.append(chunk if size > 1 else (1,))
+ return chunkmanager.from_array(array, new_chunks) # type: ignore[arg-type]
-def _chunked_array_with_chunks_hint(array, chunks, chunkmanager:
- ChunkManagerEntrypoint[Any]):
- """Create a chunked array using the chunks hint for dimensions of size > 1."""
- pass
+
+def _logical_any(args):
+ return functools.reduce(operator.or_, args)
+
+
+def _masked_result_drop_slice(key, data: duckarray[Any, Any] | None = None):
+ key = (k for k in key if not isinstance(k, slice))
+ chunks_hint = getattr(data, "chunks", None)
+
+ new_keys = []
+ for k in key:
+ if isinstance(k, np.ndarray):
+ if is_chunked_array(data): # type: ignore[arg-type]
+ chunkmanager = get_chunked_array_type(data)
+ new_keys.append(
+ _chunked_array_with_chunks_hint(k, chunks_hint, chunkmanager)
+ )
+ elif isinstance(data, array_type("sparse")):
+ import sparse
+
+ new_keys.append(sparse.COO.from_numpy(k))
+ else:
+ new_keys.append(k)
+ else:
+ new_keys.append(k)
+
+ mask = _logical_any(k == -1 for k in new_keys)
+ return mask
-def create_mask(indexer: ExplicitIndexer, shape: _Shape, data: (duckarray[
- Any, Any] | None)=None):
+def create_mask(
+ indexer: ExplicitIndexer, shape: _Shape, data: duckarray[Any, Any] | None = None
+):
"""Create a mask for indexing with a fill-value.
Parameters
@@ -680,11 +1389,34 @@ def create_mask(indexer: ExplicitIndexer, shape: _Shape, data: (duckarray[
mask : bool, np.ndarray, SparseArray or dask.array.Array with dtype=bool
Same type as data. Has the same shape as the indexing result.
"""
- pass
-
-
-def _posify_mask_subindexer(index: np.ndarray[Any, np.dtype[np.generic]]
- ) ->np.ndarray[Any, np.dtype[np.generic]]:
+ if isinstance(indexer, OuterIndexer):
+ key = _outer_to_vectorized_indexer(indexer, shape).tuple
+ assert not any(isinstance(k, slice) for k in key)
+ mask = _masked_result_drop_slice(key, data)
+
+ elif isinstance(indexer, VectorizedIndexer):
+ key = indexer.tuple
+ base_mask = _masked_result_drop_slice(key, data)
+ slice_shape = tuple(
+ np.arange(*k.indices(size)).size
+ for k, size in zip(key, shape)
+ if isinstance(k, slice)
+ )
+ expanded_mask = base_mask[(Ellipsis,) + (np.newaxis,) * len(slice_shape)]
+ mask = duck_array_ops.broadcast_to(expanded_mask, base_mask.shape + slice_shape)
+
+ elif isinstance(indexer, BasicIndexer):
+ mask = any(k == -1 for k in indexer.tuple)
+
+ else:
+ raise TypeError(f"unexpected key type: {type(indexer)}")
+
+ return mask
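For context, a plain-NumPy illustration (not the helpers above) of why a mask is needed at all: a -1 entry in a key silently wraps to the last element, so the mask records where fill values belong in the result.

```python
import numpy as np

key = np.array([2, -1, 0, -1])
mask = key == -1

values = np.arange(10.0)[key]   # -1 silently wraps to the last element...
values[mask] = np.nan           # ...so the mask is used to blank it out
print(values)                   # [ 2. nan  0. nan]
```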
+
+
+def _posify_mask_subindexer(
+ index: np.ndarray[Any, np.dtype[np.generic]],
+) -> np.ndarray[Any, np.dtype[np.generic]]:
"""Convert masked indices in a flat array to the nearest unmasked index.
Parameters
@@ -698,10 +1430,19 @@ def _posify_mask_subindexer(index: np.ndarray[Any, np.dtype[np.generic]]
One dimensional ndarray with all values equal to -1 replaced by an
adjacent non-masked element.
"""
- pass
-
-
-def posify_mask_indexer(indexer: ExplicitIndexer) ->ExplicitIndexer:
+ masked = index == -1
+ unmasked_locs = np.flatnonzero(~masked)
+ if not unmasked_locs.size:
+ # indexing unmasked_locs is invalid
+ return np.zeros_like(index)
+ masked_locs = np.flatnonzero(masked)
+ prev_value = np.maximum(0, np.searchsorted(unmasked_locs, masked_locs) - 1)
+ new_index = index.copy()
+ new_index[masked_locs] = index[unmasked_locs[prev_value]]
+ return new_index
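A worked standalone example of the same nearest-preceding-unmasked substitution (values invented):

```python
import numpy as np

index = np.array([-1, 3, -1, -1, 7, -1])
masked = index == -1
unmasked_locs = np.flatnonzero(~masked)                       # [1, 4]
masked_locs = np.flatnonzero(masked)                          # [0, 2, 3, 5]
prev = np.maximum(0, np.searchsorted(unmasked_locs, masked_locs) - 1)
out = index.copy()
out[masked_locs] = index[unmasked_locs[prev]]
print(out)                                                    # [3 3 3 3 7 7]
```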
+
+
+def posify_mask_indexer(indexer: ExplicitIndexer) -> ExplicitIndexer:
"""Convert masked values (-1) in an indexer to nearest unmasked values.
This routine is useful for dask, where it can be much faster to index
@@ -718,74 +1459,155 @@ def posify_mask_indexer(indexer: ExplicitIndexer) ->ExplicitIndexer:
Same type of input, with all values in ndarray keys equal to -1
replaced by an adjacent non-masked element.
"""
- pass
-
-
-def is_fancy_indexer(indexer: Any) ->bool:
+ key = tuple(
+ (
+ _posify_mask_subindexer(k.ravel()).reshape(k.shape)
+ if isinstance(k, np.ndarray)
+ else k
+ )
+ for k in indexer.tuple
+ )
+ return type(indexer)(key)
+
+
+def is_fancy_indexer(indexer: Any) -> bool:
"""Return False if indexer is a int, slice, a 1-dimensional list, or a 0 or
1-dimensional ndarray; in all other cases return True
"""
- pass
+ if isinstance(indexer, (int, slice)):
+ return False
+ if isinstance(indexer, np.ndarray):
+ return indexer.ndim > 1
+ if isinstance(indexer, list):
+ return bool(indexer) and not isinstance(indexer[0], int)
+ return True
class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a NumPy array to use explicit indexing."""
- __slots__ = 'array',
+
+ __slots__ = ("array",)
def __init__(self, array):
+ # In NumpyIndexingAdapter we only allow to store bare np.ndarray
if not isinstance(array, np.ndarray):
raise TypeError(
- f'NumpyIndexingAdapter only wraps np.ndarray. Trying to wrap {type(array)}'
- )
+ "NumpyIndexingAdapter only wraps np.ndarray. "
+ f"Trying to wrap {type(array)}"
+ )
self.array = array
+ def transpose(self, order):
+ return self.array.transpose(order)
+
+ def _oindex_get(self, indexer: OuterIndexer):
+ key = _outer_to_numpy_indexer(indexer, self.array.shape)
+ return self.array[key]
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ array = NumpyVIndexAdapter(self.array)
+ return array[indexer.tuple]
+
def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
+
array = self.array
+ # We want 0d slices rather than scalars. This is achieved by
+ # appending an ellipsis (see
+ # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = indexer.tuple + (Ellipsis,)
return array[key]
- def __setitem__(self, indexer: ExplicitIndexer, value: Any) ->None:
+ def _safe_setitem(self, array, key: tuple[Any, ...], value: Any) -> None:
+ try:
+ array[key] = value
+ except ValueError as exc:
+ # More informative exception if read-only view
+ if not array.flags.writeable and not array.flags.owndata:
+ raise ValueError(
+ "Assignment destination is a view. "
+ "Do you want to .copy() array first?"
+ )
+ else:
+ raise exc
+
+ def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
+ key = _outer_to_numpy_indexer(indexer, self.array.shape)
+ self._safe_setitem(self.array, key, value)
+
+ def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None:
+ array = NumpyVIndexAdapter(self.array)
+ self._safe_setitem(array, indexer.tuple, value)
+
+ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
self._check_and_raise_if_non_basic_indexer(indexer)
array = self.array
+ # We want 0d slices rather than scalars. This is achieved by
+ # appending an ellipsis (see
+ # https://numpy.org/doc/stable/reference/arrays.indexing.html#detailed-notes).
key = indexer.tuple + (Ellipsis,)
self._safe_setitem(array, key, value)
class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter):
- __slots__ = 'array',
+ __slots__ = ("array",)
def __init__(self, array):
- if not hasattr(array, '__array_function__'):
+ if not hasattr(array, "__array_function__"):
raise TypeError(
- 'NdArrayLikeIndexingAdapter must wrap an object that implements the __array_function__ protocol'
- )
+ "NdArrayLikeIndexingAdapter must wrap an object that "
+ "implements the __array_function__ protocol"
+ )
self.array = array
class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap an array API array to use explicit indexing."""
- __slots__ = 'array',
+
+ __slots__ = ("array",)
def __init__(self, array):
- if not hasattr(array, '__array_namespace__'):
+ if not hasattr(array, "__array_namespace__"):
raise TypeError(
- 'ArrayApiIndexingAdapter must wrap an object that implements the __array_namespace__ protocol'
- )
+ "ArrayApiIndexingAdapter must wrap an object that "
+ "implements the __array_namespace__ protocol"
+ )
self.array = array
+ def _oindex_get(self, indexer: OuterIndexer):
+ # manual orthogonal indexing (implemented like DaskIndexingAdapter)
+ key = indexer.tuple
+ value = self.array
+ for axis, subkey in reversed(list(enumerate(key))):
+ value = value[(slice(None),) * axis + (subkey, Ellipsis)]
+ return value
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ raise TypeError("Vectorized indexing is not supported")
+
def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return self.array[indexer.tuple]
- def __setitem__(self, indexer: ExplicitIndexer, value: Any) ->None:
+ def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
+ self.array[indexer.tuple] = value
+
+ def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None:
+ raise TypeError("Vectorized indexing is not supported")
+
+ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
self._check_and_raise_if_non_basic_indexer(indexer)
self.array[indexer.tuple] = value
+ def transpose(self, order):
+ xp = self.array.__array_namespace__()
+ return xp.permute_dims(self.array, order)
+
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""
- __slots__ = 'array',
+
+ __slots__ = ("array",)
def __init__(self, array):
"""This adapter is created in Variable.__getitem__ in
@@ -793,52 +1615,197 @@ class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""
self.array = array
+ def _oindex_get(self, indexer: OuterIndexer):
+ key = indexer.tuple
+ try:
+ return self.array[key]
+ except NotImplementedError:
+ # manual orthogonal indexing
+ value = self.array
+ for axis, subkey in reversed(list(enumerate(key))):
+ value = value[(slice(None),) * axis + (subkey,)]
+ return value
+
+ def _vindex_get(self, indexer: VectorizedIndexer):
+ return self.array.vindex[indexer.tuple]
+
def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
return self.array[indexer.tuple]
- def __setitem__(self, indexer: ExplicitIndexer, value: Any) ->None:
+ def _oindex_set(self, indexer: OuterIndexer, value: Any) -> None:
+ num_non_slices = sum(0 if isinstance(k, slice) else 1 for k in indexer.tuple)
+ if num_non_slices > 1:
+ raise NotImplementedError(
+ "xarray can't set arrays with multiple " "array indices to dask yet."
+ )
+ self.array[indexer.tuple] = value
+
+ def _vindex_set(self, indexer: VectorizedIndexer, value: Any) -> None:
+ self.array.vindex[indexer.tuple] = value
+
+ def __setitem__(self, indexer: ExplicitIndexer, value: Any) -> None:
self._check_and_raise_if_non_basic_indexer(indexer)
self.array[indexer.tuple] = value
+ def transpose(self, order):
+ return self.array.transpose(order)
+
class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a pandas.Index to preserve dtypes and handle explicit indexing."""
- __slots__ = 'array', '_dtype'
+
+ __slots__ = ("array", "_dtype")
+
array: pd.Index
_dtype: np.dtype
- def __init__(self, array: pd.Index, dtype: DTypeLike=None):
+ def __init__(self, array: pd.Index, dtype: DTypeLike = None):
from xarray.core.indexes import safe_cast_to_index
+
self.array = safe_cast_to_index(array)
+
if dtype is None:
self._dtype = get_valid_numpy_dtype(array)
else:
self._dtype = np.dtype(dtype)
- def __array__(self, dtype: DTypeLike=None) ->np.ndarray:
+ @property
+ def dtype(self) -> np.dtype:
+ return self._dtype
+
+ def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
if dtype is None:
dtype = self.dtype
array = self.array
if isinstance(array, pd.PeriodIndex):
with suppress(AttributeError):
- array = array.astype('object')
+ # this might not be public API
+ array = array.astype("object")
return np.asarray(array.values, dtype=dtype)
- def __getitem__(self, indexer: ExplicitIndexer) ->(
- PandasIndexingAdapter | NumpyIndexingAdapter | np.ndarray | np.
- datetime64 | np.timedelta64):
+ def get_duck_array(self) -> np.ndarray:
+ return np.asarray(self)
+
+ @property
+ def shape(self) -> _Shape:
+ return (len(self.array),)
+
+ def _convert_scalar(self, item):
+ if item is pd.NaT:
+ # work around the impossibility of casting NaT with asarray
+ # note: it probably would be better in general to return
+ # pd.Timestamp rather than np.datetime64 but this is easier
+ # (for now)
+ item = np.datetime64("NaT", "ns")
+ elif isinstance(item, timedelta):
+ item = np.timedelta64(getattr(item, "value", item), "ns")
+ elif isinstance(item, pd.Timestamp):
+ # Work around for GH: pydata/xarray#1932 and numpy/numpy#10668
+ # numpy fails to convert pd.Timestamp to np.datetime64[ns]
+ item = np.asarray(item.to_datetime64())
+ elif self.dtype != object:
+ item = np.asarray(item, dtype=self.dtype)
+
+ # as for numpy.ndarray indexing, we always want the result to be
+ # a NumPy array.
+ return to_0d_array(item)
+
+ def _prepare_key(self, key: tuple[Any, ...]) -> tuple[Any, ...]:
+ if isinstance(key, tuple) and len(key) == 1:
+ # unpack key so it can index a pandas.Index object (pandas.Index
+ # objects don't like tuples)
+ (key,) = key
+
+ return key
+
+ def _handle_result(
+ self, result: Any
+ ) -> (
+ PandasIndexingAdapter
+ | NumpyIndexingAdapter
+ | np.ndarray
+ | np.datetime64
+ | np.timedelta64
+ ):
+ if isinstance(result, pd.Index):
+ return type(self)(result, dtype=self.dtype)
+ else:
+ return self._convert_scalar(result)
+
+ def _oindex_get(
+ self, indexer: OuterIndexer
+ ) -> (
+ PandasIndexingAdapter
+ | NumpyIndexingAdapter
+ | np.ndarray
+ | np.datetime64
+ | np.timedelta64
+ ):
key = self._prepare_key(indexer.tuple)
- if getattr(key, 'ndim', 0) > 1:
+
+ if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
+ indexable = NumpyIndexingAdapter(np.asarray(self))
+ return indexable.oindex[indexer]
+
+ result = self.array[key]
+
+ return self._handle_result(result)
+
+ def _vindex_get(
+ self, indexer: VectorizedIndexer
+ ) -> (
+ PandasIndexingAdapter
+ | NumpyIndexingAdapter
+ | np.ndarray
+ | np.datetime64
+ | np.timedelta64
+ ):
+ key = self._prepare_key(indexer.tuple)
+
+ if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
+ indexable = NumpyIndexingAdapter(np.asarray(self))
+ return indexable.vindex[indexer]
+
+ result = self.array[key]
+
+ return self._handle_result(result)
+
+ def __getitem__(
+ self, indexer: ExplicitIndexer
+ ) -> (
+ PandasIndexingAdapter
+ | NumpyIndexingAdapter
+ | np.ndarray
+ | np.datetime64
+ | np.timedelta64
+ ):
+ key = self._prepare_key(indexer.tuple)
+
+ if getattr(key, "ndim", 0) > 1: # Return np-array if multidimensional
indexable = NumpyIndexingAdapter(np.asarray(self))
return indexable[indexer]
+
result = self.array[key]
+
return self._handle_result(result)
- def __repr__(self) ->str:
- return (
- f'{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})'
- )
+ def transpose(self, order) -> pd.Index:
+ return self.array # self.array should be always one-dimensional
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})"
+
+ def copy(self, deep: bool = True) -> Self:
+ # Not the same as just writing `self.array.copy(deep=deep)`, as
+ # shallow copies of the underlying numpy.ndarrays become deep ones
+ # upon pickling
+ # >>> len(pickle.dumps((self.array, self.array)))
+ # 4000281
+ # >>> len(pickle.dumps((self.array, self.array.copy(deep=False))))
+ # 8000341
+ array = self.array.copy(deep=True) if deep else self.array
+ return type(self)(array, self._dtype)
class PandasMultiIndexingAdapter(PandasIndexingAdapter):
@@ -848,36 +1815,109 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter):
preserving indexing efficiency (memoized + might reuse another instance with
the same multi-index).
"""
- __slots__ = 'array', '_dtype', 'level', 'adapter'
+
+ __slots__ = ("array", "_dtype", "level", "adapter")
+
array: pd.MultiIndex
_dtype: np.dtype
level: str | None
- def __init__(self, array: pd.MultiIndex, dtype: DTypeLike=None, level:
- (str | None)=None):
+ def __init__(
+ self,
+ array: pd.MultiIndex,
+ dtype: DTypeLike = None,
+ level: str | None = None,
+ ):
super().__init__(array, dtype)
self.level = level
- def __array__(self, dtype: DTypeLike=None) ->np.ndarray:
+ def __array__(self, dtype: DTypeLike = None) -> np.ndarray:
if dtype is None:
dtype = self.dtype
if self.level is not None:
- return np.asarray(self.array.get_level_values(self.level).
- values, dtype=dtype)
+ return np.asarray(
+ self.array.get_level_values(self.level).values, dtype=dtype
+ )
else:
return super().__array__(dtype)
+ def _convert_scalar(self, item):
+ if isinstance(item, tuple) and self.level is not None:
+ idx = tuple(self.array.names).index(self.level)
+ item = item[idx]
+ return super()._convert_scalar(item)
+
+ def _oindex_get(
+ self, indexer: OuterIndexer
+ ) -> (
+ PandasIndexingAdapter
+ | NumpyIndexingAdapter
+ | np.ndarray
+ | np.datetime64
+ | np.timedelta64
+ ):
+ result = super()._oindex_get(indexer)
+ if isinstance(result, type(self)):
+ result.level = self.level
+ return result
+
+ def _vindex_get(
+ self, indexer: VectorizedIndexer
+ ) -> (
+ PandasIndexingAdapter
+ | NumpyIndexingAdapter
+ | np.ndarray
+ | np.datetime64
+ | np.timedelta64
+ ):
+ result = super()._vindex_get(indexer)
+ if isinstance(result, type(self)):
+ result.level = self.level
+ return result
+
def __getitem__(self, indexer: ExplicitIndexer):
result = super().__getitem__(indexer)
if isinstance(result, type(self)):
result.level = self.level
+
return result
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
if self.level is None:
return super().__repr__()
else:
props = (
- f'(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})'
- )
- return f'{type(self).__name__}{props}'
+ f"(array={self.array!r}, level={self.level!r}, dtype={self.dtype!r})"
+ )
+ return f"{type(self).__name__}{props}"
+
+ def _get_array_subset(self) -> np.ndarray:
+ # used to speed-up the repr for big multi-indexes
+ threshold = max(100, OPTIONS["display_values_threshold"] + 2)
+ if self.size > threshold:
+ pos = threshold // 2
+ indices = np.concatenate([np.arange(0, pos), np.arange(-pos, 0)])
+ subset = self[OuterIndexer((indices,))]
+ else:
+ subset = self
+
+ return np.asarray(subset)
+
+ def _repr_inline_(self, max_width: int) -> str:
+ from xarray.core.formatting import format_array_flat
+
+ if self.level is None:
+ return "MultiIndex"
+ else:
+ return format_array_flat(self._get_array_subset(), max_width)
+
+ def _repr_html_(self) -> str:
+ from xarray.core.formatting import short_array_repr
+
+ array_repr = short_array_repr(self._get_array_subset())
+ return f"<pre>{escape(array_repr)}</pre>"
+
+ def copy(self, deep: bool = True) -> Self:
+ # see PandasIndexingAdapter.copy
+ array = self.array.copy(deep=True) if deep else self.array
+ return type(self)(array, self._dtype, self.level)
diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py
index c1c13299..ae748b00 100644
--- a/xarray/core/iterators.py
+++ b/xarray/core/iterators.py
@@ -1,7 +1,10 @@
from __future__ import annotations
+
from collections.abc import Iterator
from typing import Callable
+
from xarray.core.treenode import Tree
+
"""These iterators are copied from anytree.iterators, with minor modifications."""
@@ -59,19 +62,70 @@ class LevelOrderIter(Iterator):
['f', 'b', 'g', 'a', 'i', 'h']
"""
- def __init__(self, node: Tree, filter_: (Callable | None)=None, stop: (
- Callable | None)=None, maxlevel: (int | None)=None):
+ def __init__(
+ self,
+ node: Tree,
+ filter_: Callable | None = None,
+ stop: Callable | None = None,
+ maxlevel: int | None = None,
+ ):
self.node = node
self.filter_ = filter_
self.stop = stop
self.maxlevel = maxlevel
self.__iter = None
- def __iter__(self) ->Iterator[Tree]:
+ def __init(self):
+ node = self.node
+ maxlevel = self.maxlevel
+ filter_ = self.filter_ or LevelOrderIter.__default_filter
+ stop = self.stop or LevelOrderIter.__default_stop
+ children = (
+ []
+ if LevelOrderIter._abort_at_level(1, maxlevel)
+ else LevelOrderIter._get_children([node], stop)
+ )
+ return self._iter(children, filter_, stop, maxlevel)
+
+ @staticmethod
+ def __default_filter(node: Tree) -> bool:
+ return True
+
+ @staticmethod
+ def __default_stop(node: Tree) -> bool:
+ return False
+
+ def __iter__(self) -> Iterator[Tree]:
return self
- def __next__(self) ->Iterator[Tree]:
+ def __next__(self) -> Iterator[Tree]:
if self.__iter is None:
self.__iter = self.__init()
- item = next(self.__iter)
+ item = next(self.__iter) # type: ignore[call-overload]
return item
+
+ @staticmethod
+ def _abort_at_level(level: int, maxlevel: int | None) -> bool:
+ return maxlevel is not None and level > maxlevel
+
+ @staticmethod
+ def _get_children(children: list[Tree], stop: Callable) -> list[Tree]:
+ return [child for child in children if not stop(child)]
+
+ @staticmethod
+ def _iter(
+ children: list[Tree], filter_: Callable, stop: Callable, maxlevel: int | None
+ ) -> Iterator[Tree]:
+ level = 1
+ while children:
+ next_children = []
+ for child in children:
+ if filter_(child):
+ yield child
+ next_children += LevelOrderIter._get_children(
+ list(child.children.values()), stop
+ )
+ children = next_children
+ level += 1
+ if LevelOrderIter._abort_at_level(level, maxlevel):
+ break
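The traversal above is plain breadth-first search over `child.children`. A self-contained toy version on nested dicts (illustrative only, not the xarray `Tree` API) produces the expected level order:

```python
def level_order(name, children):
    """Breadth-first walk over a toy nested-dict tree."""
    order, frontier = [], [(name, children)]
    while frontier:
        nxt = []
        for node_name, node_children in frontier:
            order.append(node_name)
            nxt.extend(node_children.items())
        frontier = nxt
    return order

print(level_order("f", {"b": {"a": {}, "d": {}}, "g": {"i": {"h": {}}}}))
# ['f', 'b', 'g', 'a', 'd', 'i', 'h']
```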
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index eb888a66..a90e59e7 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -1,31 +1,55 @@
from __future__ import annotations
+
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping, Sequence, Set
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
+
import pandas as pd
+
from xarray.core import dtypes
from xarray.core.alignment import deep_align
from xarray.core.duck_array_ops import lazy_array_equiv
-from xarray.core.indexes import Index, create_default_index_implicit, filter_indexes_from_coords, indexes_equal
+from xarray.core.indexes import (
+ Index,
+ create_default_index_implicit,
+ filter_indexes_from_coords,
+ indexes_equal,
+)
from xarray.core.utils import Frozen, compat_dict_union, dict_equiv, equivalent
from xarray.core.variable import Variable, as_variable, calculate_dimensions
+
if TYPE_CHECKING:
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import CombineAttrsOptions, CompatOptions, JoinOptions
+
DimsLike = Union[Hashable, Sequence[Hashable]]
ArrayLike = Any
- VariableLike = Union[ArrayLike, tuple[DimsLike, ArrayLike], tuple[
- DimsLike, ArrayLike, Mapping], tuple[DimsLike, ArrayLike, Mapping,
- Mapping]]
+ VariableLike = Union[
+ ArrayLike,
+ tuple[DimsLike, ArrayLike],
+ tuple[DimsLike, ArrayLike, Mapping],
+ tuple[DimsLike, ArrayLike, Mapping, Mapping],
+ ]
XarrayValue = Union[DataArray, Variable, VariableLike]
DatasetLike = Union[Dataset, Coordinates, Mapping[Any, XarrayValue]]
CoercibleValue = Union[XarrayValue, pd.Series, pd.DataFrame]
CoercibleMapping = Union[Dataset, Mapping[Any, CoercibleValue]]
-PANDAS_TYPES = pd.Series, pd.DataFrame
-_VALID_COMPAT = Frozen({'identical': 0, 'equals': 1, 'broadcast_equals': 2,
- 'minimal': 3, 'no_conflicts': 4, 'override': 5})
+
+
+PANDAS_TYPES = (pd.Series, pd.DataFrame)
+
+_VALID_COMPAT = Frozen(
+ {
+ "identical": 0,
+ "equals": 1,
+ "broadcast_equals": 2,
+ "minimal": 3,
+ "no_conflicts": 4,
+ "override": 5,
+ }
+)
class Context:
@@ -35,20 +59,33 @@ class Context:
self.func = func
-def broadcast_dimension_size(variables: list[Variable]) ->dict[Hashable, int]:
+def broadcast_dimension_size(variables: list[Variable]) -> dict[Hashable, int]:
"""Extract dimension sizes from a dictionary of variables.
Raises ValueError if any dimensions have different sizes.
"""
- pass
+ dims: dict[Hashable, int] = {}
+ for var in variables:
+ for dim, size in zip(var.dims, var.shape):
+ if dim in dims and size != dims[dim]:
+ raise ValueError(f"index {dim!r} not aligned")
+ dims[dim] = size
+ return dims
class MergeError(ValueError):
"""Error class for merge failures due to incompatible arguments."""
+ # inherits from ValueError for backward compatibility
+ # TODO: move this to an xarray.exceptions module?
-def unique_variable(name: Hashable, variables: list[Variable], compat:
- CompatOptions='broadcast_equals', equals: (bool | None)=None) ->Variable:
+
+def unique_variable(
+ name: Hashable,
+ variables: list[Variable],
+ compat: CompatOptions = "broadcast_equals",
+ equals: bool | None = None,
+) -> Variable:
"""Return the unique variable from a list of variables or raise MergeError.
Parameters
@@ -71,25 +108,98 @@ def unique_variable(name: Hashable, variables: list[Variable], compat:
------
MergeError: if any of the variables are not equal.
"""
- pass
+ out = variables[0]
+
+ if len(variables) == 1 or compat == "override":
+ return out
+
+ combine_method = None
+
+ if compat == "minimal":
+ compat = "broadcast_equals"
+
+ if compat == "broadcast_equals":
+ dim_lengths = broadcast_dimension_size(variables)
+ out = out.set_dims(dim_lengths)
+
+ if compat == "no_conflicts":
+ combine_method = "fillna"
+
+ if equals is None:
+ # first check without comparing values i.e. no computes
+ for var in variables[1:]:
+ equals = getattr(out, compat)(var, equiv=lazy_array_equiv)
+ if equals is not True:
+ break
+
+ if equals is None:
+ # now compare values with minimum number of computes
+ out = out.compute()
+ for var in variables[1:]:
+ equals = getattr(out, compat)(var)
+ if not equals:
+ break
+
+ if not equals:
+ raise MergeError(
+ f"conflicting values for variable {name!r} on objects to be combined. "
+ "You can skip this check by specifying compat='override'."
+ )
+
+ if combine_method:
+ for var in variables[1:]:
+ out = getattr(out, combine_method)(var)
+
+ return out
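A small demonstration of the `compat` semantics through the public `xr.merge` API (values invented for illustration): "no_conflicts" fills gaps via `fillna`, while "equals" raises on the same inputs.

```python
import numpy as np
import xarray as xr

a = xr.DataArray([1.0, np.nan], dims="x", name="v")
b = xr.DataArray([np.nan, 2.0], dims="x", name="v")

# "no_conflicts" only requires agreement where both are non-NaN
merged = xr.merge([a, b], compat="no_conflicts")
print(merged["v"].values)          # [1. 2.]

# "equals" insists on identical values and raises here
try:
    xr.merge([a, b], compat="equals")
except xr.MergeError as e:
    print("MergeError:", e)
```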
+
+
+def _assert_compat_valid(compat):
+ if compat not in _VALID_COMPAT:
+ raise ValueError(f"compat={compat!r} invalid: must be {set(_VALID_COMPAT)}")
MergeElement = tuple[Variable, Optional[Index]]
-def _assert_prioritized_valid(grouped: dict[Hashable, list[MergeElement]],
- prioritized: Mapping[Any, MergeElement]) ->None:
+def _assert_prioritized_valid(
+ grouped: dict[Hashable, list[MergeElement]],
+ prioritized: Mapping[Any, MergeElement],
+) -> None:
"""Make sure that elements given in prioritized will not corrupt any
index given in grouped.
"""
- pass
-
-
-def merge_collected(grouped: dict[Any, list[MergeElement]], prioritized: (
- Mapping[Any, MergeElement] | None)=None, compat: CompatOptions=
- 'minimal', combine_attrs: CombineAttrsOptions='override', equals: (dict
- [Any, bool] | None)=None) ->tuple[dict[Hashable, Variable], dict[
- Hashable, Index]]:
+ prioritized_names = set(prioritized)
+ grouped_by_index: dict[int, list[Hashable]] = defaultdict(list)
+ indexes: dict[int, Index] = {}
+
+ for name, elements_list in grouped.items():
+ for _, index in elements_list:
+ if index is not None:
+ grouped_by_index[id(index)].append(name)
+ indexes[id(index)] = index
+
+ # An index may be corrupted when the set of its corresponding coordinate name(s)
+ # partially overlaps the set of names given in prioritized
+ for index_id, index_coord_names in grouped_by_index.items():
+ index_names = set(index_coord_names)
+ common_names = index_names & prioritized_names
+ if common_names and len(common_names) != len(index_names):
+ common_names_str = ", ".join(f"{k!r}" for k in common_names)
+ index_names_str = ", ".join(f"{k!r}" for k in index_coord_names)
+ raise ValueError(
+ f"cannot set or update variable(s) {common_names_str}, which would corrupt "
+ f"the following index built from coordinates {index_names_str}:\n"
+ f"{indexes[index_id]!r}"
+ )
+
+
+def merge_collected(
+ grouped: dict[Any, list[MergeElement]],
+ prioritized: Mapping[Any, MergeElement] | None = None,
+ compat: CompatOptions = "minimal",
+ combine_attrs: CombineAttrsOptions = "override",
+ equals: dict[Any, bool] | None = None,
+) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
"""Merge dicts of variables, while resolving conflicts appropriately.
Parameters
@@ -98,7 +208,8 @@ def merge_collected(grouped: dict[Any, list[MergeElement]], prioritized: (
prioritized : mapping
compat : str
Type of equality check to use when checking for conflicts.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -122,12 +233,81 @@ def merge_collected(grouped: dict[Any, list[MergeElement]], prioritized: (
and Variable values corresponding to those that should be found on the
merged result.
"""
- pass
-
-
-def collect_variables_and_indexes(list_of_mappings: Iterable[DatasetLike],
- indexes: (Mapping[Any, Any] | None)=None) ->dict[Hashable, list[
- MergeElement]]:
+ if prioritized is None:
+ prioritized = {}
+ if equals is None:
+ equals = {}
+
+ _assert_compat_valid(compat)
+ _assert_prioritized_valid(grouped, prioritized)
+
+ merged_vars: dict[Hashable, Variable] = {}
+ merged_indexes: dict[Hashable, Index] = {}
+ index_cmp_cache: dict[tuple[int, int], bool | None] = {}
+
+ for name, elements_list in grouped.items():
+ if name in prioritized:
+ variable, index = prioritized[name]
+ merged_vars[name] = variable
+ if index is not None:
+ merged_indexes[name] = index
+ else:
+ indexed_elements = [
+ (variable, index)
+ for variable, index in elements_list
+ if index is not None
+ ]
+ if indexed_elements:
+ # TODO(shoyer): consider adjusting this logic. Are we really
+ # OK throwing away variables without an index in favor of
+ # indexed variables, without even checking if values match?
+ variable, index = indexed_elements[0]
+ for other_var, other_index in indexed_elements[1:]:
+ if not indexes_equal(
+ index, other_index, variable, other_var, index_cmp_cache
+ ):
+ raise MergeError(
+ f"conflicting values/indexes on objects to be combined fo coordinate {name!r}\n"
+ f"first index: {index!r}\nsecond index: {other_index!r}\n"
+ f"first variable: {variable!r}\nsecond variable: {other_var!r}\n"
+ )
+ if compat == "identical":
+ for other_variable, _ in indexed_elements[1:]:
+ if not dict_equiv(variable.attrs, other_variable.attrs):
+ raise MergeError(
+ "conflicting attribute values on combined "
+ f"variable {name!r}:\nfirst value: {variable.attrs!r}\nsecond value: {other_variable.attrs!r}"
+ )
+ merged_vars[name] = variable
+ merged_vars[name].attrs = merge_attrs(
+ [var.attrs for var, _ in indexed_elements],
+ combine_attrs=combine_attrs,
+ )
+ merged_indexes[name] = index
+ else:
+ variables = [variable for variable, _ in elements_list]
+ try:
+ merged_vars[name] = unique_variable(
+ name, variables, compat, equals.get(name, None)
+ )
+ except MergeError:
+ if compat != "minimal":
+ # we need more than "minimal" compatibility (for which
+ # we drop conflicting coordinates)
+ raise
+
+ if name in merged_vars:
+ merged_vars[name].attrs = merge_attrs(
+ [var.attrs for var in variables], combine_attrs=combine_attrs
+ )
+
+ return merged_vars, merged_indexes
+
+
+def collect_variables_and_indexes(
+ list_of_mappings: Iterable[DatasetLike],
+ indexes: Mapping[Any, Any] | None = None,
+) -> dict[Hashable, list[MergeElement]]:
"""Collect variables and indexes from list of mappings of xarray objects.
Mappings can be Dataset or Coordinates objects, in which case both
@@ -145,29 +325,102 @@ def collect_variables_and_indexes(list_of_mappings: Iterable[DatasetLike],
keys are also extracted.
"""
- pass
+ from xarray.core.coordinates import Coordinates
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ if indexes is None:
+ indexes = {}
+
+ grouped: dict[Hashable, list[MergeElement]] = defaultdict(list)
+
+ def append(name, variable, index):
+ grouped[name].append((variable, index))
+ def append_all(variables, indexes):
+ for name, variable in variables.items():
+ append(name, variable, indexes.get(name))
-def collect_from_coordinates(list_of_coords: list[Coordinates]) ->dict[
- Hashable, list[MergeElement]]:
+ for mapping in list_of_mappings:
+ if isinstance(mapping, (Coordinates, Dataset)):
+ append_all(mapping.variables, mapping.xindexes)
+ continue
+
+ for name, variable in mapping.items():
+ if isinstance(variable, DataArray):
+ coords_ = variable._coords.copy() # use private API for speed
+ indexes_ = dict(variable._indexes)
+ # explicitly overwritten variables should take precedence
+ coords_.pop(name, None)
+ indexes_.pop(name, None)
+ append_all(coords_, indexes_)
+
+ variable = as_variable(variable, name=name, auto_convert=False)
+ if name in indexes:
+ append(name, variable, indexes[name])
+ elif variable.dims == (name,):
+ idx, idx_vars = create_default_index_implicit(variable)
+ append_all(idx_vars, {k: idx for k in idx_vars})
+ else:
+ append(name, variable, None)
+
+ return grouped
+
+
+def collect_from_coordinates(
+ list_of_coords: list[Coordinates],
+) -> dict[Hashable, list[MergeElement]]:
"""Collect variables and indexes to be merged from Coordinate objects."""
- pass
+ grouped: dict[Hashable, list[MergeElement]] = defaultdict(list)
+
+ for coords in list_of_coords:
+ variables = coords.variables
+ indexes = coords.xindexes
+ for name, variable in variables.items():
+ grouped[name].append((variable, indexes.get(name)))
+ return grouped
-def merge_coordinates_without_align(objects: list[Coordinates], prioritized:
- (Mapping[Any, MergeElement] | None)=None, exclude_dims: Set=frozenset(),
- combine_attrs: CombineAttrsOptions='override') ->tuple[dict[Hashable,
- Variable], dict[Hashable, Index]]:
+
+def merge_coordinates_without_align(
+ objects: list[Coordinates],
+ prioritized: Mapping[Any, MergeElement] | None = None,
+ exclude_dims: Set = frozenset(),
+ combine_attrs: CombineAttrsOptions = "override",
+) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
"""Merge variables/indexes from coordinates without automatic alignments.
This function is used for merging coordinates from pre-existing xarray
objects.
"""
- pass
-
-
-def determine_coords(list_of_mappings: Iterable[DatasetLike]) ->tuple[set[
- Hashable], set[Hashable]]:
+ collected = collect_from_coordinates(objects)
+
+ if exclude_dims:
+ filtered: dict[Hashable, list[MergeElement]] = {}
+ for name, elements in collected.items():
+ new_elements = [
+ (variable, index)
+ for variable, index in elements
+ if exclude_dims.isdisjoint(variable.dims)
+ ]
+ if new_elements:
+ filtered[name] = new_elements
+ else:
+ filtered = collected
+
+ # TODO: indexes should probably be filtered in collected elements
+ # before merging them
+ merged_coords, merged_indexes = merge_collected(
+ filtered, prioritized, combine_attrs=combine_attrs
+ )
+ merged_indexes = filter_indexes_from_coords(merged_indexes, set(merged_coords))
+
+ return merged_coords, merged_indexes
+
+
+def determine_coords(
+ list_of_mappings: Iterable[DatasetLike],
+) -> tuple[set[Hashable], set[Hashable]]:
"""Given a list of dicts with xarray object values, identify coordinates.
Parameters
@@ -182,11 +435,28 @@ def determine_coords(list_of_mappings: Iterable[DatasetLike]) ->tuple[set[
All variable found in the input should appear in either the set of
coordinate or non-coordinate names.
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ coord_names: set[Hashable] = set()
+ noncoord_names: set[Hashable] = set()
+ for mapping in list_of_mappings:
+ if isinstance(mapping, Dataset):
+ coord_names.update(mapping.coords)
+ noncoord_names.update(mapping.data_vars)
+ else:
+ for name, var in mapping.items():
+ if isinstance(var, DataArray):
+ coords = set(var._coords) # use private API for speed
+ # explicitly overwritten variables should take precedence
+ coords.discard(name)
+ coord_names.update(coords)
-def coerce_pandas_values(objects: Iterable[CoercibleMapping]) ->list[
- DatasetLike]:
+ return coord_names, noncoord_names
+
+
+def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLike]:
"""Convert pandas values found in a list of labeled objects.
Parameters
@@ -200,12 +470,32 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) ->list[
List of Dataset or dictionary objects. Any inputs or values in the inputs
that were pandas objects have been converted into native xarray objects.
"""
- pass
-
+ from xarray.core.coordinates import Coordinates
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
-def _get_priority_vars_and_indexes(objects: Sequence[DatasetLike],
- priority_arg: (int | None), compat: CompatOptions='equals') ->dict[
- Hashable, MergeElement]:
+ out: list[DatasetLike] = []
+ for obj in objects:
+ variables: DatasetLike
+ if isinstance(obj, (Dataset, Coordinates)):
+ variables = obj
+ else:
+ variables = {}
+ if isinstance(obj, PANDAS_TYPES):
+ obj = dict(obj.items())
+ for k, v in obj.items():
+ if isinstance(v, PANDAS_TYPES):
+ v = DataArray(v)
+ variables[k] = v
+ out.append(variables)
+ return out
+
+
+def _get_priority_vars_and_indexes(
+ objects: Sequence[DatasetLike],
+ priority_arg: int | None,
+ compat: CompatOptions = "equals",
+) -> dict[Hashable, MergeElement]:
"""Extract the priority variable from a list of mappings.
We need this method because in some cases the priority argument itself
@@ -236,25 +526,94 @@ def _get_priority_vars_and_indexes(objects: Sequence[DatasetLike],
-------
A dictionary of variables and associated indexes (if any) to prioritize.
"""
- pass
-
-
-def merge_coords(objects: Iterable[CoercibleMapping], compat: CompatOptions
- ='minimal', join: JoinOptions='outer', priority_arg: (int | None)=None,
- indexes: (Mapping[Any, Index] | None)=None, fill_value: object=dtypes.NA
- ) ->tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
+ if priority_arg is None:
+ return {}
+
+ collected = collect_variables_and_indexes([objects[priority_arg]])
+ variables, indexes = merge_collected(collected, compat=compat)
+ grouped: dict[Hashable, MergeElement] = {}
+ for name, variable in variables.items():
+ grouped[name] = (variable, indexes.get(name))
+ return grouped
+
+
+def merge_coords(
+ objects: Iterable[CoercibleMapping],
+ compat: CompatOptions = "minimal",
+ join: JoinOptions = "outer",
+ priority_arg: int | None = None,
+ indexes: Mapping[Any, Index] | None = None,
+ fill_value: object = dtypes.NA,
+) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
"""Merge coordinate variables.
See merge_core below for argument descriptions. This works similarly to
merge_core, except that we don't worry about whether variables are
coordinates or not.
"""
- pass
+ _assert_compat_valid(compat)
+ coerced = coerce_pandas_values(objects)
+ aligned = deep_align(
+ coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value
+ )
+ collected = collect_variables_and_indexes(aligned, indexes=indexes)
+ prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat)
+ variables, out_indexes = merge_collected(collected, prioritized, compat=compat)
+ return variables, out_indexes
def merge_attrs(variable_attrs, combine_attrs, context=None):
"""Combine attributes from different variables according to combine_attrs"""
- pass
+ if not variable_attrs:
+ # no attributes to merge
+ return None
+
+ if callable(combine_attrs):
+ return combine_attrs(variable_attrs, context=context)
+ elif combine_attrs == "drop":
+ return {}
+ elif combine_attrs == "override":
+ return dict(variable_attrs[0])
+ elif combine_attrs == "no_conflicts":
+ result = dict(variable_attrs[0])
+ for attrs in variable_attrs[1:]:
+ try:
+ result = compat_dict_union(result, attrs)
+ except ValueError as e:
+ raise MergeError(
+ "combine_attrs='no_conflicts', but some values are not "
+ f"the same. Merging {str(result)} with {str(attrs)}"
+ ) from e
+ return result
+ elif combine_attrs == "drop_conflicts":
+ result = {}
+ dropped_keys = set()
+ for attrs in variable_attrs:
+ result.update(
+ {
+ key: value
+ for key, value in attrs.items()
+ if key not in result and key not in dropped_keys
+ }
+ )
+ result = {
+ key: value
+ for key, value in result.items()
+ if key not in attrs or equivalent(attrs[key], value)
+ }
+ dropped_keys |= {key for key in attrs if key not in result}
+ return result
+ elif combine_attrs == "identical":
+ result = dict(variable_attrs[0])
+ for attrs in variable_attrs[1:]:
+ if not dict_equiv(result, attrs):
+ raise MergeError(
+ f"combine_attrs='identical', but attrs differ. First is {str(result)} "
+ f", other is {str(attrs)}."
+ )
+ return result
+ else:
+ raise ValueError(f"Unrecognised value for combine_attrs={combine_attrs}")
class _MergeResult(NamedTuple):
@@ -265,12 +624,17 @@ class _MergeResult(NamedTuple):
attrs: dict[Hashable, Any]
-def merge_core(objects: Iterable[CoercibleMapping], compat: CompatOptions=
- 'broadcast_equals', join: JoinOptions='outer', combine_attrs:
- CombineAttrsOptions='override', priority_arg: (int | None)=None,
- explicit_coords: (Iterable[Hashable] | None)=None, indexes: (Mapping[
- Any, Any] | None)=None, fill_value: object=dtypes.NA, skip_align_args:
- (list[int] | None)=None) ->_MergeResult:
+def merge_core(
+ objects: Iterable[CoercibleMapping],
+ compat: CompatOptions = "broadcast_equals",
+ join: JoinOptions = "outer",
+ combine_attrs: CombineAttrsOptions = "override",
+ priority_arg: int | None = None,
+ explicit_coords: Iterable[Hashable] | None = None,
+ indexes: Mapping[Any, Any] | None = None,
+ fill_value: object = dtypes.NA,
+ skip_align_args: list[int] | None = None,
+) -> _MergeResult:
"""Core logic for merging labeled objects.
This is not public API.
@@ -283,7 +647,8 @@ def merge_core(objects: Iterable[CoercibleMapping], compat: CompatOptions=
Compatibility checks to use when merging variables.
join : {"outer", "inner", "left", "right"}, optional
How to combine objects with different indexes.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "override"
How to combine attributes of objects
priority_arg : int, optional
Optional argument in `objects` that takes precedence over the others.
@@ -312,12 +677,64 @@ def merge_core(objects: Iterable[CoercibleMapping], compat: CompatOptions=
------
MergeError if the merge cannot be done successfully.
"""
- pass
-
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
-def merge(objects: Iterable[DataArray | CoercibleMapping], compat:
- CompatOptions='no_conflicts', join: JoinOptions='outer', fill_value:
- object=dtypes.NA, combine_attrs: CombineAttrsOptions='override') ->Dataset:
+ _assert_compat_valid(compat)
+
+ objects = list(objects)
+ if skip_align_args is None:
+ skip_align_args = []
+
+ skip_align_objs = [(pos, objects.pop(pos)) for pos in skip_align_args]
+
+ coerced = coerce_pandas_values(objects)
+ aligned = deep_align(
+ coerced, join=join, copy=False, indexes=indexes, fill_value=fill_value
+ )
+
+ for pos, obj in skip_align_objs:
+ aligned.insert(pos, obj)
+
+ collected = collect_variables_and_indexes(aligned, indexes=indexes)
+ prioritized = _get_priority_vars_and_indexes(aligned, priority_arg, compat=compat)
+ variables, out_indexes = merge_collected(
+ collected, prioritized, compat=compat, combine_attrs=combine_attrs
+ )
+
+ dims = calculate_dimensions(variables)
+
+ coord_names, noncoord_names = determine_coords(coerced)
+ if compat == "minimal":
+ # coordinates may be dropped in merged results
+ coord_names.intersection_update(variables)
+ if explicit_coords is not None:
+ coord_names.update(explicit_coords)
+ for dim, size in dims.items():
+ if dim in variables:
+ coord_names.add(dim)
+ ambiguous_coords = coord_names.intersection(noncoord_names)
+ if ambiguous_coords:
+ raise MergeError(
+ "unable to determine if these variables should be "
+ f"coordinates or not in the merged result: {ambiguous_coords}"
+ )
+
+ attrs = merge_attrs(
+ [var.attrs for var in coerced if isinstance(var, (Dataset, DataArray))],
+ combine_attrs,
+ )
+
+ return _MergeResult(variables, coord_names, dims, out_indexes, attrs)
+
+
+def merge(
+ objects: Iterable[DataArray | CoercibleMapping],
+ compat: CompatOptions = "no_conflicts",
+ join: JoinOptions = "outer",
+ fill_value: object = dtypes.NA,
+ combine_attrs: CombineAttrsOptions = "override",
+) -> Dataset:
"""Merge any number of xarray objects into a single Dataset as variables.
Parameters
@@ -325,7 +742,8 @@ def merge(objects: Iterable[DataArray | CoercibleMapping], compat:
objects : iterable of Dataset or iterable of DataArray or iterable of dict-like
Merge together all variables from these objects. If any of them are
DataArray objects, they must have a name.
- compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal"}, default: "no_conflicts"
+ compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \
+ "override", "minimal"}, default: "no_conflicts"
String indicating how to compare variables of the same name for
potential conflicts:
@@ -357,7 +775,8 @@ def merge(objects: Iterable[DataArray | CoercibleMapping], compat:
Value to use for newly missing values. If a dict-like, maps
variable names to fill values. Use a data array's name to
refer to its values.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"} or callable, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
merged:
@@ -535,23 +954,107 @@ def merge(objects: Iterable[DataArray | CoercibleMapping], compat:
combine_nested
combine_by_coords
"""
- pass
+ from xarray.core.coordinates import Coordinates
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
-def dataset_merge_method(dataset: Dataset, other: CoercibleMapping,
- overwrite_vars: (Hashable | Iterable[Hashable]), compat: CompatOptions,
- join: JoinOptions, fill_value: Any, combine_attrs: CombineAttrsOptions
- ) ->_MergeResult:
+ dict_like_objects = []
+ for obj in objects:
+ if not isinstance(obj, (DataArray, Dataset, Coordinates, dict)):
+ raise TypeError(
+ "objects must be an iterable containing only "
+ "Dataset(s), DataArray(s), and dictionaries."
+ )
+
+ if isinstance(obj, DataArray):
+ obj = obj.to_dataset(promote_attrs=True)
+ elif isinstance(obj, Coordinates):
+ obj = obj.to_dataset()
+ dict_like_objects.append(obj)
+
+ merge_result = merge_core(
+ dict_like_objects,
+ compat,
+ join,
+ combine_attrs=combine_attrs,
+ fill_value=fill_value,
+ )
+ return Dataset._construct_direct(**merge_result._asdict())
+
+
+def dataset_merge_method(
+ dataset: Dataset,
+ other: CoercibleMapping,
+ overwrite_vars: Hashable | Iterable[Hashable],
+ compat: CompatOptions,
+ join: JoinOptions,
+ fill_value: Any,
+ combine_attrs: CombineAttrsOptions,
+) -> _MergeResult:
"""Guts of the Dataset.merge method."""
- pass
-
-
-def dataset_update_method(dataset: Dataset, other: CoercibleMapping
- ) ->_MergeResult:
+ # we are locked into supporting overwrite_vars for the Dataset.merge
+ # method for backwards compatibility
+ # TODO: consider deprecating it?
+
+ if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable):
+ overwrite_vars = set(overwrite_vars)
+ else:
+ overwrite_vars = {overwrite_vars}
+
+ if not overwrite_vars:
+ objs = [dataset, other]
+ priority_arg = None
+ elif overwrite_vars == set(other):
+ objs = [dataset, other]
+ priority_arg = 1
+ else:
+ other_overwrite: dict[Hashable, CoercibleValue] = {}
+ other_no_overwrite: dict[Hashable, CoercibleValue] = {}
+ for k, v in other.items():
+ if k in overwrite_vars:
+ other_overwrite[k] = v
+ else:
+ other_no_overwrite[k] = v
+ objs = [dataset, other_no_overwrite, other_overwrite]
+ priority_arg = 2
+
+ return merge_core(
+ objs,
+ compat,
+ join,
+ priority_arg=priority_arg,
+ fill_value=fill_value,
+ combine_attrs=combine_attrs,
+ )
+
+
+def dataset_update_method(dataset: Dataset, other: CoercibleMapping) -> _MergeResult:
"""Guts of the Dataset.update method.
This drops duplicated coordinates from `other` if `other` is not an
`xarray.Dataset`, e.g., if it's a dict with DataArray values (GH2068,
GH2180).
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
+
+ if not isinstance(other, Dataset):
+ other = dict(other)
+ for key, value in other.items():
+ if isinstance(value, DataArray):
+ # drop conflicting coordinates
+ coord_names = [
+ c
+ for c in value.coords
+ if c not in value.dims and c in dataset.coords
+ ]
+ if coord_names:
+ other[key] = value.drop_vars(coord_names)
+
+ return merge_core(
+ [dataset, other],
+ priority_arg=1,
+ indexes=dataset.xindexes,
+ combine_attrs="override",
+ )
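Assuming the coordinate-dropping behaviour implemented above, a public-API illustration (invented data) of `Dataset.update` keeping the dataset's own non-dimension coordinate rather than the conflicting one on the incoming DataArray:

```python
import xarray as xr

ds = xr.Dataset({"a": ("x", [1, 2])}, coords={"x": [10, 20], "c": 1})
new = xr.DataArray([3, 4], dims="x", coords={"x": [10, 20], "c": 2})

ds.update({"b": new})
print(int(ds["c"]))   # 1 -- the conflicting non-dimension coord from `new` was dropped
```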
diff --git a/xarray/core/missing.py b/xarray/core/missing.py
index e589337e..bfbad726 100644
--- a/xarray/core/missing.py
+++ b/xarray/core/missing.py
@@ -1,38 +1,68 @@
from __future__ import annotations
+
import datetime as dt
import warnings
from collections.abc import Hashable, Sequence
from functools import partial
from numbers import Number
from typing import TYPE_CHECKING, Any, Callable, get_args
+
import numpy as np
import pandas as pd
+
from xarray.core import utils
from xarray.core.common import _contains_datetime_like_objects, ones_like
from xarray.core.computation import apply_ufunc
-from xarray.core.duck_array_ops import datetime_to_numeric, push, reshape, timedelta_to_numeric
+from xarray.core.duck_array_ops import (
+ datetime_to_numeric,
+ push,
+ reshape,
+ timedelta_to_numeric,
+)
from xarray.core.options import _get_keep_attrs
from xarray.core.types import Interp1dOptions, InterpOptions
from xarray.core.utils import OrderedSet, is_scalar
from xarray.core.variable import Variable, broadcast_variables
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
+
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
-def _get_nan_block_lengths(obj: (Dataset | DataArray | Variable), dim:
- Hashable, index: Variable):
+def _get_nan_block_lengths(
+ obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable
+):
"""
Return an object where each NaN element in 'obj' is replaced by the
length of the gap the element is in.
"""
- pass
+
+ # make variable so that we get broadcasting for free
+ index = Variable([dim], index)
+
+ # algorithm from https://github.com/pydata/xarray/pull/3302#discussion_r324707072
+ arange = ones_like(obj) * index
+ valid = obj.notnull()
+ valid_arange = arange.where(valid)
+ cumulative_nans = valid_arange.ffill(dim=dim).fillna(index[0])
+
+ nan_block_lengths = (
+ cumulative_nans.diff(dim=dim, label="upper")
+ .reindex({dim: obj[dim]})
+ .where(valid)
+ .bfill(dim=dim)
+ .where(~valid, 0)
+ .fillna(index[-1] - valid_arange.max(dim=[dim]))
+ )
+
+ return nan_block_lengths
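A plain numpy/pandas sketch of the gap-length idea used by `_get_nan_block_lengths` (illustration only; the implementation above works on broadcast xarray objects instead):

```python
import numpy as np
import pandas as pd

index = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([1.0, np.nan, np.nan, 2.0, np.nan, 3.0])

coords_of_valid = pd.Series(index, index=index).where(~np.isnan(y))
left = coords_of_valid.ffill().fillna(index[0])    # last valid coordinate to the left
right = coords_of_valid.bfill().fillna(index[-1])  # next valid coordinate to the right
gap_lengths = np.where(np.isnan(y), right - left, 0.0)
print(gap_lengths)  # [0. 3. 3. 0. 2. 0.]
```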
class BaseInterpolator:
"""Generic interpolator class for normalizing interpolation methods"""
+
cons_kwargs: dict[str, Any]
call_kwargs: dict[str, Any]
f: Callable
@@ -42,7 +72,7 @@ class BaseInterpolator:
return self.f(x, **self.call_kwargs)
def __repr__(self):
- return f'{self.__class__.__name__}: method={self.method}'
+ return f"{self.__class__.__name__}: method={self.method}"
class NumpyInterpolator(BaseInterpolator):
@@ -53,17 +83,20 @@ class NumpyInterpolator(BaseInterpolator):
numpy.interp
"""
- def __init__(self, xi, yi, method='linear', fill_value=None, period=None):
- if method != 'linear':
- raise ValueError(
- 'only method `linear` is valid for the NumpyInterpolator')
+ def __init__(self, xi, yi, method="linear", fill_value=None, period=None):
+ if method != "linear":
+ raise ValueError("only method `linear` is valid for the NumpyInterpolator")
+
self.method = method
self.f = np.interp
self.cons_kwargs = {}
- self.call_kwargs = {'period': period}
+ self.call_kwargs = {"period": period}
+
self._xi = xi
self._yi = yi
- nan = np.nan if yi.dtype.kind != 'c' else np.nan + np.nan * 1.0j
+
+ nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j
+
if fill_value is None:
self._left = nan
self._right = nan
@@ -74,11 +107,17 @@ class NumpyInterpolator(BaseInterpolator):
self._left = fill_value
self._right = fill_value
else:
- raise ValueError(f'{fill_value} is not a valid fill_value')
+ raise ValueError(f"{fill_value} is not a valid fill_value")
def __call__(self, x):
- return self.f(x, self._xi, self._yi, left=self._left, right=self.
- _right, **self.call_kwargs)
+ return self.f(
+ x,
+ self._xi,
+ self._yi,
+ left=self._left,
+ right=self._right,
+ **self.call_kwargs,
+ )
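Minimal usage sketch of the wrapper above (internal class, shown only to make the `np.interp` delegation concrete):

```python
import numpy as np
from xarray.core.missing import NumpyInterpolator

f = NumpyInterpolator(np.array([0.0, 1.0, 2.0]), np.array([0.0, 10.0, 20.0]))
print(f(np.array([0.5, 1.5])))  # [ 5. 15.] -- out-of-range points become NaN by default
```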
class ScipyInterpolator(BaseInterpolator):
@@ -89,28 +128,53 @@ class ScipyInterpolator(BaseInterpolator):
scipy.interpolate.interp1d
"""
- def __init__(self, xi, yi, method=None, fill_value=None, assume_sorted=
- True, copy=False, bounds_error=False, order=None, **kwargs):
+ def __init__(
+ self,
+ xi,
+ yi,
+ method=None,
+ fill_value=None,
+ assume_sorted=True,
+ copy=False,
+ bounds_error=False,
+ order=None,
+ **kwargs,
+ ):
from scipy.interpolate import interp1d
+
if method is None:
raise ValueError(
- 'method is a required argument, please supply a valid scipy.inter1d method (kind)'
- )
- if method == 'polynomial':
+ "method is a required argument, please supply a "
+ "valid scipy.interp1d method (kind)"
+ )
+
+ if method == "polynomial":
if order is None:
- raise ValueError('order is required when method=polynomial')
+ raise ValueError("order is required when method=polynomial")
method = order
+
self.method = method
+
self.cons_kwargs = kwargs
self.call_kwargs = {}
- nan = np.nan if yi.dtype.kind != 'c' else np.nan + np.nan * 1.0j
- if fill_value is None and method == 'linear':
+
+ nan = np.nan if yi.dtype.kind != "c" else np.nan + np.nan * 1j
+
+ if fill_value is None and method == "linear":
fill_value = nan, nan
elif fill_value is None:
fill_value = nan
- self.f = interp1d(xi, yi, kind=self.method, fill_value=fill_value,
- bounds_error=bounds_error, assume_sorted=assume_sorted, copy=
- copy, **self.cons_kwargs)
+
+ self.f = interp1d(
+ xi,
+ yi,
+ kind=self.method,
+ fill_value=fill_value,
+ bounds_error=bounds_error,
+ assume_sorted=assume_sorted,
+ copy=copy,
+ **self.cons_kwargs,
+ )
class SplineInterpolator(BaseInterpolator):
@@ -121,27 +185,48 @@ class SplineInterpolator(BaseInterpolator):
scipy.interpolate.UnivariateSpline
"""
- def __init__(self, xi, yi, method='spline', fill_value=None, order=3,
- nu=0, ext=None, **kwargs):
+ def __init__(
+ self,
+ xi,
+ yi,
+ method="spline",
+ fill_value=None,
+ order=3,
+ nu=0,
+ ext=None,
+ **kwargs,
+ ):
from scipy.interpolate import UnivariateSpline
- if method != 'spline':
- raise ValueError(
- 'only method `spline` is valid for the SplineInterpolator')
+
+ if method != "spline":
+ raise ValueError("only method `spline` is valid for the SplineInterpolator")
+
self.method = method
self.cons_kwargs = kwargs
- self.call_kwargs = {'nu': nu, 'ext': ext}
+ self.call_kwargs = {"nu": nu, "ext": ext}
+
if fill_value is not None:
- raise ValueError('SplineInterpolator does not support fill_value')
+ raise ValueError("SplineInterpolator does not support fill_value")
+
self.f = UnivariateSpline(xi, yi, k=order, **self.cons_kwargs)
def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):
"""Wrapper for datasets"""
- pass
+ ds = type(self)(coords=self.coords, attrs=self.attrs)
+
+ for name, var in self.data_vars.items():
+ if dim in var.dims:
+ ds[name] = func(var, dim=dim, **kwargs)
+ else:
+ ds[name] = var
+
+ return ds
-def get_clean_interp_index(arr, dim: Hashable, use_coordinate: (str | bool)
- =True, strict: bool=True):
+def get_clean_interp_index(
+ arr, dim: Hashable, use_coordinate: str | bool = True, strict: bool = True
+):
"""Return index to use for x values in interpolation or curve fitting.
Parameters
@@ -167,49 +252,268 @@ def get_clean_interp_index(arr, dim: Hashable, use_coordinate: (str | bool)
If indexing is along the time dimension, datetime coordinates are converted
to time deltas with respect to 1970-01-01.
"""
- pass
+ # Question: If use_coordinate is a string, what role does `dim` play?
+ from xarray.coding.cftimeindex import CFTimeIndex
-def interp_na(self, dim: (Hashable | None)=None, use_coordinate: (bool |
- str)=True, method: InterpOptions='linear', limit: (int | None)=None,
- max_gap: (int | float | str | pd.Timedelta | np.timedelta64 | dt.
- timedelta | None)=None, keep_attrs: (bool | None)=None, **kwargs):
+ if use_coordinate is False:
+ axis = arr.get_axis_num(dim)
+ return np.arange(arr.shape[axis], dtype=np.float64)
+
+ if use_coordinate is True:
+ index = arr.get_index(dim)
+
+ else: # string
+ index = arr.coords[use_coordinate]
+ if index.ndim != 1:
+ raise ValueError(
+ f"Coordinates used for interpolation must be 1D, "
+ f"{use_coordinate} is {index.ndim}D."
+ )
+ index = index.to_index()
+
+ # TODO: index.name is None for multiindexes
+ # set name for nice error messages below
+ if isinstance(index, pd.MultiIndex):
+ index.name = dim
+
+ if strict:
+ if not index.is_monotonic_increasing:
+ raise ValueError(f"Index {index.name!r} must be monotonically increasing")
+
+ if not index.is_unique:
+ raise ValueError(f"Index {index.name!r} has duplicate values")
+
+ # Special case for non-standard calendar indexes
+ # Numerical datetime values are defined with respect to 1970-01-01T00:00:00 in units of nanoseconds
+ if isinstance(index, (CFTimeIndex, pd.DatetimeIndex)):
+ offset = type(index[0])(1970, 1, 1)
+ if isinstance(index, CFTimeIndex):
+ index = index.values
+ index = Variable(
+ data=datetime_to_numeric(index, offset=offset, datetime_unit="ns"),
+ dims=(dim,),
+ )
+
+ # raise if index cannot be cast to a float (e.g. MultiIndex)
+ try:
+ index = index.values.astype(np.float64)
+ except (TypeError, ValueError):
+ # pandas raises a TypeError
+ # xarray/numpy raise a ValueError
+ raise TypeError(
+ f"Index {index.name!r} must be castable to float64 to support "
+ f"interpolation or curve fitting, got {type(index).__name__}."
+ )
+
+ return index
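Sketch of what `get_clean_interp_index` returns for a datetime coordinate (illustration only; it is an internal helper and the exact float values depend on the dates chosen):

```python
import numpy as np
import xarray as xr
from xarray.core.missing import get_clean_interp_index

times = np.array(["2000-01-01", "2000-01-02", "2000-01-03"], dtype="datetime64[ns]")
da = xr.DataArray([0.0, 1.0, 2.0], dims="time", coords={"time": times})

idx = get_clean_interp_index(da, "time")
print(idx.dtype)        # float64 -- nanoseconds since 1970-01-01
print(idx[1] - idx[0])  # 86400000000000.0 -- one day in nanoseconds
```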
+
+
+def interp_na(
+ self,
+ dim: Hashable | None = None,
+ use_coordinate: bool | str = True,
+ method: InterpOptions = "linear",
+ limit: int | None = None,
+ max_gap: (
+ int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta | None
+ ) = None,
+ keep_attrs: bool | None = None,
+ **kwargs,
+):
"""Interpolate values according to different methods."""
- pass
+ from xarray.coding.cftimeindex import CFTimeIndex
+
+ if dim is None:
+ raise NotImplementedError("dim is a required argument")
+
+ if limit is not None:
+ valids = _get_valid_fill_mask(self, dim, limit)
+
+ if max_gap is not None:
+ max_type = type(max_gap).__name__
+ if not is_scalar(max_gap):
+ raise ValueError("max_gap must be a scalar.")
+
+ if (
+ dim in self._indexes
+ and isinstance(
+ self._indexes[dim].to_pandas_index(), (pd.DatetimeIndex, CFTimeIndex)
+ )
+ and use_coordinate
+ ):
+ # Convert to float
+ max_gap = timedelta_to_numeric(max_gap)
+
+ if not use_coordinate:
+ if not isinstance(max_gap, (Number, np.number)):
+ raise TypeError(
+ f"Expected integer or floating point max_gap since use_coordinate=False. Received {max_type}."
+ )
+
+ # method
+ index = get_clean_interp_index(self, dim, use_coordinate=use_coordinate)
+ interp_class, kwargs = _get_interpolator(method, **kwargs)
+ interpolator = partial(func_interpolate_na, interp_class, **kwargs)
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", "overflow", RuntimeWarning)
+ warnings.filterwarnings("ignore", "invalid value", RuntimeWarning)
+ arr = apply_ufunc(
+ interpolator,
+ self,
+ index,
+ input_core_dims=[[dim], [dim]],
+ output_core_dims=[[dim]],
+ output_dtypes=[self.dtype],
+ dask="parallelized",
+ vectorize=True,
+ keep_attrs=keep_attrs,
+ ).transpose(*self.dims)
+
+ if limit is not None:
+ arr = arr.where(valids)
+
+ if max_gap is not None:
+ if dim not in self.coords:
+ raise NotImplementedError(
+ "max_gap not implemented for unlabeled coordinates yet."
+ )
+ nan_block_lengths = _get_nan_block_lengths(self, dim, index)
+ arr = arr.where(nan_block_lengths <= max_gap)
+
+ return arr
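A short sketch of the public entry point that routes through `interp_na`, showing the `max_gap` masking (gap lengths are measured in coordinate units):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    [1.0, np.nan, np.nan, np.nan, 5.0, np.nan, 7.0],
    dims="x",
    coords={"x": np.arange(7)},
)

filled = da.interpolate_na(dim="x", method="linear", max_gap=2)
print(filled.values)  # [ 1. nan nan nan  5.  6.  7.] -- the 4-wide gap stays masked
```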
def func_interpolate_na(interpolator, y, x, **kwargs):
"""helper function to apply interpolation along 1 dimension"""
- pass
+ # reversed arguments are so that attrs are preserved from da, not index
+ # it would be nice if this wasn't necessary, works around:
+ # "ValueError: assignment destination is read-only" in assignment below
+ out = y.copy()
+
+ nans = pd.isnull(y)
+ nonans = ~nans
+
+ # fast track for no-nans, all nan but one, and all-nans cases
+ n_nans = nans.sum()
+ if n_nans == 0 or n_nans >= len(y) - 1:
+ return y
+
+ f = interpolator(x[nonans], y[nonans], **kwargs)
+ out[nans] = f(x[nans])
+ return out
def _bfill(arr, n=None, axis=-1):
"""inverse of ffill"""
- pass
+ arr = np.flip(arr, axis=axis)
+
+ # fill
+ arr = push(arr, axis=axis, n=n)
+
+ # reverse back to original
+ return np.flip(arr, axis=axis)
def ffill(arr, dim=None, limit=None):
"""forward fill missing values"""
- pass
+
+ axis = arr.get_axis_num(dim)
+
+ # work around for bottleneck 178
+ _limit = limit if limit is not None else arr.shape[axis]
+
+ return apply_ufunc(
+ push,
+ arr,
+ dask="allowed",
+ keep_attrs=True,
+ output_dtypes=[arr.dtype],
+ kwargs=dict(n=_limit, axis=axis),
+ ).transpose(*arr.dims)
def bfill(arr, dim=None, limit=None):
"""backfill missing values"""
- pass
+
+ axis = arr.get_axis_num(dim)
+
+ # work around for bottleneck 178
+ _limit = limit if limit is not None else arr.shape[axis]
+
+ return apply_ufunc(
+ _bfill,
+ arr,
+ dask="allowed",
+ keep_attrs=True,
+ output_dtypes=[arr.dtype],
+ kwargs=dict(n=_limit, axis=axis),
+ ).transpose(*arr.dims)
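Usage sketch of the corresponding public `ffill`/`bfill` methods (note these ultimately call `push`, so bottleneck or numbagg must be installed):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([np.nan, 1.0, np.nan, np.nan, 4.0], dims="x")

print(da.ffill("x").values)           # [nan  1.  1.  1.  4.]
print(da.ffill("x", limit=1).values)  # [nan  1.  1. nan  4.]
print(da.bfill("x").values)           # [ 1.  1.  4.  4.  4.]
```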
def _import_interpolant(interpolant, method):
"""Import interpolant from scipy.interpolate."""
- pass
+ try:
+ from scipy import interpolate
+
+ return getattr(interpolate, interpolant)
+ except ImportError as e:
+ raise ImportError(f"Interpolation with method {method} requires scipy.") from e
-def _get_interpolator(method: InterpOptions, vectorizeable_only: bool=False,
- **kwargs):
+def _get_interpolator(
+ method: InterpOptions, vectorizeable_only: bool = False, **kwargs
+):
"""helper function to select the appropriate interpolator class
returns interpolator class and keyword arguments for the class
"""
- pass
+ interp_class: (
+ type[NumpyInterpolator] | type[ScipyInterpolator] | type[SplineInterpolator]
+ )
+
+ interp1d_methods = get_args(Interp1dOptions)
+ valid_methods = tuple(vv for v in get_args(InterpOptions) for vv in get_args(v))
+
+ # prioritize scipy.interpolate
+ if (
+ method == "linear"
+ and not kwargs.get("fill_value", None) == "extrapolate"
+ and not vectorizeable_only
+ ):
+ kwargs.update(method=method)
+ interp_class = NumpyInterpolator
+
+ elif method in valid_methods:
+ if method in interp1d_methods:
+ kwargs.update(method=method)
+ interp_class = ScipyInterpolator
+ elif vectorizeable_only:
+ raise ValueError(
+ f"{method} is not a vectorizeable interpolator. "
+ f"Available methods are {interp1d_methods}"
+ )
+ elif method == "barycentric":
+ interp_class = _import_interpolant("BarycentricInterpolator", method)
+ elif method in ["krogh", "krog"]:
+ interp_class = _import_interpolant("KroghInterpolator", method)
+ elif method == "pchip":
+ interp_class = _import_interpolant("PchipInterpolator", method)
+ elif method == "spline":
+ kwargs.update(method=method)
+ interp_class = SplineInterpolator
+ elif method == "akima":
+ interp_class = _import_interpolant("Akima1DInterpolator", method)
+ else:
+ raise ValueError(f"{method} is not a valid scipy interpolator")
+ else:
+ raise ValueError(f"{method} is not a valid interpolator")
+
+ return interp_class, kwargs
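Illustration of the dispatch implemented above (internal helper, shown for clarity only):

```python
from xarray.core.missing import NumpyInterpolator, ScipyInterpolator, _get_interpolator

cls_, kwargs = _get_interpolator("linear")
assert cls_ is NumpyInterpolator   # plain linear interpolation stays on numpy

cls_, kwargs = _get_interpolator("cubic")
assert cls_ is ScipyInterpolator   # interp1d-backed methods go through scipy
```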
def _get_interpolator_nd(method, **kwargs):
@@ -217,20 +521,48 @@ def _get_interpolator_nd(method, **kwargs):
returns interpolator class and keyword arguments for the class
"""
- pass
+ valid_methods = ["linear", "nearest"]
+
+ if method in valid_methods:
+ kwargs.update(method=method)
+ interp_class = _import_interpolant("interpn", method)
+ else:
+ raise ValueError(
+ f"{method} is not a valid interpolator for interpolating "
+ "over multiple dimensions."
+ )
+
+ return interp_class, kwargs
def _get_valid_fill_mask(arr, dim, limit):
"""helper function to determine values that can be filled when limit is not
None"""
- pass
+ kw = {dim: limit + 1}
+ # we explicitly use construct method to avoid copy.
+ new_dim = utils.get_temp_dimname(arr.dims, "_window")
+ return (
+ arr.isnull()
+ .rolling(min_periods=1, **kw)
+ .construct(new_dim, fill_value=False)
+ .sum(new_dim, skipna=False)
+ ) <= limit
def _localize(var, indexes_coords):
"""Speed up for linear and nearest neighbor method.
Only consider a subspace that is needed for the interpolation
"""
- pass
+ indexes = {}
+ for dim, [x, new_x] in indexes_coords.items():
+ new_x_loaded = new_x.values
+ minval = np.nanmin(new_x_loaded)
+ maxval = np.nanmax(new_x_loaded)
+ index = x.to_index()
+ imin, imax = index.get_indexer([minval, maxval], method="nearest")
+ indexes[dim] = slice(max(imin - 2, 0), imax + 2)
+ indexes_coords[dim] = (x[indexes[dim]], new_x)
+ return var.isel(**indexes), indexes_coords
def _floatize_x(x, new_x):
@@ -238,7 +570,19 @@ def _floatize_x(x, new_x):
This is particularly useful for datetime dtype.
x, new_x: tuple of np.ndarray
"""
- pass
+ x = list(x)
+ new_x = list(new_x)
+ for i in range(len(x)):
+ if _contains_datetime_like_objects(x[i]):
+ # Scipy casts coordinates to np.float64, which is not accurate
+ # enough for datetime64 (uses 64bit integer).
+ # We assume that most of the bits are used to represent the
+ # offset (min(x)) and the variation (x - min(x)) can be
+ # represented by float.
+ xmin = x[i].values.min()
+ x[i] = x[i]._to_numeric(offset=xmin, dtype=np.float64)
+ new_x[i] = new_x[i]._to_numeric(offset=xmin, dtype=np.float64)
+ return x, new_x
def interp(var, indexes_coords, method: InterpOptions, **kwargs):
@@ -267,7 +611,42 @@ def interp(var, indexes_coords, method: InterpOptions, **kwargs):
DataArray.interp
Dataset.interp
"""
- pass
+ if not indexes_coords:
+ return var.copy()
+
+ # default behavior
+ kwargs["bounds_error"] = kwargs.get("bounds_error", False)
+
+ result = var
+ # decompose the interpolation into a succession of independent interpolations
+ for indexes_coords in decompose_interp(indexes_coords):
+ var = result
+
+ # target dimensions
+ dims = list(indexes_coords)
+ x, new_x = zip(*[indexes_coords[d] for d in dims])
+ destination = broadcast_variables(*new_x)
+
+ # transpose to make the interpolated axis to the last position
+ broadcast_dims = [d for d in var.dims if d not in dims]
+ original_dims = broadcast_dims + dims
+ new_dims = broadcast_dims + list(destination[0].dims)
+ interped = interp_func(
+ var.transpose(*original_dims).data, x, destination, method, kwargs
+ )
+
+ result = Variable(new_dims, interped, attrs=var.attrs, fastpath=True)
+
+ # dimension of the output array
+ out_dims: OrderedSet = OrderedSet()
+ for d in var.dims:
+ if d in dims:
+ out_dims.update(indexes_coords[d][1].dims)
+ else:
+ out_dims.add(d)
+ if len(out_dims) > 1:
+ result = result.transpose(*out_dims)
+ return result
def interp_func(var, x, new_x, method: InterpOptions, kwargs):
@@ -303,19 +682,161 @@ def interp_func(var, x, new_x, method: InterpOptions, kwargs):
--------
scipy.interpolate.interp1d
"""
- pass
+ if not x:
+ return var.copy()
+
+ if len(x) == 1:
+ func, kwargs = _get_interpolator(method, vectorizeable_only=True, **kwargs)
+ else:
+ func, kwargs = _get_interpolator_nd(method, **kwargs)
+
+ if is_chunked_array(var):
+ chunkmanager = get_chunked_array_type(var)
+
+ ndim = var.ndim
+ nconst = ndim - len(x)
+
+ out_ind = list(range(nconst)) + list(range(ndim, ndim + new_x[0].ndim))
+
+ # blockwise args format
+ x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)]
+ x_arginds = [item for pair in x_arginds for item in pair]
+ new_x_arginds = [
+ [_x, [ndim + index for index in range(_x.ndim)]] for _x in new_x
+ ]
+ new_x_arginds = [item for pair in new_x_arginds for item in pair]
+
+ args = (var, range(ndim), *x_arginds, *new_x_arginds)
+ _, rechunked = chunkmanager.unify_chunks(*args)
-def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs,
- localize=True):
+ args = tuple(elem for pair in zip(rechunked, args[1::2]) for elem in pair)
+
+ new_x = rechunked[1 + (len(rechunked) - 1) // 2 :]
+
+ new_x0_chunks = new_x[0].chunks
+ new_x0_shape = new_x[0].shape
+ new_x0_chunks_is_not_none = new_x0_chunks is not None
+ new_axes = {
+ ndim + i: new_x0_chunks[i] if new_x0_chunks_is_not_none else new_x0_shape[i]
+ for i in range(new_x[0].ndim)
+ }
+
+ # if useful, reuse localize for each chunk of new_x
+ localize = (method in ["linear", "nearest"]) and new_x0_chunks_is_not_none
+
+ # scipy.interpolate.interp1d always forces to float.
+ # Use the same check for blockwise as well:
+ if not issubclass(var.dtype.type, np.inexact):
+ dtype = float
+ else:
+ dtype = var.dtype
+
+ meta = var._meta
+
+ return chunkmanager.blockwise(
+ _chunked_aware_interpnd,
+ out_ind,
+ *args,
+ interp_func=func,
+ interp_kwargs=kwargs,
+ localize=localize,
+ concatenate=True,
+ dtype=dtype,
+ new_axes=new_axes,
+ meta=meta,
+ align_arrays=False,
+ )
+
+ return _interpnd(var, x, new_x, func, kwargs)
+
+
+def _interp1d(var, x, new_x, func, kwargs):
+ # x, new_x are tuples of size 1.
+ x, new_x = x[0], new_x[0]
+ rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
+ if new_x.ndim > 1:
+ return reshape(rslt, (var.shape[:-1] + new_x.shape))
+ if new_x.ndim == 0:
+ return rslt[..., -1]
+ return rslt
+
+
+def _interpnd(var, x, new_x, func, kwargs):
+ x, new_x = _floatize_x(x, new_x)
+
+ if len(x) == 1:
+ return _interp1d(var, x, new_x, func, kwargs)
+
+ # move the interpolation axes to the start position
+ var = var.transpose(range(-len(x), var.ndim - len(x)))
+ # stack new_x to 1 vector, with reshape
+ xi = np.stack([x1.values.ravel() for x1 in new_x], axis=-1)
+ rslt = func(x, var, xi, **kwargs)
+ # move back the interpolation axes to the last position
+ rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
+ return reshape(rslt, rslt.shape[:-1] + new_x[0].shape)
+
+
+def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):
"""Wrapper for `_interpnd` through `blockwise` for chunked arrays.
The first half of the arrays in `coords` are the original coordinates,
the other half are the destination coordinates
"""
- pass
+ n_x = len(coords) // 2
+ nconst = len(var.shape) - n_x
+
+ # _interpnd expects coords to be Variables
+ x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])]
+ new_x = [
+ Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x)
+ for _x in coords[n_x:]
+ ]
+
+ if localize:
+ # _localize expects var to be a Variable
+ var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var)
+
+ indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)}
+
+ # simple speed up for the local interpolation
+ var, indexes_coords = _localize(var, indexes_coords)
+ x, new_x = zip(*[indexes_coords[d] for d in indexes_coords])
+
+ # put var back as a ndarray
+ var = var.data
+
+ return _interpnd(var, x, new_x, interp_func, interp_kwargs)
def decompose_interp(indexes_coords):
"""Decompose the interpolation into a succession of independent interpolation keeping the order"""
- pass
+
+ dest_dims = [
+ dest[1].dims if dest[1].ndim > 0 else [dim]
+ for dim, dest in indexes_coords.items()
+ ]
+ partial_dest_dims = []
+ partial_indexes_coords = {}
+ for i, index_coords in enumerate(indexes_coords.items()):
+ partial_indexes_coords.update([index_coords])
+
+ if i == len(dest_dims) - 1:
+ break
+
+ partial_dest_dims += [dest_dims[i]]
+ other_dims = dest_dims[i + 1 :]
+
+ s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims}
+ s_other_dims = {dim for dims in other_dims for dim in dims}
+
+ if not s_partial_dest_dims.intersection(s_other_dims):
+ # this interpolation is orthogonal to the rest
+
+ yield partial_indexes_coords
+
+ partial_dest_dims = []
+ partial_indexes_coords = {}
+
+ yield partial_indexes_coords
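Sketch of how `decompose_interp` splits orthogonal targets into independent steps (illustration only, using the internal helper directly):

```python
import numpy as np
import xarray as xr
from xarray.core.missing import decompose_interp

x, y = xr.Variable("x", np.arange(4)), xr.Variable("y", np.arange(3))
new_x = xr.Variable("a", [0.5, 1.5])  # destinations live on unrelated dimensions
new_y = xr.Variable("b", [0.5])

groups = list(decompose_interp({"x": (x, new_x), "y": (y, new_y)}))
print(len(groups))  # 2 -- the two 1-D interpolations can run independently
```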
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 78ecefc3..fc724013 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -1,29 +1,177 @@
from __future__ import annotations
+
import warnings
+
import numpy as np
+
from xarray.core import dtypes, duck_array_ops, nputils, utils
-from xarray.core.duck_array_ops import astype, count, fillna, isnull, sum_where, where, where_method
+from xarray.core.duck_array_ops import (
+ astype,
+ count,
+ fillna,
+ isnull,
+ sum_where,
+ where,
+ where_method,
+)
def _maybe_null_out(result, axis, mask, min_count=1):
"""
xarray version of pandas.core.nanops._maybe_null_out
"""
- pass
+ if axis is not None and getattr(result, "ndim", False):
+ null_mask = (
+ np.take(mask.shape, axis).prod()
+ - duck_array_ops.sum(mask, axis)
+ - min_count
+ ) < 0
+ dtype, fill_value = dtypes.maybe_promote(result.dtype)
+ result = where(null_mask, fill_value, astype(result, dtype))
+
+ elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
+ null_mask = mask.size - duck_array_ops.sum(mask)
+ result = where(null_mask < min_count, np.nan, result)
+
+ return result
def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
"""In house nanargmin, nanargmax for object arrays. Always return integer
type
"""
- pass
+ valid_count = count(value, axis=axis)
+ value = fillna(value, fill_value)
+ data = getattr(np, func)(value, axis=axis, **kwargs)
+
+ # TODO This will evaluate dask arrays and might be costly.
+ if (valid_count == 0).any():
+ raise ValueError("All-NaN slice encountered")
+
+ return data
def _nan_minmax_object(func, fill_value, value, axis=None, **kwargs):
"""In house nanmin and nanmax for object array"""
- pass
+ valid_count = count(value, axis=axis)
+ filled_value = fillna(value, fill_value)
+ data = getattr(np, func)(filled_value, axis=axis, **kwargs)
+ if not hasattr(data, "dtype"): # scalar case
+ data = fill_value if valid_count == 0 else data
+ # we've computed a single min, max value of type object.
+ # don't let np.array turn a tuple back into an array
+ return utils.to_0d_object_array(data)
+ return where_method(data, valid_count != 0)
+
+
+def nanmin(a, axis=None, out=None):
+ if a.dtype.kind == "O":
+ return _nan_minmax_object("min", dtypes.get_pos_infinity(a.dtype), a, axis)
+
+ return nputils.nanmin(a, axis=axis)
+
+
+def nanmax(a, axis=None, out=None):
+ if a.dtype.kind == "O":
+ return _nan_minmax_object("max", dtypes.get_neg_infinity(a.dtype), a, axis)
+
+ return nputils.nanmax(a, axis=axis)
+
+
+def nanargmin(a, axis=None):
+ if a.dtype.kind == "O":
+ fill_value = dtypes.get_pos_infinity(a.dtype)
+ return _nan_argminmax_object("argmin", fill_value, a, axis=axis)
+
+ return nputils.nanargmin(a, axis=axis)
+
+
+def nanargmax(a, axis=None):
+ if a.dtype.kind == "O":
+ fill_value = dtypes.get_neg_infinity(a.dtype)
+ return _nan_argminmax_object("argmax", fill_value, a, axis=axis)
+
+ return nputils.nanargmax(a, axis=axis)
+
+
+def nansum(a, axis=None, dtype=None, out=None, min_count=None):
+ mask = isnull(a)
+ result = sum_where(a, axis=axis, dtype=dtype, where=mask)
+ if min_count is not None:
+ return _maybe_null_out(result, axis, mask, min_count)
+ else:
+ return result
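The `min_count` handling above is what the public reductions rely on; a quick sketch:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([[1.0, np.nan, np.nan], [1.0, 2.0, 3.0]], dims=("x", "y"))

print(da.sum("y", min_count=2).values)  # [nan  6.] -- first row has only one valid value
print(da.sum("y", min_count=1).values)  # [1. 6.]
```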
def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
"""In house nanmean. ddof argument will be used in _nanvar method"""
- pass
+ from xarray.core.duck_array_ops import count, fillna, where_method
+
+ valid_count = count(value, axis=axis)
+ value = fillna(value, 0)
+ # As dtype inference is impossible for object dtype, we assume float
+ # https://github.com/dask/dask/issues/3162
+ if dtype is None and value.dtype.kind == "O":
+ dtype = value.dtype if value.dtype.kind in ["cf"] else float
+
+ data = np.sum(value, axis=axis, dtype=dtype, **kwargs)
+ data = data / (valid_count - ddof)
+ return where_method(data, valid_count != 0)
+
+
+def nanmean(a, axis=None, dtype=None, out=None):
+ if a.dtype.kind == "O":
+ return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype)
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore", r"Mean of empty slice", category=RuntimeWarning
+ )
+
+ return np.nanmean(a, axis=axis, dtype=dtype)
+
+
+def nanmedian(a, axis=None, out=None):
+ # The dask algorithm works by rechunking to one chunk along axis
+ # Make sure we trigger the dask error when passing all dimensions
+ # so that we don't rechunk the entire array to one chunk and
+ # possibly blow memory
+ if axis is not None and len(np.atleast_1d(axis)) == a.ndim:
+ axis = None
+ return nputils.nanmedian(a, axis=axis)
+
+
+def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs):
+ value_mean = _nanmean_ddof_object(
+ ddof=0, value=value, axis=axis, keepdims=True, **kwargs
+ )
+ squared = (astype(value, value_mean.dtype) - value_mean) ** 2
+ return _nanmean_ddof_object(ddof, squared, axis=axis, keepdims=keepdims, **kwargs)
+
+
+def nanvar(a, axis=None, dtype=None, out=None, ddof=0):
+ if a.dtype.kind == "O":
+ return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof)
+
+ return nputils.nanvar(a, axis=axis, dtype=dtype, ddof=ddof)
+
+
+def nanstd(a, axis=None, dtype=None, out=None, ddof=0):
+ return nputils.nanstd(a, axis=axis, dtype=dtype, ddof=ddof)
+
+
+def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
+ mask = isnull(a)
+ result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out)
+ if min_count is not None:
+ return _maybe_null_out(result, axis, mask, min_count)
+ else:
+ return result
+
+
+def nancumsum(a, axis=None, dtype=None, out=None):
+ return nputils.nancumsum(a, axis=axis, dtype=dtype)
+
+
+def nancumprod(a, axis=None, dtype=None, out=None):
+ return nputils.nancumprod(a, axis=axis, dtype=dtype)
diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py
index 3b0aa890..92d30e1d 100644
--- a/xarray/core/npcompat.py
+++ b/xarray/core/npcompat.py
@@ -1,11 +1,73 @@
+# Copyright (c) 2005-2011, NumPy Developers.
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the NumPy Developers nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
+
from typing import Any
+
try:
- from numpy import isdtype
+ # requires numpy>=2.0
+ from numpy import isdtype # type: ignore[attr-defined,unused-ignore]
except ImportError:
import numpy as np
from numpy.typing import DTypeLike
- kind_mapping = {'bool': np.bool_, 'signed integer': np.signedinteger,
- 'unsigned integer': np.unsignedinteger, 'integral': np.integer,
- 'real floating': np.floating, 'complex floating': np.
- complexfloating, 'numeric': np.number}
+
+ kind_mapping = {
+ "bool": np.bool_,
+ "signed integer": np.signedinteger,
+ "unsigned integer": np.unsignedinteger,
+ "integral": np.integer,
+ "real floating": np.floating,
+ "complex floating": np.complexfloating,
+ "numeric": np.number,
+ }
+
+ def isdtype(
+ dtype: np.dtype[Any] | type[Any], kind: DTypeLike | tuple[DTypeLike, ...]
+ ) -> bool:
+ kinds = kind if isinstance(kind, tuple) else (kind,)
+ str_kinds = {k for k in kinds if isinstance(k, str)}
+ type_kinds = {k.type for k in kinds if isinstance(k, np.dtype)}
+
+ if unknown_kind_types := set(kinds) - str_kinds - type_kinds:
+ raise TypeError(
+ f"kind must be str, np.dtype or a tuple of these, got {unknown_kind_types}"
+ )
+ if unknown_kinds := {k for k in str_kinds if k not in kind_mapping}:
+ raise ValueError(
+ f"unknown kind: {unknown_kinds}, must be a np.dtype or one of {list(kind_mapping)}"
+ )
+
+ # verified the dtypes already, no need to check again
+ translated_kinds = {kind_mapping[k] for k in str_kinds} | type_kinds
+ if isinstance(dtype, np.generic):
+ return isinstance(dtype, translated_kinds)
+ else:
+ return any(np.issubdtype(dtype, k) for k in translated_kinds)
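Behaviour sketch of the `isdtype` shim (same semantics whether the numpy>=2.0 function or the fallback above is used):

```python
import numpy as np
from xarray.core.npcompat import isdtype

print(isdtype(np.dtype("int32"), "integral"))                               # True
print(isdtype(np.dtype("float64"), ("real floating", "complex floating")))  # True
print(isdtype(np.dtype("bool"), "numeric"))                                 # False
```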
diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py
index abe87b99..1d30fe9d 100644
--- a/xarray/core/nputils.py
+++ b/xarray/core/nputils.py
@@ -1,30 +1,75 @@
from __future__ import annotations
+
import warnings
from typing import Callable
+
import numpy as np
import pandas as pd
from packaging.version import Version
+
from xarray.core.utils import is_duck_array, module_available
from xarray.namedarray import pycompat
-if module_available('numpy', minversion='2.0.0.dev0'):
- from numpy.lib.array_utils import normalize_axis_index
+
+# remove once numpy 2.0 is the oldest supported version
+if module_available("numpy", minversion="2.0.0.dev0"):
+ from numpy.lib.array_utils import ( # type: ignore[import-not-found,unused-ignore]
+ normalize_axis_index,
+ )
else:
- from numpy.core.multiarray import normalize_axis_index
+ from numpy.core.multiarray import ( # type: ignore[attr-defined,no-redef,unused-ignore]
+ normalize_axis_index,
+ )
+
+# remove once numpy 2.0 is the oldest supported version
try:
- from numpy.exceptions import RankWarning
+ from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore]
except ImportError:
- from numpy import RankWarning
+ from numpy import RankWarning # type: ignore[attr-defined,no-redef,unused-ignore]
+
from xarray.core.options import OPTIONS
+
try:
import bottleneck as bn
+
_BOTTLENECK_AVAILABLE = True
except ImportError:
+ # use numpy methods instead
bn = np
_BOTTLENECK_AVAILABLE = False
-def inverse_permutation(indices: np.ndarray, N: (int | None)=None
- ) ->np.ndarray:
+def _select_along_axis(values, idx, axis):
+ other_ind = np.ix_(*[np.arange(s) for s in idx.shape])
+ sl = other_ind[:axis] + (idx,) + other_ind[axis:]
+ return values[sl]
+
+
+def nanfirst(values, axis, keepdims=False):
+ if isinstance(axis, tuple):
+ (axis,) = axis
+ axis = normalize_axis_index(axis, values.ndim)
+ idx_first = np.argmax(~pd.isnull(values), axis=axis)
+ result = _select_along_axis(values, idx_first, axis)
+ if keepdims:
+ return np.expand_dims(result, axis=axis)
+ else:
+ return result
+
+
+def nanlast(values, axis, keepdims=False):
+ if isinstance(axis, tuple):
+ (axis,) = axis
+ axis = normalize_axis_index(axis, values.ndim)
+ rev = (slice(None),) * axis + (slice(None, None, -1),)
+ idx_last = -1 - np.argmax(~pd.isnull(values)[rev], axis=axis)
+ result = _select_along_axis(values, idx_last, axis)
+ if keepdims:
+ return np.expand_dims(result, axis=axis)
+ else:
+ return result
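Quick sketch of `nanfirst`/`nanlast`, which pick the first/last non-NaN value along an axis:

```python
import numpy as np
from xarray.core.nputils import nanfirst, nanlast

values = np.array([[np.nan, 1.0, 2.0], [np.nan, np.nan, 3.0]])
print(nanfirst(values, axis=1))  # [1. 3.]
print(nanlast(values, axis=1))   # [2. 3.]
```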
+
+
+def inverse_permutation(indices: np.ndarray, N: int | None = None) -> np.ndarray:
"""Return indices for an inverse permutation.
Parameters
@@ -40,17 +85,73 @@ def inverse_permutation(indices: np.ndarray, N: (int | None)=None
Integer indices to take from the original array to create the
permutation.
"""
- pass
+ if N is None:
+ N = len(indices)
+ # use intp instead of int64 because of windows :(
+ inverse_permutation = np.full(N, -1, dtype=np.intp)
+ inverse_permutation[indices] = np.arange(len(indices), dtype=np.intp)
+ return inverse_permutation
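Sketch of `inverse_permutation` (the returned indices undo the original take):

```python
import numpy as np
from xarray.core.nputils import inverse_permutation

indices = np.array([2, 0, 1])
arr = np.array(["a", "b", "c"])
shuffled = arr[indices]
print(shuffled[inverse_permutation(indices)])  # ['a' 'b' 'c']
```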
+
+
+def _ensure_bool_is_ndarray(result, *args):
+ # numpy will sometimes return a scalar value from binary comparisons if it
+ # can't handle the comparison instead of broadcasting, e.g.,
+ # In [10]: 1 == np.array(['a', 'b'])
+ # Out[10]: False
+ # This function ensures that the result is the appropriate shape in these
+ # cases
+ if isinstance(result, bool):
+ shape = np.broadcast(*args).shape
+ constructor = np.ones if result else np.zeros
+ result = constructor(shape, dtype=bool)
+ return result
+
+
+def array_eq(self, other):
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", r"elementwise comparison failed")
+ return _ensure_bool_is_ndarray(self == other, self, other)
+
+
+def array_ne(self, other):
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore", r"elementwise comparison failed")
+ return _ensure_bool_is_ndarray(self != other, self, other)
def _is_contiguous(positions):
"""Given a non-empty list, does it consist of contiguous integers?"""
- pass
+ previous = positions[0]
+ for current in positions[1:]:
+ if current != previous + 1:
+ return False
+ previous = current
+ return True
def _advanced_indexer_subspaces(key):
"""Indices of the advanced indexes subspaces for mixed indexing and vindex."""
- pass
+ if not isinstance(key, tuple):
+ key = (key,)
+ advanced_index_positions = [
+ i for i, k in enumerate(key) if not isinstance(k, slice)
+ ]
+
+ if not advanced_index_positions or not _is_contiguous(advanced_index_positions):
+ # Nothing to reorder: dimensions on the indexing result are already
+ # ordered like vindex. See NumPy's rule for "Combining advanced and
+ # basic indexing":
+ # https://numpy.org/doc/stable/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
+ return (), ()
+
+ non_slices = [k for k in key if not isinstance(k, slice)]
+ broadcasted_shape = np.broadcast_shapes(
+ *[item.shape if is_duck_array(item) else (0,) for item in non_slices]
+ )
+ ndim = len(broadcasted_shape)
+ mixed_positions = advanced_index_positions[0] + np.arange(ndim)
+ vindex_positions = np.arange(ndim)
+ return mixed_positions, vindex_positions
class NumpyVIndexAdapter:
@@ -70,19 +171,128 @@ class NumpyVIndexAdapter:
def __setitem__(self, key, value):
"""Value must have dimensionality matching the key."""
mixed_positions, vindex_positions = _advanced_indexer_subspaces(key)
- self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions
+ self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions)
+
+
+def _create_method(name, npmodule=np) -> Callable:
+ def f(values, axis=None, **kwargs):
+ dtype = kwargs.get("dtype", None)
+ bn_func = getattr(bn, name, None)
+
+ if (
+ module_available("numbagg")
+ and pycompat.mod_version("numbagg") >= Version("0.5.0")
+ and OPTIONS["use_numbagg"]
+ and isinstance(values, np.ndarray)
+ # numbagg<0.7.0 uses ddof=1 only, but numpy uses ddof=0 by default
+ and (
+ pycompat.mod_version("numbagg") >= Version("0.7.0")
+ or ("var" not in name and "std" not in name)
+ or kwargs.get("ddof", 0) == 1
+ )
+ # TODO: bool?
+ and values.dtype.kind in "uif"
+ # and values.dtype.isnative
+ and (dtype is None or np.dtype(dtype) == values.dtype)
+ # numbagg.nanquantile only available after 0.8.0 and with linear method
+ and (
+ name != "nanquantile"
+ or (
+ pycompat.mod_version("numbagg") >= Version("0.8.0")
+ and kwargs.get("method", "linear") == "linear"
+ )
+ )
+ ):
+ import numbagg
+
+ nba_func = getattr(numbagg, name, None)
+ if nba_func is not None:
+ # numbagg does not use dtype
+ kwargs.pop("dtype", None)
+ # prior to 0.7.0, numbagg did not support ddof; we ensure it's limited
+ # to ddof=1 above.
+ if pycompat.mod_version("numbagg") < Version("0.7.0"):
+ kwargs.pop("ddof", None)
+ if name == "nanquantile":
+ kwargs["quantiles"] = kwargs.pop("q")
+ kwargs.pop("method", None)
+ return nba_func(values, axis=axis, **kwargs)
+ if (
+ _BOTTLENECK_AVAILABLE
+ and OPTIONS["use_bottleneck"]
+ and isinstance(values, np.ndarray)
+ and bn_func is not None
+ and not isinstance(axis, tuple)
+ and values.dtype.kind in "uifc"
+ and values.dtype.isnative
+ and (dtype is None or np.dtype(dtype) == values.dtype)
+ ):
+ # bottleneck does not take care of dtype, min_count
+ kwargs.pop("dtype", None)
+ result = bn_func(values, axis=axis, **kwargs)
+ else:
+ result = getattr(npmodule, name)(values, axis=axis, **kwargs)
+
+ return result
+
+ f.__name__ = name
+ return f
+
+
+def _nanpolyfit_1d(arr, x, rcond=None):
+ out = np.full((x.shape[1] + 1,), np.nan)
+ mask = np.isnan(arr)
+ if not np.all(mask):
+ out[:-1], resid, rank, _ = np.linalg.lstsq(x[~mask, :], arr[~mask], rcond=rcond)
+ out[-1] = resid[0] if resid.size > 0 else np.nan
+ warn_on_deficient_rank(rank, x.shape[1])
+ return out
+
+
+def warn_on_deficient_rank(rank, order):
+ if rank != order:
+ warnings.warn("Polyfit may be poorly conditioned", RankWarning, stacklevel=2)
+
+
+def least_squares(lhs, rhs, rcond=None, skipna=False):
+ if skipna:
+ added_dim = rhs.ndim == 1
+ if added_dim:
+ rhs = rhs.reshape(rhs.shape[0], 1)
+ nan_cols = np.any(np.isnan(rhs), axis=0)
+ out = np.empty((lhs.shape[1] + 1, rhs.shape[1]))
+ if np.any(nan_cols):
+ out[:, nan_cols] = np.apply_along_axis(
+ _nanpolyfit_1d, 0, rhs[:, nan_cols], lhs
+ )
+ if np.any(~nan_cols):
+ out[:-1, ~nan_cols], resids, rank, _ = np.linalg.lstsq(
+ lhs, rhs[:, ~nan_cols], rcond=rcond
)
+ out[-1, ~nan_cols] = resids if resids.size > 0 else np.nan
+ warn_on_deficient_rank(rank, lhs.shape[1])
+ coeffs = out[:-1, :]
+ residuals = out[-1, :]
+ if added_dim:
+ coeffs = coeffs.reshape(coeffs.shape[0])
+ residuals = residuals.reshape(residuals.shape[0])
+ else:
+ coeffs, residuals, rank, _ = np.linalg.lstsq(lhs, rhs, rcond=rcond)
+ if residuals.size == 0:
+ residuals = coeffs[0] * np.nan
+ warn_on_deficient_rank(rank, lhs.shape[1])
+ return coeffs, residuals
-nanmin = _create_method('nanmin')
-nanmax = _create_method('nanmax')
-nanmean = _create_method('nanmean')
-nanmedian = _create_method('nanmedian')
-nanvar = _create_method('nanvar')
-nanstd = _create_method('nanstd')
-nanprod = _create_method('nanprod')
-nancumsum = _create_method('nancumsum')
-nancumprod = _create_method('nancumprod')
-nanargmin = _create_method('nanargmin')
-nanargmax = _create_method('nanargmax')
-nanquantile = _create_method('nanquantile')
+nanmin = _create_method("nanmin")
+nanmax = _create_method("nanmax")
+nanmean = _create_method("nanmean")
+nanmedian = _create_method("nanmedian")
+nanvar = _create_method("nanvar")
+nanstd = _create_method("nanstd")
+nanprod = _create_method("nanprod")
+nancumsum = _create_method("nancumsum")
+nancumprod = _create_method("nancumprod")
+nanargmin = _create_method("nanargmin")
+nanargmax = _create_method("nanargmax")
+nanquantile = _create_method("nanquantile")
diff --git a/xarray/core/ops.py b/xarray/core/ops.py
index f0f8e050..c67b4669 100644
--- a/xarray/core/ops.py
+++ b/xarray/core/ops.py
@@ -4,23 +4,61 @@ TODO(shoyer): rewrite this module, making use of xarray.core.computation,
NumPy's __array_ufunc__ and mixin classes instead of the unintuitive "inject"
functions.
"""
+
from __future__ import annotations
+
import operator
+
import numpy as np
+
from xarray.core import dtypes, duck_array_ops
+
try:
import bottleneck as bn
+
has_bottleneck = True
except ImportError:
+ # use numpy methods instead
bn = np
has_bottleneck = False
-NUM_BINARY_OPS = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow',
- 'and', 'xor', 'or', 'lshift', 'rshift']
-NUMPY_SAME_METHODS = ['item', 'searchsorted']
-REDUCE_METHODS = ['all', 'any']
-NAN_REDUCE_METHODS = ['max', 'min', 'mean', 'prod', 'sum', 'std', 'var',
- 'median']
-_CUM_DOCSTRING_TEMPLATE = """Apply `{name}` along some dimension of {cls}.
+
+
+NUM_BINARY_OPS = [
+ "add",
+ "sub",
+ "mul",
+ "truediv",
+ "floordiv",
+ "mod",
+ "pow",
+ "and",
+ "xor",
+ "or",
+ "lshift",
+ "rshift",
+]
+
+# methods which pass on the numpy return value unchanged
+# be careful not to list methods that we would want to wrap later
+NUMPY_SAME_METHODS = ["item", "searchsorted"]
+
+# methods which remove an axis
+REDUCE_METHODS = ["all", "any"]
+NAN_REDUCE_METHODS = [
+ "max",
+ "min",
+ "mean",
+ "prod",
+ "sum",
+ "std",
+ "var",
+ "median",
+]
+# TODO: wrap take, dot, sort
+
+
+_CUM_DOCSTRING_TEMPLATE = """\
+Apply `{name}` along some dimension of {cls}.
Parameters
----------
@@ -43,7 +81,9 @@ cumvalue : {cls}
New {cls} object with `{name}` applied to its data along the
indicated dimension.
"""
-_REDUCE_DOCSTRING_TEMPLATE = """Reduce this {cls}'s data by applying `{name}` along some dimension(s).
+
+_REDUCE_DOCSTRING_TEMPLATE = """\
+Reduce this {cls}'s data by applying `{name}` along some dimension(s).
Parameters
----------
@@ -62,12 +102,14 @@ reduced : {cls}
New {cls} object with `{name}` applied to its data and the
indicated dimension(s) removed.
"""
+
_SKIPNA_DOCSTRING = """
skipna : bool, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64)."""
+
_MINCOUNT_DOCSTRING = """
min_count : int, default: None
The required number of valid values to perform the operation. If
@@ -78,7 +120,7 @@ min_count : int, default: None
and skipna=True, the result will be a float array."""
-def fillna(data, other, join='left', dataset_join='left'):
+def fillna(data, other, join="left", dataset_join="left"):
"""Fill missing values in this object with data from the other object.
Follows normal broadcasting and alignment rules.
@@ -101,7 +143,18 @@ def fillna(data, other, join='left', dataset_join='left'):
- "left": take only variables from the first object
- "right": take only variables from the last object
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ return apply_ufunc(
+ duck_array_ops.fillna,
+ data,
+ other,
+ join=join,
+ dask="allowed",
+ dataset_join=dataset_join,
+ dataset_fill_value=np.nan,
+ keep_attrs=True,
+ )
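Usage sketch of the public `fillna` behaviour this helper provides (missing values are filled from the other object, element-wise):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, 3.0], dims="x")
other = xr.DataArray([10.0, 20.0, 30.0], dims="x")

print(da.fillna(0).values)      # [1. 0. 3.]
print(da.fillna(other).values)  # [ 1. 20.  3.]
```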
def where_method(self, cond, other=dtypes.NA):
@@ -119,14 +172,124 @@ def where_method(self, cond, other=dtypes.NA):
-------
Same type as caller.
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ # alignment for three arguments is complicated, so don't support it yet
+ join = "inner" if other is dtypes.NA else "exact"
+ return apply_ufunc(
+ duck_array_ops.where_method,
+ self,
+ cond,
+ other,
+ join=join,
+ dataset_join=join,
+ dask="allowed",
+ keep_attrs=True,
+ )
+
+
+def _call_possibly_missing_method(arg, name, args, kwargs):
+ try:
+ method = getattr(arg, name)
+ except AttributeError:
+ duck_array_ops.fail_on_dask_array_input(arg, func_name=name)
+ if hasattr(arg, "data"):
+ duck_array_ops.fail_on_dask_array_input(arg.data, func_name=name)
+ raise
+ else:
+ return method(*args, **kwargs)
+
+
+def _values_method_wrapper(name):
+ def func(self, *args, **kwargs):
+ return _call_possibly_missing_method(self.data, name, args, kwargs)
+
+ func.__name__ = name
+ func.__doc__ = getattr(np.ndarray, name).__doc__
+ return func
-NON_INPLACE_OP = {get_op('i' + name): get_op(name) for name in NUM_BINARY_OPS}
-argsort = _method_wrapper('argsort')
-conj = _method_wrapper('conj')
-conjugate = _method_wrapper('conjugate')
-round_ = _func_slash_method_wrapper(duck_array_ops.around, name='round')
+def _method_wrapper(name):
+ def func(self, *args, **kwargs):
+ return _call_possibly_missing_method(self, name, args, kwargs)
+
+ func.__name__ = name
+ func.__doc__ = getattr(np.ndarray, name).__doc__
+ return func
+
+
+def _func_slash_method_wrapper(f, name=None):
+ # try to wrap a method, but if not found use the function
+ # this is useful when patching in a function as both a DataArray and
+ # Dataset method
+ if name is None:
+ name = f.__name__
+
+ def func(self, *args, **kwargs):
+ try:
+ return getattr(self, name)(*args, **kwargs)
+ except AttributeError:
+ return f(self, *args, **kwargs)
+
+ func.__name__ = name
+ func.__doc__ = f.__doc__
+ return func
+
+
+def inject_reduce_methods(cls):
+ methods = (
+ [
+ (name, getattr(duck_array_ops, f"array_{name}"), False)
+ for name in REDUCE_METHODS
+ ]
+ + [(name, getattr(duck_array_ops, name), True) for name in NAN_REDUCE_METHODS]
+ + [("count", duck_array_ops.count, False)]
+ )
+ for name, f, include_skipna in methods:
+ numeric_only = getattr(f, "numeric_only", False)
+ available_min_count = getattr(f, "available_min_count", False)
+ skip_na_docs = _SKIPNA_DOCSTRING if include_skipna else ""
+ min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else ""
+
+ func = cls._reduce_method(f, include_skipna, numeric_only)
+ func.__name__ = name
+ func.__doc__ = _REDUCE_DOCSTRING_TEMPLATE.format(
+ name=name,
+ cls=cls.__name__,
+ extra_args=cls._reduce_extra_args_docstring.format(name=name),
+ skip_na_docs=skip_na_docs,
+ min_count_docs=min_count_docs,
+ )
+ setattr(cls, name, func)
+
+
+def op_str(name):
+ return f"__{name}__"
+
+
+def get_op(name):
+ return getattr(operator, op_str(name))
+
+
+NON_INPLACE_OP = {get_op("i" + name): get_op(name) for name in NUM_BINARY_OPS}
+
+
+def inplace_to_noninplace_op(f):
+ return NON_INPLACE_OP[f]
+
+
+# _typed_ops.py uses the following wrapped functions as a kind of unary operator
+argsort = _method_wrapper("argsort")
+conj = _method_wrapper("conj")
+conjugate = _method_wrapper("conjugate")
+round_ = _func_slash_method_wrapper(duck_array_ops.around, name="round")
+
+
+def inject_numpy_same(cls):
+ # these methods don't return arrays of the same shape as the input, so
+ # don't try to patch these in for Dataset objects
+ for name in NUMPY_SAME_METHODS:
+ setattr(cls, name, _values_method_wrapper(name))
class IncludeReduceMethods:
@@ -134,7 +297,8 @@ class IncludeReduceMethods:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
- if getattr(cls, '_reduce_method', None):
+
+ if getattr(cls, "_reduce_method", None):
inject_reduce_methods(cls)
@@ -143,4 +307,5 @@ class IncludeNumpySameMethods:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
- inject_numpy_same(cls)
+
+ inject_numpy_same(cls) # some methods not applicable to Dataset objects
diff --git a/xarray/core/options.py b/xarray/core/options.py
index ca162668..f5614104 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -1,77 +1,157 @@
from __future__ import annotations
+
import warnings
from typing import TYPE_CHECKING, Literal, TypedDict
+
from xarray.core.utils import FrozenDict
+
if TYPE_CHECKING:
from matplotlib.colors import Colormap
- Options = Literal['arithmetic_join', 'cmap_divergent',
- 'cmap_sequential', 'display_max_rows', 'display_values_threshold',
- 'display_style', 'display_width', 'display_expand_attrs',
- 'display_expand_coords', 'display_expand_data_vars',
- 'display_expand_data', 'display_expand_groups',
- 'display_expand_indexes', 'display_default_indexes',
- 'enable_cftimeindex', 'file_cache_maxsize', 'keep_attrs',
- 'warn_for_unclosed_files', 'use_bottleneck', 'use_numbagg',
- 'use_opt_einsum', 'use_flox']
+ Options = Literal[
+ "arithmetic_join",
+ "cmap_divergent",
+ "cmap_sequential",
+ "display_max_rows",
+ "display_values_threshold",
+ "display_style",
+ "display_width",
+ "display_expand_attrs",
+ "display_expand_coords",
+ "display_expand_data_vars",
+ "display_expand_data",
+ "display_expand_groups",
+ "display_expand_indexes",
+ "display_default_indexes",
+ "enable_cftimeindex",
+ "file_cache_maxsize",
+ "keep_attrs",
+ "warn_for_unclosed_files",
+ "use_bottleneck",
+ "use_numbagg",
+ "use_opt_einsum",
+ "use_flox",
+ ]
class T_Options(TypedDict):
arithmetic_broadcast: bool
- arithmetic_join: Literal['inner', 'outer', 'left', 'right', 'exact']
+ arithmetic_join: Literal["inner", "outer", "left", "right", "exact"]
cmap_divergent: str | Colormap
cmap_sequential: str | Colormap
display_max_rows: int
display_values_threshold: int
- display_style: Literal['text', 'html']
+ display_style: Literal["text", "html"]
display_width: int
- display_expand_attrs: Literal['default', True, False]
- display_expand_coords: Literal['default', True, False]
- display_expand_data_vars: Literal['default', True, False]
- display_expand_data: Literal['default', True, False]
- display_expand_groups: Literal['default', True, False]
- display_expand_indexes: Literal['default', True, False]
- display_default_indexes: Literal['default', True, False]
+ display_expand_attrs: Literal["default", True, False]
+ display_expand_coords: Literal["default", True, False]
+ display_expand_data_vars: Literal["default", True, False]
+ display_expand_data: Literal["default", True, False]
+ display_expand_groups: Literal["default", True, False]
+ display_expand_indexes: Literal["default", True, False]
+ display_default_indexes: Literal["default", True, False]
enable_cftimeindex: bool
file_cache_maxsize: int
- keep_attrs: Literal['default', True, False]
+ keep_attrs: Literal["default", True, False]
warn_for_unclosed_files: bool
use_bottleneck: bool
use_flox: bool
use_numbagg: bool
use_opt_einsum: bool
-OPTIONS: T_Options = {'arithmetic_broadcast': True, 'arithmetic_join':
- 'inner', 'cmap_divergent': 'RdBu_r', 'cmap_sequential': 'viridis',
- 'display_max_rows': 12, 'display_values_threshold': 200,
- 'display_style': 'html', 'display_width': 80, 'display_expand_attrs':
- 'default', 'display_expand_coords': 'default',
- 'display_expand_data_vars': 'default', 'display_expand_data': 'default',
- 'display_expand_groups': 'default', 'display_expand_indexes': 'default',
- 'display_default_indexes': False, 'enable_cftimeindex': True,
- 'file_cache_maxsize': 128, 'keep_attrs': 'default',
- 'warn_for_unclosed_files': False, 'use_bottleneck': True, 'use_flox':
- True, 'use_numbagg': True, 'use_opt_einsum': True}
-_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact'])
-_DISPLAY_OPTIONS = frozenset(['text', 'html'])
-_VALIDATORS = {'arithmetic_broadcast': lambda value: isinstance(value, bool
- ), 'arithmetic_join': _JOIN_OPTIONS.__contains__, 'display_max_rows':
- _positive_integer, 'display_values_threshold': _positive_integer,
- 'display_style': _DISPLAY_OPTIONS.__contains__, 'display_width':
- _positive_integer, 'display_expand_attrs': lambda choice: choice in [
- True, False, 'default'], 'display_expand_coords': lambda choice: choice in
- [True, False, 'default'], 'display_expand_data_vars': lambda choice:
- choice in [True, False, 'default'], 'display_expand_data': lambda
- choice: choice in [True, False, 'default'], 'display_expand_indexes':
- lambda choice: choice in [True, False, 'default'],
- 'display_default_indexes': lambda choice: choice in [True, False,
- 'default'], 'enable_cftimeindex': lambda value: isinstance(value, bool),
- 'file_cache_maxsize': _positive_integer, 'keep_attrs': lambda choice:
- choice in [True, False, 'default'], 'use_bottleneck': lambda value:
- isinstance(value, bool), 'use_numbagg': lambda value: isinstance(value,
- bool), 'use_opt_einsum': lambda value: isinstance(value, bool),
- 'use_flox': lambda value: isinstance(value, bool),
- 'warn_for_unclosed_files': lambda value: isinstance(value, bool)}
-_SETTERS = {'enable_cftimeindex': _warn_on_setting_enable_cftimeindex,
- 'file_cache_maxsize': _set_file_cache_maxsize}
+
+
+OPTIONS: T_Options = {
+ "arithmetic_broadcast": True,
+ "arithmetic_join": "inner",
+ "cmap_divergent": "RdBu_r",
+ "cmap_sequential": "viridis",
+ "display_max_rows": 12,
+ "display_values_threshold": 200,
+ "display_style": "html",
+ "display_width": 80,
+ "display_expand_attrs": "default",
+ "display_expand_coords": "default",
+ "display_expand_data_vars": "default",
+ "display_expand_data": "default",
+ "display_expand_groups": "default",
+ "display_expand_indexes": "default",
+ "display_default_indexes": False,
+ "enable_cftimeindex": True,
+ "file_cache_maxsize": 128,
+ "keep_attrs": "default",
+ "warn_for_unclosed_files": False,
+ "use_bottleneck": True,
+ "use_flox": True,
+ "use_numbagg": True,
+ "use_opt_einsum": True,
+}
+
+_JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"])
+_DISPLAY_OPTIONS = frozenset(["text", "html"])
+
+
+def _positive_integer(value: int) -> bool:
+ return isinstance(value, int) and value > 0
+
+
+_VALIDATORS = {
+ "arithmetic_broadcast": lambda value: isinstance(value, bool),
+ "arithmetic_join": _JOIN_OPTIONS.__contains__,
+ "display_max_rows": _positive_integer,
+ "display_values_threshold": _positive_integer,
+ "display_style": _DISPLAY_OPTIONS.__contains__,
+ "display_width": _positive_integer,
+ "display_expand_attrs": lambda choice: choice in [True, False, "default"],
+ "display_expand_coords": lambda choice: choice in [True, False, "default"],
+ "display_expand_data_vars": lambda choice: choice in [True, False, "default"],
+ "display_expand_data": lambda choice: choice in [True, False, "default"],
+ "display_expand_indexes": lambda choice: choice in [True, False, "default"],
+ "display_default_indexes": lambda choice: choice in [True, False, "default"],
+ "enable_cftimeindex": lambda value: isinstance(value, bool),
+ "file_cache_maxsize": _positive_integer,
+ "keep_attrs": lambda choice: choice in [True, False, "default"],
+ "use_bottleneck": lambda value: isinstance(value, bool),
+ "use_numbagg": lambda value: isinstance(value, bool),
+ "use_opt_einsum": lambda value: isinstance(value, bool),
+ "use_flox": lambda value: isinstance(value, bool),
+ "warn_for_unclosed_files": lambda value: isinstance(value, bool),
+}
+
+
+def _set_file_cache_maxsize(value) -> None:
+ from xarray.backends.file_manager import FILE_CACHE
+
+ FILE_CACHE.maxsize = value
+
+
+def _warn_on_setting_enable_cftimeindex(enable_cftimeindex):
+ warnings.warn(
+ "The enable_cftimeindex option is now a no-op "
+ "and will be removed in a future version of xarray.",
+ FutureWarning,
+ )
+
+
+_SETTERS = {
+ "enable_cftimeindex": _warn_on_setting_enable_cftimeindex,
+ "file_cache_maxsize": _set_file_cache_maxsize,
+}
+
+
+def _get_boolean_with_default(option: Options, default: bool) -> bool:
+ global_choice = OPTIONS[option]
+
+ if global_choice == "default":
+ return default
+ elif isinstance(global_choice, bool):
+ return global_choice
+ else:
+ raise ValueError(
+ f"The global option {option} must be one of True, False or 'default'."
+ )
+
+
+def _get_keep_attrs(default: bool) -> bool:
+ return _get_boolean_with_default("keep_attrs", default)
class set_options:
@@ -199,20 +279,27 @@ class set_options:
for k, v in kwargs.items():
if k not in OPTIONS:
raise ValueError(
- f'argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}'
- )
+ f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}"
+ )
if k in _VALIDATORS and not _VALIDATORS[k](v):
- if k == 'arithmetic_join':
- expected = f'Expected one of {_JOIN_OPTIONS!r}'
- elif k == 'display_style':
- expected = f'Expected one of {_DISPLAY_OPTIONS!r}'
+ if k == "arithmetic_join":
+ expected = f"Expected one of {_JOIN_OPTIONS!r}"
+ elif k == "display_style":
+ expected = f"Expected one of {_DISPLAY_OPTIONS!r}"
else:
- expected = ''
+ expected = ""
raise ValueError(
- f'option {k!r} given an invalid value: {v!r}. ' + expected)
+ f"option {k!r} given an invalid value: {v!r}. " + expected
+ )
self.old[k] = OPTIONS[k]
self._apply_update(kwargs)
+ def _apply_update(self, options_dict):
+ for k, v in options_dict.items():
+ if k in _SETTERS:
+ _SETTERS[k](v)
+ OPTIONS.update(options_dict)
+
def __enter__(self):
return
@@ -229,4 +316,4 @@ def get_options():
set_options
"""
- pass
+ return FrozenDict(OPTIONS)
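
As a quick illustration of the restored validation path (a minimal sketch assuming only `import xarray as xr`): invalid values are rejected by `_VALIDATORS` before `OPTIONS` is mutated, and values changed inside the context manager are restored on exit from the copies saved in `self.old`.

import xarray as xr

default_width = xr.get_options()["display_width"]  # 80 unless already overridden

# Inside the context manager the new values are live; the saved values
# are restored when the block exits.
with xr.set_options(display_width=40, keep_attrs=True):
    assert xr.get_options()["display_width"] == 40
assert xr.get_options()["display_width"] == default_width

# An invalid value is rejected by the validator before OPTIONS is touched.
try:
    xr.set_options(display_style="latex")
except ValueError as err:
    print(err)  # names the allowed values: frozenset({'text', 'html'})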
diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py
index 9f983816..41311497 100644
--- a/xarray/core/parallel.py
+++ b/xarray/core/parallel.py
@@ -1,10 +1,13 @@
from __future__ import annotations
+
import collections
import itertools
import operator
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
+
import numpy as np
+
from xarray.core.alignment import align
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
@@ -13,6 +16,7 @@ from xarray.core.indexes import Index
from xarray.core.merge import merge
from xarray.core.utils import is_dask_collection
from xarray.core.variable import Variable
+
if TYPE_CHECKING:
from xarray.core.types import T_Xarray
@@ -24,41 +28,211 @@ class ExpectedDict(TypedDict):
indexes: dict[Hashable, Index]
+def unzip(iterable):
+ return zip(*iterable)
+
+
+def assert_chunks_compatible(a: Dataset, b: Dataset):
+ a = a.unify_chunks()
+ b = b.unify_chunks()
+
+ for dim in set(a.chunks).intersection(set(b.chunks)):
+ if a.chunks[dim] != b.chunks[dim]:
+ raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.")
+
+
+def check_result_variables(
+ result: DataArray | Dataset,
+ expected: ExpectedDict,
+ kind: Literal["coords", "data_vars"],
+):
+ if kind == "coords":
+ nice_str = "coordinate"
+ elif kind == "data_vars":
+ nice_str = "data"
+
+ # check that coords and data variables are as expected
+ missing = expected[kind] - set(getattr(result, kind))
+ if missing:
+ raise ValueError(
+ "Result from applying user function does not contain "
+ f"{nice_str} variables {missing}."
+ )
+ extra = set(getattr(result, kind)) - expected[kind]
+ if extra:
+ raise ValueError(
+ "Result from applying user function has unexpected "
+ f"{nice_str} variables {extra}."
+ )
+
+
+def dataset_to_dataarray(obj: Dataset) -> DataArray:
+ if not isinstance(obj, Dataset):
+ raise TypeError(f"Expected Dataset, got {type(obj)}")
+
+ if len(obj.data_vars) > 1:
+ raise TypeError(
+ "Trying to convert Dataset with more than one data variable to DataArray"
+ )
+
+ return next(iter(obj.data_vars.values()))
+
+
+def dataarray_to_dataset(obj: DataArray) -> Dataset:
+ # only using _to_temp_dataset would break
+ # func = lambda x: x.to_dataset()
+ # since that relies on preserving name.
+ if obj.name is None:
+ dataset = obj._to_temp_dataset()
+ else:
+ dataset = obj.to_dataset()
+ return dataset
+
+
def make_meta(obj):
"""If obj is a DataArray or Dataset, return a new object of the same type and with
the same variables and dtypes, but where all variables have size 0 and numpy
backend.
If obj is neither a DataArray nor Dataset, return it unaltered.
"""
- pass
+ if isinstance(obj, DataArray):
+ obj_array = obj
+ obj = dataarray_to_dataset(obj)
+ elif isinstance(obj, Dataset):
+ obj_array = None
+ else:
+ return obj
+
+ from dask.array.utils import meta_from_array
+
+ meta = Dataset()
+ for name, variable in obj.variables.items():
+ meta_obj = meta_from_array(variable.data, ndim=variable.ndim)
+ meta[name] = (variable.dims, meta_obj, variable.attrs)
+ meta.attrs = obj.attrs
+ meta = meta.set_coords(obj.coords)
+ if obj_array is not None:
+ return dataset_to_dataarray(meta)
+ return meta
-def infer_template(func: Callable[..., T_Xarray], obj: (DataArray | Dataset
- ), *args, **kwargs) ->T_Xarray:
+
+def infer_template(
+ func: Callable[..., T_Xarray], obj: DataArray | Dataset, *args, **kwargs
+) -> T_Xarray:
"""Infer return object by running the function on meta objects."""
- pass
+ meta_args = [make_meta(arg) for arg in (obj,) + args]
+
+ try:
+ template = func(*meta_args, **kwargs)
+ except Exception as e:
+ raise Exception(
+ "Cannot infer object returned from running user provided function. "
+ "Please supply the 'template' kwarg to map_blocks."
+ ) from e
+
+ if not isinstance(template, (Dataset, DataArray)):
+ raise TypeError(
+ "Function must return an xarray DataArray or Dataset. Instead it returned "
+ f"{type(template)}"
+ )
+ return template
-def make_dict(x: (DataArray | Dataset)) ->dict[Hashable, Any]:
+
+def make_dict(x: DataArray | Dataset) -> dict[Hashable, Any]:
"""Map variable name to numpy(-like) data
(Dataset.to_dict() is too complicated).
"""
- pass
+ if isinstance(x, DataArray):
+ x = x._to_temp_dataset()
+
+ return {k: v.data for k, v in x.variables.items()}
-def subset_dataset_to_block(graph: dict, gname: str, dataset: Dataset,
- input_chunk_bounds, chunk_index):
+def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping):
+ if dim in chunk_index:
+ which_chunk = chunk_index[dim]
+ return slice(chunk_bounds[dim][which_chunk], chunk_bounds[dim][which_chunk + 1])
+ return slice(None)
+
+
+def subset_dataset_to_block(
+ graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
+):
"""
Creates a task that subsets an xarray dataset to a block determined by chunk_index.
Block extents are determined by input_chunk_bounds.
Also subtasks that subset the constituent variables of a dataset.
"""
- pass
+ import dask
+
+ # this will become [[name1, variable1],
+ # [name2, variable2],
+ # ...]
+ # which is passed to dict and then to Dataset
+ data_vars = []
+ coords = []
+
+ chunk_tuple = tuple(chunk_index.values())
+ chunk_dims_set = set(chunk_index)
+ variable: Variable
+ for name, variable in dataset.variables.items():
+ # make a task that creates tuple of (dims, chunk)
+ if dask.is_dask_collection(variable.data):
+ # get task name for chunk
+ chunk = (
+ variable.data.name,
+ *tuple(chunk_index[dim] for dim in variable.dims),
+ )
+
+ chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple
+ graph[chunk_variable_task] = (
+ tuple,
+ [variable.dims, chunk, variable.attrs],
+ )
+ else:
+ assert name in dataset.dims or variable.ndim == 0
+ # non-dask array possibly with dimensions chunked on other variables
+ # index into variable appropriately
+ subsetter = {
+ dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
+ for dim in variable.dims
+ }
+ if set(variable.dims) < chunk_dims_set:
+ this_var_chunk_tuple = tuple(chunk_index[dim] for dim in variable.dims)
+ else:
+ this_var_chunk_tuple = chunk_tuple
-def map_blocks(func: Callable[..., T_Xarray], obj: (DataArray | Dataset),
- args: Sequence[Any]=(), kwargs: (Mapping[str, Any] | None)=None,
- template: (DataArray | Dataset | None)=None) ->T_Xarray:
+ chunk_variable_task = (
+ f"{name}-{gname}-{dask.base.tokenize(subsetter)}",
+ ) + this_var_chunk_tuple
+ # We are including a dimension coordinate,
+ # minimize duplication by not copying it in the graph for every chunk.
+ if variable.ndim == 0 or chunk_variable_task not in graph:
+ subset = variable.isel(subsetter)
+ graph[chunk_variable_task] = (
+ tuple,
+ [subset.dims, subset._data, subset.attrs],
+ )
+
+ # this task creates dict mapping variable name to above tuple
+ if name in dataset._coord_names:
+ coords.append([name, chunk_variable_task])
+ else:
+ data_vars.append([name, chunk_variable_task])
+
+ return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
+
+
+def map_blocks(
+ func: Callable[..., T_Xarray],
+ obj: DataArray | Dataset,
+ args: Sequence[Any] = (),
+ kwargs: Mapping[str, Any] | None = None,
+ template: DataArray | Dataset | None = None,
+) -> T_Xarray:
"""Apply a function to each block of a DataArray or Dataset.
.. warning::
@@ -156,4 +330,310 @@ def map_blocks(func: Callable[..., T_Xarray], obj: (DataArray | Dataset),
* time (time) object 192B 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
month (time) int64 192B dask.array<chunksize=(24,), meta=np.ndarray>
"""
- pass
+
+ def _wrapper(
+ func: Callable,
+ args: list,
+ kwargs: dict,
+ arg_is_array: Iterable[bool],
+ expected: ExpectedDict,
+ ):
+ """
+ Wrapper function that receives datasets in args; converts to dataarrays when necessary;
+ passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc.
+ """
+
+ converted_args = [
+ dataset_to_dataarray(arg) if is_array else arg
+ for is_array, arg in zip(arg_is_array, args)
+ ]
+
+ result = func(*converted_args, **kwargs)
+
+ merged_coordinates = merge(
+ [arg.coords for arg in args if isinstance(arg, (Dataset, DataArray))]
+ ).coords
+
+ # check all dims are present
+ missing_dimensions = set(expected["shapes"]) - set(result.sizes)
+ if missing_dimensions:
+ raise ValueError(
+ f"Dimensions {missing_dimensions} missing on returned object."
+ )
+
+ # check that index lengths and values are as expected
+ for name, index in result._indexes.items():
+ if name in expected["shapes"]:
+ if result.sizes[name] != expected["shapes"][name]:
+ raise ValueError(
+ f"Received dimension {name!r} of length {result.sizes[name]}. "
+ f"Expected length {expected['shapes'][name]}."
+ )
+
+ # ChainMap wants MutableMapping, but xindexes is Mapping
+ merged_indexes = collections.ChainMap(
+ expected["indexes"], merged_coordinates.xindexes # type: ignore[arg-type]
+ )
+ expected_index = merged_indexes.get(name, None)
+ if expected_index is not None and not index.equals(expected_index):
+ raise ValueError(
+ f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead."
+ )
+
+ # check that all expected variables were returned
+ check_result_variables(result, expected, "coords")
+ if isinstance(result, Dataset):
+ check_result_variables(result, expected, "data_vars")
+
+ return make_dict(result)
+
+ if template is not None and not isinstance(template, (DataArray, Dataset)):
+ raise TypeError(
+ f"template must be a DataArray or Dataset. Received {type(template).__name__} instead."
+ )
+ if not isinstance(args, Sequence):
+ raise TypeError("args must be a sequence (for example, a list or tuple).")
+ if kwargs is None:
+ kwargs = {}
+ elif not isinstance(kwargs, Mapping):
+ raise TypeError("kwargs must be a mapping (for example, a dict)")
+
+ for value in kwargs.values():
+ if is_dask_collection(value):
+ raise TypeError(
+ "Cannot pass dask collections in kwargs yet. Please compute or "
+ "load values before passing to map_blocks."
+ )
+
+ if not is_dask_collection(obj):
+ return func(obj, *args, **kwargs)
+
+ try:
+ import dask
+ import dask.array
+ from dask.highlevelgraph import HighLevelGraph
+
+ except ImportError:
+ pass
+
+ all_args = [obj] + list(args)
+ is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args]
+ is_array = [isinstance(arg, DataArray) for arg in all_args]
+
+ # there should be a better way to group this. partition?
+ xarray_indices, xarray_objs = unzip(
+ (index, arg) for index, arg in enumerate(all_args) if is_xarray[index]
+ )
+ others = [
+ (index, arg) for index, arg in enumerate(all_args) if not is_xarray[index]
+ ]
+
+ # all xarray objects must be aligned. This is consistent with apply_ufunc.
+ aligned = align(*xarray_objs, join="exact")
+ xarray_objs = tuple(
+ dataarray_to_dataset(arg) if isinstance(arg, DataArray) else arg
+ for arg in aligned
+ )
+ # rechunk any numpy variables appropriately
+ xarray_objs = tuple(arg.chunk(arg.chunksizes) for arg in xarray_objs)
+
+ merged_coordinates = merge([arg.coords for arg in aligned]).coords
+
+ _, npargs = unzip(
+ sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
+ )
+
+ # check that chunk sizes are compatible
+ input_chunks = dict(npargs[0].chunks)
+ for arg in xarray_objs[1:]:
+ assert_chunks_compatible(npargs[0], arg)
+ input_chunks.update(arg.chunks)
+
+ coordinates: Coordinates
+ if template is None:
+ # infer template by providing zero-shaped arrays
+ template = infer_template(func, aligned[0], *args, **kwargs)
+ template_coords = set(template.coords)
+ preserved_coord_vars = template_coords & set(merged_coordinates)
+ new_coord_vars = template_coords - set(merged_coordinates)
+
+ preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
+ # preserved_coords contains all coordinates variables that share a dimension
+ # with any index variable in preserved_indexes
+ # Drop any unneeded vars in a second pass, this is required for e.g.
+ # if the mapped function were to drop a non-dimension coordinate variable.
+ preserved_coords = preserved_coords.drop_vars(
+ tuple(k for k in preserved_coords.variables if k not in template_coords)
+ )
+
+ coordinates = merge(
+ (preserved_coords, template.coords.to_dataset()[new_coord_vars])
+ ).coords
+ output_chunks: Mapping[Hashable, tuple[int, ...]] = {
+ dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
+ }
+
+ else:
+ # template xarray object has been provided with proper sizes and chunk shapes
+ coordinates = template.coords
+ output_chunks = template.chunksizes
+ if not output_chunks:
+ raise ValueError(
+ "Provided template has no dask arrays. "
+ " Please construct a template with appropriately chunked dask arrays."
+ )
+
+ new_indexes = set(template.xindexes) - set(merged_coordinates)
+ modified_indexes = set(
+ name
+ for name, xindex in coordinates.xindexes.items()
+ if not xindex.equals(merged_coordinates.xindexes.get(name, None))
+ )
+
+ for dim in output_chunks:
+ if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):
+ raise ValueError(
+ "map_blocks requires that one block of the input maps to one block of output. "
+ f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. "
+ f"Received {len(output_chunks[dim])} instead. Please provide template if not provided, or "
+ "fix the provided template."
+ )
+
+ if isinstance(template, DataArray):
+ result_is_array = True
+ template_name = template.name
+ template = template._to_temp_dataset()
+ elif isinstance(template, Dataset):
+ result_is_array = False
+ else:
+ raise TypeError(
+ f"func output must be DataArray or Dataset; got {type(template)}"
+ )
+
+ # We're building a new HighLevelGraph hlg. We'll have one new layer
+ # for each variable in the dataset, which is the result of the
+ # func applied to the values.
+
+ graph: dict[Any, Any] = {}
+ new_layers: collections.defaultdict[str, dict[Any, Any]] = collections.defaultdict(
+ dict
+ )
+ gname = f"{dask.utils.funcname(func)}-{dask.base.tokenize(npargs[0], args, kwargs)}"
+
+ # map dims to list of chunk indexes
+ ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()}
+ # mapping from chunk index to slice bounds
+ input_chunk_bounds = {
+ dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items()
+ }
+ output_chunk_bounds = {
+ dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
+ }
+
+ computed_variables = set(template.variables) - set(coordinates.indexes)
+ # iterate over all possible chunk combinations
+ for chunk_tuple in itertools.product(*ichunk.values()):
+ # mapping from dimension name to chunk index
+ chunk_index = dict(zip(ichunk.keys(), chunk_tuple))
+
+ blocked_args = [
+ (
+ subset_dataset_to_block(
+ graph, gname, arg, input_chunk_bounds, chunk_index
+ )
+ if isxr
+ else arg
+ )
+ for isxr, arg in zip(is_xarray, npargs)
+ ]
+
+ # raise nice error messages in _wrapper
+ expected: ExpectedDict = {
+ # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
+ # even if length of dimension is changed by the applied function
+ "shapes": {
+ k: output_chunks[k][v]
+ for k, v in chunk_index.items()
+ if k in output_chunks
+ },
+ "data_vars": set(template.data_vars.keys()),
+ "coords": set(template.coords.keys()),
+ # only include new or modified indexes to minimize duplication of data, and graph size.
+ "indexes": {
+ dim: coordinates.xindexes[dim][
+ _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
+ ]
+ for dim in (new_indexes | modified_indexes)
+ },
+ }
+
+ from_wrapper = (gname,) + chunk_tuple
+ graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
+
+ # mapping from variable name to dask graph key
+ var_key_map: dict[Hashable, str] = {}
+ for name in computed_variables:
+ variable = template.variables[name]
+ gname_l = f"{name}-{gname}"
+ var_key_map[name] = gname_l
+
+ # unchunked dimensions in the input have one chunk in the result
+ # output can have new dimensions with exactly one chunk
+ key: tuple[Any, ...] = (gname_l,) + tuple(
+ chunk_index[dim] if dim in chunk_index else 0 for dim in variable.dims
+ )
+
+ # We're adding multiple new layers to the graph:
+ # The first new layer is the result of the computation on
+ # the array.
+ # Then we add one layer per variable, which extracts the
+ # result for that variable, and depends on just the first new
+ # layer.
+ new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)
+
+ hlg = HighLevelGraph.from_collections(
+ gname,
+ graph,
+ dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)],
+ )
+
+ # This adds in the getitems for each variable in the dataset.
+ hlg = HighLevelGraph(
+ {**hlg.layers, **new_layers},
+ dependencies={
+ **hlg.dependencies,
+ **{name: {gname} for name in new_layers.keys()},
+ },
+ )
+
+ result = Dataset(coords=coordinates, attrs=template.attrs)
+
+ for index in result._indexes:
+ result[index].attrs = template[index].attrs
+ result[index].encoding = template[index].encoding
+
+ for name, gname_l in var_key_map.items():
+ dims = template[name].dims
+ var_chunks = []
+ for dim in dims:
+ if dim in output_chunks:
+ var_chunks.append(output_chunks[dim])
+ elif dim in result._indexes:
+ var_chunks.append((result.sizes[dim],))
+ elif dim in template.dims:
+ # new unindexed dimension
+ var_chunks.append((template.sizes[dim],))
+
+ data = dask.array.Array(
+ hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
+ )
+ result[name] = (dims, data, template[name].attrs)
+ result[name].encoding = template[name].encoding
+
+ result = result.set_coords(template._coord_names)
+
+ if result_is_array:
+ da = dataset_to_dataarray(result)
+ da.name = template_name
+ return da # type: ignore[return-value]
+ return result # type: ignore[return-value]
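
For context, a minimal `map_blocks` sketch (assuming dask is installed): the template is inferred by running the user function on the zero-sized meta objects built by `make_meta`, and `subset_dataset_to_block` then emits one task per chunk.

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(12.0).reshape(3, 4),
    dims=("x", "y"),
    coords={"x": [10, 20, 30]},
).chunk({"x": 1})

# Each block receives a DataArray covering a single x-chunk; the results are
# reassembled with the same chunk layout as the input.
doubled = xr.map_blocks(lambda block: 2 * block, da)
print(doubled.compute())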
diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py
index cecf4d90..ae4febd6 100644
--- a/xarray/core/pdcompat.py
+++ b/xarray/core/pdcompat.py
@@ -1,16 +1,53 @@
+# For reference, here is a copy of the pandas copyright notice:
+
+# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2008-2011 AQR Capital Management, LLC
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following
+# disclaimer in the documentation and/or other materials provided
+# with the distribution.
+
+# * Neither the name of the copyright holder nor the names of any
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
+
from enum import Enum
from typing import Literal
+
import pandas as pd
from packaging.version import Version
-def count_not_none(*args) ->int:
+def count_not_none(*args) -> int:
"""Compute the number of non-None arguments.
Copied from pandas.core.common.count_not_none (not part of the public API)
"""
- pass
+ return sum(arg is not None for arg in args)
class _NoDefault(Enum):
@@ -23,20 +60,26 @@ class _NoDefault(Enum):
- pandas-dev/pandas#40715
- pandas-dev/pandas#47045
"""
- no_default = 'NO_DEFAULT'
- def __repr__(self) ->str:
- return '<no_default>'
+ no_default = "NO_DEFAULT"
+
+ def __repr__(self) -> str:
+ return "<no_default>"
-no_default = _NoDefault.no_default
-NoDefault = Literal[_NoDefault.no_default]
+no_default = (
+ _NoDefault.no_default
+) # Sentinel indicating the default value following pandas
+NoDefault = Literal[_NoDefault.no_default] # For typing following pandas
-def nanosecond_precision_timestamp(*args, **kwargs) ->pd.Timestamp:
+def nanosecond_precision_timestamp(*args, **kwargs) -> pd.Timestamp:
"""Return a nanosecond-precision Timestamp object.
Note this function should no longer be needed after addressing GitHub issue
#7493.
"""
- pass
+ if Version(pd.__version__) >= Version("2.0.0"):
+ return pd.Timestamp(*args, **kwargs).as_unit("ns")
+ else:
+ return pd.Timestamp(*args, **kwargs)
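
The helpers above are small enough to restate in a self-contained sketch (pandas only; the version check mirrors the one in `nanosecond_precision_timestamp`):

import pandas as pd
from packaging.version import Version

def count_not_none(*args) -> int:
    # same semantics as the pandas-derived helper above
    return sum(arg is not None for arg in args)

assert count_not_none(1, None, "a", None) == 2

ts = pd.Timestamp("2000-01-01 12:00")
if Version(pd.__version__) >= Version("2.0.0"):
    # pandas >= 2.0 supports non-nanosecond units, so pin to nanoseconds
    ts = ts.as_unit("ns")
print(ts)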
diff --git a/xarray/core/resample.py b/xarray/core/resample.py
index 147ca143..86b55046 100644
--- a/xarray/core/resample.py
+++ b/xarray/core/resample.py
@@ -1,13 +1,20 @@
from __future__ import annotations
+
import warnings
from collections.abc import Hashable, Iterable, Sequence
from typing import TYPE_CHECKING, Any, Callable
-from xarray.core._aggregations import DataArrayResampleAggregations, DatasetResampleAggregations
+
+from xarray.core._aggregations import (
+ DataArrayResampleAggregations,
+ DatasetResampleAggregations,
+)
from xarray.core.groupby import DataArrayGroupByBase, DatasetGroupByBase, GroupBy
from xarray.core.types import Dims, InterpOptions, T_Xarray
+
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
+
from xarray.groupers import RESAMPLE_DIM
@@ -25,21 +32,41 @@ class Resample(GroupBy[T_Xarray]):
"""
- def __init__(self, *args, dim: (Hashable | None)=None, resample_dim: (
- Hashable | None)=None, **kwargs) ->None:
+ def __init__(
+ self,
+ *args,
+ dim: Hashable | None = None,
+ resample_dim: Hashable | None = None,
+ **kwargs,
+ ) -> None:
if dim == resample_dim:
raise ValueError(
- f"Proxy resampling dimension ('{resample_dim}') cannot have the same name as actual dimension ('{dim}')!"
- )
+ f"Proxy resampling dimension ('{resample_dim}') "
+ f"cannot have the same name as actual dimension ('{dim}')!"
+ )
self._dim = dim
+
super().__init__(*args, **kwargs)
- def _drop_coords(self) ->T_Xarray:
+ def _flox_reduce(
+ self,
+ dim: Dims,
+ keep_attrs: bool | None = None,
+ **kwargs,
+ ) -> T_Xarray:
+ result = super()._flox_reduce(dim=dim, keep_attrs=keep_attrs, **kwargs)
+ result = result.rename({RESAMPLE_DIM: self._group_dim})
+ return result
+
+ def _drop_coords(self) -> T_Xarray:
"""Drop non-dimension coordinates along the resampled dimension."""
- pass
+ obj = self._obj
+ for k, v in obj.coords.items():
+ if k != self._dim and self._dim in v.dims:
+ obj = obj.drop_vars([k])
+ return obj
- def pad(self, tolerance: (float | Iterable[float] | str | None)=None
- ) ->T_Xarray:
+ def pad(self, tolerance: float | Iterable[float] | str | None = None) -> T_Xarray:
"""Forward fill new values at up-sampled frequency.
Parameters
@@ -56,11 +83,17 @@ class Resample(GroupBy[T_Xarray]):
-------
padded : DataArray or Dataset
"""
- pass
+ obj = self._drop_coords()
+ (grouper,) = self.groupers
+ return obj.reindex(
+ {self._dim: grouper.full_index}, method="pad", tolerance=tolerance
+ )
+
ffill = pad
- def backfill(self, tolerance: (float | Iterable[float] | str | None)=None
- ) ->T_Xarray:
+ def backfill(
+ self, tolerance: float | Iterable[float] | str | None = None
+ ) -> T_Xarray:
"""Backward fill new values at up-sampled frequency.
Parameters
@@ -77,11 +110,17 @@ class Resample(GroupBy[T_Xarray]):
-------
backfilled : DataArray or Dataset
"""
- pass
+ obj = self._drop_coords()
+ (grouper,) = self.groupers
+ return obj.reindex(
+ {self._dim: grouper.full_index}, method="backfill", tolerance=tolerance
+ )
+
bfill = backfill
- def nearest(self, tolerance: (float | Iterable[float] | str | None)=None
- ) ->T_Xarray:
+ def nearest(
+ self, tolerance: float | Iterable[float] | str | None = None
+ ) -> T_Xarray:
"""Take new values from nearest original coordinate to up-sampled
frequency coordinates.
@@ -99,14 +138,19 @@ class Resample(GroupBy[T_Xarray]):
-------
upsampled : DataArray or Dataset
"""
- pass
+ obj = self._drop_coords()
+ (grouper,) = self.groupers
+ return obj.reindex(
+ {self._dim: grouper.full_index}, method="nearest", tolerance=tolerance
+ )
- def interpolate(self, kind: InterpOptions='linear', **kwargs) ->T_Xarray:
+ def interpolate(self, kind: InterpOptions = "linear", **kwargs) -> T_Xarray:
"""Interpolate up-sampled data using the original data as knots.
Parameters
----------
- kind : {"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"}, default: "linear"
+ kind : {"linear", "nearest", "zero", "slinear", \
+ "quadratic", "cubic", "polynomial"}, default: "linear"
The method used to interpolate. The method should be supported by
the scipy interpolator:
@@ -128,22 +172,38 @@ class Resample(GroupBy[T_Xarray]):
scipy.interpolate.interp1d
"""
- pass
+ return self._interpolate(kind=kind, **kwargs)
- def _interpolate(self, kind='linear', **kwargs) ->T_Xarray:
+ def _interpolate(self, kind="linear", **kwargs) -> T_Xarray:
"""Apply scipy.interpolate.interp1d along resampling dimension."""
- pass
-
-
-class DataArrayResample(Resample['DataArray'], DataArrayGroupByBase,
- DataArrayResampleAggregations):
+ obj = self._drop_coords()
+ (grouper,) = self.groupers
+ kwargs.setdefault("bounds_error", False)
+ return obj.interp(
+ coords={self._dim: grouper.full_index},
+ assume_sorted=True,
+ method=kind,
+ kwargs=kwargs,
+ )
+
+
+# https://github.com/python/mypy/issues/9031
+class DataArrayResample(Resample["DataArray"], DataArrayGroupByBase, DataArrayResampleAggregations): # type: ignore[misc]
"""DataArrayGroupBy object specialized to time resampling operations over a
specified dimension
"""
- def reduce(self, func: Callable[..., Any], dim: Dims=None, *, axis: (
- int | Sequence[int] | None)=None, keep_attrs: (bool | None)=None,
- keepdims: bool=False, shortcut: bool=True, **kwargs: Any) ->DataArray:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ shortcut: bool = True,
+ **kwargs: Any,
+ ) -> DataArray:
"""Reduce the items in this group by applying `func` along the
pre-defined resampling dimension.
@@ -168,10 +228,23 @@ class DataArrayResample(Resample['DataArray'], DataArrayGroupByBase,
Array with summarized data and the indicated dimension(s)
removed.
"""
- pass
-
- def map(self, func: Callable[..., Any], args: tuple[Any, ...]=(),
- shortcut: (bool | None)=False, **kwargs: Any) ->DataArray:
+ return super().reduce(
+ func=func,
+ dim=dim,
+ axis=axis,
+ keep_attrs=keep_attrs,
+ keepdims=keepdims,
+ shortcut=shortcut,
+ **kwargs,
+ )
+
+ def map(
+ self,
+ func: Callable[..., Any],
+ args: tuple[Any, ...] = (),
+ shortcut: bool | None = False,
+ **kwargs: Any,
+ ) -> DataArray:
"""Apply a function to each array in the group and concatenate them
together into a new array.
@@ -213,7 +286,20 @@ class DataArrayResample(Resample['DataArray'], DataArrayGroupByBase,
applied : DataArray
The result of splitting, applying and combining this array.
"""
- pass
+ # TODO: the argument order for Resample doesn't match that for its parent,
+ # GroupBy
+ combined = super().map(func, shortcut=shortcut, args=args, **kwargs)
+
+ # If the aggregation function didn't drop the original resampling
+ # dimension, then we need to do so before we can rename the proxy
+ # dimension we used.
+ if self._dim in combined.coords:
+ combined = combined.drop_vars([self._dim])
+
+ if RESAMPLE_DIM in combined.dims:
+ combined = combined.rename({RESAMPLE_DIM: self._dim})
+
+ return combined
def apply(self, func, args=(), shortcut=None, **kwargs):
"""
@@ -223,9 +309,14 @@ class DataArrayResample(Resample['DataArray'], DataArrayGroupByBase,
--------
DataArrayResample.map
"""
- pass
-
- def asfreq(self) ->DataArray:
+ warnings.warn(
+ "Resample.apply may be deprecated in the future. Using Resample.map is encouraged",
+ PendingDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.map(func=func, shortcut=shortcut, args=args, **kwargs)
+
+ def asfreq(self) -> DataArray:
"""Return values of original object at the new up-sampling frequency;
essentially a re-index with new times set to NaN.
@@ -233,15 +324,21 @@ class DataArrayResample(Resample['DataArray'], DataArrayGroupByBase,
-------
resampled : DataArray
"""
- pass
+ self._obj = self._drop_coords()
+ return self.mean(None if self._dim is None else [self._dim])
-class DatasetResample(Resample['Dataset'], DatasetGroupByBase,
- DatasetResampleAggregations):
+# https://github.com/python/mypy/issues/9031
+class DatasetResample(Resample["Dataset"], DatasetGroupByBase, DatasetResampleAggregations): # type: ignore[misc]
"""DatasetGroupBy object specialized to resampling a specified dimension"""
- def map(self, func: Callable[..., Any], args: tuple[Any, ...]=(),
- shortcut: (bool | None)=None, **kwargs: Any) ->Dataset:
+ def map(
+ self,
+ func: Callable[..., Any],
+ args: tuple[Any, ...] = (),
+ shortcut: bool | None = None,
+ **kwargs: Any,
+ ) -> Dataset:
"""Apply a function over each Dataset in the groups generated for
resampling and concatenate them together into a new Dataset.
@@ -271,7 +368,20 @@ class DatasetResample(Resample['Dataset'], DatasetGroupByBase,
applied : Dataset
The result of splitting, applying and combining this dataset.
"""
- pass
+ # ignore shortcut if set (for now)
+ applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
+ combined = self._combine(applied)
+
+ # If the aggregation function didn't drop the original resampling
+ # dimension, then we need to do so before we can rename the proxy
+ # dimension we used.
+ if self._dim in combined.coords:
+ combined = combined.drop_vars(self._dim)
+
+ if RESAMPLE_DIM in combined.dims:
+ combined = combined.rename({RESAMPLE_DIM: self._dim})
+
+ return combined
def apply(self, func, args=(), shortcut=None, **kwargs):
"""
@@ -281,11 +391,25 @@ class DatasetResample(Resample['Dataset'], DatasetGroupByBase,
--------
DataSetResample.map
"""
- pass
- def reduce(self, func: Callable[..., Any], dim: Dims=None, *, axis: (
- int | Sequence[int] | None)=None, keep_attrs: (bool | None)=None,
- keepdims: bool=False, shortcut: bool=True, **kwargs: Any) ->Dataset:
+ warnings.warn(
+ "Resample.apply may be deprecated in the future. Using Resample.map is encouraged",
+ PendingDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.map(func=func, shortcut=shortcut, args=args, **kwargs)
+
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ shortcut: bool = True,
+ **kwargs: Any,
+ ) -> Dataset:
"""Reduce the items in this group by applying `func` along the
pre-defined resampling dimension.
@@ -310,9 +434,17 @@ class DatasetResample(Resample['Dataset'], DatasetGroupByBase,
Array with summarized data and the indicated dimension(s)
removed.
"""
- pass
-
- def asfreq(self) ->Dataset:
+ return super().reduce(
+ func=func,
+ dim=dim,
+ axis=axis,
+ keep_attrs=keep_attrs,
+ keepdims=keepdims,
+ shortcut=shortcut,
+ **kwargs,
+ )
+
+ def asfreq(self) -> Dataset:
"""Return values of original object at the new up-sampling frequency;
essentially a re-index with new times set to NaN.
@@ -320,4 +452,5 @@ class DatasetResample(Resample['Dataset'], DatasetGroupByBase,
-------
resampled : Dataset
"""
- pass
+ self._obj = self._drop_coords()
+ return self.mean(None if self._dim is None else [self._dim])
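
A short upsampling sketch that exercises the methods above (standard pandas datetime index; `interpolate` additionally requires scipy):

import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01", periods=4, freq="6h")
da = xr.DataArray(np.arange(4.0), coords={"time": times}, dims="time")

up = da.resample(time="3h")
print(up.asfreq())               # original values, NaN at the inserted times
print(up.ffill())                # pad/ffill: reindex with method="pad"
print(up.interpolate("linear"))  # interpolation along the resampled dimension (scipy)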
diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py
index c819839a..caa2d166 100644
--- a/xarray/core/resample_cftime.py
+++ b/xarray/core/resample_cftime.py
@@ -1,12 +1,62 @@
"""Resampling for CFTimeIndex. Does not support non-integer freq."""
+
+# The mechanisms for resampling CFTimeIndex was copied and adapted from
+# the source code defined in pandas.core.resample
+#
+# For reference, here is a copy of the pandas copyright notice:
+#
+# BSD 3-Clause License
+#
+# Copyright (c) 2008-2012, AQR Capital Management, LLC, Lambda Foundry, Inc.
+# and PyData Development Team
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
+
import datetime
import typing
+
import numpy as np
import pandas as pd
-from xarray.coding.cftime_offsets import BaseCFTimeOffset, MonthEnd, QuarterEnd, Tick, YearEnd, cftime_range, normalize_date, to_offset
+
+from xarray.coding.cftime_offsets import (
+ BaseCFTimeOffset,
+ MonthEnd,
+ QuarterEnd,
+ Tick,
+ YearEnd,
+ cftime_range,
+ normalize_date,
+ to_offset,
+)
from xarray.coding.cftimeindex import CFTimeIndex
from xarray.core.types import SideOptions
+
if typing.TYPE_CHECKING:
from xarray.core.types import CFTimeDatetime
@@ -15,6 +65,7 @@ class CFTimeGrouper:
"""This is a simple container for the grouping parameters that implements a
single method, the only one required for resampling in xarray. It cannot
be used in a call to groupby like a pandas.Grouper object can."""
+
freq: BaseCFTimeOffset
closed: SideOptions
label: SideOptions
@@ -22,46 +73,60 @@ class CFTimeGrouper:
origin: str | CFTimeDatetime
offset: datetime.timedelta | None
- def __init__(self, freq: (str | BaseCFTimeOffset), closed: (SideOptions |
- None)=None, label: (SideOptions | None)=None, origin: (str |
- CFTimeDatetime)='start_day', offset: (str | datetime.timedelta |
- BaseCFTimeOffset | None)=None):
+ def __init__(
+ self,
+ freq: str | BaseCFTimeOffset,
+ closed: SideOptions | None = None,
+ label: SideOptions | None = None,
+ origin: str | CFTimeDatetime = "start_day",
+ offset: str | datetime.timedelta | BaseCFTimeOffset | None = None,
+ ):
self.freq = to_offset(freq)
self.origin = origin
+
if isinstance(self.freq, (MonthEnd, QuarterEnd, YearEnd)):
if closed is None:
- self.closed = 'right'
- else:
- self.closed = closed
- if label is None:
- self.label = 'right'
- else:
- self.label = label
- elif self.origin in ['end', 'end_day']:
- if closed is None:
- self.closed = 'right'
+ self.closed = "right"
else:
self.closed = closed
if label is None:
- self.label = 'right'
+ self.label = "right"
else:
self.label = label
else:
- if closed is None:
- self.closed = 'left'
+ # The backward resample sets ``closed`` to ``'right'`` by default
+ # since the last value should be considered as the edge point for
+ # the last bin. When origin in "end" or "end_day", the value for a
+ # specific ``cftime.datetime`` index stands for the resample result
+ # from the current ``cftime.datetime`` minus ``freq`` to the current
+ # ``cftime.datetime`` with a right close.
+ if self.origin in ["end", "end_day"]:
+ if closed is None:
+ self.closed = "right"
+ else:
+ self.closed = closed
+ if label is None:
+ self.label = "right"
+ else:
+ self.label = label
else:
- self.closed = closed
- if label is None:
- self.label = 'left'
- else:
- self.label = label
+ if closed is None:
+ self.closed = "left"
+ else:
+ self.closed = closed
+ if label is None:
+ self.label = "left"
+ else:
+ self.label = label
+
if offset is not None:
try:
self.offset = _convert_offset_to_timedelta(offset)
except (ValueError, TypeError) as error:
raise ValueError(
- f'offset must be a datetime.timedelta object or an offset string that can be converted to a timedelta. Got {type(offset)} instead.'
- ) from error
+ f"offset must be a datetime.timedelta object or an offset string "
+ f"that can be converted to a timedelta. Got {type(offset)} instead."
+ ) from error
else:
self.offset = None
@@ -74,12 +139,34 @@ class CFTimeGrouper:
with index being a CFTimeIndex instead of a DatetimeIndex.
"""
- pass
-
-def _get_time_bins(index: CFTimeIndex, freq: BaseCFTimeOffset, closed:
- SideOptions, label: SideOptions, origin: (str | CFTimeDatetime), offset:
- (datetime.timedelta | None)):
+ datetime_bins, labels = _get_time_bins(
+ index, self.freq, self.closed, self.label, self.origin, self.offset
+ )
+ # check binner fits data
+ if index[0] < datetime_bins[0]:
+ raise ValueError("Value falls before first bin")
+ if index[-1] > datetime_bins[-1]:
+ raise ValueError("Value falls after last bin")
+
+ integer_bins = np.searchsorted(index, datetime_bins, side=self.closed)
+ counts = np.diff(integer_bins)
+ codes = np.repeat(np.arange(len(labels)), counts)
+ first_items = pd.Series(integer_bins[:-1], labels, copy=False)
+
+ # Mask duplicate values with NaNs, preserving the last values
+ non_duplicate = ~first_items.duplicated("last")
+ return first_items.where(non_duplicate), codes
+
+
+def _get_time_bins(
+ index: CFTimeIndex,
+ freq: BaseCFTimeOffset,
+ closed: SideOptions,
+ label: SideOptions,
+ origin: str | CFTimeDatetime,
+ offset: datetime.timedelta | None,
+):
"""Obtain the bins and their respective labels for resampling operations.
Parameters
@@ -119,12 +206,42 @@ def _get_time_bins(index: CFTimeIndex, freq: BaseCFTimeOffset, closed:
labels : CFTimeIndex
Define what the user actually sees the bins labeled as.
"""
- pass
-
-def _adjust_bin_edges(datetime_bins: CFTimeIndex, freq: BaseCFTimeOffset,
- closed: SideOptions, index: CFTimeIndex, labels: CFTimeIndex) ->tuple[
- CFTimeIndex, CFTimeIndex]:
+ if not isinstance(index, CFTimeIndex):
+ raise TypeError(
+ "index must be a CFTimeIndex, but got "
+ f"an instance of {type(index).__name__!r}"
+ )
+ if len(index) == 0:
+ datetime_bins = labels = CFTimeIndex(data=[], name=index.name)
+ return datetime_bins, labels
+
+ first, last = _get_range_edges(
+ index.min(), index.max(), freq, closed=closed, origin=origin, offset=offset
+ )
+ datetime_bins = labels = cftime_range(
+ freq=freq, start=first, end=last, name=index.name
+ )
+
+ datetime_bins, labels = _adjust_bin_edges(
+ datetime_bins, freq, closed, index, labels
+ )
+
+ labels = labels[1:] if label == "right" else labels[:-1]
+ # TODO: when CFTimeIndex supports missing values, if the reference index
+ # contains missing values, insert the appropriate NaN value at the
+ # beginning of the datetime_bins and labels indexes.
+
+ return datetime_bins, labels
+
+
+def _adjust_bin_edges(
+ datetime_bins: CFTimeIndex,
+ freq: BaseCFTimeOffset,
+ closed: SideOptions,
+ index: CFTimeIndex,
+ labels: CFTimeIndex,
+) -> tuple[CFTimeIndex, CFTimeIndex]:
"""This is required for determining the bin edges resampling with
month end, quarter end, and year end frequencies.
@@ -155,12 +272,24 @@ def _adjust_bin_edges(datetime_bins: CFTimeIndex, freq: BaseCFTimeOffset,
CFTimeIndex([2000-01-31 00:00:00, 2000-02-29 00:00:00], dtype='object')
"""
- pass
-
-
-def _get_range_edges(first: CFTimeDatetime, last: CFTimeDatetime, freq:
- BaseCFTimeOffset, closed: SideOptions='left', origin: (str |
- CFTimeDatetime)='start_day', offset: (datetime.timedelta | None)=None):
+ if isinstance(freq, (MonthEnd, QuarterEnd, YearEnd)):
+ if closed == "right":
+ datetime_bins = datetime_bins + datetime.timedelta(days=1, microseconds=-1)
+ if datetime_bins[-2] > index.max():
+ datetime_bins = datetime_bins[:-1]
+ labels = labels[:-1]
+
+ return datetime_bins, labels
+
+
+def _get_range_edges(
+ first: CFTimeDatetime,
+ last: CFTimeDatetime,
+ freq: BaseCFTimeOffset,
+ closed: SideOptions = "left",
+ origin: str | CFTimeDatetime = "start_day",
+ offset: datetime.timedelta | None = None,
+):
"""Get the correct starting and ending datetimes for the resampled
CFTimeIndex range.
@@ -198,12 +327,28 @@ def _get_range_edges(first: CFTimeDatetime, last: CFTimeDatetime, freq:
last : cftime.datetime
Corrected ending datetime object for resampled CFTimeIndex range.
"""
- pass
-
-
-def _adjust_dates_anchored(first: CFTimeDatetime, last: CFTimeDatetime,
- freq: Tick, closed: SideOptions='right', origin: (str | CFTimeDatetime)
- ='start_day', offset: (datetime.timedelta | None)=None):
+ if isinstance(freq, Tick):
+ first, last = _adjust_dates_anchored(
+ first, last, freq, closed=closed, origin=origin, offset=offset
+ )
+ return first, last
+ else:
+ first = normalize_date(first)
+ last = normalize_date(last)
+
+ first = freq.rollback(first) if closed == "left" else first - freq
+ last = last + freq
+ return first, last
+
+
+def _adjust_dates_anchored(
+ first: CFTimeDatetime,
+ last: CFTimeDatetime,
+ freq: Tick,
+ closed: SideOptions = "right",
+ origin: str | CFTimeDatetime = "start_day",
+ offset: datetime.timedelta | None = None,
+):
"""First and last offsets should be calculated from the start day to fix
an error cause by resampling across multiple days when a one day period is
not a multiple of the frequency.
@@ -243,7 +388,56 @@ def _adjust_dates_anchored(first: CFTimeDatetime, last: CFTimeDatetime,
A datetime object representing the end of a date range that has been
adjusted to fix resampling errors.
"""
- pass
+ import cftime
+
+ if origin == "start_day":
+ origin_date = normalize_date(first)
+ elif origin == "start":
+ origin_date = first
+ elif origin == "epoch":
+ origin_date = type(first)(1970, 1, 1)
+ elif origin in ["end", "end_day"]:
+ origin_last = last if origin == "end" else _ceil_via_cftimeindex(last, "D")
+ sub_freq_times = (origin_last - first) // freq.as_timedelta()
+ if closed == "left":
+ sub_freq_times += 1
+ first = origin_last - sub_freq_times * freq
+ origin_date = first
+ elif isinstance(origin, cftime.datetime):
+ origin_date = origin
+ else:
+ raise ValueError(
+ f"origin must be one of {{'epoch', 'start_day', 'start', 'end', 'end_day'}} "
+ f"or a cftime.datetime object. Got {origin}."
+ )
+
+ if offset is not None:
+ origin_date = origin_date + offset
+
+ foffset = (first - origin_date) % freq.as_timedelta()
+ loffset = (last - origin_date) % freq.as_timedelta()
+
+ if closed == "right":
+ if foffset.total_seconds() > 0:
+ fresult = first - foffset
+ else:
+ fresult = first - freq.as_timedelta()
+
+ if loffset.total_seconds() > 0:
+ lresult = last + (freq.as_timedelta() - loffset)
+ else:
+ lresult = last
+ else:
+ if foffset.total_seconds() > 0:
+ fresult = first - foffset
+ else:
+ fresult = first
+
+ if loffset.total_seconds() > 0:
+ lresult = last + (freq.as_timedelta() - loffset)
+ else:
+ lresult = last + freq
+ return fresult, lresult
def exact_cftime_datetime_difference(a: CFTimeDatetime, b: CFTimeDatetime):
@@ -280,4 +474,24 @@ def exact_cftime_datetime_difference(a: CFTimeDatetime, b: CFTimeDatetime):
-------
datetime.timedelta
"""
- pass
+ seconds = b.replace(microsecond=0) - a.replace(microsecond=0)
+ seconds = int(round(seconds.total_seconds()))
+ microseconds = b.microsecond - a.microsecond
+ return datetime.timedelta(seconds=seconds, microseconds=microseconds)
+
+
+def _convert_offset_to_timedelta(
+ offset: datetime.timedelta | str | BaseCFTimeOffset,
+) -> datetime.timedelta:
+ if isinstance(offset, datetime.timedelta):
+ return offset
+ if isinstance(offset, (str, Tick)):
+ timedelta_cftime_offset = to_offset(offset)
+ if isinstance(timedelta_cftime_offset, Tick):
+ return timedelta_cftime_offset.as_timedelta()
+ raise TypeError(f"Expected timedelta, str or Tick, got {type(offset)}")
+
+
+def _ceil_via_cftimeindex(date: CFTimeDatetime, freq: str | BaseCFTimeOffset):
+ index = CFTimeIndex([date])
+ return index.ceil(freq).item()
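
To exercise the CFTimeIndex path, a minimal sketch (requires the optional cftime dependency): `CFTimeGrouper.first_items` builds the bins through `_get_time_bins` and `np.searchsorted` as shown above.

import numpy as np
import xarray as xr

# A non-standard calendar forces a CFTimeIndex rather than a DatetimeIndex.
times = xr.cftime_range("2000-01-01", periods=6, freq="MS", calendar="noleap")
da = xr.DataArray(np.arange(6.0), coords={"time": times}, dims="time")

# Six monthly values fall into two-month bins whose labels come from
# cftime_range(start, end, freq) inside _get_time_bins.
print(da.resample(time="2MS").mean())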
diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py
index 5a608958..6cf49fc9 100644
--- a/xarray/core/rolling.py
+++ b/xarray/core/rolling.py
@@ -1,28 +1,41 @@
from __future__ import annotations
+
import functools
import itertools
import math
import warnings
from collections.abc import Hashable, Iterator, Mapping
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
+
import numpy as np
from packaging.version import Version
+
from xarray.core import dtypes, duck_array_ops, utils
from xarray.core.arithmetic import CoarsenArithmetic
from xarray.core.options import OPTIONS, _get_keep_attrs
from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray
-from xarray.core.utils import either_dict_or_kwargs, is_duck_dask_array, module_available
+from xarray.core.utils import (
+ either_dict_or_kwargs,
+ is_duck_dask_array,
+ module_available,
+)
from xarray.namedarray import pycompat
+
try:
import bottleneck
except ImportError:
+ # use numpy methods instead
bottleneck = None
+
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
+
RollingKey = Any
- _T = TypeVar('_T')
-_ROLLING_REDUCE_DOCSTRING_TEMPLATE = """Reduce this object's data windows by applying `{name}` along its dimension.
+ _T = TypeVar("_T")
+
+_ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\
+Reduce this object's data windows by applying `{name}` along its dimension.
Parameters
----------
@@ -50,17 +63,22 @@ class Rolling(Generic[T_Xarray]):
xarray.Dataset.rolling
xarray.DataArray.rolling
"""
- __slots__ = 'obj', 'window', 'min_periods', 'center', 'dim'
- _attributes = 'window', 'min_periods', 'center', 'dim'
+
+ __slots__ = ("obj", "window", "min_periods", "center", "dim")
+ _attributes = ("window", "min_periods", "center", "dim")
dim: list[Hashable]
window: list[int]
center: list[bool]
obj: T_Xarray
min_periods: int
- def __init__(self, obj: T_Xarray, windows: Mapping[Any, int],
- min_periods: (int | None)=None, center: (bool | Mapping[Any, bool])
- =False) ->None:
+ def __init__(
+ self,
+ obj: T_Xarray,
+ windows: Mapping[Any, int],
+ min_periods: int | None = None,
+ center: bool | Mapping[Any, bool] = False,
+ ) -> None:
"""
Moving window object.
@@ -88,33 +106,49 @@ class Rolling(Generic[T_Xarray]):
for d, w in windows.items():
self.dim.append(d)
if w <= 0:
- raise ValueError('window must be > 0')
+ raise ValueError("window must be > 0")
self.window.append(w)
+
self.center = self._mapping_to_list(center, default=False)
self.obj = obj
- missing_dims = tuple(dim for dim in self.dim if dim not in self.obj
- .dims)
+
+ missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims)
if missing_dims:
+ # NOTE: we raise KeyError here but ValueError in Coarsen.
raise KeyError(
- f'Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} dimensions {tuple(self.obj.dims)}'
- )
+ f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} "
+ f"dimensions {tuple(self.obj.dims)}"
+ )
+
+ # attributes
if min_periods is not None and min_periods <= 0:
- raise ValueError('min_periods must be greater than zero or None')
- self.min_periods = math.prod(self.window
- ) if min_periods is None else min_periods
+ raise ValueError("min_periods must be greater than zero or None")
+
+ self.min_periods = (
+ math.prod(self.window) if min_periods is None else min_periods
+ )
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
"""provide a nice str repr of our rolling object"""
- attrs = ['{k}->{v}{c}'.format(k=k, v=w, c='(center)' if c else '') for
- k, w, c in zip(self.dim, self.window, self.center)]
- return '{klass} [{attrs}]'.format(klass=self.__class__.__name__,
- attrs=','.join(attrs))
- def __len__(self) ->int:
+ attrs = [
+ "{k}->{v}{c}".format(k=k, v=w, c="(center)" if c else "")
+ for k, w, c in zip(self.dim, self.window, self.center)
+ ]
+ return "{klass} [{attrs}]".format(
+ klass=self.__class__.__name__, attrs=",".join(attrs)
+ )
+
+ def __len__(self) -> int:
return math.prod(self.obj.sizes[d] for d in self.dim)
- def _reduce_method(name: str, fillna: Any, rolling_agg_func: (Callable |
- None)=None) ->Callable[..., T_Xarray]:
+ @property
+ def ndim(self) -> int:
+ return len(self.dim)
+
+ def _reduce_method( # type: ignore[misc]
+ name: str, fillna: Any, rolling_agg_func: Callable | None = None
+ ) -> Callable[..., T_Xarray]:
"""Constructs reduction methods built on a numpy reduction function (e.g. sum),
a numbagg reduction function (e.g. move_sum), a bottleneck reduction function
(e.g. move_sum), or a Rolling reduction (_mean).
@@ -124,27 +158,105 @@ class Rolling(Generic[T_Xarray]):
need context of xarray options, of the functions each library offers, of
the array (e.g. dtype).
"""
- pass
- _mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name='mean')
- argmax = _reduce_method('argmax', dtypes.NINF)
- argmin = _reduce_method('argmin', dtypes.INF)
- max = _reduce_method('max', dtypes.NINF)
- min = _reduce_method('min', dtypes.INF)
- prod = _reduce_method('prod', 1)
- sum = _reduce_method('sum', 0)
- mean = _reduce_method('mean', None, _mean)
- std = _reduce_method('std', None)
- var = _reduce_method('var', None)
- median = _reduce_method('median', None)
- count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name='count')
-
-
-class DataArrayRolling(Rolling['DataArray']):
- __slots__ = 'window_labels',
-
- def __init__(self, obj: DataArray, windows: Mapping[Any, int],
- min_periods: (int | None)=None, center: (bool | Mapping[Any, bool])
- =False) ->None:
+ if rolling_agg_func:
+ array_agg_func = None
+ else:
+ array_agg_func = getattr(duck_array_ops, name)
+
+ bottleneck_move_func = getattr(bottleneck, "move_" + name, None)
+ if module_available("numbagg"):
+ import numbagg
+
+ numbagg_move_func = getattr(numbagg, "move_" + name, None)
+ else:
+ numbagg_move_func = None
+
+ def method(self, keep_attrs=None, **kwargs):
+ keep_attrs = self._get_keep_attrs(keep_attrs)
+
+ return self._array_reduce(
+ array_agg_func=array_agg_func,
+ bottleneck_move_func=bottleneck_move_func,
+ numbagg_move_func=numbagg_move_func,
+ rolling_agg_func=rolling_agg_func,
+ keep_attrs=keep_attrs,
+ fillna=fillna,
+ **kwargs,
+ )
+
+ method.__name__ = name
+ method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name)
+ return method
+
+ def _mean(self, keep_attrs, **kwargs):
+ result = self.sum(keep_attrs=False, **kwargs) / duck_array_ops.astype(
+ self.count(keep_attrs=False), dtype=self.obj.dtype, copy=False
+ )
+ if keep_attrs:
+ result.attrs = self.obj.attrs
+ return result
+
+ _mean.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="mean")
+
+ argmax = _reduce_method("argmax", dtypes.NINF)
+ argmin = _reduce_method("argmin", dtypes.INF)
+ max = _reduce_method("max", dtypes.NINF)
+ min = _reduce_method("min", dtypes.INF)
+ prod = _reduce_method("prod", 1)
+ sum = _reduce_method("sum", 0)
+ mean = _reduce_method("mean", None, _mean)
+ std = _reduce_method("std", None)
+ var = _reduce_method("var", None)
+ median = _reduce_method("median", None)
+
+ def _counts(self, keep_attrs: bool | None) -> T_Xarray:
+ raise NotImplementedError()
+
+ def count(self, keep_attrs: bool | None = None) -> T_Xarray:
+ keep_attrs = self._get_keep_attrs(keep_attrs)
+ rolling_count = self._counts(keep_attrs=keep_attrs)
+ enough_periods = rolling_count >= self.min_periods
+ return rolling_count.where(enough_periods)
+
+ count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count")
+
+ def _mapping_to_list(
+ self,
+ arg: _T | Mapping[Any, _T],
+ default: _T | None = None,
+ allow_default: bool = True,
+ allow_allsame: bool = True,
+ ) -> list[_T]:
+ if utils.is_dict_like(arg):
+ if allow_default:
+ return [arg.get(d, default) for d in self.dim]
+ for d in self.dim:
+ if d not in arg:
+ raise KeyError(f"Argument has no dimension key {d}.")
+ return [arg[d] for d in self.dim]
+ if allow_allsame: # for single argument
+ return [arg] * self.ndim # type: ignore[list-item] # no check for negatives
+ if self.ndim == 1:
+ return [arg] # type: ignore[list-item] # no check for negatives
+ raise ValueError(f"Mapping argument is necessary for {self.ndim}d-rolling.")
+
+ def _get_keep_attrs(self, keep_attrs):
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ return keep_attrs
+
+
+class DataArrayRolling(Rolling["DataArray"]):
+ __slots__ = ("window_labels",)
+
+ def __init__(
+ self,
+ obj: DataArray,
+ windows: Mapping[Any, int],
+ min_periods: int | None = None,
+ center: bool | Mapping[Any, bool] = False,
+ ) -> None:
"""
Moving window object for DataArray.
You should use DataArray.rolling() method to construct this object
@@ -176,27 +288,37 @@ class DataArrayRolling(Rolling['DataArray']):
xarray.Dataset.groupby
"""
super().__init__(obj, windows, min_periods=min_periods, center=center)
+
+ # TODO legacy attribute
self.window_labels = self.obj[self.dim[0]]
- def __iter__(self) ->Iterator[tuple[DataArray, DataArray]]:
+ def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]:
if self.ndim > 1:
- raise ValueError('__iter__ is only supported for 1d-rolling')
+ raise ValueError("__iter__ is only supported for 1d-rolling")
+
dim0 = self.dim[0]
window0 = int(self.window[0])
offset = (window0 + 1) // 2 if self.center[0] else 1
stops = np.arange(offset, self.obj.sizes[dim0] + offset)
starts = stops - window0
- starts[:window0 - offset] = 0
+ starts[: window0 - offset] = 0
+
for label, start, stop in zip(self.window_labels, starts, stops):
window = self.obj.isel({dim0: slice(start, stop)})
+
counts = window.count(dim=[dim0])
window = window.where(counts >= self.min_periods)
- yield label, window
- def construct(self, window_dim: (Hashable | Mapping[Any, Hashable] |
- None)=None, stride: (int | Mapping[Any, int])=1, fill_value: Any=
- dtypes.NA, keep_attrs: (bool | None)=None, **window_dim_kwargs:
- Hashable) ->DataArray:
+ yield (label, window)
+
+ def construct(
+ self,
+ window_dim: Hashable | Mapping[Any, Hashable] | None = None,
+ stride: int | Mapping[Any, int] = 1,
+ fill_value: Any = dtypes.NA,
+ keep_attrs: bool | None = None,
+ **window_dim_kwargs: Hashable,
+ ) -> DataArray:
"""
Convert this rolling object to xr.DataArray,
where the window dimension is stacked as a new dimension
@@ -254,10 +376,59 @@ class DataArrayRolling(Rolling['DataArray']):
Dimensions without coordinates: a, b, window_dim
"""
- pass
- def reduce(self, func: Callable, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ return self._construct(
+ self.obj,
+ window_dim=window_dim,
+ stride=stride,
+ fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **window_dim_kwargs,
+ )
+
+ def _construct(
+ self,
+ obj: DataArray,
+ window_dim: Hashable | Mapping[Any, Hashable] | None = None,
+ stride: int | Mapping[Any, int] = 1,
+ fill_value: Any = dtypes.NA,
+ keep_attrs: bool | None = None,
+ **window_dim_kwargs: Hashable,
+ ) -> DataArray:
+ from xarray.core.dataarray import DataArray
+
+ keep_attrs = self._get_keep_attrs(keep_attrs)
+
+ if window_dim is None:
+ if len(window_dim_kwargs) == 0:
+ raise ValueError(
+ "Either window_dim or window_dim_kwargs need to be specified."
+ )
+ window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim}
+
+ window_dims = self._mapping_to_list(
+ window_dim, allow_default=False, allow_allsame=False
+ )
+ strides = self._mapping_to_list(stride, default=1)
+
+ window = obj.variable.rolling_window(
+ self.dim, self.window, window_dims, self.center, fill_value=fill_value
+ )
+
+ attrs = obj.attrs if keep_attrs else {}
+
+ result = DataArray(
+ window,
+ dims=obj.dims + tuple(window_dims),
+ coords=obj.coords,
+ attrs=attrs,
+ name=obj.name,
+ )
+ return result.isel({d: slice(None, None, s) for d, s in zip(self.dim, strides)})
+
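A construct sketch (hypothetical data): the window is stacked as a new dimension and the rolled dimension can be strided.

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(8.0), dims="x")
da.rolling(x=3).construct("window", stride=2)
# shape (4, 3); rows: [nan, nan, 0.], [0., 1., 2.], [2., 3., 4.], [4., 5., 6.]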
+ def reduce(
+ self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any
+ ) -> DataArray:
"""Reduce the items in this group by applying `func` along some
dimension(s).
@@ -309,19 +480,227 @@ class DataArrayRolling(Rolling['DataArray']):
[ 4., 9., 15., 18.]])
Dimensions without coordinates: a, b
"""
- pass
- def _counts(self, keep_attrs: (bool | None)) ->DataArray:
+ keep_attrs = self._get_keep_attrs(keep_attrs)
+
+ rolling_dim = {
+ d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}")
+ for d in self.dim
+ }
+
+ # save memory with reductions GH4325
+ fillna = kwargs.pop("fillna", dtypes.NA)
+ if fillna is not dtypes.NA:
+ obj = self.obj.fillna(fillna)
+ else:
+ obj = self.obj
+ windows = self._construct(
+ obj, rolling_dim, keep_attrs=keep_attrs, fill_value=fillna
+ )
+
+ dim = list(rolling_dim.values())
+ result = windows.reduce(func, dim=dim, keep_attrs=keep_attrs, **kwargs)
+
+ # Find valid windows based on count.
+ counts = self._counts(keep_attrs=False)
+ return result.where(counts >= self.min_periods)
+
+ def _counts(self, keep_attrs: bool | None) -> DataArray:
"""Number of non-nan entries in each rolling window."""
- pass
+ rolling_dim = {
+ d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}")
+ for d in self.dim
+ }
+        # We use False as the fill_value instead of np.nan, since a boolean
+        # array is faster to reduce than an object array.
+        # Using skipna=False is also faster, since it avoids copying the
+        # strided array.
+ dim = list(rolling_dim.values())
+ counts = (
+ self.obj.notnull(keep_attrs=keep_attrs)
+ .rolling(
+ {d: w for d, w in zip(self.dim, self.window)},
+ center={d: self.center[i] for i, d in enumerate(self.dim)},
+ )
+ .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
+ .sum(dim=dim, skipna=False, keep_attrs=keep_attrs)
+ )
+ return counts
+
+ def _numbagg_reduce(self, func, keep_attrs, **kwargs):
+ # Some of this is copied from `_bottleneck_reduce`, we could reduce this as part
+ # of a wider refactor.
+
+ axis = self.obj.get_axis_num(self.dim[0])
+
+ padded = self.obj.variable
+ if self.center[0]:
+ if is_duck_dask_array(padded.data):
+ # workaround to make the padded chunk size larger than
+ # self.window - 1
+ shift = -(self.window[0] + 1) // 2
+ offset = (self.window[0] - 1) // 2
+ valid = (slice(None),) * axis + (
+ slice(offset, offset + self.obj.shape[axis]),
+ )
+ else:
+ shift = (-self.window[0] // 2) + 1
+ valid = (slice(None),) * axis + (slice(-shift, None),)
+ padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")
-class DatasetRolling(Rolling['Dataset']):
- __slots__ = 'rollings',
+        if is_duck_dask_array(padded.data):
+ raise AssertionError("should not be reachable")
+ else:
+ values = func(
+ padded.data,
+ window=self.window[0],
+ min_count=self.min_periods,
+ axis=axis,
+ )
+
+ if self.center[0]:
+ values = values[valid]
+
+ attrs = self.obj.attrs if keep_attrs else {}
+
+ return self.obj.__class__(
+ values, self.obj.coords, attrs=attrs, name=self.obj.name
+ )
+
+ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
+ # bottleneck doesn't allow min_count to be 0, although it should
+ # work the same as if min_count = 1
+ # Note bottleneck only works with 1d-rolling.
+ if self.min_periods is not None and self.min_periods == 0:
+ min_count = 1
+ else:
+ min_count = self.min_periods
+
+ axis = self.obj.get_axis_num(self.dim[0])
+
+ padded = self.obj.variable
+ if self.center[0]:
+ if is_duck_dask_array(padded.data):
+ # workaround to make the padded chunk size larger than
+ # self.window - 1
+ shift = -(self.window[0] + 1) // 2
+ offset = (self.window[0] - 1) // 2
+ valid = (slice(None),) * axis + (
+ slice(offset, offset + self.obj.shape[axis]),
+ )
+ else:
+ shift = (-self.window[0] // 2) + 1
+ valid = (slice(None),) * axis + (slice(-shift, None),)
+ padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")
+
+ if is_duck_dask_array(padded.data):
+ raise AssertionError("should not be reachable")
+ else:
+ values = func(
+ padded.data, window=self.window[0], min_count=min_count, axis=axis
+ )
+ # index 0 is at the rightmost edge of the window
+ # need to reverse index here
+ # see GH #8541
+ if func in [bottleneck.move_argmin, bottleneck.move_argmax]:
+ values = self.window[0] - 1 - values
+
+ if self.center[0]:
+ values = values[valid]
+
+ attrs = self.obj.attrs if keep_attrs else {}
+
+ return self.obj.__class__(
+ values, self.obj.coords, attrs=attrs, name=self.obj.name
+ )
+
+ def _array_reduce(
+ self,
+ array_agg_func,
+ bottleneck_move_func,
+ numbagg_move_func,
+ rolling_agg_func,
+ keep_attrs,
+ fillna,
+ **kwargs,
+ ):
+ if "dim" in kwargs:
+ warnings.warn(
+ f"Reductions are applied along the rolling dimension(s) "
+ f"'{self.dim}'. Passing the 'dim' kwarg to reduction "
+ f"operations has no effect.",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+ del kwargs["dim"]
+
+ if (
+ OPTIONS["use_numbagg"]
+ and module_available("numbagg")
+ and pycompat.mod_version("numbagg") >= Version("0.6.3")
+ and numbagg_move_func is not None
+ # TODO: we could at least allow this for the equivalent of `apply_ufunc`'s
+ # "parallelized". `rolling_exp` does this, as an example (but rolling_exp is
+ # much simpler)
+ and not is_duck_dask_array(self.obj.data)
+ # Numbagg doesn't handle object arrays and generally has dtype consistency,
+ # so doesn't deal well with bool arrays which are expected to change type.
+ and self.obj.data.dtype.kind not in "ObMm"
+ # TODO: we could also allow this, probably as part of a refactoring of this
+ # module, so we can use the machinery in `self.reduce`.
+ and self.ndim == 1
+ ):
+ import numbagg
+
+ # Numbagg has a default ddof of 1. I (@max-sixty) think we should make
+ # this the default in xarray too, but until we do, don't use numbagg for
+ # std and var unless ddof is set to 1.
+ if (
+ numbagg_move_func not in [numbagg.move_std, numbagg.move_var]
+ or kwargs.get("ddof") == 1
+ ):
+ return self._numbagg_reduce(
+ numbagg_move_func, keep_attrs=keep_attrs, **kwargs
+ )
- def __init__(self, obj: Dataset, windows: Mapping[Any, int],
- min_periods: (int | None)=None, center: (bool | Mapping[Any, bool])
- =False) ->None:
+ if (
+ OPTIONS["use_bottleneck"]
+ and bottleneck_move_func is not None
+ and not is_duck_dask_array(self.obj.data)
+ and self.ndim == 1
+ ):
+ # TODO: re-enable bottleneck with dask after the issues
+ # underlying https://github.com/pydata/xarray/issues/2940 are
+ # fixed.
+ return self._bottleneck_reduce(
+ bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
+ )
+
+ if rolling_agg_func:
+ return rolling_agg_func(self, keep_attrs=self._get_keep_attrs(keep_attrs))
+
+ if fillna is not None:
+ if fillna is dtypes.INF:
+ fillna = dtypes.get_pos_infinity(self.obj.dtype, max_for_int=True)
+ elif fillna is dtypes.NINF:
+ fillna = dtypes.get_neg_infinity(self.obj.dtype, min_for_int=True)
+ kwargs.setdefault("skipna", False)
+ kwargs.setdefault("fillna", fillna)
+
+ return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)
+
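How the dispatch above is typically exercised (a sketch; the option names mirror the OPTIONS keys checked in _array_reduce): numbagg is preferred, then bottleneck, then the generic construct-and-reduce path via self.reduce.

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0), dims="x")
with xr.set_options(use_numbagg=False, use_bottleneck=False):
    da.rolling(x=3, min_periods=1).mean()  # falls through to the generic reduce path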
+
+class DatasetRolling(Rolling["Dataset"]):
+ __slots__ = ("rollings",)
+
+ def __init__(
+ self,
+ obj: Dataset,
+ windows: Mapping[Any, int],
+ min_periods: int | None = None,
+ center: bool | Mapping[Any, bool] = False,
+ ) -> None:
"""
Moving window object for Dataset.
You should use Dataset.rolling() method to construct this object
@@ -353,20 +732,42 @@ class DatasetRolling(Rolling['Dataset']):
xarray.DataArray.groupby
"""
super().__init__(obj, windows, min_periods, center)
+
+ # Keep each Rolling object as a dictionary
self.rollings = {}
for key, da in self.obj.data_vars.items():
+ # keeps rollings only for the dataset depending on self.dim
dims, center = [], {}
for i, d in enumerate(self.dim):
if d in da.dims:
dims.append(d)
center[d] = self.center[i]
+
if dims:
w = {d: windows[d] for d in dims}
- self.rollings[key] = DataArrayRolling(da, w, min_periods,
- center)
+ self.rollings[key] = DataArrayRolling(da, w, min_periods, center)
+
+ def _dataset_implementation(self, func, keep_attrs, **kwargs):
+ from xarray.core.dataset import Dataset
+
+ keep_attrs = self._get_keep_attrs(keep_attrs)
- def reduce(self, func: Callable, keep_attrs: (bool | None)=None, **
- kwargs: Any) ->DataArray:
+ reduced = {}
+ for key, da in self.obj.data_vars.items():
+ if any(d in da.dims for d in self.dim):
+ reduced[key] = func(self.rollings[key], keep_attrs=keep_attrs, **kwargs)
+ else:
+ reduced[key] = self.obj[key].copy()
+ # we need to delete the attrs of the copied DataArray
+ if not keep_attrs:
+ reduced[key].attrs = {}
+
+ attrs = self.obj.attrs if keep_attrs else {}
+ return Dataset(reduced, coords=self.obj.coords, attrs=attrs)
+
+ def reduce(
+ self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any
+ ) -> DataArray:
"""Reduce the items in this group by applying `func` along some
dimension(s).
@@ -388,12 +789,44 @@ class DatasetRolling(Rolling['Dataset']):
reduced : DataArray
Array with summarized data.
"""
- pass
-
- def construct(self, window_dim: (Hashable | Mapping[Any, Hashable] |
- None)=None, stride: (int | Mapping[Any, int])=1, fill_value: Any=
- dtypes.NA, keep_attrs: (bool | None)=None, **window_dim_kwargs:
- Hashable) ->Dataset:
+ return self._dataset_implementation(
+ functools.partial(DataArrayRolling.reduce, func=func),
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def _counts(self, keep_attrs: bool | None) -> Dataset:
+ return self._dataset_implementation(
+ DataArrayRolling._counts, keep_attrs=keep_attrs
+ )
+
+ def _array_reduce(
+ self,
+ array_agg_func,
+ bottleneck_move_func,
+ rolling_agg_func,
+ keep_attrs,
+ **kwargs,
+ ):
+ return self._dataset_implementation(
+ functools.partial(
+ DataArrayRolling._array_reduce,
+ array_agg_func=array_agg_func,
+ bottleneck_move_func=bottleneck_move_func,
+ rolling_agg_func=rolling_agg_func,
+ ),
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ def construct(
+ self,
+ window_dim: Hashable | Mapping[Any, Hashable] | None = None,
+ stride: int | Mapping[Any, int] = 1,
+ fill_value: Any = dtypes.NA,
+ keep_attrs: bool | None = None,
+ **window_dim_kwargs: Hashable,
+ ) -> Dataset:
"""
Convert this rolling object to xr.Dataset,
where the window dimension is stacked as a new dimension
@@ -414,7 +847,52 @@ class DatasetRolling(Rolling['Dataset']):
-------
Dataset with variables converted from rolling object.
"""
- pass
+
+ from xarray.core.dataset import Dataset
+
+ keep_attrs = self._get_keep_attrs(keep_attrs)
+
+ if window_dim is None:
+ if len(window_dim_kwargs) == 0:
+ raise ValueError(
+ "Either window_dim or window_dim_kwargs need to be specified."
+ )
+ window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim}
+
+ window_dims = self._mapping_to_list(
+ window_dim, allow_default=False, allow_allsame=False
+ )
+ strides = self._mapping_to_list(stride, default=1)
+
+ dataset = {}
+ for key, da in self.obj.data_vars.items():
+ # keeps rollings only for the dataset depending on self.dim
+ dims = [d for d in self.dim if d in da.dims]
+ if dims:
+ wi = {d: window_dims[i] for i, d in enumerate(self.dim) if d in da.dims}
+ st = {d: strides[i] for i, d in enumerate(self.dim) if d in da.dims}
+
+ dataset[key] = self.rollings[key].construct(
+ window_dim=wi,
+ fill_value=fill_value,
+ stride=st,
+ keep_attrs=keep_attrs,
+ )
+ else:
+ dataset[key] = da.copy()
+
+ # as the DataArrays can be copied we need to delete the attrs
+ if not keep_attrs:
+ dataset[key].attrs = {}
+
+ # Need to stride coords as well. TODO: is there a better way?
+ coords = self.obj.isel(
+ {d: slice(None, None, s) for d, s in zip(self.dim, strides)}
+ ).coords
+
+ attrs = self.obj.attrs if keep_attrs else {}
+
+ return Dataset(dataset, coords=coords, attrs=attrs)
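
A Dataset-level sketch (hypothetical data): per-variable DataArrayRolling objects are used where the rolled dimension is present, and other variables are copied through.

import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(5.0)), "b": ("y", np.arange(3.0))})
out = ds.rolling(x=2).mean()
# out["a"] is the rolling mean along x; out["b"] is passed through unchanged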
class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
@@ -425,19 +903,30 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
Dataset.coarsen
DataArray.coarsen
"""
- __slots__ = ('obj', 'boundary', 'coord_func', 'windows', 'side',
- 'trim_excess')
- _attributes = 'windows', 'side', 'trim_excess'
+
+ __slots__ = (
+ "obj",
+ "boundary",
+ "coord_func",
+ "windows",
+ "side",
+ "trim_excess",
+ )
+ _attributes = ("windows", "side", "trim_excess")
obj: T_Xarray
windows: Mapping[Hashable, int]
side: SideOptions | Mapping[Hashable, SideOptions]
boundary: CoarsenBoundaryOptions
coord_func: Mapping[Hashable, str | Callable]
- def __init__(self, obj: T_Xarray, windows: Mapping[Any, int], boundary:
- CoarsenBoundaryOptions, side: (SideOptions | Mapping[Any,
- SideOptions]), coord_func: (str | Callable | Mapping[Any, str |
- Callable])) ->None:
+ def __init__(
+ self,
+ obj: T_Xarray,
+ windows: Mapping[Any, int],
+ boundary: CoarsenBoundaryOptions,
+ side: SideOptions | Mapping[Any, SideOptions],
+ coord_func: str | Callable | Mapping[Any, str | Callable],
+ ) -> None:
"""
Moving window object.
@@ -464,30 +953,47 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
self.windows = windows
self.side = side
self.boundary = boundary
- missing_dims = tuple(dim for dim in windows.keys() if dim not in
- self.obj.dims)
+
+ missing_dims = tuple(dim for dim in windows.keys() if dim not in self.obj.dims)
if missing_dims:
raise ValueError(
- f'Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} dimensions {tuple(self.obj.dims)}'
- )
+ f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} "
+ f"dimensions {tuple(self.obj.dims)}"
+ )
+
if utils.is_dict_like(coord_func):
coord_func_map = coord_func
else:
coord_func_map = {d: coord_func for d in self.obj.dims}
for c in self.obj.coords:
if c not in coord_func_map:
- coord_func_map[c] = duck_array_ops.mean
+ coord_func_map[c] = duck_array_ops.mean # type: ignore[index]
self.coord_func = coord_func_map
- def __repr__(self) ->str:
+ def _get_keep_attrs(self, keep_attrs):
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ return keep_attrs
+
+ def __repr__(self) -> str:
"""provide a nice str repr of our coarsen object"""
- attrs = [f'{k}->{getattr(self, k)}' for k in self._attributes if
- getattr(self, k, None) is not None]
- return '{klass} [{attrs}]'.format(klass=self.__class__.__name__,
- attrs=','.join(attrs))
- def construct(self, window_dim=None, keep_attrs=None, **window_dim_kwargs
- ) ->T_Xarray:
+ attrs = [
+ f"{k}->{getattr(self, k)}"
+ for k in self._attributes
+ if getattr(self, k, None) is not None
+ ]
+ return "{klass} [{attrs}]".format(
+ klass=self.__class__.__name__, attrs=",".join(attrs)
+ )
+
+ def construct(
+ self,
+ window_dim=None,
+ keep_attrs=None,
+ **window_dim_kwargs,
+ ) -> T_Xarray:
"""
Convert this Coarsen object to a DataArray or Dataset,
where the coarsening dimension is split or reshaped to two
@@ -522,24 +1028,125 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]):
DataArrayRolling.construct
DatasetRolling.construct
"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.core.dataset import Dataset
-class DataArrayCoarsen(Coarsen['DataArray']):
+ window_dim = either_dict_or_kwargs(
+ window_dim, window_dim_kwargs, "Coarsen.construct"
+ )
+ if not window_dim:
+ raise ValueError(
+ "Either window_dim or window_dim_kwargs need to be specified."
+ )
+
+ bad_new_dims = tuple(
+ win
+ for win, dims in window_dim.items()
+ if len(dims) != 2 or isinstance(dims, str)
+ )
+ if bad_new_dims:
+ raise ValueError(
+ f"Please provide exactly two dimension names for the following coarsening dimensions: {bad_new_dims}"
+ )
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ missing_dims = set(window_dim) - set(self.windows)
+ if missing_dims:
+ raise ValueError(
+ f"'window_dim' must contain entries for all dimensions to coarsen. Missing {missing_dims}"
+ )
+ extra_windows = set(self.windows) - set(window_dim)
+ if extra_windows:
+ raise ValueError(
+ f"'window_dim' includes dimensions that will not be coarsened: {extra_windows}"
+ )
+
+ reshaped = Dataset()
+ if isinstance(self.obj, DataArray):
+ obj = self.obj._to_temp_dataset()
+ else:
+ obj = self.obj
+
+ reshaped.attrs = obj.attrs if keep_attrs else {}
+
+ for key, var in obj.variables.items():
+ reshaped_dims = tuple(
+ itertools.chain(*[window_dim.get(dim, [dim]) for dim in list(var.dims)])
+ )
+ if reshaped_dims != var.dims:
+ windows = {w: self.windows[w] for w in window_dim if w in var.dims}
+ reshaped_var, _ = var.coarsen_reshape(windows, self.boundary, self.side)
+ attrs = var.attrs if keep_attrs else {}
+ reshaped[key] = (reshaped_dims, reshaped_var, attrs)
+ else:
+ reshaped[key] = var
+
+ # should handle window_dim being unindexed
+ should_be_coords = (set(window_dim) & set(self.obj.coords)) | set(
+ self.obj.coords
+ )
+ result = reshaped.set_coords(should_be_coords)
+ if isinstance(self.obj, DataArray):
+ return self.obj._from_temp_dataset(result)
+ else:
+ return result
+
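A Coarsen.construct sketch (hypothetical data and dimension names): each coarsened dimension is reshaped into the two new dimensions given in window_dim.

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(12.0), dims="time")
da.coarsen(time=4).construct(time=("outer", "inner"))
# dims ("outer", "inner"), shape (3, 4)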
+
+class DataArrayCoarsen(Coarsen["DataArray"]):
__slots__ = ()
- _reduce_extra_args_docstring = ''
+
+ _reduce_extra_args_docstring = """"""
@classmethod
- def _reduce_method(cls, func: Callable, include_skipna: bool=False,
- numeric_only: bool=False) ->Callable[..., DataArray]:
+ def _reduce_method(
+ cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False
+ ) -> Callable[..., DataArray]:
"""
Return a wrapped function for injecting reduction methods.
see ops.inject_reduce_methods
"""
- pass
-
- def reduce(self, func: Callable, keep_attrs: (bool | None)=None, **kwargs
- ) ->DataArray:
+ kwargs: dict[str, Any] = {}
+ if include_skipna:
+ kwargs["skipna"] = None
+
+ def wrapped_func(
+ self: DataArrayCoarsen, keep_attrs: bool | None = None, **kwargs
+ ) -> DataArray:
+ from xarray.core.dataarray import DataArray
+
+ keep_attrs = self._get_keep_attrs(keep_attrs)
+
+ reduced = self.obj.variable.coarsen(
+ self.windows, func, self.boundary, self.side, keep_attrs, **kwargs
+ )
+ coords = {}
+ for c, v in self.obj.coords.items():
+ if c == self.obj.name:
+ coords[c] = reduced
+ else:
+ if any(d in self.windows for d in v.dims):
+ coords[c] = v.variable.coarsen(
+ self.windows,
+ self.coord_func[c],
+ self.boundary,
+ self.side,
+ keep_attrs,
+ **kwargs,
+ )
+ else:
+ coords[c] = v
+ return DataArray(
+ reduced, dims=self.obj.dims, coords=coords, name=self.obj.name
+ )
+
+ return wrapped_func
+
+ def reduce(
+ self, func: Callable, keep_attrs: bool | None = None, **kwargs
+ ) -> DataArray:
"""Reduce the items in this group by applying `func` along some
dimension(s).
@@ -572,23 +1179,68 @@ class DataArrayCoarsen(Coarsen['DataArray']):
[ 9, 13]])
Dimensions without coordinates: a, b
"""
- pass
+ wrapped_func = self._reduce_method(func)
+ return wrapped_func(self, keep_attrs=keep_attrs, **kwargs)
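
The injected aggregations (e.g. .max(), .mean()) come from _reduce_method; reduce() wraps an arbitrary callable the same way. A small sketch (hypothetical data):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(8.0), dims="x")
da.coarsen(x=4).reduce(np.nanmax)  # comparable to da.coarsen(x=4).max()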
-class DatasetCoarsen(Coarsen['Dataset']):
+class DatasetCoarsen(Coarsen["Dataset"]):
__slots__ = ()
- _reduce_extra_args_docstring = ''
+
+ _reduce_extra_args_docstring = """"""
@classmethod
- def _reduce_method(cls, func: Callable, include_skipna: bool=False,
- numeric_only: bool=False) ->Callable[..., Dataset]:
+ def _reduce_method(
+ cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False
+ ) -> Callable[..., Dataset]:
"""
Return a wrapped function for injecting reduction methods.
see ops.inject_reduce_methods
"""
- pass
+ kwargs: dict[str, Any] = {}
+ if include_skipna:
+ kwargs["skipna"] = None
+
+ def wrapped_func(
+ self: DatasetCoarsen, keep_attrs: bool | None = None, **kwargs
+ ) -> Dataset:
+ from xarray.core.dataset import Dataset
+
+ keep_attrs = self._get_keep_attrs(keep_attrs)
+
+ if keep_attrs:
+ attrs = self.obj.attrs
+ else:
+ attrs = {}
+
+ reduced = {}
+ for key, da in self.obj.data_vars.items():
+ reduced[key] = da.variable.coarsen(
+ self.windows,
+ func,
+ self.boundary,
+ self.side,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ coords = {}
+ for c, v in self.obj.coords.items():
+ # variable.coarsen returns variables not containing the window dims
+ # unchanged (maybe removes attrs)
+ coords[c] = v.variable.coarsen(
+ self.windows,
+ self.coord_func[c],
+ self.boundary,
+ self.side,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+
+ return Dataset(reduced, coords=coords, attrs=attrs)
+
+ return wrapped_func
- def reduce(self, func: Callable, keep_attrs=None, **kwargs) ->Dataset:
+ def reduce(self, func: Callable, keep_attrs=None, **kwargs) -> Dataset:
"""Reduce the items in this group by applying `func` along some
dimension(s).
@@ -611,4 +1263,5 @@ class DatasetCoarsen(Coarsen['Dataset']):
reduced : Dataset
Arrays with summarized data.
"""
- pass
+ wrapped_func = self._reduce_method(func)
+ return wrapped_func(self, keep_attrs=keep_attrs, **kwargs)
diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py
index a86a17a1..4e085a0a 100644
--- a/xarray/core/rolling_exp.py
+++ b/xarray/core/rolling_exp.py
@@ -1,8 +1,11 @@
from __future__ import annotations
+
from collections.abc import Mapping
from typing import Any, Generic
+
import numpy as np
from packaging.version import Version
+
from xarray.core.computation import apply_ufunc
from xarray.core.options import _get_keep_attrs
from xarray.core.pdcompat import count_not_none
@@ -11,12 +14,38 @@ from xarray.core.utils import module_available
from xarray.namedarray import pycompat
-def _get_alpha(com: (float | None)=None, span: (float | None)=None,
- halflife: (float | None)=None, alpha: (float | None)=None) ->float:
+def _get_alpha(
+ com: float | None = None,
+ span: float | None = None,
+ halflife: float | None = None,
+ alpha: float | None = None,
+) -> float:
"""
Convert com, span, halflife to alpha.
"""
- pass
+ valid_count = count_not_none(com, span, halflife, alpha)
+ if valid_count > 1:
+ raise ValueError("com, span, halflife, and alpha are mutually exclusive")
+
+ # Convert to alpha
+ if com is not None:
+ if com < 0:
+            raise ValueError("com must satisfy: com >= 0")
+ return 1 / (com + 1)
+ elif span is not None:
+ if span < 1:
+ raise ValueError("span must satisfy: span >= 1")
+ return 2 / (span + 1)
+ elif halflife is not None:
+ if halflife <= 0:
+ raise ValueError("halflife must satisfy: halflife > 0")
+ return 1 - np.exp(np.log(0.5) / halflife)
+ elif alpha is not None:
+ if not 0 < alpha <= 1:
+ raise ValueError("alpha must satisfy: 0 < alpha <= 1")
+ return alpha
+ else:
+        raise ValueError("Must pass one of com, span, halflife, or alpha")
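
A few equivalent parameterisations, as a quick check of the branches above (the values follow directly from the formulas in the code):

_get_alpha(com=4)      # 0.2  == 1 / (4 + 1)
_get_alpha(span=9)     # 0.2  == 2 / (9 + 1)
_get_alpha(alpha=0.2)  # 0.2  (validated and passed through)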
class RollingExp(Generic[T_DataWithCoords]):
@@ -41,32 +70,38 @@ class RollingExp(Generic[T_DataWithCoords]):
RollingExp : type of input argument
"""
- def __init__(self, obj: T_DataWithCoords, windows: Mapping[Any, int |
- float], window_type: str='span', min_weight: float=0.0):
- if not module_available('numbagg'):
+ def __init__(
+ self,
+ obj: T_DataWithCoords,
+ windows: Mapping[Any, int | float],
+ window_type: str = "span",
+ min_weight: float = 0.0,
+ ):
+ if not module_available("numbagg"):
raise ImportError(
- 'numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed'
- )
- elif pycompat.mod_version('numbagg') < Version('0.2.1'):
+ "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed"
+ )
+ elif pycompat.mod_version("numbagg") < Version("0.2.1"):
raise ImportError(
f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed"
- )
- elif pycompat.mod_version('numbagg') < Version('0.3.1'
- ) and min_weight > 0:
+ )
+ elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0:
raise ImportError(
f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed"
- )
+ )
+
self.obj: T_DataWithCoords = obj
dim, window = next(iter(windows.items()))
self.dim = dim
self.alpha = _get_alpha(**{window_type: window})
self.min_weight = min_weight
+ # Don't pass min_weight=0 so we can support older versions of numbagg
kwargs = dict(alpha=self.alpha, axis=-1)
if min_weight > 0:
- kwargs['min_weight'] = min_weight
+ kwargs["min_weight"] = min_weight
self.kwargs = kwargs
- def mean(self, keep_attrs: (bool | None)=None) ->T_DataWithCoords:
+ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
"""
Exponentially weighted moving average.
@@ -85,9 +120,26 @@ class RollingExp(Generic[T_DataWithCoords]):
array([1. , 1. , 1.69230769, 1.9 , 1.96694215])
Dimensions without coordinates: x
"""
- pass
- def sum(self, keep_attrs: (bool | None)=None) ->T_DataWithCoords:
+ import numbagg
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ dim_order = self.obj.dims
+
+ return apply_ufunc(
+ numbagg.move_exp_nanmean,
+ self.obj,
+ input_core_dims=[[self.dim]],
+ kwargs=self.kwargs,
+ output_core_dims=[[self.dim]],
+ keep_attrs=keep_attrs,
+ on_missing_core_dim="copy",
+ dask="parallelized",
+ ).transpose(*dim_order)
+
+ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords:
"""
Exponentially weighted moving sum.
@@ -106,9 +158,26 @@ class RollingExp(Generic[T_DataWithCoords]):
array([1. , 1.33333333, 2.44444444, 2.81481481, 2.9382716 ])
Dimensions without coordinates: x
"""
- pass
- def std(self) ->T_DataWithCoords:
+ import numbagg
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ dim_order = self.obj.dims
+
+ return apply_ufunc(
+ numbagg.move_exp_nansum,
+ self.obj,
+ input_core_dims=[[self.dim]],
+ kwargs=self.kwargs,
+ output_core_dims=[[self.dim]],
+ keep_attrs=keep_attrs,
+ on_missing_core_dim="copy",
+ dask="parallelized",
+ ).transpose(*dim_order)
+
+ def std(self) -> T_DataWithCoords:
"""
Exponentially weighted moving standard deviation.
@@ -122,9 +191,27 @@ class RollingExp(Generic[T_DataWithCoords]):
array([ nan, 0. , 0.67936622, 0.42966892, 0.25389527])
Dimensions without coordinates: x
"""
- pass
- def var(self) ->T_DataWithCoords:
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
+ raise ImportError(
+ f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed"
+ )
+ import numbagg
+
+ dim_order = self.obj.dims
+
+ return apply_ufunc(
+ numbagg.move_exp_nanstd,
+ self.obj,
+ input_core_dims=[[self.dim]],
+ kwargs=self.kwargs,
+ output_core_dims=[[self.dim]],
+ keep_attrs=True,
+ on_missing_core_dim="copy",
+ dask="parallelized",
+ ).transpose(*dim_order)
+
+ def var(self) -> T_DataWithCoords:
"""
Exponentially weighted moving variance.
@@ -138,9 +225,25 @@ class RollingExp(Generic[T_DataWithCoords]):
array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281])
Dimensions without coordinates: x
"""
- pass
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
+ raise ImportError(
+ f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {pycompat.mod_version('numbagg')} is installed"
+ )
+ dim_order = self.obj.dims
+ import numbagg
+
+ return apply_ufunc(
+ numbagg.move_exp_nanvar,
+ self.obj,
+ input_core_dims=[[self.dim]],
+ kwargs=self.kwargs,
+ output_core_dims=[[self.dim]],
+ keep_attrs=True,
+ on_missing_core_dim="copy",
+ dask="parallelized",
+ ).transpose(*dim_order)
- def cov(self, other: T_DataWithCoords) ->T_DataWithCoords:
+ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords:
"""
Exponentially weighted moving covariance.
@@ -154,9 +257,27 @@ class RollingExp(Generic[T_DataWithCoords]):
array([ nan, 0. , 1.38461538, 0.55384615, 0.19338843])
Dimensions without coordinates: x
"""
- pass
- def corr(self, other: T_DataWithCoords) ->T_DataWithCoords:
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
+ raise ImportError(
+ f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {pycompat.mod_version('numbagg')} is installed"
+ )
+ dim_order = self.obj.dims
+ import numbagg
+
+ return apply_ufunc(
+ numbagg.move_exp_nancov,
+ self.obj,
+ other,
+ input_core_dims=[[self.dim], [self.dim]],
+ kwargs=self.kwargs,
+ output_core_dims=[[self.dim]],
+ keep_attrs=True,
+ on_missing_core_dim="copy",
+ dask="parallelized",
+ ).transpose(*dim_order)
+
+ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords:
"""
Exponentially weighted moving correlation.
@@ -170,4 +291,22 @@ class RollingExp(Generic[T_DataWithCoords]):
array([ nan, nan, nan, 0.4330127 , 0.48038446])
Dimensions without coordinates: x
"""
- pass
+
+ if pycompat.mod_version("numbagg") < Version("0.4.0"):
+ raise ImportError(
+ f"numbagg >= 0.4.0 is required for rolling_exp().corr(), currently {pycompat.mod_version('numbagg')} is installed"
+ )
+ dim_order = self.obj.dims
+ import numbagg
+
+ return apply_ufunc(
+ numbagg.move_exp_nancorr,
+ self.obj,
+ other,
+ input_core_dims=[[self.dim], [self.dim]],
+ kwargs=self.kwargs,
+ output_core_dims=[[self.dim]],
+ keep_attrs=True,
+ on_missing_core_dim="copy",
+ dask="parallelized",
+ ).transpose(*dim_order)
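
A usage sketch for the methods above (requires numbagg; window_type selects the conversion performed by _get_alpha):

import xarray as xr

da = xr.DataArray([1.0, 1.0, 2.0, 2.0, 2.0], dims="x")
da.rolling_exp(x=3, window_type="span").mean()
da.rolling_exp(x=3, window_type="span").sum()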
diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py
index 30d6e6d1..77e7ed23 100644
--- a/xarray/core/treenode.py
+++ b/xarray/core/treenode.py
@@ -1,9 +1,16 @@
from __future__ import annotations
+
import sys
from collections.abc import Iterator, Mapping
from pathlib import PurePosixPath
-from typing import TYPE_CHECKING, Generic, TypeVar
+from typing import (
+ TYPE_CHECKING,
+ Generic,
+ TypeVar,
+)
+
from xarray.core.utils import Frozen, is_dict_like
+
if TYPE_CHECKING:
from xarray.core.types import T_DataArray
@@ -25,14 +32,16 @@ class NodePath(PurePosixPath):
else:
super().__new__(PurePosixPath, *pathsegments)
if self.drive:
- raise ValueError('NodePaths cannot have drives')
- if self.root not in ['/', '']:
+ raise ValueError("NodePaths cannot have drives")
+
+ if self.root not in ["/", ""]:
raise ValueError(
'Root of NodePath can only be either "/" or "", with "" meaning the path is relative.'
- )
+ )
+ # TODO should we also forbid suffixes to avoid node names with dots in them?
-Tree = TypeVar('Tree', bound='TreeNode')
+Tree = TypeVar("Tree", bound="TreeNode")
class TreeNode(Generic[Tree]):
@@ -61,10 +70,11 @@ class TreeNode(Generic[Tree]):
(This class is heavily inspired by the anytree library's NodeMixin class.)
"""
+
_parent: Tree | None
_children: dict[str, Tree]
- def __init__(self, children: (Mapping[str, Tree] | None)=None):
+ def __init__(self, children: Mapping[str, Tree] | None = None):
"""Create a parentless node."""
self._parent = None
self._children = {}
@@ -72,107 +82,248 @@ class TreeNode(Generic[Tree]):
self.children = children
@property
- def parent(self) ->(Tree | None):
+ def parent(self) -> Tree | None:
"""Parent of this node."""
- pass
-
- def _check_loop(self, new_parent: (Tree | None)) ->None:
+ return self._parent
+
+ def _set_parent(
+ self, new_parent: Tree | None, child_name: str | None = None
+ ) -> None:
+ # TODO is it possible to refactor in a way that removes this private method?
+
+ if new_parent is not None and not isinstance(new_parent, TreeNode):
+ raise TypeError(
+ "Parent nodes must be of type DataTree or None, "
+ f"not type {type(new_parent)}"
+ )
+
+ old_parent = self._parent
+ if new_parent is not old_parent:
+ self._check_loop(new_parent)
+ self._detach(old_parent)
+ self._attach(new_parent, child_name)
+
+ def _check_loop(self, new_parent: Tree | None) -> None:
"""Checks that assignment of this new parent will not create a cycle."""
- pass
+ if new_parent is not None:
+ if new_parent is self:
+ raise InvalidTreeError(
+ f"Cannot set parent, as node {self} cannot be a parent of itself."
+ )
+
+ if self._is_descendant_of(new_parent):
+ raise InvalidTreeError(
+ "Cannot set parent, as intended parent is already a descendant of this node."
+ )
+
+ def _is_descendant_of(self, node: Tree) -> bool:
+ return any(n is self for n in node.parents)
+
+ def _detach(self, parent: Tree | None) -> None:
+ if parent is not None:
+ self._pre_detach(parent)
+ parents_children = parent.children
+ parent._children = {
+ name: child
+ for name, child in parents_children.items()
+ if child is not self
+ }
+ self._parent = None
+ self._post_detach(parent)
+
+ def _attach(self, parent: Tree | None, child_name: str | None = None) -> None:
+ if parent is not None:
+ if child_name is None:
+ raise ValueError(
+ "To directly set parent, child needs a name, but child is unnamed"
+ )
- def orphan(self) ->None:
+ self._pre_attach(parent, child_name)
+ parentchildren = parent._children
+ assert not any(
+ child is self for child in parentchildren
+ ), "Tree is corrupt."
+ parentchildren[child_name] = self
+ self._parent = parent
+ self._post_attach(parent, child_name)
+ else:
+ self._parent = None
+
+ def orphan(self) -> None:
"""Detach this node from its parent."""
- pass
+ self._set_parent(new_parent=None)
@property
- def children(self: Tree) ->Mapping[str, Tree]:
+ def children(self: Tree) -> Mapping[str, Tree]:
"""Child nodes of this node, stored under a mapping via their names."""
- pass
+ return Frozen(self._children)
+
+ @children.setter
+ def children(self: Tree, children: Mapping[str, Tree]) -> None:
+ self._check_children(children)
+ children = {**children}
+
+ old_children = self.children
+ del self.children
+ try:
+ self._pre_attach_children(children)
+ for name, child in children.items():
+ child._set_parent(new_parent=self, child_name=name)
+ self._post_attach_children(children)
+ assert len(self.children) == len(children)
+ except Exception:
+ # if something goes wrong then revert to previous children
+ self.children = old_children
+ raise
+
+ @children.deleter
+ def children(self) -> None:
+ # TODO this just detaches all the children, it doesn't actually delete them...
+ children = self.children
+ self._pre_detach_children(children)
+ for child in self.children.values():
+ child.orphan()
+ assert len(self.children) == 0
+ self._post_detach_children(children)
@staticmethod
- def _check_children(children: Mapping[str, Tree]) ->None:
+ def _check_children(children: Mapping[str, Tree]) -> None:
"""Check children for correct types and for any duplicates."""
- pass
+ if not is_dict_like(children):
+ raise TypeError(
+ "children must be a dict-like mapping from names to node objects"
+ )
+
+ seen = set()
+ for name, child in children.items():
+ if not isinstance(child, TreeNode):
+ raise TypeError(
+ f"Cannot add object {name}. It is of type {type(child)}, "
+ "but can only add children of type DataTree"
+ )
+
+ childid = id(child)
+ if childid not in seen:
+ seen.add(childid)
+ else:
+ raise InvalidTreeError(
+ f"Cannot add same node {name} multiple times as different children."
+ )
- def __repr__(self) ->str:
- return f'TreeNode(children={dict(self._children)})'
+ def __repr__(self) -> str:
+ return f"TreeNode(children={dict(self._children)})"
- def _pre_detach_children(self: Tree, children: Mapping[str, Tree]) ->None:
+ def _pre_detach_children(self: Tree, children: Mapping[str, Tree]) -> None:
"""Method call before detaching `children`."""
pass
- def _post_detach_children(self: Tree, children: Mapping[str, Tree]) ->None:
+ def _post_detach_children(self: Tree, children: Mapping[str, Tree]) -> None:
"""Method call after detaching `children`."""
pass
- def _pre_attach_children(self: Tree, children: Mapping[str, Tree]) ->None:
+ def _pre_attach_children(self: Tree, children: Mapping[str, Tree]) -> None:
"""Method call before attaching `children`."""
pass
- def _post_attach_children(self: Tree, children: Mapping[str, Tree]) ->None:
+ def _post_attach_children(self: Tree, children: Mapping[str, Tree]) -> None:
"""Method call after attaching `children`."""
pass
- def _iter_parents(self: Tree) ->Iterator[Tree]:
+ def _iter_parents(self: Tree) -> Iterator[Tree]:
"""Iterate up the tree, starting from the current node's parent."""
- pass
+ node: Tree | None = self.parent
+ while node is not None:
+ yield node
+ node = node.parent
- def iter_lineage(self: Tree) ->tuple[Tree, ...]:
+ def iter_lineage(self: Tree) -> tuple[Tree, ...]:
"""Iterate up the tree, starting from the current node."""
- pass
+ from warnings import warn
+
+        warn(
+            "`iter_lineage` has been deprecated, and in the future will raise an error. "
+            "Please use `parents` from now on.",
+            DeprecationWarning,
+        )
+ return tuple((self, *self.parents))
@property
- def lineage(self: Tree) ->tuple[Tree, ...]:
+ def lineage(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the closest."""
- pass
+ from warnings import warn
+
+        warn(
+            "`lineage` has been deprecated, and in the future will raise an error. "
+            "Please use `parents` from now on.",
+            DeprecationWarning,
+        )
+ return self.iter_lineage()
@property
- def parents(self: Tree) ->tuple[Tree, ...]:
+ def parents(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the closest."""
- pass
+ return tuple(self._iter_parents())
@property
- def ancestors(self: Tree) ->tuple[Tree, ...]:
+ def ancestors(self: Tree) -> tuple[Tree, ...]:
"""All parent nodes and their parent nodes, starting with the most distant."""
- pass
+
+ from warnings import warn
+
+        warn(
+            "`ancestors` has been deprecated, and in the future will raise an error. "
+            "Please use `parents`. Example: `tuple(reversed(node.parents))`",
+            DeprecationWarning,
+        )
+ return tuple((*reversed(self.parents), self))
@property
- def root(self: Tree) ->Tree:
+ def root(self: Tree) -> Tree:
"""Root node of the tree"""
- pass
+ node = self
+ while node.parent is not None:
+ node = node.parent
+ return node
@property
- def is_root(self) ->bool:
+ def is_root(self) -> bool:
"""Whether this node is the tree root."""
- pass
+ return self.parent is None
@property
- def is_leaf(self) ->bool:
+ def is_leaf(self) -> bool:
"""
Whether this node is a leaf node.
Leaf nodes are defined as nodes which have no children.
"""
- pass
+ return self.children == {}
@property
- def leaves(self: Tree) ->tuple[Tree, ...]:
+ def leaves(self: Tree) -> tuple[Tree, ...]:
"""
All leaf nodes.
Leaf nodes are defined as nodes which have no children.
"""
- pass
+ return tuple([node for node in self.subtree if node.is_leaf])
@property
- def siblings(self: Tree) ->dict[str, Tree]:
+ def siblings(self: Tree) -> dict[str, Tree]:
"""
Nodes with the same parent as this node.
"""
- pass
+ if self.parent:
+ return {
+ name: child
+ for name, child in self.parent.children.items()
+ if child is not self
+ }
+ else:
+ return {}
@property
- def subtree(self: Tree) ->Iterator[Tree]:
+ def subtree(self: Tree) -> Iterator[Tree]:
"""
An iterator over all nodes in this tree, including both self and all descendants.
@@ -182,10 +333,12 @@ class TreeNode(Generic[Tree]):
--------
DataTree.descendants
"""
- pass
+ from xarray.core.iterators import LevelOrderIter
+
+ return LevelOrderIter(self)
@property
- def descendants(self: Tree) ->tuple[Tree, ...]:
+ def descendants(self: Tree) -> tuple[Tree, ...]:
"""
Child nodes and all their child nodes.
@@ -195,10 +348,12 @@ class TreeNode(Generic[Tree]):
--------
DataTree.subtree
"""
- pass
+ all_nodes = tuple(self.subtree)
+ this_node, *descendants = all_nodes
+ return tuple(descendants)
@property
- def level(self: Tree) ->int:
+ def level(self: Tree) -> int:
"""
Level of this node.
@@ -214,10 +369,10 @@ class TreeNode(Generic[Tree]):
depth
width
"""
- pass
+ return len(self.parents)
@property
- def depth(self: Tree) ->int:
+ def depth(self: Tree) -> int:
"""
Maximum level of this tree.
@@ -232,10 +387,10 @@ class TreeNode(Generic[Tree]):
level
width
"""
- pass
+ return max(node.level for node in self.root.subtree)
@property
- def width(self: Tree) ->int:
+ def width(self: Tree) -> int:
"""
Number of nodes at this level in the tree.
@@ -250,52 +405,85 @@ class TreeNode(Generic[Tree]):
level
depth
"""
- pass
+ return len([node for node in self.root.subtree if node.level == self.level])
- def _pre_detach(self: Tree, parent: Tree) ->None:
+ def _pre_detach(self: Tree, parent: Tree) -> None:
"""Method call before detaching from `parent`."""
pass
- def _post_detach(self: Tree, parent: Tree) ->None:
+ def _post_detach(self: Tree, parent: Tree) -> None:
"""Method call after detaching from `parent`."""
pass
- def _pre_attach(self: Tree, parent: Tree, name: str) ->None:
+ def _pre_attach(self: Tree, parent: Tree, name: str) -> None:
"""Method call before attaching to `parent`."""
pass
- def _post_attach(self: Tree, parent: Tree, name: str) ->None:
+ def _post_attach(self: Tree, parent: Tree, name: str) -> None:
"""Method call after attaching to `parent`."""
pass
- def get(self: Tree, key: str, default: (Tree | None)=None) ->(Tree | None):
+ def get(self: Tree, key: str, default: Tree | None = None) -> Tree | None:
"""
Return the child node with the specified key.
Only looks for the node within the immediate children of this node,
not in other nodes of the tree.
"""
- pass
+ if key in self.children:
+ return self.children[key]
+ else:
+ return default
+
+ # TODO `._walk` method to be called by both `_get_item` and `_set_item`
- def _get_item(self: Tree, path: (str | NodePath)) ->(Tree | T_DataArray):
+ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray:
"""
Returns the object lying at the given path.
Raises a KeyError if there is no object at the given path.
"""
- pass
+ if isinstance(path, str):
+ path = NodePath(path)
- def _set(self: Tree, key: str, val: Tree) ->None:
+ if path.root:
+ current_node = self.root
+ root, *parts = list(path.parts)
+ else:
+ current_node = self
+ parts = list(path.parts)
+
+ for part in parts:
+ if part == "..":
+ if current_node.parent is None:
+ raise KeyError(f"Could not find node at {path}")
+ else:
+ current_node = current_node.parent
+ elif part in ("", "."):
+ pass
+ else:
+ if current_node.get(part) is None:
+ raise KeyError(f"Could not find node at {path}")
+ else:
+ current_node = current_node.get(part)
+ return current_node
+
+ def _set(self: Tree, key: str, val: Tree) -> None:
"""
Set the child node with the specified key to value.
Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree.
"""
- pass
-
- def _set_item(self: Tree, path: (str | NodePath), item: (Tree |
- T_DataArray), new_nodes_along_path: bool=False, allow_overwrite:
- bool=True) ->None:
+ new_children = {**self.children, key: val}
+ self.children = new_children
+
+ def _set_item(
+ self: Tree,
+ path: str | NodePath,
+ item: Tree | T_DataArray,
+ new_nodes_along_path: bool = False,
+ allow_overwrite: bool = True,
+ ) -> None:
"""
Set a new item in the tree, overwriting anything already present at that path.
@@ -319,7 +507,51 @@ class TreeNode(Generic[Tree]):
If node cannot be reached, and new_nodes_along_path=False.
Or if a node already exists at the specified path, and allow_overwrite=False.
"""
- pass
+ if isinstance(path, str):
+ path = NodePath(path)
+
+ if not path.name:
+ raise ValueError("Can't set an item under a path which has no name")
+
+ if path.root:
+ # absolute path
+ current_node = self.root
+ root, *parts, name = path.parts
+ else:
+ # relative path
+ current_node = self
+ *parts, name = path.parts
+
+ if parts:
+ # Walk to location of new node, creating intermediate node objects as we go if necessary
+ for part in parts:
+ if part == "..":
+ if current_node.parent is None:
+ # We can't create a parent if `new_nodes_along_path=True` as we wouldn't know what to name it
+ raise KeyError(f"Could not reach node at path {path}")
+ else:
+ current_node = current_node.parent
+ elif part in ("", "."):
+ pass
+ else:
+ if part in current_node.children:
+ current_node = current_node.children[part]
+ elif new_nodes_along_path:
+ # Want child classes (i.e. DataTree) to populate tree with their own types
+ new_node = type(self)()
+ current_node._set(part, new_node)
+ current_node = current_node.children[part]
+ else:
+ raise KeyError(f"Could not reach node at path {path}")
+
+ if name in current_node.children:
+ # Deal with anything already existing at this location
+ if allow_overwrite:
+ current_node._set(name, item)
+ else:
+ raise KeyError(f"Already a node object at path {path}")
+ else:
+ current_node._set(name, item)
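
A path-resolution sketch using the internal classes defined in this module: absolute paths start from the root, "" and "." are no-ops, and ".." steps to the parent.

from xarray.core.treenode import TreeNode

root = TreeNode()
root._set_item("/a/b", TreeNode(), new_nodes_along_path=True)
b = root._get_item("a/b")            # relative lookup from root
assert root._get_item("/a/b") is b   # absolute path resolves via .root
assert b._get_item("../..") is root  # ".." walks up to the parent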
def __delitem__(self: Tree, key: str):
"""Remove a child node from this tree object."""
@@ -328,14 +560,14 @@ class TreeNode(Generic[Tree]):
del self._children[key]
child.orphan()
else:
- raise KeyError('Cannot delete')
+ raise KeyError("Cannot delete")
- def same_tree(self, other: Tree) ->bool:
+ def same_tree(self, other: Tree) -> bool:
"""True if other node is in the same tree as this node."""
- pass
+ return self.root is other.root
-AnyNamedNode = TypeVar('AnyNamedNode', bound='NamedNode')
+AnyNamedNode = TypeVar("AnyNamedNode", bound="NamedNode")
class NamedNode(TreeNode, Generic[Tree]):
@@ -344,6 +576,7 @@ class NamedNode(TreeNode, Generic[Tree]):
Implements path-like relationships to other nodes in its tree.
"""
+
_name: str | None
_parent: Tree | None
_children: dict[str, Tree]
@@ -354,45 +587,95 @@ class NamedNode(TreeNode, Generic[Tree]):
self.name = name
@property
- def name(self) ->(str | None):
+ def name(self) -> str | None:
"""The name of this node."""
- pass
+ return self._name
+
+ @name.setter
+ def name(self, name: str | None) -> None:
+ if name is not None:
+ if not isinstance(name, str):
+ raise TypeError("node name must be a string or None")
+ if "/" in name:
+ raise ValueError("node names cannot contain forward slashes")
+ self._name = name
def __repr__(self, level=0):
- repr_value = '\t' * level + self.__str__() + '\n'
+ repr_value = "\t" * level + self.__str__() + "\n"
for child in self.children:
repr_value += self.get(child).__repr__(level + 1)
return repr_value
- def __str__(self) ->str:
- return f"NamedNode('{self.name}')" if self.name else 'NamedNode()'
+ def __str__(self) -> str:
+ return f"NamedNode('{self.name}')" if self.name else "NamedNode()"
- def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str
- ) ->None:
+ def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
"""Ensures child has name attribute corresponding to key under which it has been stored."""
- pass
+ self.name = name
@property
- def path(self) ->str:
+ def path(self) -> str:
"""Return the file-like path from the root to this node."""
- pass
+ if self.is_root:
+ return "/"
+ else:
+ root, *ancestors = tuple(reversed(self.parents))
+ # don't include name of root because (a) root might not have a name & (b) we want path relative to root.
+ names = [*(node.name for node in ancestors), self.name]
+ return "/" + "/".join(names)
- def relative_to(self: NamedNode, other: NamedNode) ->str:
+ def relative_to(self: NamedNode, other: NamedNode) -> str:
"""
Compute the relative path from this node to node `other`.
If other is not in this tree, or it's otherwise impossible, raise a ValueError.
"""
- pass
+ if not self.same_tree(other):
+ raise NotFoundInTreeError(
+ "Cannot find relative path because nodes do not lie within the same tree"
+ )
+
+ this_path = NodePath(self.path)
+ if other.path in list(parent.path for parent in (self, *self.parents)):
+ return str(this_path.relative_to(other.path))
+ else:
+ common_ancestor = self.find_common_ancestor(other)
+ path_to_common_ancestor = other._path_to_ancestor(common_ancestor)
+ return str(
+ path_to_common_ancestor / this_path.relative_to(common_ancestor.path)
+ )
- def find_common_ancestor(self, other: NamedNode) ->NamedNode:
+ def find_common_ancestor(self, other: NamedNode) -> NamedNode:
"""
Find the first common ancestor of two nodes in the same tree.
Raise ValueError if they are not in the same tree.
"""
- pass
+ if self is other:
+ return self
- def _path_to_ancestor(self, ancestor: NamedNode) ->NodePath:
+ other_paths = [op.path for op in other.parents]
+ for parent in (self, *self.parents):
+ if parent.path in other_paths:
+ return parent
+
+ raise NotFoundInTreeError(
+ "Cannot find common ancestor because nodes do not lie within the same tree"
+ )
+
+ def _path_to_ancestor(self, ancestor: NamedNode) -> NodePath:
"""Return the relative path from this node to the given ancestor node"""
- pass
+
+ if not self.same_tree(ancestor):
+ raise NotFoundInTreeError(
+ "Cannot find relative path to ancestor because nodes do not lie within the same tree"
+ )
+ if ancestor.path not in list(a.path for a in (self, *self.parents)):
+ raise NotFoundInTreeError(
+ "Cannot find relative path to ancestor because given node is not an ancestor of this node"
+ )
+
+ parents_paths = list(parent.path for parent in (self, *self.parents))
+ generation_gap = list(parents_paths).index(ancestor.path)
+ path_upwards = "../" * generation_gap if generation_gap > 0 else "."
+ return NodePath(path_upwards)
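
How relative_to and find_common_ancestor interact, sketched for nodes b and c at paths "/a/b" and "/a/c" in the same tree:

# b.relative_to(a) == "b"      # `a` is on b's parent chain, so the path is a plain suffix
# b.relative_to(c) == "../b"   # goes up to the common ancestor "/a", then back down
# b.find_common_ancestor(c)    # returns the node at "/a"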
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 64217fbc..591320d2 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -1,10 +1,22 @@
from __future__ import annotations
+
import datetime
import sys
from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, SupportsIndex, TypeVar, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Literal,
+ Protocol,
+ SupportsIndex,
+ TypeVar,
+ Union,
+)
+
import numpy as np
import pandas as pd
+
try:
if sys.version_info >= (3, 11):
from typing import Self, TypeAlias
@@ -15,8 +27,11 @@ except ImportError:
raise
else:
Self: Any = None
+
+
from numpy._typing import _SupportsDType
from numpy.typing import ArrayLike
+
if TYPE_CHECKING:
from xarray.backends.common import BackendEntrypoint
from xarray.core.alignment import Aligner
@@ -28,32 +43,62 @@ if TYPE_CHECKING:
from xarray.core.utils import Frozen
from xarray.core.variable import Variable
from xarray.groupers import TimeResampler
+
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray
+
try:
from cubed import Array as CubedArray
except ImportError:
CubedArray = np.ndarray
+
try:
from zarr.core import Array as ZarrArray
except ImportError:
ZarrArray = np.ndarray
+
+ # Anything that can be coerced to a shape tuple
_ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]]
- _DTypeLikeNested = Any
- DTypeLikeSave = Union[np.dtype[Any], None, type[Any], str, tuple[
- _DTypeLikeNested, int], tuple[_DTypeLikeNested, _ShapeLike], tuple[
- _DTypeLikeNested, _DTypeLikeNested], list[Any], _SupportsDType[np.
- dtype[Any]]]
+ _DTypeLikeNested = Any # TODO: wait for support for recursive types
+
+ # Xarray requires a Mapping[Hashable, dtype] in many places which
+ # conflicts with numpys own DTypeLike (with dtypes for fields).
+ # https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike
+ # This is a copy of this DTypeLike that allows only non-Mapping dtypes.
+ DTypeLikeSave = Union[
+ np.dtype[Any],
+ # default data type (float64)
+ None,
+ # array-scalar types and generic types
+ type[Any],
+ # character codes, type strings or comma-separated fields, e.g., 'float64'
+ str,
+ # (flexible_dtype, itemsize)
+ tuple[_DTypeLikeNested, int],
+ # (fixed_dtype, shape)
+ tuple[_DTypeLikeNested, _ShapeLike],
+ # (base_dtype, new_dtype)
+ tuple[_DTypeLikeNested, _DTypeLikeNested],
+ # because numpy does the same?
+ list[Any],
+ # anything with a dtype attribute
+ _SupportsDType[np.dtype[Any]],
+ ]
+
else:
DTypeLikeSave: Any = None
+
+# https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
try:
from cftime import datetime as CFTimeDatetime
except ImportError:
CFTimeDatetime = np.datetime64
-DatetimeLike: TypeAlias = Union[pd.Timestamp, datetime.datetime, np.
- datetime64, CFTimeDatetime]
+
+DatetimeLike: TypeAlias = Union[
+ pd.Timestamp, datetime.datetime, np.datetime64, CFTimeDatetime
+]
class Alignable(Protocol):
@@ -64,83 +109,196 @@ class Alignable(Protocol):
"""
- def __len__(self) ->int:
- ...
-
- def __iter__(self) ->Iterator[Hashable]:
- ...
-
-
-T_Alignable = TypeVar('T_Alignable', bound='Alignable')
-T_Backend = TypeVar('T_Backend', bound='BackendEntrypoint')
-T_Dataset = TypeVar('T_Dataset', bound='Dataset')
-T_DataArray = TypeVar('T_DataArray', bound='DataArray')
-T_Variable = TypeVar('T_Variable', bound='Variable')
-T_Coordinates = TypeVar('T_Coordinates', bound='Coordinates')
-T_Array = TypeVar('T_Array', bound='AbstractArray')
-T_Index = TypeVar('T_Index', bound='Index')
-T_Xarray = TypeVar('T_Xarray', 'DataArray', 'Dataset')
-T_DataArrayOrSet = TypeVar('T_DataArrayOrSet', bound=Union['Dataset',
- 'DataArray'])
-T_DataWithCoords = TypeVar('T_DataWithCoords', bound='DataWithCoords')
-T_DuckArray = TypeVar('T_DuckArray', bound=Any, covariant=True)
-T_ExtensionArray = TypeVar('T_ExtensionArray', bound=pd.api.extensions.
- ExtensionArray)
-ScalarOrArray = Union['ArrayLike', np.generic]
-VarCompatible = Union['Variable', 'ScalarOrArray']
-DaCompatible = Union['DataArray', 'VarCompatible']
-DsCompatible = Union['Dataset', 'DaCompatible']
-GroupByCompatible = Union['Dataset', 'DataArray']
-Dims = Union[str, Collection[Hashable], 'ellipsis', None]
-T_ChunkDim: TypeAlias = Union[str, int, Literal['auto'], None, tuple[int, ...]]
-T_ChunkDimFreq: TypeAlias = Union['TimeResampler', T_ChunkDim]
+ @property
+ def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: ...
+
+ @property
+ def sizes(self) -> Mapping[Hashable, int]: ...
+
+ @property
+ def xindexes(self) -> Indexes[Index]: ...
+
+ def _reindex_callback(
+ self,
+ aligner: Aligner,
+ dim_pos_indexers: dict[Hashable, Any],
+ variables: dict[Hashable, Variable],
+ indexes: dict[Hashable, Index],
+ fill_value: Any,
+ exclude_dims: frozenset[Hashable],
+ exclude_vars: frozenset[Hashable],
+ ) -> Self: ...
+
+ def _overwrite_indexes(
+ self,
+ indexes: Mapping[Any, Index],
+ variables: Mapping[Any, Variable] | None = None,
+ ) -> Self: ...
+
+ def __len__(self) -> int: ...
+
+ def __iter__(self) -> Iterator[Hashable]: ...
+
+ def copy(
+ self,
+ deep: bool = False,
+ ) -> Self: ...
+
+
+T_Alignable = TypeVar("T_Alignable", bound="Alignable")
+
+T_Backend = TypeVar("T_Backend", bound="BackendEntrypoint")
+T_Dataset = TypeVar("T_Dataset", bound="Dataset")
+T_DataArray = TypeVar("T_DataArray", bound="DataArray")
+T_Variable = TypeVar("T_Variable", bound="Variable")
+T_Coordinates = TypeVar("T_Coordinates", bound="Coordinates")
+T_Array = TypeVar("T_Array", bound="AbstractArray")
+T_Index = TypeVar("T_Index", bound="Index")
+
+# `T_Xarray` is a type variable that can be either "DataArray" or "Dataset". When used
+# in a function definition, all inputs and outputs annotated with `T_Xarray` must be of
+# the same concrete type, either "DataArray" or "Dataset". This is generally preferred
+# over `T_DataArrayOrSet`, given the type system can determine the exact type.
+T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
+
+# `T_DataArrayOrSet` is a type variable that is bounded to either "DataArray" or
+# "Dataset". Use it for functions that might return either type, but where the exact
+# type cannot be determined statically using the type system.
+T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])
+
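An illustration of the distinction above (hypothetical helper, not part of this module; uses the T_Xarray TypeVar defined here):

def roundtrip(obj: T_Xarray) -> T_Xarray:
    # a DataArray argument gives a DataArray return type, a Dataset gives a Dataset;
    # with T_DataArrayOrSet the checker would only know "Dataset or DataArray".
    return obj.copy()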
+# For working directly with `DataWithCoords`. It will only allow using methods defined
+# on `DataWithCoords`.
+T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
+
+
+# Temporary placeholder for indicating an array api compliant type.
+# hopefully in the future we can narrow this down more:
+T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True)
+
+# For typing pandas extension arrays.
+T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray)
+
+
+ScalarOrArray = Union["ArrayLike", np.generic]
+VarCompatible = Union["Variable", "ScalarOrArray"]
+DaCompatible = Union["DataArray", "VarCompatible"]
+DsCompatible = Union["Dataset", "DaCompatible"]
+GroupByCompatible = Union["Dataset", "DataArray"]
+
+# Don't change to Hashable | Collection[Hashable]
+# Read: https://github.com/pydata/xarray/issues/6142
+Dims = Union[str, Collection[Hashable], "ellipsis", None]
+
+# FYI in some cases we don't allow `None`, which this doesn't take account of.
+# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
+T_ChunkDim: TypeAlias = Union[str, int, Literal["auto"], None, tuple[int, ...]]
+T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
T_ChunksFreq: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDimFreq]]
+# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]]
T_NormalizedChunks = tuple[tuple[int, ...], ...]
+
DataVars = Mapping[Any, Any]
-ErrorOptions = Literal['raise', 'ignore']
-ErrorOptionsWithWarn = Literal['raise', 'warn', 'ignore']
-CompatOptions = Literal['identical', 'equals', 'broadcast_equals',
- 'no_conflicts', 'override', 'minimal']
-ConcatOptions = Literal['all', 'minimal', 'different']
-CombineAttrsOptions = Union[Literal['drop', 'identical', 'no_conflicts',
- 'drop_conflicts', 'override'], Callable[..., Any]]
-JoinOptions = Literal['outer', 'inner', 'left', 'right', 'exact', 'override']
-Interp1dOptions = Literal['linear', 'nearest', 'zero', 'slinear',
- 'quadratic', 'cubic', 'polynomial']
-InterpolantOptions = Literal['barycentric', 'krogh', 'pchip', 'spline', 'akima'
- ]
+
+
+ErrorOptions = Literal["raise", "ignore"]
+ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"]
+
+CompatOptions = Literal[
+ "identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal"
+]
+ConcatOptions = Literal["all", "minimal", "different"]
+CombineAttrsOptions = Union[
+ Literal["drop", "identical", "no_conflicts", "drop_conflicts", "override"],
+ Callable[..., Any],
+]
+JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override"]
+
+Interp1dOptions = Literal[
+ "linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
+]
+InterpolantOptions = Literal["barycentric", "krogh", "pchip", "spline", "akima"]
InterpOptions = Union[Interp1dOptions, InterpolantOptions]
-DatetimeUnitOptions = Literal['Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us',
- 'μs', 'ns', 'ps', 'fs', 'as', None]
-NPDatetimeUnitOptions = Literal['D', 'h', 'm', 's', 'ms', 'us', 'ns']
-QueryEngineOptions = Literal['python', 'numexpr', None]
-QueryParserOptions = Literal['pandas', 'python']
-ReindexMethodOptions = Literal['nearest', 'pad', 'ffill', 'backfill',
- 'bfill', None]
-PadModeOptions = Literal['constant', 'edge', 'linear_ramp', 'maximum',
- 'mean', 'median', 'minimum', 'reflect', 'symmetric', 'wrap']
-PadReflectOptions = Literal['even', 'odd', None]
-CFCalendar = Literal['standard', 'gregorian', 'proleptic_gregorian',
- 'noleap', '365_day', '360_day', 'julian', 'all_leap', '366_day']
-CoarsenBoundaryOptions = Literal['exact', 'trim', 'pad']
-SideOptions = Literal['left', 'right']
-InclusiveOptions = Literal['both', 'neither', 'left', 'right']
-ScaleOptions = Literal['linear', 'symlog', 'log', 'logit', None]
-HueStyleOptions = Literal['continuous', 'discrete', None]
-AspectOptions = Union[Literal['auto', 'equal'], float, None]
-ExtendOptions = Literal['neither', 'both', 'min', 'max', None]
-_T = TypeVar('_T')
-NestedSequence = Union[_T, Sequence[_T], Sequence[Sequence[_T]], Sequence[
- Sequence[Sequence[_T]]], Sequence[Sequence[Sequence[Sequence[_T]]]]]
-QuantileMethods = Literal['inverted_cdf', 'averaged_inverted_cdf',
- 'closest_observation', 'interpolated_inverted_cdf', 'hazen', 'weibull',
- 'linear', 'median_unbiased', 'normal_unbiased', 'lower', 'higher',
- 'midpoint', 'nearest']
-NetcdfWriteModes = Literal['w', 'a']
-ZarrWriteModes = Literal['w', 'w-', 'a', 'a-', 'r+', 'r']
+
+DatetimeUnitOptions = Literal[
+ "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as", None
+]
+NPDatetimeUnitOptions = Literal["D", "h", "m", "s", "ms", "us", "ns"]
+
+QueryEngineOptions = Literal["python", "numexpr", None]
+QueryParserOptions = Literal["pandas", "python"]
+
+ReindexMethodOptions = Literal["nearest", "pad", "ffill", "backfill", "bfill", None]
+
+PadModeOptions = Literal[
+ "constant",
+ "edge",
+ "linear_ramp",
+ "maximum",
+ "mean",
+ "median",
+ "minimum",
+ "reflect",
+ "symmetric",
+ "wrap",
+]
+PadReflectOptions = Literal["even", "odd", None]
+
+CFCalendar = Literal[
+ "standard",
+ "gregorian",
+ "proleptic_gregorian",
+ "noleap",
+ "365_day",
+ "360_day",
+ "julian",
+ "all_leap",
+ "366_day",
+]
+
+CoarsenBoundaryOptions = Literal["exact", "trim", "pad"]
+SideOptions = Literal["left", "right"]
+InclusiveOptions = Literal["both", "neither", "left", "right"]
+
+ScaleOptions = Literal["linear", "symlog", "log", "logit", None]
+HueStyleOptions = Literal["continuous", "discrete", None]
+AspectOptions = Union[Literal["auto", "equal"], float, None]
+ExtendOptions = Literal["neither", "both", "min", "max", None]
+
+# TODO: Wait until mypy supports recursive objects in combination with typevars
+_T = TypeVar("_T")
+NestedSequence = Union[
+ _T,
+ Sequence[_T],
+ Sequence[Sequence[_T]],
+ Sequence[Sequence[Sequence[_T]]],
+ Sequence[Sequence[Sequence[Sequence[_T]]]],
+]
+
+
+QuantileMethods = Literal[
+ "inverted_cdf",
+ "averaged_inverted_cdf",
+ "closest_observation",
+ "interpolated_inverted_cdf",
+ "hazen",
+ "weibull",
+ "linear",
+ "median_unbiased",
+ "normal_unbiased",
+ "lower",
+ "higher",
+ "midpoint",
+ "nearest",
+]
+
+
+NetcdfWriteModes = Literal["w", "a"]
+ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]
+
GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
GroupIndices = tuple[GroupIndex, ...]
-Bins = Union[int, Sequence[int], Sequence[float], Sequence[pd.Timestamp],
- np.ndarray, pd.Index]
+Bins = Union[
+ int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index
+]
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index fb70a0c1..c2859632 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -1,5 +1,41 @@
"""Internal utilities; not for external use"""
+
+# Some functions in this module are derived from functions in pandas. For
+# reference, here is a copy of the pandas copyright notice:
+
+# BSD 3-Clause License
+
+# Copyright (c) 2008-2011, AQR Capital Management, LLC, Lambda Foundry, Inc. and PyData Development Team
+# All rights reserved.
+
+# Copyright (c) 2011-2022, Open source contributors.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
+
import contextlib
import functools
import inspect
@@ -10,28 +46,91 @@ import os
import re
import sys
import warnings
-from collections.abc import Collection, Container, Hashable, ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, MutableSet, Sequence, ValuesView
+from collections.abc import (
+ Collection,
+ Container,
+ Hashable,
+ ItemsView,
+ Iterable,
+ Iterator,
+ KeysView,
+ Mapping,
+ MutableMapping,
+ MutableSet,
+ Sequence,
+ ValuesView,
+)
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, overload
+
import numpy as np
import pandas as pd
-from xarray.namedarray.utils import ReprObject, drop_missing_dims, either_dict_or_kwargs, infix_dims, is_dask_collection, is_dict_like, is_duck_array, is_duck_dask_array, module_available, to_0d_object_array
+
+from xarray.namedarray.utils import ( # noqa: F401
+ ReprObject,
+ drop_missing_dims,
+ either_dict_or_kwargs,
+ infix_dims,
+ is_dask_collection,
+ is_dict_like,
+ is_duck_array,
+ is_duck_dask_array,
+ module_available,
+ to_0d_object_array,
+)
+
if TYPE_CHECKING:
from xarray.core.types import Dims, ErrorOptionsWithWarn
-K = TypeVar('K')
-V = TypeVar('V')
-T = TypeVar('T')
+K = TypeVar("K")
+V = TypeVar("V")
+T = TypeVar("T")
+
+
+def alias_message(old_name: str, new_name: str) -> str:
+ return f"{old_name} has been deprecated. Use {new_name} instead."
-def get_valid_numpy_dtype(array: (np.ndarray | pd.Index)) ->np.dtype:
+
+def alias_warning(old_name: str, new_name: str, stacklevel: int = 3) -> None:
+ warnings.warn(
+ alias_message(old_name, new_name), FutureWarning, stacklevel=stacklevel
+ )
+
+
+def alias(obj: Callable[..., T], old_name: str) -> Callable[..., T]:
+ assert isinstance(old_name, str)
+
+ @functools.wraps(obj)
+ def wrapper(*args, **kwargs):
+ alias_warning(old_name, obj.__name__)
+ return obj(*args, **kwargs)
+
+ wrapper.__doc__ = alias_message(old_name, obj.__name__)
+ return wrapper
+
+
+def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype:
"""Return a numpy compatible dtype from either
a numpy array or a pandas.Index.
Used for wrapping a pandas.Index as an xarray.Variable.
"""
- pass
+ if isinstance(array, pd.PeriodIndex):
+ return np.dtype("O")
+
+ if hasattr(array, "categories"):
+ # category isn't a real numpy dtype
+ dtype = array.categories.dtype
+ if not is_valid_numpy_dtype(dtype):
+ dtype = np.dtype("O")
+ return dtype
+
+ if not is_valid_numpy_dtype(array.dtype):
+ return np.dtype("O")
+
+ return array.dtype # type: ignore[return-value]
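A small usage sketch (not part of the patch), assuming the helper above is importable from `xarray.core.utils` as in the diff:

```python
import pandas as pd

from xarray.core.utils import get_valid_numpy_dtype

get_valid_numpy_dtype(pd.period_range("2000", periods=3, freq="D"))  # dtype('O'): PeriodIndex
get_valid_numpy_dtype(pd.Index([1, 2, 3]))                           # e.g. dtype('int64')
get_valid_numpy_dtype(pd.CategoricalIndex(["a", "b"]))               # dtype('O'): str categories
```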
def maybe_coerce_to_str(index, original_coords):
@@ -39,7 +138,17 @@ def maybe_coerce_to_str(index, original_coords):
pd.Index uses object-dtype to store str - try to avoid this for coords
"""
- pass
+ from xarray.core import dtypes
+
+ try:
+ result_type = dtypes.result_type(*original_coords)
+ except TypeError:
+ pass
+ else:
+ if result_type.kind in "SU":
+ index = np.asarray(index, dtype=result_type.type)
+
+ return index
def maybe_wrap_array(original, new_array):
@@ -48,26 +157,53 @@ def maybe_wrap_array(original, new_array):
This lets us treat arbitrary functions that take and return ndarray objects
like ufuncs, as long as they return an array with the same shape.
"""
- pass
+ # in case func lost array's metadata
+ if isinstance(new_array, np.ndarray) and new_array.shape == original.shape:
+ return original.__array_wrap__(new_array)
+ else:
+ return new_array
-def equivalent(first: T, second: T) ->bool:
+def equivalent(first: T, second: T) -> bool:
"""Compare two objects for equivalence (identity or equality), using
array_equiv if either object is an ndarray. If both objects are lists,
equivalent is sequentially called on all the elements.
"""
- pass
+ # TODO: refactor to avoid circular import
+ from xarray.core import duck_array_ops
+ if first is second:
+ return True
+ if isinstance(first, np.ndarray) or isinstance(second, np.ndarray):
+ return duck_array_ops.array_equiv(first, second)
+ if isinstance(first, list) or isinstance(second, list):
+ return list_equiv(first, second) # type: ignore[arg-type]
+ return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload]
-def peek_at(iterable: Iterable[T]) ->tuple[T, Iterator[T]]:
+
+def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool:
+ if len(first) != len(second):
+ return False
+ for f, s in zip(first, second):
+ if not equivalent(f, s):
+ return False
+ return True
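A quick illustration of the equivalence semantics implemented above (sketch, not part of the patch):

```python
import numpy as np

from xarray.core.utils import equivalent

equivalent(np.array([1, 2]), np.array([1, 2]))  # True: delegates to array_equiv
equivalent(np.nan, np.nan)                      # True: a pair of NaNs counts as equivalent
equivalent([1, [2, 3]], [1, [2, 3]])            # True: lists are compared element-wise
equivalent("a", "b")                            # False
```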
+
+
+def peek_at(iterable: Iterable[T]) -> tuple[T, Iterator[T]]:
"""Returns the first value from iterable, as well as a new iterator with
the same content as the original iterable
"""
- pass
+ gen = iter(iterable)
+ peek = next(gen)
+ return peek, itertools.chain([peek], gen)
-def update_safety_check(first_dict: Mapping[K, V], second_dict: Mapping[K,
- V], compat: Callable[[V, V], bool]=equivalent) ->None:
+def update_safety_check(
+ first_dict: Mapping[K, V],
+ second_dict: Mapping[K, V],
+ compat: Callable[[V, V], bool] = equivalent,
+) -> None:
"""Check the safety of updating one dictionary with another.
Raises ValueError if dictionaries have non-compatible values for any key,
@@ -83,11 +219,19 @@ def update_safety_check(first_dict: Mapping[K, V], second_dict: Mapping[K,
Binary operator to determine if two values are compatible. By default,
checks for equivalence.
"""
- pass
-
-
-def remove_incompatible_items(first_dict: MutableMapping[K, V], second_dict:
- Mapping[K, V], compat: Callable[[V, V], bool]=equivalent) ->None:
+ for k, v in second_dict.items():
+ if k in first_dict and not compat(v, first_dict[k]):
+ raise ValueError(
+ "unsafe to merge dictionaries without "
+ f"overriding values; conflicting key {k!r}"
+ )
+
+
+def remove_incompatible_items(
+ first_dict: MutableMapping[K, V],
+ second_dict: Mapping[K, V],
+ compat: Callable[[V, V], bool] = equivalent,
+) -> None:
"""Remove incompatible items from the first dictionary in-place.
Items are retained if their keys are found in both dictionaries and the
@@ -101,9 +245,38 @@ def remove_incompatible_items(first_dict: MutableMapping[K, V], second_dict:
Binary operator to determine if two values are compatible. By default,
checks for equivalence.
"""
- pass
+ for k in list(first_dict):
+ if k not in second_dict or not compat(first_dict[k], second_dict[k]):
+ del first_dict[k]
+
+
+def is_full_slice(value: Any) -> bool:
+ return isinstance(value, slice) and value == slice(None)
+
+def is_list_like(value: Any) -> TypeGuard[list | tuple]:
+ return isinstance(value, (list, tuple))
+
+def _is_scalar(value, include_0d):
+ from xarray.core.variable import NON_NUMPY_SUPPORTED_ARRAY_TYPES
+
+ if include_0d:
+ include_0d = getattr(value, "ndim", None) == 0
+ return (
+ include_0d
+ or isinstance(value, (str, bytes))
+ or not (
+ isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES)
+ or hasattr(value, "__array_function__")
+ or hasattr(value, "__array_namespace__")
+ )
+ )
+
+
+# See GH5624, this is a convoluted way to allow type-checking to use `TypeGuard` without
+# requiring typing_extensions as a required dependency to _run_ the code (it is required
+# to type-check).
try:
if sys.version_info >= (3, 10):
from typing import TypeGuard
@@ -114,29 +287,45 @@ except ImportError:
raise
else:
- def is_scalar(value: Any, include_0d: bool=True) ->bool:
+ def is_scalar(value: Any, include_0d: bool = True) -> bool:
"""Whether to treat a value as a scalar.
Any non-iterable, string, or 0-D array
"""
- pass
+ return _is_scalar(value, include_0d)
+
else:
- def is_scalar(value: Any, include_0d: bool=True) ->TypeGuard[Hashable]:
+ def is_scalar(value: Any, include_0d: bool = True) -> TypeGuard[Hashable]:
"""Whether to treat a value as a scalar.
Any non-iterable, string, or 0-D array
"""
- pass
+ return _is_scalar(value, include_0d)
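Behaviour sketch for `is_scalar` as defined above (illustrative, not part of the patch):

```python
import numpy as np

from xarray.core.utils import is_scalar

assert is_scalar(1.5)
assert is_scalar("text")          # strings and bytes count as scalars
assert is_scalar(np.array(5))     # 0-d arrays count while include_0d=True
assert not is_scalar([1, 2, 3])   # other iterables do not
assert not is_scalar(np.arange(3))
```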
-def to_0d_array(value: Any) ->np.ndarray:
+def is_valid_numpy_dtype(dtype: Any) -> bool:
+ try:
+ np.dtype(dtype)
+ except (TypeError, ValueError):
+ return False
+ else:
+ return True
+
+
+def to_0d_array(value: Any) -> np.ndarray:
"""Given a value, wrap it in a 0-D numpy.ndarray."""
- pass
+ if np.isscalar(value) or (isinstance(value, np.ndarray) and value.ndim == 0):
+ return np.array(value)
+ else:
+ return to_0d_object_array(value)
-def dict_equiv(first: Mapping[K, V], second: Mapping[K, V], compat:
- Callable[[V, V], bool]=equivalent) ->bool:
+def dict_equiv(
+ first: Mapping[K, V],
+ second: Mapping[K, V],
+ compat: Callable[[V, V], bool] = equivalent,
+) -> bool:
"""Test equivalence of two dict-like objects. If any of the values are
numpy arrays, compare them correctly.
@@ -153,12 +342,17 @@ def dict_equiv(first: Mapping[K, V], second: Mapping[K, V], compat:
equals : bool
True if the dictionaries are equal
"""
- pass
+ for k in first:
+ if k not in second or not compat(first[k], second[k]):
+ return False
+ return all(k in first for k in second)
-def compat_dict_intersection(first_dict: Mapping[K, V], second_dict:
- Mapping[K, V], compat: Callable[[V, V], bool]=equivalent) ->MutableMapping[
- K, V]:
+def compat_dict_intersection(
+ first_dict: Mapping[K, V],
+ second_dict: Mapping[K, V],
+ compat: Callable[[V, V], bool] = equivalent,
+) -> MutableMapping[K, V]:
"""Return the intersection of two dictionaries as a new dictionary.
Items are retained if their keys are found in both dictionaries and the
@@ -177,11 +371,16 @@ def compat_dict_intersection(first_dict: Mapping[K, V], second_dict:
intersection : dict
Intersection of the contents.
"""
- pass
+ new_dict = dict(first_dict)
+ remove_incompatible_items(new_dict, second_dict, compat)
+ return new_dict
-def compat_dict_union(first_dict: Mapping[K, V], second_dict: Mapping[K, V],
- compat: Callable[[V, V], bool]=equivalent) ->MutableMapping[K, V]:
+def compat_dict_union(
+ first_dict: Mapping[K, V],
+ second_dict: Mapping[K, V],
+ compat: Callable[[V, V], bool] = equivalent,
+) -> MutableMapping[K, V]:
"""Return the union of two dictionaries as a new dictionary.
An exception is raised if any keys are found in both dictionaries and the
@@ -200,7 +399,10 @@ def compat_dict_union(first_dict: Mapping[K, V], second_dict: Mapping[K, V],
union : dict
union of the contents.
"""
- pass
+ new_dict = dict(first_dict)
+ update_safety_check(first_dict, second_dict, compat)
+ new_dict.update(second_dict)
+ return new_dict
class Frozen(Mapping[K, V]):
@@ -208,25 +410,30 @@ class Frozen(Mapping[K, V]):
immutable. If you really want to modify the mapping, the mutable version is
saved under the `mapping` attribute.
"""
- __slots__ = 'mapping',
+
+ __slots__ = ("mapping",)
def __init__(self, mapping: Mapping[K, V]):
self.mapping = mapping
- def __getitem__(self, key: K) ->V:
+ def __getitem__(self, key: K) -> V:
return self.mapping[key]
- def __iter__(self) ->Iterator[K]:
+ def __iter__(self) -> Iterator[K]:
return iter(self.mapping)
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self.mapping)
- def __contains__(self, key: object) ->bool:
+ def __contains__(self, key: object) -> bool:
return key in self.mapping
- def __repr__(self) ->str:
- return f'{type(self).__name__}({self.mapping!r})'
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}({self.mapping!r})"
+
+
+def FrozenDict(*args, **kwargs) -> Frozen:
+ return Frozen(dict(*args, **kwargs))
class FrozenMappingWarningOnValuesAccess(Frozen[K, V]):
@@ -240,12 +447,43 @@ class FrozenMappingWarningOnValuesAccess(Frozen[K, V]):
of ds.dims is used like a dictionary (i.e. it doesn't raise a warning if used in a way that
would also be valid for a FrozenSet, e.g. iteration).
"""
- __slots__ = 'mapping',
- def __getitem__(self, key: K) ->V:
+ __slots__ = ("mapping",)
+
+ def _warn(self) -> None:
+ emit_user_level_warning(
+ "The return type of `Dataset.dims` will be changed to return a set of dimension names in future, "
+ "in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, "
+ "please use `Dataset.sizes`.",
+ FutureWarning,
+ )
+
+ def __getitem__(self, key: K) -> V:
self._warn()
return super().__getitem__(key)
+ @overload
+ def get(self, key: K, /) -> V | None: ...
+
+ @overload
+ def get(self, key: K, /, default: V | T) -> V | T: ...
+
+ def get(self, key: K, default: T | None = None) -> V | T | None:
+ self._warn()
+ return super().get(key, default)
+
+ def keys(self) -> KeysView[K]:
+ self._warn()
+ return super().keys()
+
+ def items(self) -> ItemsView[K, V]:
+ self._warn()
+ return super().items()
+
+ def values(self) -> ValuesView[V]:
+ self._warn()
+ return super().values()
+
class HybridMappingProxy(Mapping[K, V]):
"""Implements the Mapping interface. Uses the wrapped mapping for item lookup
@@ -259,19 +497,20 @@ class HybridMappingProxy(Mapping[K, V]):
and `mapping`. It is the caller's responsibility to ensure that they are
suitable for the task at hand.
"""
- __slots__ = '_keys', 'mapping'
+
+ __slots__ = ("_keys", "mapping")
def __init__(self, keys: Collection[K], mapping: Mapping[K, V]):
self._keys = keys
self.mapping = mapping
- def __getitem__(self, key: K) ->V:
+ def __getitem__(self, key: K) -> V:
return self.mapping[key]
- def __iter__(self) ->Iterator[K]:
+ def __iter__(self) -> Iterator[K]:
return iter(self._keys)
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self._keys)
@@ -281,35 +520,51 @@ class OrderedSet(MutableSet[T]):
The API matches the builtin set, but it preserves insertion order of elements, like
a dict. Note that, unlike in an OrderedDict, equality tests are not order-sensitive.
"""
+
_d: dict[T, None]
- __slots__ = '_d',
- def __init__(self, values: (Iterable[T] | None)=None):
+ __slots__ = ("_d",)
+
+ def __init__(self, values: Iterable[T] | None = None):
self._d = {}
if values is not None:
self.update(values)
- def __contains__(self, value: Hashable) ->bool:
+ # Required methods for MutableSet
+
+ def __contains__(self, value: Hashable) -> bool:
return value in self._d
- def __iter__(self) ->Iterator[T]:
+ def __iter__(self) -> Iterator[T]:
return iter(self._d)
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self._d)
- def __repr__(self) ->str:
- return f'{type(self).__name__}({list(self)!r})'
+ def add(self, value: T) -> None:
+ self._d[value] = None
+
+ def discard(self, value: T) -> None:
+ del self._d[value]
+
+ # Additional methods
+
+ def update(self, values: Iterable[T]) -> None:
+ self._d.update(dict.fromkeys(values))
+
+ def __repr__(self) -> str:
+ return f"{type(self).__name__}({list(self)!r})"
class NdimSizeLenMixin:
"""Mixin class that extends a class that defines a ``shape`` property to
one that also defines ``ndim``, ``size`` and ``__len__``.
"""
+
__slots__ = ()
@property
- def ndim(self: Any) ->int:
+ def ndim(self: Any) -> int:
"""
Number of array dimensions.
@@ -317,10 +572,10 @@ class NdimSizeLenMixin:
--------
numpy.ndarray.ndim
"""
- pass
+ return len(self.shape)
@property
- def size(self: Any) ->int:
+ def size(self: Any) -> int:
"""
Number of elements in the array.
@@ -330,13 +585,13 @@ class NdimSizeLenMixin:
--------
numpy.ndarray.size
"""
- pass
+ return math.prod(self.shape)
- def __len__(self: Any) ->int:
+ def __len__(self: Any) -> int:
try:
return self.shape[0]
except IndexError:
- raise TypeError('len() of unsized object')
+ raise TypeError("len() of unsized object")
class NDArrayMixin(NdimSizeLenMixin):
@@ -346,13 +601,22 @@ class NDArrayMixin(NdimSizeLenMixin):
A subclass should set the `array` property and override one or more of
`dtype`, `shape` and `__getitem__`.
"""
+
__slots__ = ()
+ @property
+ def dtype(self: Any) -> np.dtype:
+ return self.array.dtype
+
+ @property
+ def shape(self: Any) -> tuple[int, ...]:
+ return self.array.shape
+
def __getitem__(self: Any, key):
return self.array[key]
- def __repr__(self: Any) ->str:
- return f'{type(self).__name__}(array={self.array!r})'
+ def __repr__(self: Any) -> str:
+ return f"{type(self).__name__}(array={self.array!r})"
@contextlib.contextmanager
@@ -360,19 +624,58 @@ def close_on_error(f):
"""Context manager to ensure that a file opened by xarray is closed if an
exception is raised before the user sees the file object.
"""
- pass
+ try:
+ yield
+ except Exception:
+ f.close()
+ raise
-def is_remote_uri(path: str) ->bool:
+def is_remote_uri(path: str) -> bool:
"""Finds URLs of the form protocol:// or protocol::
This also matches for http[s]://, which were the only remote URLs
supported in <=v0.16.2.
"""
- pass
+ return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path))
+
+
+def read_magic_number_from_file(filename_or_obj, count=8) -> bytes:
+ # check byte header to determine file type
+ if isinstance(filename_or_obj, bytes):
+ magic_number = filename_or_obj[:count]
+ elif isinstance(filename_or_obj, io.IOBase):
+ if filename_or_obj.tell() != 0:
+ filename_or_obj.seek(0)
+ magic_number = filename_or_obj.read(count)
+ filename_or_obj.seek(0)
+ else:
+ raise TypeError(f"cannot read the magic number from {type(filename_or_obj)}")
+ return magic_number
+
+
+def try_read_magic_number_from_path(pathlike, count=8) -> bytes | None:
+ if isinstance(pathlike, str) or hasattr(pathlike, "__fspath__"):
+ path = os.fspath(pathlike)
+ try:
+ with open(path, "rb") as f:
+ return read_magic_number_from_file(f, count)
+ except (FileNotFoundError, TypeError):
+ pass
+ return None
-def is_uniform_spaced(arr, **kwargs) ->bool:
+def try_read_magic_number_from_file_or_path(filename_or_obj, count=8) -> bytes | None:
+ magic_number = try_read_magic_number_from_path(filename_or_obj, count)
+ if magic_number is None:
+ try:
+ magic_number = read_magic_number_from_file(filename_or_obj, count)
+ except TypeError:
+ pass
+ return magic_number
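Illustration of the magic-number helpers above, using the HDF5 file signature as sample bytes (sketch, not part of the patch):

```python
import io

from xarray.core.utils import read_magic_number_from_file

buf = io.BytesIO(b"\x89HDF\r\n\x1a\n...rest of the file...")
read_magic_number_from_file(buf)  # b"\x89HDF\r\n\x1a\n" (first 8 bytes)
buf.tell()                        # 0: the helper seeks back to the start after reading
```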
+
+
+def is_uniform_spaced(arr, **kwargs) -> bool:
"""Return True if values of an array are uniformly spaced and sorted.
>>> is_uniform_spaced(range(5))
@@ -382,68 +685,100 @@ def is_uniform_spaced(arr, **kwargs) ->bool:
kwargs are additional arguments to ``np.isclose``
"""
- pass
+ arr = np.array(arr, dtype=float)
+ diffs = np.diff(arr)
+ return bool(np.isclose(diffs.min(), diffs.max(), **kwargs))
-def hashable(v: Any) ->TypeGuard[Hashable]:
+def hashable(v: Any) -> TypeGuard[Hashable]:
"""Determine whether `v` can be hashed."""
- pass
+ try:
+ hash(v)
+ except TypeError:
+ return False
+ return True
-def iterable(v: Any) ->TypeGuard[Iterable[Any]]:
+def iterable(v: Any) -> TypeGuard[Iterable[Any]]:
"""Determine whether `v` is iterable."""
- pass
+ try:
+ iter(v)
+ except TypeError:
+ return False
+ return True
-def iterable_of_hashable(v: Any) ->TypeGuard[Iterable[Hashable]]:
+def iterable_of_hashable(v: Any) -> TypeGuard[Iterable[Hashable]]:
"""Determine whether `v` is an Iterable of Hashables."""
- pass
+ try:
+ it = iter(v)
+ except TypeError:
+ return False
+ return all(hashable(elm) for elm in it)
-def decode_numpy_dict_values(attrs: Mapping[K, V]) ->dict[K, V]:
+def decode_numpy_dict_values(attrs: Mapping[K, V]) -> dict[K, V]:
"""Convert attribute values from numpy objects to native Python objects,
for use in to_dict
"""
- pass
+ attrs = dict(attrs)
+ for k, v in attrs.items():
+ if isinstance(v, np.ndarray):
+ attrs[k] = v.tolist()
+ elif isinstance(v, np.generic):
+ attrs[k] = v.item()
+ return attrs
def ensure_us_time_resolution(val):
"""Convert val out of numpy time, for use in to_dict.
Needed because of numpy bug GH#7619"""
- pass
+ if np.issubdtype(val.dtype, np.datetime64):
+ val = val.astype("datetime64[us]")
+ elif np.issubdtype(val.dtype, np.timedelta64):
+ val = val.astype("timedelta64[us]")
+ return val
class HiddenKeyDict(MutableMapping[K, V]):
"""Acts like a normal dictionary, but hides certain keys."""
- __slots__ = '_data', '_hidden_keys'
+
+ __slots__ = ("_data", "_hidden_keys")
+
+ # ``__init__`` method required to create instance from class.
def __init__(self, data: MutableMapping[K, V], hidden_keys: Iterable[K]):
self._data = data
self._hidden_keys = frozenset(hidden_keys)
- def __setitem__(self, key: K, value: V) ->None:
+ def _raise_if_hidden(self, key: K) -> None:
+ if key in self._hidden_keys:
+ raise KeyError(f"Key `{key!r}` is hidden.")
+
+ # The next five methods are requirements of the ABC.
+ def __setitem__(self, key: K, value: V) -> None:
self._raise_if_hidden(key)
self._data[key] = value
- def __getitem__(self, key: K) ->V:
+ def __getitem__(self, key: K) -> V:
self._raise_if_hidden(key)
return self._data[key]
- def __delitem__(self, key: K) ->None:
+ def __delitem__(self, key: K) -> None:
self._raise_if_hidden(key)
del self._data[key]
- def __iter__(self) ->Iterator[K]:
+ def __iter__(self) -> Iterator[K]:
for k in self._data:
if k not in self._hidden_keys:
yield k
- def __len__(self) ->int:
+ def __len__(self) -> int:
num_hidden = len(self._hidden_keys & self._data.keys())
return len(self._data) - num_hidden
-def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) ->Hashable:
+def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
"""Get an new dimension name based on new_dim, that is not used in dims.
If the same name exists, we add an underscore(s) in the head.
@@ -456,12 +791,16 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) ->Hashable:
new_dim: ['_rolling']
-> ['__rolling']
"""
- pass
+ while new_dim in dims:
+ new_dim = "_" + str(new_dim)
+ return new_dim
-def drop_dims_from_indexers(indexers: Mapping[Any, Any], dims: (Iterable[
- Hashable] | Mapping[Any, int]), missing_dims: ErrorOptionsWithWarn
- ) ->Mapping[Hashable, Any]:
+def drop_dims_from_indexers(
+ indexers: Mapping[Any, Any],
+ dims: Iterable[Hashable] | Mapping[Any, int],
+ missing_dims: ErrorOptionsWithWarn,
+) -> Mapping[Hashable, Any]:
"""Depending on the setting of missing_dims, drop any dimensions from indexers that
are not present in dims.
@@ -471,12 +810,66 @@ def drop_dims_from_indexers(indexers: Mapping[Any, Any], dims: (Iterable[
dims : sequence
missing_dims : {"raise", "warn", "ignore"}
"""
- pass
+ if missing_dims == "raise":
+ invalid = indexers.keys() - set(dims)
+ if invalid:
+ raise ValueError(
+ f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
+ )
+
+ return indexers
-def parse_dims(dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists:
- bool=True, replace_none: bool=True) ->(tuple[Hashable, ...] | None |
- ellipsis):
+ elif missing_dims == "warn":
+ # don't modify input
+ indexers = dict(indexers)
+
+ invalid = indexers.keys() - set(dims)
+ if invalid:
+ warnings.warn(
+ f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
+ )
+ for key in invalid:
+ indexers.pop(key)
+
+ return indexers
+
+ elif missing_dims == "ignore":
+ return {key: val for key, val in indexers.items() if key in dims}
+
+ else:
+ raise ValueError(
+ f"Unrecognised option {missing_dims} for missing_dims argument"
+ )
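The three `missing_dims` modes implemented above, in one small sketch (not part of the patch):

```python
from xarray.core.utils import drop_dims_from_indexers

indexers = {"x": 0, "t": 0}
dims = ("x", "y")

drop_dims_from_indexers(indexers, dims, "ignore")  # {"x": 0}
drop_dims_from_indexers(indexers, dims, "warn")    # warns about "t", returns {"x": 0}
drop_dims_from_indexers(indexers, dims, "raise")   # raises ValueError
```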
+
+
+@overload
+def parse_dims(
+ dim: Dims,
+ all_dims: tuple[Hashable, ...],
+ *,
+ check_exists: bool = True,
+ replace_none: Literal[True] = True,
+) -> tuple[Hashable, ...]: ...
+
+
+@overload
+def parse_dims(
+ dim: Dims,
+ all_dims: tuple[Hashable, ...],
+ *,
+ check_exists: bool = True,
+ replace_none: Literal[False],
+) -> tuple[Hashable, ...] | None | ellipsis: ...
+
+
+def parse_dims(
+ dim: Dims,
+ all_dims: tuple[Hashable, ...],
+ *,
+ check_exists: bool = True,
+ replace_none: bool = True,
+) -> tuple[Hashable, ...] | None | ellipsis:
"""Parse one or more dimensions.
A single dimension must be always a str, multiple dimensions
@@ -500,12 +893,44 @@ def parse_dims(dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists:
parsed_dims : tuple of Hashable
Input dimensions as a tuple.
"""
- pass
-
-
-def parse_ordered_dims(dim: Dims, all_dims: tuple[Hashable, ...], *,
- check_exists: bool=True, replace_none: bool=True) ->(tuple[Hashable,
- ...] | None | ellipsis):
+ if dim is None or dim is ...:
+ if replace_none:
+ return all_dims
+ return dim
+ if isinstance(dim, str):
+ dim = (dim,)
+ if check_exists:
+ _check_dims(set(dim), set(all_dims))
+ return tuple(dim)
+
+
+@overload
+def parse_ordered_dims(
+ dim: Dims,
+ all_dims: tuple[Hashable, ...],
+ *,
+ check_exists: bool = True,
+ replace_none: Literal[True] = True,
+) -> tuple[Hashable, ...]: ...
+
+
+@overload
+def parse_ordered_dims(
+ dim: Dims,
+ all_dims: tuple[Hashable, ...],
+ *,
+ check_exists: bool = True,
+ replace_none: Literal[False],
+) -> tuple[Hashable, ...] | None | ellipsis: ...
+
+
+def parse_ordered_dims(
+ dim: Dims,
+ all_dims: tuple[Hashable, ...],
+ *,
+ check_exists: bool = True,
+ replace_none: bool = True,
+) -> tuple[Hashable, ...] | None | ellipsis:
"""Parse one or more dimensions.
A single dimension must be always a str, multiple dimensions
@@ -531,10 +956,39 @@ def parse_ordered_dims(dim: Dims, all_dims: tuple[Hashable, ...], *,
parsed_dims : tuple of Hashable
Input dimensions as a tuple.
"""
- pass
+ if dim is not None and dim is not ... and not isinstance(dim, str) and ... in dim:
+ dims_set: set[Hashable | ellipsis] = set(dim)
+ all_dims_set = set(all_dims)
+ if check_exists:
+ _check_dims(dims_set, all_dims_set)
+ if len(all_dims_set) != len(all_dims):
+ raise ValueError("Cannot use ellipsis with repeated dims")
+ dims = tuple(dim)
+ if dims.count(...) > 1:
+ raise ValueError("More than one ellipsis supplied")
+ other_dims = tuple(d for d in all_dims if d not in dims_set)
+ idx = dims.index(...)
+ return dims[:idx] + other_dims + dims[idx + 1 :]
+ else:
+ # mypy cannot resolve that the sequence cannot contain "..."
+ return parse_dims( # type: ignore[call-overload]
+ dim=dim,
+ all_dims=all_dims,
+ check_exists=check_exists,
+ replace_none=replace_none,
+ )
+
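How the ellipsis expansion above plays out (illustrative sketch, not part of the patch):

```python
from xarray.core.utils import parse_ordered_dims

parse_ordered_dims(("y", ...), all_dims=("x", "y", "z"))
# -> ("y", "x", "z"): listed dims keep their position, ... is replaced by the
#    remaining dims in their original order

parse_ordered_dims("x", all_dims=("x", "y"))   # -> ("x",)  (delegates to parse_dims)
parse_ordered_dims(None, all_dims=("x", "y"))  # -> ("x", "y") since replace_none=True
```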
+def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None:
+ wrong_dims = (dim - all_dims) - {...}
+ if wrong_dims:
+ wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims)
+ raise ValueError(
+ f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}"
+ )
-_Accessor = TypeVar('_Accessor')
+
+_Accessor = TypeVar("_Accessor")
class UncachedAccessor(Generic[_Accessor]):
@@ -545,24 +999,23 @@ class UncachedAccessor(Generic[_Accessor]):
accessor.
"""
- def __init__(self, accessor: type[_Accessor]) ->None:
+ def __init__(self, accessor: type[_Accessor]) -> None:
self._accessor = accessor
@overload
- def __get__(self, obj: None, cls) ->type[_Accessor]:
- ...
+ def __get__(self, obj: None, cls) -> type[_Accessor]: ...
@overload
- def __get__(self, obj: object, cls) ->_Accessor:
- ...
+ def __get__(self, obj: object, cls) -> _Accessor: ...
- def __get__(self, obj: (None | object), cls) ->(type[_Accessor] | _Accessor
- ):
+ def __get__(self, obj: None | object, cls) -> type[_Accessor] | _Accessor:
if obj is None:
return self._accessor
- return self._accessor(obj)
+ return self._accessor(obj) # type: ignore # assume it is a valid accessor!
+
+# Singleton type, as per https://github.com/python/typing/pull/240
class Default(Enum):
token = 0
@@ -570,14 +1023,33 @@ class Default(Enum):
_default = Default.token
-def contains_only_chunked_or_numpy(obj) ->bool:
+def iterate_nested(nested_list):
+ for item in nested_list:
+ if isinstance(item, list):
+ yield from iterate_nested(item)
+ else:
+ yield item
+
+
+def contains_only_chunked_or_numpy(obj) -> bool:
"""Returns True if xarray object contains only numpy arrays or chunked arrays (i.e. pure dask or cubed).
Expects obj to be Dataset or DataArray"""
- pass
+ from xarray.core.dataarray import DataArray
+ from xarray.namedarray.pycompat import is_chunked_array
+ if isinstance(obj, DataArray):
+ obj = obj._to_temp_dataset()
-def find_stack_level(test_mode=False) ->int:
+ return all(
+ [
+ isinstance(var.data, np.ndarray) or is_chunked_array(var.data)
+ for var in obj.variables.values()
+ ]
+ )
+
+
+def find_stack_level(test_mode=False) -> int:
"""Find the first place in the stack that is not inside xarray or the Python standard library.
This is unless the code emanates from a test, in which case we would prefer
@@ -596,20 +1068,101 @@ def find_stack_level(test_mode=False) ->int:
stacklevel : int
First level in the stack that is not part of xarray or the Python standard library.
"""
- pass
+ import xarray as xr
+
+ pkg_dir = Path(xr.__file__).parent
+ test_dir = pkg_dir / "tests"
+
+ std_lib_init = sys.modules["os"].__file__
+ # Mostly to appease mypy; I don't think this can happen...
+ if std_lib_init is None:
+ return 0
+
+ std_lib_dir = Path(std_lib_init).parent
+
+ frame = inspect.currentframe()
+ n = 0
+ while frame:
+ fname = inspect.getfile(frame)
+ if (
+ fname.startswith(str(pkg_dir))
+ and (not fname.startswith(str(test_dir)) or test_mode)
+ ) or (
+ fname.startswith(str(std_lib_dir))
+ and "site-packages" not in fname
+ and "dist-packages" not in fname
+ ):
+ frame = frame.f_back
+ n += 1
+ else:
+ break
+ return n
-def emit_user_level_warning(message, category=None) ->None:
+def emit_user_level_warning(message, category=None) -> None:
"""Emit a warning at the user level by inspecting the stack trace."""
- pass
+ stacklevel = find_stack_level()
+ return warnings.warn(message, category=category, stacklevel=stacklevel)
-def consolidate_dask_from_array_kwargs(from_array_kwargs: dict[Any, Any],
- name: (str | None)=None, lock: (bool | None)=None, inline_array: (bool |
- None)=None) ->dict[Any, Any]:
+def consolidate_dask_from_array_kwargs(
+ from_array_kwargs: dict[Any, Any],
+ name: str | None = None,
+ lock: bool | None = None,
+ inline_array: bool | None = None,
+) -> dict[Any, Any]:
"""
Merge dask-specific kwargs with arbitrary from_array_kwargs dict.
Temporary function, to be deleted once explicitly passing dask-specific kwargs to .chunk() is deprecated.
"""
- pass
+
+ from_array_kwargs = _resolve_doubly_passed_kwarg(
+ from_array_kwargs,
+ kwarg_name="name",
+ passed_kwarg_value=name,
+ default=None,
+ err_msg_dict_name="from_array_kwargs",
+ )
+ from_array_kwargs = _resolve_doubly_passed_kwarg(
+ from_array_kwargs,
+ kwarg_name="lock",
+ passed_kwarg_value=lock,
+ default=False,
+ err_msg_dict_name="from_array_kwargs",
+ )
+ from_array_kwargs = _resolve_doubly_passed_kwarg(
+ from_array_kwargs,
+ kwarg_name="inline_array",
+ passed_kwarg_value=inline_array,
+ default=False,
+ err_msg_dict_name="from_array_kwargs",
+ )
+
+ return from_array_kwargs
+
+
+def _resolve_doubly_passed_kwarg(
+ kwargs_dict: dict[Any, Any],
+ kwarg_name: str,
+ passed_kwarg_value: str | bool | None,
+ default: bool | None,
+ err_msg_dict_name: str,
+) -> dict[Any, Any]:
+ # if in kwargs_dict but not passed explicitly then just pass kwargs_dict through unaltered
+ if kwarg_name in kwargs_dict and passed_kwarg_value is None:
+ pass
+ # if passed explicitly but not in kwargs_dict then use that
+ elif kwarg_name not in kwargs_dict and passed_kwarg_value is not None:
+ kwargs_dict[kwarg_name] = passed_kwarg_value
+ # if in neither then use default
+ elif kwarg_name not in kwargs_dict and passed_kwarg_value is None:
+ kwargs_dict[kwarg_name] = default
+ # if in both then raise
+ else:
+ raise ValueError(
+ f"argument {kwarg_name} cannot be passed both as a keyword argument and within "
+ f"the {err_msg_dict_name} dictionary"
+ )
+
+ return kwargs_dict
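A sketch of the kwarg-merging rules the two helpers above implement (not part of the patch):

```python
from xarray.core.utils import consolidate_dask_from_array_kwargs

consolidate_dask_from_array_kwargs({}, name="a")
# -> {"name": "a", "lock": False, "inline_array": False}

consolidate_dask_from_array_kwargs({"name": "a"})
# -> {"name": "a", "lock": False, "inline_array": False}  (dict value passes through)

consolidate_dask_from_array_kwargs({"name": "a"}, name="b")
# raises ValueError: the kwarg was passed both explicitly and via from_array_kwargs
```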
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 13cee905..828c53e6 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import copy
import itertools
import math
@@ -8,38 +9,86 @@ from collections.abc import Hashable, Mapping, Sequence
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, NoReturn, cast
+
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from pandas.api.types import is_extension_array_dtype
-import xarray as xr
+
+import xarray as xr # only for Dataset and DataArray
from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils
from xarray.core.arithmetic import VariableArithmetic
from xarray.core.common import AbstractArray
from xarray.core.extension_array import PandasExtensionArray
-from xarray.core.indexing import BasicIndexer, OuterIndexer, PandasIndexingAdapter, VectorizedIndexer, as_indexable
+from xarray.core.indexing import (
+ BasicIndexer,
+ OuterIndexer,
+ PandasIndexingAdapter,
+ VectorizedIndexer,
+ as_indexable,
+)
from xarray.core.options import OPTIONS, _get_keep_attrs
-from xarray.core.utils import OrderedSet, _default, consolidate_dask_from_array_kwargs, decode_numpy_dict_values, drop_dims_from_indexers, either_dict_or_kwargs, emit_user_level_warning, ensure_us_time_resolution, infix_dims, is_dict_like, is_duck_array, is_duck_dask_array, maybe_coerce_to_str
+from xarray.core.utils import (
+ OrderedSet,
+ _default,
+ consolidate_dask_from_array_kwargs,
+ decode_numpy_dict_values,
+ drop_dims_from_indexers,
+ either_dict_or_kwargs,
+ emit_user_level_warning,
+ ensure_us_time_resolution,
+ infix_dims,
+ is_dict_like,
+ is_duck_array,
+ is_duck_dask_array,
+ maybe_coerce_to_str,
+)
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
from xarray.util.deprecation_helpers import deprecate_dims
-NON_NUMPY_SUPPORTED_ARRAY_TYPES = (indexing.ExplicitlyIndexed, pd.Index, pd
- .api.extensions.ExtensionArray)
+
+NON_NUMPY_SUPPORTED_ARRAY_TYPES = (
+ indexing.ExplicitlyIndexed,
+ pd.Index,
+ pd.api.extensions.ExtensionArray,
+)
+# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)
+
if TYPE_CHECKING:
- from xarray.core.types import Dims, ErrorOptionsWithWarn, PadModeOptions, PadReflectOptions, QuantileMethods, Self, T_Chunks, T_DuckArray
+ from xarray.core.types import (
+ Dims,
+ ErrorOptionsWithWarn,
+ PadModeOptions,
+ PadReflectOptions,
+ QuantileMethods,
+ Self,
+ T_Chunks,
+ T_DuckArray,
+ )
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
+
+
NON_NANOSECOND_WARNING = (
- 'Converting non-nanosecond precision {case} values to nanosecond precision. This behavior can eventually be relaxed in xarray, as it is an artifact from pandas which is now beginning to support non-nanosecond precision values. This warning is caused by passing non-nanosecond np.datetime64 or np.timedelta64 values to the DataArray or Variable constructor; it can be silenced by converting the values to nanosecond precision ahead of time.'
- )
+ "Converting non-nanosecond precision {case} values to nanosecond precision. "
+ "This behavior can eventually be relaxed in xarray, as it is an artifact from "
+ "pandas which is now beginning to support non-nanosecond precision values. "
+ "This warning is caused by passing non-nanosecond np.datetime64 or "
+ "np.timedelta64 values to the DataArray or Variable constructor; it can be "
+ "silenced by converting the values to nanosecond precision ahead of time."
+)
class MissingDimensionsError(ValueError):
"""Error class used when we can't safely guess a dimension name."""
+ # inherits from ValueError for backward compatibility
+ # TODO: move this to an xarray.exceptions module?
+
-def as_variable(obj: (T_DuckArray | Any), name=None, auto_convert: bool=True
- ) ->(Variable | IndexVariable):
+def as_variable(
+ obj: T_DuckArray | Any, name=None, auto_convert: bool = True
+) -> Variable | IndexVariable:
"""Convert an object into a Variable.
Parameters
@@ -69,7 +118,65 @@ def as_variable(obj: (T_DuckArray | Any), name=None, auto_convert: bool=True
The newly created variable.
"""
- pass
+ from xarray.core.dataarray import DataArray
+
+ # TODO: consider extending this method to automatically handle Iris and
+ if isinstance(obj, DataArray):
+ # extract the primary Variable from DataArrays
+ obj = obj.variable
+
+ if isinstance(obj, Variable):
+ obj = obj.copy(deep=False)
+ elif isinstance(obj, tuple):
+ try:
+ dims_, data_, *attrs = obj
+ except ValueError:
+ raise ValueError(f"Tuple {obj} is not in the form (dims, data[, attrs])")
+
+ if isinstance(data_, DataArray):
+ raise TypeError(
+ f"Variable {name!r}: Using a DataArray object to construct a variable is"
+ " ambiguous, please extract the data using the .data property."
+ )
+ try:
+ obj = Variable(dims_, data_, *attrs)
+ except (TypeError, ValueError) as error:
+ raise error.__class__(
+ f"Variable {name!r}: Could not convert tuple of form "
+ f"(dims, data[, attrs, encoding]): {obj} to Variable."
+ )
+ elif utils.is_scalar(obj):
+ obj = Variable([], obj)
+ elif isinstance(obj, (pd.Index, IndexVariable)) and obj.name is not None:
+ obj = Variable(obj.name, obj)
+ elif isinstance(obj, (set, dict)):
+ raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}")
+ elif name is not None:
+ data: T_DuckArray = as_compatible_data(obj)
+ if data.ndim != 1:
+ raise MissingDimensionsError(
+ f"cannot set variable {name!r} with {data.ndim!r}-dimensional data "
+ "without explicit dimension names. Pass a tuple of "
+ "(dims, data) instead."
+ )
+ obj = Variable(name, data, fastpath=True)
+ else:
+ raise TypeError(
+ f"Variable {name!r}: unable to convert object into a variable without an "
+ f"explicit list of dimensions: {obj!r}"
+ )
+
+ if auto_convert:
+ if name is not None and name in obj.dims and obj.ndim == 1:
+ # automatically convert the Variable into an Index
+ emit_user_level_warning(
+ f"variable {name!r} with name matching its dimension will not be "
+ "automatically converted into an `IndexVariable` object in the future.",
+ FutureWarning,
+ )
+ obj = obj.to_index_variable()
+
+ return obj
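Usage sketch for the tuple form accepted by `as_variable` (illustrative, not part of the patch):

```python
import xarray as xr
from xarray.core.variable import as_variable

v = as_variable(("x", [1, 2, 3], {"units": "m"}), name="distance")
# -> a Variable with dims ("x",), data [1, 2, 3] and attrs {"units": "m"}

as_variable(("x", xr.DataArray([1, 2, 3])))
# raises TypeError: pass the DataArray's .data instead, as the error message says
```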
def _maybe_wrap_data(data):
@@ -80,7 +187,35 @@ def _maybe_wrap_data(data):
NumpyArrayAdapter, PandasIndexingAdapter and LazilyIndexedArray should
all pass through unmodified.
"""
- pass
+ if isinstance(data, pd.Index):
+ return PandasIndexingAdapter(data)
+ if isinstance(data, pd.api.extensions.ExtensionArray):
+ return PandasExtensionArray[type(data)](data)
+ return data
+
+
+def _as_nanosecond_precision(data):
+ dtype = data.dtype
+ non_ns_datetime64 = (
+ dtype.kind == "M"
+ and isinstance(dtype, np.dtype)
+ and dtype != np.dtype("datetime64[ns]")
+ )
+ non_ns_datetime_tz_dtype = (
+ isinstance(dtype, pd.DatetimeTZDtype) and dtype.unit != "ns"
+ )
+ if non_ns_datetime64 or non_ns_datetime_tz_dtype:
+ utils.emit_user_level_warning(NON_NANOSECOND_WARNING.format(case="datetime"))
+ if isinstance(dtype, pd.DatetimeTZDtype):
+ nanosecond_precision_dtype = pd.DatetimeTZDtype("ns", dtype.tz)
+ else:
+ nanosecond_precision_dtype = "datetime64[ns]"
+ return duck_array_ops.astype(data, nanosecond_precision_dtype)
+ elif dtype.kind == "m" and dtype != np.dtype("timedelta64[ns]"):
+ utils.emit_user_level_warning(NON_NANOSECOND_WARNING.format(case="timedelta"))
+ return duck_array_ops.astype(data, "timedelta64[ns]")
+ else:
+ return data
def _possibly_convert_objects(values):
@@ -95,7 +230,17 @@ def _possibly_convert_objects(values):
within the valid date range for ns precision, as pandas will raise an error
if they are not.
"""
- pass
+ as_series = pd.Series(values.ravel(), copy=False)
+ if as_series.dtype.kind in "mM":
+ as_series = _as_nanosecond_precision(as_series)
+ result = np.asarray(as_series).reshape(values.shape)
+ if not result.flags.writeable:
+ # GH8843, pandas copy-on-write mode creates read-only arrays by default
+ try:
+ result.flags.writeable = True
+ except ValueError:
+ result = result.copy()
+ return result
def _possibly_convert_datetime_or_timedelta_index(data):
@@ -104,11 +249,17 @@ def _possibly_convert_datetime_or_timedelta_index(data):
this in version 2.0.0, in xarray we will need to make sure we are ready to
handle non-nanosecond precision datetimes or timedeltas in our code
before allowing such values to pass through unchanged."""
- pass
+ if isinstance(data, PandasIndexingAdapter):
+ if isinstance(data.array, (pd.DatetimeIndex, pd.TimedeltaIndex)):
+ data = PandasIndexingAdapter(_as_nanosecond_precision(data.array))
+ elif isinstance(data, (pd.DatetimeIndex, pd.TimedeltaIndex)):
+ data = _as_nanosecond_precision(data)
+ return data
-def as_compatible_data(data: (T_DuckArray | ArrayLike), fastpath: bool=False
- ) ->T_DuckArray:
+def as_compatible_data(
+ data: T_DuckArray | ArrayLike, fastpath: bool = False
+) -> T_DuckArray:
"""Prepare and wrap data to put in a Variable.
- If data does not have the necessary attributes, convert it to ndarray.
@@ -119,7 +270,56 @@ def as_compatible_data(data: (T_DuckArray | ArrayLike), fastpath: bool=False
Finally, wrap it up with an adapter if necessary.
"""
- pass
+ if fastpath and getattr(data, "ndim", None) is not None:
+ return cast("T_DuckArray", data)
+
+ from xarray.core.dataarray import DataArray
+
+    # TODO: do this unwrapping in the Variable/NamedArray constructor instead.
+    # TODO: do this unwrapping in the Variable/NamedArray constructor instead.
+ if isinstance(data, Variable):
+ return cast("T_DuckArray", data._data)
+
+    # TODO: do this unwrapping in the DataArray constructor instead.
+ if isinstance(data, DataArray):
+ return cast("T_DuckArray", data._variable._data)
+
+ if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
+ data = _possibly_convert_datetime_or_timedelta_index(data)
+ return cast("T_DuckArray", _maybe_wrap_data(data))
+
+ if isinstance(data, tuple):
+ data = utils.to_0d_object_array(data)
+
+ if isinstance(data, pd.Timestamp):
+ # TODO: convert, handle datetime objects, too
+ data = np.datetime64(data.value, "ns")
+
+ if isinstance(data, timedelta):
+ data = np.timedelta64(getattr(data, "value", data), "ns")
+
+ # we don't want nested self-described arrays
+ if isinstance(data, (pd.Series, pd.DataFrame)):
+ data = data.values # type: ignore[assignment]
+
+ if isinstance(data, np.ma.MaskedArray):
+ mask = np.ma.getmaskarray(data)
+ if mask.any():
+ dtype, fill_value = dtypes.maybe_promote(data.dtype)
+ data = duck_array_ops.where_method(data, ~mask, fill_value)
+ else:
+ data = np.asarray(data)
+
+ if not isinstance(data, np.ndarray) and (
+ hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
+ ):
+ return cast("T_DuckArray", data)
+
+ # validate whether the data is valid data types.
+ data = np.asarray(data)
+
+ if isinstance(data, np.ndarray) and data.dtype.kind in "OMm":
+ data = _possibly_convert_objects(data)
+ return _maybe_wrap_data(data)
def _as_array_or_item(data):
@@ -136,7 +336,13 @@ def _as_array_or_item(data):
TODO: remove this (replace with np.asarray) once these issues are fixed
"""
- pass
+ data = np.asarray(data)
+ if data.ndim == 0:
+ if data.dtype.kind == "M":
+ data = np.datetime64(data, "ns")
+ elif data.dtype.kind == "m":
+ data = np.timedelta64(data, "ns")
+ return data
class Variable(NamedArray, AbstractArray, VariableArithmetic):
@@ -159,10 +365,17 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
form of a Dataset or DataArray should almost always be preferred, because
they can use more complete metadata in context of coordinate labels.
"""
- __slots__ = '_dims', '_data', '_attrs', '_encoding'
- def __init__(self, dims, data: (T_DuckArray | ArrayLike), attrs=None,
- encoding=None, fastpath=False):
+ __slots__ = ("_dims", "_data", "_attrs", "_encoding")
+
+ def __init__(
+ self,
+ dims,
+ data: T_DuckArray | ArrayLike,
+ attrs=None,
+ encoding=None,
+ fastpath=False,
+ ):
"""
Parameters
----------
@@ -182,12 +395,42 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
Well-behaved code to serialize a Variable should ignore
unrecognized encoding items.
"""
- super().__init__(dims=dims, data=as_compatible_data(data, fastpath=
- fastpath), attrs=attrs)
+ super().__init__(
+ dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs
+ )
+
self._encoding = None
if encoding is not None:
self.encoding = encoding
+ def _new(
+ self,
+ dims=_default,
+ data=_default,
+ attrs=_default,
+ ):
+ dims_ = copy.copy(self._dims) if dims is _default else dims
+
+ if attrs is _default:
+ attrs_ = None if self._attrs is None else self._attrs.copy()
+ else:
+ attrs_ = attrs
+
+ if data is _default:
+ return type(self)(dims_, copy.copy(self._data), attrs_)
+ else:
+ cls_ = type(self)
+ return cls_(dims_, data, attrs_)
+
+ @property
+ def _in_memory(self):
+ return isinstance(
+ self._data, (np.ndarray, np.number, PandasIndexingAdapter)
+ ) or (
+ isinstance(self._data, indexing.MemoryCachedArray)
+ and isinstance(self._data.array, indexing.NumpyIndexingAdapter)
+ )
+
@property
def data(self):
"""
@@ -200,10 +443,29 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
Variable.as_numpy
Variable.values
"""
- pass
-
- def astype(self, dtype, *, order=None, casting=None, subok=None, copy=
- None, keep_attrs=True) ->Self:
+ if is_duck_array(self._data):
+ return self._data
+ elif isinstance(self._data, indexing.ExplicitlyIndexed):
+ return self._data.get_duck_array()
+ else:
+ return self.values
+
+ @data.setter
+ def data(self, data: T_DuckArray | ArrayLike) -> None:
+ data = as_compatible_data(data)
+ self._check_shape(data)
+ self._data = data
+
+ def astype(
+ self,
+ dtype,
+ *,
+ order=None,
+ casting=None,
+ subok=None,
+ copy=None,
+ keep_attrs=True,
+ ) -> Self:
"""
Copy of the Variable object, with data cast to a specified type.
@@ -255,31 +517,86 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
dask.array.Array.astype
sparse.COO.astype
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ kwargs = dict(order=order, casting=casting, subok=subok, copy=copy)
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+ return apply_ufunc(
+ duck_array_ops.astype,
+ self,
+ dtype,
+ kwargs=kwargs,
+ keep_attrs=keep_attrs,
+ dask="allowed",
+ )
+
+ def _dask_finalize(self, results, array_func, *args, **kwargs):
+ data = array_func(results, *args, **kwargs)
+ return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding)
@property
def values(self):
"""The variable's data as a numpy.ndarray"""
- pass
+ return _as_array_or_item(self._data)
+
+ @values.setter
+ def values(self, values):
+ self.data = values
- def to_base_variable(self) ->Variable:
+ def to_base_variable(self) -> Variable:
"""Return this variable as a base xarray.Variable"""
- pass
- to_variable = utils.alias(to_base_variable, 'to_variable')
+ return Variable(
+ self._dims, self._data, self._attrs, encoding=self._encoding, fastpath=True
+ )
- def to_index_variable(self) ->IndexVariable:
+ to_variable = utils.alias(to_base_variable, "to_variable")
+
+ def to_index_variable(self) -> IndexVariable:
"""Return this variable as an xarray.IndexVariable"""
- pass
- to_coord = utils.alias(to_index_variable, 'to_coord')
+ return IndexVariable(
+ self._dims, self._data, self._attrs, encoding=self._encoding, fastpath=True
+ )
+
+ to_coord = utils.alias(to_index_variable, "to_coord")
+
+ def _to_index(self) -> pd.Index:
+ return self.to_index_variable()._to_index()
- def to_index(self) ->pd.Index:
+ def to_index(self) -> pd.Index:
"""Convert this variable to a pandas.Index"""
- pass
+ return self.to_index_variable().to_index()
- def to_dict(self, data: (bool | str)='list', encoding: bool=False) ->dict[
- str, Any]:
+ def to_dict(
+ self, data: bool | str = "list", encoding: bool = False
+ ) -> dict[str, Any]:
"""Dictionary representation of variable."""
- pass
+ item: dict[str, Any] = {
+ "dims": self.dims,
+ "attrs": decode_numpy_dict_values(self.attrs),
+ }
+ if data is not False:
+ if data in [True, "list"]:
+ item["data"] = ensure_us_time_resolution(self.to_numpy()).tolist()
+ elif data == "array":
+ item["data"] = ensure_us_time_resolution(self.data)
+ else:
+ msg = 'data argument must be bool, "list", or "array"'
+ raise ValueError(msg)
+
+ else:
+ item.update({"dtype": str(self.dtype), "shape": self.shape})
+
+ if encoding:
+ item["encoding"] = dict(self.encoding)
+
+ return item
+
+ def _item_key_to_tuple(self, key):
+ if is_dict_like(key):
+ return tuple(key.get(dim, slice(None)) for dim in self.dims)
+ else:
+ return key
def _broadcast_indexes(self, key):
"""Prepare an indexing key for an indexing operation.
@@ -302,13 +619,170 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
the first len(new_order) indexing should be moved to these
positions.
"""
- pass
+ key = self._item_key_to_tuple(key) # key is a tuple
+ # key is a tuple of full size
+ key = indexing.expanded_indexer(key, self.ndim)
+ # Convert a scalar Variable to a 0d-array
+ key = tuple(
+ k.data if isinstance(k, Variable) and k.ndim == 0 else k for k in key
+ )
+ # Convert a 0d numpy arrays to an integer
+ # dask 0d arrays are passed through
+ key = tuple(
+ k.item() if isinstance(k, np.ndarray) and k.ndim == 0 else k for k in key
+ )
+
+ if all(isinstance(k, BASIC_INDEXING_TYPES) for k in key):
+ return self._broadcast_indexes_basic(key)
+
+ self._validate_indexers(key)
+ # Detect it can be mapped as an outer indexer
+ # If all key is unlabeled, or
+ # key can be mapped as an OuterIndexer.
+ if all(not isinstance(k, Variable) for k in key):
+ return self._broadcast_indexes_outer(key)
+
+ # If all key is 1-dimensional and there are no duplicate labels,
+ # key can be mapped as an OuterIndexer.
+ dims = []
+ for k, d in zip(key, self.dims):
+ if isinstance(k, Variable):
+ if len(k.dims) > 1:
+ return self._broadcast_indexes_vectorized(key)
+ dims.append(k.dims[0])
+ elif not isinstance(k, integer_types):
+ dims.append(d)
+ if len(set(dims)) == len(dims):
+ return self._broadcast_indexes_outer(key)
+
+ return self._broadcast_indexes_vectorized(key)
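A rough sketch of how keys get dispatched to basic, outer, or vectorized indexing by the logic above (not part of the patch):

```python
import numpy as np

from xarray import Variable

v = Variable(("x", "y"), np.arange(12).reshape(3, 4))

v[0, :2]                        # ints and slices only     -> basic indexing
v[[0, 2], :]                    # unlabeled 1-d array key  -> outer indexing
ind = Variable(("z",), [0, 1])
v[ind, ind]                     # keys share the dim "z"   -> vectorized (pointwise) indexing
```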
+
+ def _broadcast_indexes_basic(self, key):
+ dims = tuple(
+ dim for k, dim in zip(key, self.dims) if not isinstance(k, integer_types)
+ )
+ return dims, BasicIndexer(key), None
def _validate_indexers(self, key):
"""Make sanity checks"""
- pass
-
- def __getitem__(self, key) ->Self:
+ for dim, k in zip(self.dims, key):
+ if not isinstance(k, BASIC_INDEXING_TYPES):
+ if not isinstance(k, Variable):
+ if not is_duck_array(k):
+ k = np.asarray(k)
+ if k.ndim > 1:
+ raise IndexError(
+ "Unlabeled multi-dimensional array cannot be "
+ f"used for indexing: {k}"
+ )
+ if k.dtype.kind == "b":
+ if self.shape[self.get_axis_num(dim)] != len(k):
+ raise IndexError(
+ f"Boolean array size {len(k):d} is used to index array "
+ f"with shape {str(self.shape):s}."
+ )
+ if k.ndim > 1:
+ raise IndexError(
+ f"{k.ndim}-dimensional boolean indexing is "
+ "not supported. "
+ )
+ if is_duck_dask_array(k.data):
+ raise KeyError(
+ "Indexing with a boolean dask array is not allowed. "
+ "This will result in a dask array of unknown shape. "
+ "Such arrays are unsupported by Xarray."
+ "Please compute the indexer first using .compute()"
+ )
+ if getattr(k, "dims", (dim,)) != (dim,):
+ raise IndexError(
+ "Boolean indexer should be unlabeled or on the "
+ "same dimension to the indexed array. Indexer is "
+ f"on {str(k.dims):s} but the target dimension is {dim:s}."
+ )
+
+ def _broadcast_indexes_outer(self, key):
+ # drop dim if k is integer or if k is a 0d dask array
+ dims = tuple(
+ k.dims[0] if isinstance(k, Variable) else dim
+ for k, dim in zip(key, self.dims)
+ if (not isinstance(k, integer_types) and not is_0d_dask_array(k))
+ )
+
+ new_key = []
+ for k in key:
+ if isinstance(k, Variable):
+ k = k.data
+ if not isinstance(k, BASIC_INDEXING_TYPES):
+ if not is_duck_array(k):
+ k = np.asarray(k)
+ if k.size == 0:
+ # Slice by empty list; numpy could not infer the dtype
+ k = k.astype(int)
+ elif k.dtype.kind == "b":
+ (k,) = np.nonzero(k)
+ new_key.append(k)
+
+ return dims, OuterIndexer(tuple(new_key)), None
+
+ def _broadcast_indexes_vectorized(self, key):
+ variables = []
+ out_dims_set = OrderedSet()
+ for dim, value in zip(self.dims, key):
+ if isinstance(value, slice):
+ out_dims_set.add(dim)
+ else:
+ variable = (
+ value
+ if isinstance(value, Variable)
+ else as_variable(value, name=dim, auto_convert=False)
+ )
+ if variable.dims == (dim,):
+ variable = variable.to_index_variable()
+ if variable.dtype.kind == "b": # boolean indexing case
+ (variable,) = variable._nonzero()
+
+ variables.append(variable)
+ out_dims_set.update(variable.dims)
+
+ variable_dims = set()
+ for variable in variables:
+ variable_dims.update(variable.dims)
+
+ slices = []
+ for i, (dim, value) in enumerate(zip(self.dims, key)):
+ if isinstance(value, slice):
+ if dim in variable_dims:
+ # We only convert slice objects to variables if they share
+ # a dimension with at least one other variable. Otherwise,
+                    # we can equivalently leave them as slices and transpose
+ # the result. This is significantly faster/more efficient
+ # for most array backends.
+ values = np.arange(*value.indices(self.sizes[dim]))
+ variables.insert(i - len(slices), Variable((dim,), values))
+ else:
+ slices.append((i, value))
+
+ try:
+ variables = _broadcast_compat_variables(*variables)
+ except ValueError:
+ raise IndexError(f"Dimensions of indexers mismatch: {key}")
+
+ out_key = [variable.data for variable in variables]
+ out_dims = tuple(out_dims_set)
+ slice_positions = set()
+ for i, value in slices:
+ out_key.insert(i, value)
+ new_position = out_dims.index(self.dims[i])
+ slice_positions.add(new_position)
+
+ if slice_positions:
+ new_order = [i for i in range(len(out_dims)) if i not in slice_positions]
+ else:
+ new_order = None
+
+ return out_dims, VectorizedIndexer(tuple(out_key)), new_order
+
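# Usage sketch (illustration only, not part of the patch): how the
# _broadcast_indexes* helpers above dispatch between indexing modes.
import numpy as np
import xarray as xr

v = xr.Variable(("x", "y"), np.arange(12).reshape(3, 4))
v[0, :2]           # ints/slices -> basic indexing, dims ('y',), shape (2,)
v[[0, 2], [1, 3]]  # unlabeled arrays -> outer indexing, shape (2, 2)
ix = xr.Variable(("pts",), [0, 2])
iy = xr.Variable(("pts",), [1, 3])
v[ix, iy]          # labeled indexers sharing a dim -> vectorized (pointwise), shape (2,)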
+ def __getitem__(self, key) -> Self:
"""Return a new Variable object whose contents are consistent with
getting the provided key from the underlying data.
@@ -323,18 +797,59 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
"""
dims, indexer, new_order = self._broadcast_indexes(key)
indexable = as_indexable(self._data)
+
data = indexing.apply_indexer(indexable, indexer)
+
if new_order:
data = np.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)
- def _finalize_indexing_result(self, dims, data) ->Self:
+ def _finalize_indexing_result(self, dims, data) -> Self:
"""Used by IndexVariable to return IndexVariable objects when possible."""
- pass
+ return self._replace(dims=dims, data=data)
def _getitem_with_mask(self, key, fill_value=dtypes.NA):
"""Index this Variable with -1 remapped to fill_value."""
- pass
+ # TODO(shoyer): expose this method in public API somewhere (isel?) and
+ # use it for reindex.
+ # TODO(shoyer): add a sanity check that all other integers are
+ # non-negative
+ # TODO(shoyer): add an optimization, remapping -1 to an adjacent value
+ # that is actually indexed rather than mapping it to the last value
+ # along each axis.
+
+ if fill_value is dtypes.NA:
+ fill_value = dtypes.get_fill_value(self.dtype)
+
+ dims, indexer, new_order = self._broadcast_indexes(key)
+
+ if self.size:
+
+ if is_duck_dask_array(self._data):
+ # dask's indexing is faster this way; also vindex does not
+ # support negative indices yet:
+ # https://github.com/dask/dask/pull/2967
+ actual_indexer = indexing.posify_mask_indexer(indexer)
+ else:
+ actual_indexer = indexer
+
+ indexable = as_indexable(self._data)
+ data = indexing.apply_indexer(indexable, actual_indexer)
+
+ mask = indexing.create_mask(indexer, self.shape, data)
+ # we need to invert the mask in order to pass data first. This helps
+ # pint to choose the correct unit
+ # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed
+ data = duck_array_ops.where(np.logical_not(mask), data, fill_value)
+ else:
+ # array cannot be indexed along dimensions of size 0, so just
+ # build the mask directly instead.
+ mask = indexing.create_mask(indexer, self.shape)
+ data = np.broadcast_to(fill_value, getattr(mask, "shape", ()))
+
+ if new_order:
+ data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order)
+ return self._finalize_indexing_result(dims, data)
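# Usage sketch (illustration only, not part of the patch): _getitem_with_mask is
# private API; assuming the code above, index -1 maps to the fill value.
import xarray as xr

v = xr.Variable(("x",), [10, 20, 30])
v._getitem_with_mask([0, -1, 2])
# dtype promoted to float: [10., nan, 30.]
v._getitem_with_mask([0, -1, 2], fill_value=-999)
# original dtype kept: [10, -999, 30]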
def __setitem__(self, key, value):
"""__setitem__ is overloaded to access the underlying numpy values with
@@ -343,33 +858,104 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
See __getitem__ for more details.
"""
dims, index_tuple, new_order = self._broadcast_indexes(key)
+
if not isinstance(value, Variable):
value = as_compatible_data(value)
if value.ndim > len(dims):
raise ValueError(
- f'shape mismatch: value array of shape {value.shape} could not be broadcast to indexing result with {len(dims)} dimensions'
- )
+ f"shape mismatch: value array of shape {value.shape} could not be "
+ f"broadcast to indexing result with {len(dims)} dimensions"
+ )
if value.ndim == 0:
value = Variable((), value)
else:
- value = Variable(dims[-value.ndim:], value)
+ value = Variable(dims[-value.ndim :], value)
+ # broadcast to become assignable
value = value.set_dims(dims).data
+
if new_order:
value = duck_array_ops.asarray(value)
- value = value[(len(dims) - value.ndim) * (np.newaxis,) + (
- Ellipsis,)]
+ value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)]
value = np.moveaxis(value, new_order, range(len(new_order)))
+
indexable = as_indexable(self._data)
indexing.set_with_indexer(indexable, index_tuple, value)
@property
- def encoding(self) ->dict[Any, Any]:
+ def encoding(self) -> dict[Any, Any]:
"""Dictionary of encodings on this variable."""
- pass
-
- def drop_encoding(self) ->Self:
+ if self._encoding is None:
+ self._encoding = {}
+ return self._encoding
+
+ @encoding.setter
+ def encoding(self, value):
+ try:
+ self._encoding = dict(value)
+ except ValueError:
+ raise ValueError("encoding must be castable to a dictionary")
+
+ def reset_encoding(self) -> Self:
+ warnings.warn(
+ "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead"
+ )
+ return self.drop_encoding()
+
+ def drop_encoding(self) -> Self:
"""Return a new Variable without encoding."""
- pass
+ return self._replace(encoding={})
+
+ def _copy(
+ self,
+ deep: bool = True,
+ data: T_DuckArray | ArrayLike | None = None,
+ memo: dict[int, Any] | None = None,
+ ) -> Self:
+ if data is None:
+ data_old = self._data
+
+ if not isinstance(data_old, indexing.MemoryCachedArray):
+ ndata = data_old
+ else:
+ # don't share caching between copies
+ # TODO: MemoryCachedArray doesn't match the array api:
+ ndata = indexing.MemoryCachedArray(data_old.array) # type: ignore[assignment]
+
+ if deep:
+ ndata = copy.deepcopy(ndata, memo)
+
+ else:
+ ndata = as_compatible_data(data)
+ if self.shape != ndata.shape: # type: ignore[attr-defined]
+ raise ValueError(
+ f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
+ )
+
+ attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
+ encoding = (
+ copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding)
+ )
+
+ # note: dims is already an immutable tuple
+ return self._replace(data=ndata, attrs=attrs, encoding=encoding)
+
+ def _replace(
+ self,
+ dims=_default,
+ data=_default,
+ attrs=_default,
+ encoding=_default,
+ ) -> Self:
+ if dims is _default:
+ dims = copy.copy(self._dims)
+ if data is _default:
+ data = copy.copy(self.data)
+ if attrs is _default:
+ attrs = copy.copy(self._attrs)
+
+ if encoding is _default:
+ encoding = copy.copy(self._encoding)
+ return type(self)(dims, data, attrs, encoding, fastpath=True)
def load(self, **kwargs):
"""Manually trigger loading of this variable's data from disk or a
@@ -388,7 +974,8 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
dask.array.compute
"""
- pass
+ self._data = to_duck_array(self._data, **kwargs)
+ return self
def compute(self, **kwargs):
"""Manually trigger loading of this variable's data from disk or a
@@ -408,10 +995,15 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
dask.array.compute
"""
- pass
-
- def isel(self, indexers: (Mapping[Any, Any] | None)=None, missing_dims:
- ErrorOptionsWithWarn='raise', **indexers_kwargs: Any) ->Self:
+ new = self.copy(deep=False)
+ return new.load(**kwargs)
+
+ def isel(
+ self,
+ indexers: Mapping[Any, Any] | None = None,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ **indexers_kwargs: Any,
+ ) -> Self:
"""Return a new array indexed along the specified dimension(s).
Parameters
@@ -434,7 +1026,12 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
unless numpy fancy indexing was triggered by using an array
indexer, in which case the data will be a copy.
"""
- pass
+ indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
+
+ indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)
+
+ key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
+ return self[key]
def squeeze(self, dim=None):
"""Return a new object with squeezed data.
@@ -456,7 +1053,44 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
numpy.squeeze
"""
- pass
+ dims = common.get_squeeze_dims(self, dim)
+ return self.isel({d: 0 for d in dims})
+
+ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA):
+ axis = self.get_axis_num(dim)
+
+ if count > 0:
+ keep = slice(None, -count)
+ elif count < 0:
+ keep = slice(-count, None)
+ else:
+ keep = slice(None)
+
+ trimmed_data = self[(slice(None),) * axis + (keep,)].data
+
+ if fill_value is dtypes.NA:
+ dtype, fill_value = dtypes.maybe_promote(self.dtype)
+ else:
+ dtype = self.dtype
+
+ width = min(abs(count), self.shape[axis])
+ dim_pad = (width, 0) if count >= 0 else (0, width)
+ pads = [(0, 0) if d != dim else dim_pad for d in self.dims]
+
+ data = np.pad(
+ duck_array_ops.astype(trimmed_data, dtype),
+ pads,
+ mode="constant",
+ constant_values=fill_value,
+ )
+
+ if is_duck_dask_array(data):
+ # chunked data should come out with the same chunks; this makes
+ # it feasible to combine shifted and unshifted data
+ # TODO: remove this once dask.array automatically aligns chunks
+ data = data.rechunk(self.data.chunks)
+
+ return self._replace(data=data)
def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
"""
@@ -479,16 +1113,39 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
shifted : Variable
Variable with the same dimensions and attributes but shifted data.
"""
- pass
-
- def pad(self, pad_width: (Mapping[Any, int | tuple[int, int]] | None)=
- None, mode: PadModeOptions='constant', stat_length: (int | tuple[
- int, int] | Mapping[Any, tuple[int, int]] | None)=None,
- constant_values: (float | tuple[float, float] | Mapping[Any, tuple[
- float, float]] | None)=None, end_values: (int | tuple[int, int] |
- Mapping[Any, tuple[int, int]] | None)=None, reflect_type:
- PadReflectOptions=None, keep_attrs: (bool | None)=None, **
- pad_width_kwargs: Any):
+ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "shift")
+ result = self
+ for dim, count in shifts.items():
+ result = result._shift_one_dim(dim, count, fill_value=fill_value)
+ return result
+
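# Usage sketch (illustration only, not part of the patch): shift() fills the
# vacated positions, promoting integers to float when the fill is NaN.
import xarray as xr

v = xr.Variable(("x",), [1, 2, 3, 4])
v.shift(x=1)                 # [nan, 1., 2., 3.]
v.shift(x=-2, fill_value=0)  # [3, 4, 0, 0], dtype preserved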
+ def _pad_options_dim_to_index(
+ self,
+ pad_option: Mapping[Any, int | tuple[int, int]],
+ fill_with_shape=False,
+ ):
+ if fill_with_shape:
+ return [
+ (n, n) if d not in pad_option else pad_option[d]
+ for d, n in zip(self.dims, self.data.shape)
+ ]
+ return [(0, 0) if d not in pad_option else pad_option[d] for d in self.dims]
+
+ def pad(
+ self,
+ pad_width: Mapping[Any, int | tuple[int, int]] | None = None,
+ mode: PadModeOptions = "constant",
+ stat_length: (
+ int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None
+ ) = None,
+ constant_values: (
+ float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None
+ ) = None,
+ end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None,
+ reflect_type: PadReflectOptions = None,
+ keep_attrs: bool | None = None,
+ **pad_width_kwargs: Any,
+ ):
"""
Return a new Variable with padded data.
@@ -526,7 +1183,80 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
padded : Variable
Variable with the same dimensions and attributes but padded data.
"""
- pass
+ pad_width = either_dict_or_kwargs(pad_width, pad_width_kwargs, "pad")
+
+ # change default behaviour of pad with mode constant
+ if mode == "constant" and (
+ constant_values is None or constant_values is dtypes.NA
+ ):
+ dtype, constant_values = dtypes.maybe_promote(self.dtype)
+ else:
+ dtype = self.dtype
+
+ # create pad_options_kwargs, numpy requires only relevant kwargs to be nonempty
+ if isinstance(stat_length, dict):
+ stat_length = self._pad_options_dim_to_index(
+ stat_length, fill_with_shape=True
+ )
+ if isinstance(constant_values, dict):
+ constant_values = self._pad_options_dim_to_index(constant_values)
+ if isinstance(end_values, dict):
+ end_values = self._pad_options_dim_to_index(end_values)
+
+ # workaround for bug in Dask's default value of stat_length https://github.com/dask/dask/issues/5303
+ if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]:
+ stat_length = [(n, n) for n in self.data.shape] # type: ignore[assignment]
+
+ # change integer values to a tuple of two of those values and change pad_width to index
+ for k, v in pad_width.items():
+ if isinstance(v, numbers.Number):
+ pad_width[k] = (v, v)
+ pad_width_by_index = self._pad_options_dim_to_index(pad_width)
+
+ # create pad_options_kwargs, numpy/dask requires only relevant kwargs to be nonempty
+ pad_option_kwargs: dict[str, Any] = {}
+ if stat_length is not None:
+ pad_option_kwargs["stat_length"] = stat_length
+ if constant_values is not None:
+ pad_option_kwargs["constant_values"] = constant_values
+ if end_values is not None:
+ pad_option_kwargs["end_values"] = end_values
+ if reflect_type is not None:
+ pad_option_kwargs["reflect_type"] = reflect_type
+
+ array = np.pad(
+ duck_array_ops.astype(self.data, dtype, copy=False),
+ pad_width_by_index,
+ mode=mode,
+ **pad_option_kwargs,
+ )
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+ attrs = self._attrs if keep_attrs else None
+
+ return type(self)(self.dims, array, attrs=attrs)
+
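# Usage sketch (illustration only, not part of the patch): pad() with the default
# constant mode promotes the dtype so NaN can be used unless constant_values is given.
import xarray as xr

v = xr.Variable(("x",), [1, 2, 3])
v.pad(x=(1, 2))                                      # [nan, 1., 2., 3., nan, nan]
v.pad(x=1, mode="edge")                              # [1, 1, 2, 3, 3]
v.pad(x=(0, 2), mode="constant", constant_values=0)  # [1, 2, 3, 0, 0]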
+ def _roll_one_dim(self, dim, count):
+ axis = self.get_axis_num(dim)
+
+ count %= self.shape[axis]
+ if count != 0:
+ indices = [slice(-count, None), slice(None, -count)]
+ else:
+ indices = [slice(None)]
+
+ arrays = [self[(slice(None),) * axis + (idx,)].data for idx in indices]
+
+ data = duck_array_ops.concatenate(arrays, axis)
+
+ if is_duck_dask_array(data):
+ # chunked data should come out with the same chunks; this makes
+ # it feasible to combine shifted and unshifted data
+ # TODO: remove this once dask.array automatically aligns chunks
+ data = data.rechunk(self.data.chunks)
+
+ return self._replace(data=data)
def roll(self, shifts=None, **shifts_kwargs):
"""
@@ -547,11 +1277,19 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
shifted : Variable
Variable with the same dimensions and attributes but rolled data.
"""
- pass
+ shifts = either_dict_or_kwargs(shifts, shifts_kwargs, "roll")
+
+ result = self
+ for dim, count in shifts.items():
+ result = result._roll_one_dim(dim, count)
+ return result
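# Usage sketch (illustration only, not part of the patch): roll() wraps values
# around instead of introducing a fill value.
import xarray as xr

v = xr.Variable(("x",), [1, 2, 3, 4])
v.roll(x=1)   # [4, 1, 2, 3]
v.roll(x=-1)  # [2, 3, 4, 1]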
@deprecate_dims
- def transpose(self, *dim: (Hashable | ellipsis), missing_dims:
- ErrorOptionsWithWarn='raise') ->Self:
+ def transpose(
+ self,
+ *dim: Hashable | ellipsis,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ ) -> Self:
"""Return a new Variable object with transposed dimensions.
Parameters
@@ -581,7 +1319,23 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
numpy.transpose
"""
- pass
+ if len(dim) == 0:
+ dim = self.dims[::-1]
+ else:
+ dim = tuple(infix_dims(dim, self.dims, missing_dims))
+
+ if len(dim) < 2 or dim == self.dims:
+ # no need to transpose if only one dimension
+ # or dims are in same order
+ return self.copy(deep=False)
+
+ axes = self.get_axis_num(dim)
+ data = as_indexable(self._data).transpose(axes)
+ return self._replace(dims=dim, data=data)
+
+ @property
+ def T(self) -> Self:
+ return self.transpose()
@deprecate_dims
def set_dims(self, dim, shape=None):
@@ -601,9 +1355,66 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
-------
Variable
"""
- pass
+ if isinstance(dim, str):
+ dim = [dim]
+
+ if shape is None and is_dict_like(dim):
+ shape = dim.values()
+
+ missing_dims = set(self.dims) - set(dim)
+ if missing_dims:
+ raise ValueError(
+ f"new dimensions {dim!r} must be a superset of "
+ f"existing dimensions {self.dims!r}"
+ )
+
+ self_dims = set(self.dims)
+ expanded_dims = tuple(d for d in dim if d not in self_dims) + self.dims
+
+ if self.dims == expanded_dims:
+ # don't use broadcast_to unless necessary so the result remains
+ # writeable if possible
+ expanded_data = self.data
+ elif shape is not None:
+ dims_map = dict(zip(dim, shape))
+ tmp_shape = tuple(dims_map[d] for d in expanded_dims)
+ expanded_data = duck_array_ops.broadcast_to(self.data, tmp_shape)
+ else:
+ indexer = (None,) * (len(expanded_dims) - self.ndim) + (...,)
+ expanded_data = self.data[indexer]
+
+ expanded_var = Variable(
+ expanded_dims, expanded_data, self._attrs, self._encoding, fastpath=True
+ )
+ return expanded_var.transpose(*dim)
+
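# Usage sketch (illustration only, not part of the patch): set_dims() inserts
# missing dimensions (size 1, or broadcast to a given size) in the requested order.
import xarray as xr

v = xr.Variable(("x",), [1, 2, 3])
v.set_dims(("y", "x")).shape        # (1, 3): new 'y' axis of length 1
v.set_dims({"y": 2, "x": 3}).shape  # (2, 3): broadcast along the new 'y'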
+ def _stack_once(self, dim: list[Hashable], new_dim: Hashable):
+ if not set(dim) <= set(self.dims):
+ raise ValueError(f"invalid existing dimensions: {dim}")
+
+ if new_dim in self.dims:
+ raise ValueError(
+ "cannot create a new dimension with the same "
+ "name as an existing dimension"
+ )
+
+ if len(dim) == 0:
+ # don't stack
+ return self.copy(deep=False)
+
+ other_dims = [d for d in self.dims if d not in dim]
+ dim_order = other_dims + list(dim)
+ reordered = self.transpose(*dim_order)
- @partial(deprecate_dims, old_name='dimensions')
+ new_shape = reordered.shape[: len(other_dims)] + (-1,)
+ new_data = duck_array_ops.reshape(reordered.data, new_shape)
+ new_dims = reordered.dims[: len(other_dims)] + (new_dim,)
+
+ return type(self)(
+ new_dims, new_data, self._attrs, self._encoding, fastpath=True
+ )
+
+ @partial(deprecate_dims, old_name="dimensions")
def stack(self, dim=None, **dim_kwargs):
"""
Stack any number of existing dim into a single new dimension.
@@ -630,27 +1441,118 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
Variable.unstack
"""
- pass
+ dim = either_dict_or_kwargs(dim, dim_kwargs, "stack")
+ result = self
+ for new_dim, dims in dim.items():
+ result = result._stack_once(dims, new_dim)
+ return result
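# Usage sketch (illustration only, not part of the patch): stack() flattens the
# listed dims into one new dim in row-major order (last listed dim varies fastest).
import numpy as np
import xarray as xr

v = xr.Variable(("x", "y"), np.arange(6).reshape(2, 3))
s = v.stack(z=("x", "y"))
s.dims, s.shape  # (('z',), (6,))
s.values         # [0, 1, 2, 3, 4, 5]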
- def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable
- ) ->Self:
+ def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self:
"""
Unstacks the variable without needing an index.
Unlike `_unstack_once`, this function requires the existing dimension to
contain the full product of the new dimensions.
"""
- pass
+ new_dim_names = tuple(dim.keys())
+ new_dim_sizes = tuple(dim.values())
+
+ if old_dim not in self.dims:
+ raise ValueError(f"invalid existing dimension: {old_dim}")
- def _unstack_once(self, index: pd.MultiIndex, dim: Hashable, fill_value
- =dtypes.NA, sparse: bool=False) ->Self:
+ if set(new_dim_names).intersection(self.dims):
+ raise ValueError(
+ "cannot create a new dimension with the same "
+ "name as an existing dimension"
+ )
+
+ if math.prod(new_dim_sizes) != self.sizes[old_dim]:
+ raise ValueError(
+ "the product of the new dimension sizes must "
+ "equal the size of the old dimension"
+ )
+
+ other_dims = [d for d in self.dims if d != old_dim]
+ dim_order = other_dims + [old_dim]
+ reordered = self.transpose(*dim_order)
+
+ new_shape = reordered.shape[: len(other_dims)] + new_dim_sizes
+ new_data = duck_array_ops.reshape(reordered.data, new_shape)
+ new_dims = reordered.dims[: len(other_dims)] + new_dim_names
+
+ return type(self)(
+ new_dims, new_data, self._attrs, self._encoding, fastpath=True
+ )
+
+ def _unstack_once(
+ self,
+ index: pd.MultiIndex,
+ dim: Hashable,
+ fill_value=dtypes.NA,
+ sparse: bool = False,
+ ) -> Self:
"""
Unstacks this variable given an index to unstack and the name of the
dimension to which the index refers.
"""
- pass
- @partial(deprecate_dims, old_name='dimensions')
+ reordered = self.transpose(..., dim)
+
+ new_dim_sizes = [lev.size for lev in index.levels]
+ new_dim_names = index.names
+ indexer = index.codes
+
+ # Potentially we could replace `len(other_dims)` with just `-1`
+ other_dims = [d for d in self.dims if d != dim]
+ new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes)
+ new_dims = reordered.dims[: len(other_dims)] + tuple(new_dim_names)
+
+ create_template: Callable
+ if fill_value is dtypes.NA:
+ is_missing_values = math.prod(new_shape) > math.prod(self.shape)
+ if is_missing_values:
+ dtype, fill_value = dtypes.maybe_promote(self.dtype)
+
+ create_template = partial(np.full_like, fill_value=fill_value)
+ else:
+ dtype = self.dtype
+ fill_value = dtypes.get_fill_value(dtype)
+ create_template = np.empty_like
+ else:
+ dtype = self.dtype
+ create_template = partial(np.full_like, fill_value=fill_value)
+
+ if sparse:
+            # unstacking a dense multi-indexed array to a sparse array
+ from sparse import COO
+
+ codes = zip(*index.codes)
+ if reordered.ndim == 1:
+ indexes = codes
+ else:
+ sizes = itertools.product(*[range(s) for s in reordered.shape[:-1]])
+ tuple_indexes = itertools.product(sizes, codes)
+ indexes = map(lambda x: list(itertools.chain(*x)), tuple_indexes) # type: ignore
+
+ data = COO(
+ coords=np.array(list(indexes)).T,
+ data=self.data.astype(dtype).ravel(),
+ fill_value=fill_value,
+ shape=new_shape,
+ sorted=index.is_monotonic_increasing,
+ )
+
+ else:
+ data = create_template(self.data, shape=new_shape, dtype=dtype)
+
+ # Indexer is a list of lists of locations. Each list is the locations
+ # on the new dimension. This is robust to the data being sparse; in that
+ # case the destinations will be NaN / zero.
+ data[(..., *indexer)] = reordered
+
+ return self._replace(dims=new_dims, data=data)
+
+ @partial(deprecate_dims, old_name="dimensions")
def unstack(self, dim=None, **dim_kwargs):
"""
Unstack an existing dimension into multiple new dimensions.
@@ -683,7 +1585,17 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
DataArray.unstack
Dataset.unstack
"""
- pass
+ dim = either_dict_or_kwargs(dim, dim_kwargs, "unstack")
+ result = self
+ for old_dim, dims in dim.items():
+ result = result._unstack_once_full(dims, old_dim)
+ return result
+
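# Usage sketch (illustration only, not part of the patch): Variable.unstack()
# (via _unstack_once_full) requires the new sizes to multiply out to the old length.
import numpy as np
import xarray as xr

s = xr.Variable(("z",), np.arange(6))
u = s.unstack(z={"x": 2, "y": 3})
u.dims, u.shape  # (('x', 'y'), (2, 3))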
+ def fillna(self, value):
+ return ops.fillna(self, value)
+
+ def where(self, cond, other=dtypes.NA):
+ return ops.where_method(self, cond, other)
def clip(self, min=None, max=None):
"""
@@ -696,11 +1608,19 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
numpy.clip : equivalent function
"""
- pass
-
- def reduce(self, func: Callable[..., Any], dim: Dims=None, axis: (int |
- Sequence[int] | None)=None, keep_attrs: (bool | None)=None,
- keepdims: bool=False, **kwargs) ->Variable:
+ from xarray.core.computation import apply_ufunc
+
+ return apply_ufunc(np.clip, self, min, max, dask="allowed")
+
+ def reduce( # type: ignore[override]
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ axis: int | Sequence[int] | None = None,
+ keep_attrs: bool | None = None,
+ keepdims: bool = False,
+ **kwargs,
+ ) -> Variable:
"""Reduce this array by applying `func` along some dimension(s).
Parameters
@@ -733,11 +1653,31 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
Array with summarized data and the indicated dimension(s)
removed.
"""
- pass
+ keep_attrs_ = (
+ _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
+ )
+
+        # Note that the call order for Variable.mean is
+ # Variable.mean -> NamedArray.mean -> Variable.reduce
+ # -> NamedArray.reduce
+ result = super().reduce(
+ func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs
+ )
+
+ # return Variable always to support IndexVariable
+ return Variable(
+ result.dims, result._data, attrs=result._attrs if keep_attrs_ else None
+ )
@classmethod
- def concat(cls, variables, dim='concat_dim', positions=None, shortcut=
- False, combine_attrs='override'):
+ def concat(
+ cls,
+ variables,
+ dim="concat_dim",
+ positions=None,
+ shortcut=False,
+ combine_attrs="override",
+ ):
"""Concatenate variables along a new or existing dimension.
Parameters
@@ -760,7 +1700,8 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
This option is used internally to speed-up groupby operations.
If `shortcut` is True, some checks of internal consistency between
arrays to concatenate are skipped.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"}, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"}, default: "override"
String indicating how to combine attrs of the objects being merged:
- "drop": empty attrs on returned Dataset.
@@ -778,7 +1719,45 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
Concatenated Variable formed by stacking all the supplied variables
along the given dimension.
"""
- pass
+ from xarray.core.merge import merge_attrs
+
+ if not isinstance(dim, str):
+ (dim,) = dim.dims
+
+ # can't do this lazily: we need to loop through variables at least
+ # twice
+ variables = list(variables)
+ first_var = variables[0]
+ first_var_dims = first_var.dims
+
+ arrays = [v._data for v in variables]
+
+ if dim in first_var_dims:
+ axis = first_var.get_axis_num(dim)
+ dims = first_var_dims
+ data = duck_array_ops.concatenate(arrays, axis=axis)
+ if positions is not None:
+ # TODO: deprecate this option -- we don't need it for groupby
+ # any more.
+ indices = nputils.inverse_permutation(np.concatenate(positions))
+ data = duck_array_ops.take(data, indices, axis=axis)
+ else:
+ axis = 0
+ dims = (dim,) + first_var_dims
+ data = duck_array_ops.stack(arrays, axis=axis)
+
+ attrs = merge_attrs(
+ [var.attrs for var in variables], combine_attrs=combine_attrs
+ )
+ encoding = dict(first_var.encoding)
+ if not shortcut:
+ for var in variables:
+ if var.dims != first_var_dims:
+ raise ValueError(
+ f"Variable has dimensions {tuple(var.dims)} but first Variable has dimensions {tuple(first_var_dims)}"
+ )
+
+ return cls(dims, data, attrs, encoding, fastpath=True)
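# Usage sketch (illustration only, not part of the patch): Variable.concat()
# appends along an existing dim, or stacks along a brand-new one.
import xarray as xr

a = xr.Variable(("x",), [1, 2])
b = xr.Variable(("x",), [3, 4])
xr.Variable.concat([a, b], dim="x").shape  # (4,)
xr.Variable.concat([a, b], dim="t").dims   # ('t', 'x'), shape (2, 2)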
def equals(self, other, equiv=duck_array_ops.array_equiv):
"""True if two Variables have the same dimensions and values;
@@ -790,7 +1769,13 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
This method is necessary because `v1 == v2` for Variables
does element-wise comparisons (like numpy.ndarrays).
"""
- pass
+ other = getattr(other, "variable", other)
+ try:
+ return self.dims == other.dims and (
+ self._data is other._data or equiv(self.data, other.data)
+ )
+ except (TypeError, AttributeError):
+ return False
def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv):
"""True if two Variables have the values after being broadcast against
@@ -799,11 +1784,20 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
Variables can still be equal (like pandas objects) if they have NaN
values in the same locations.
"""
- pass
+ try:
+ self, other = broadcast_variables(self, other)
+ except (ValueError, AttributeError):
+ return False
+ return self.equals(other, equiv=equiv)
def identical(self, other, equiv=duck_array_ops.array_equiv):
"""Like equals, but also checks attributes."""
- pass
+ try:
+ return utils.dict_equiv(self.attrs, other.attrs) and self.equals(
+ other, equiv=equiv
+ )
+ except (TypeError, AttributeError):
+ return False
def no_conflicts(self, other, equiv=duck_array_ops.array_notnull_equiv):
"""True if the intersection of two Variable's non-null data is
@@ -812,12 +1806,17 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
Variables can thus still be equal if there are locations where either,
or both, contain NaN values.
"""
- pass
-
- def quantile(self, q: ArrayLike, dim: (str | Sequence[Hashable] | None)
- =None, method: QuantileMethods='linear', keep_attrs: (bool | None)=
- None, skipna: (bool | None)=None, interpolation: (QuantileMethods |
- None)=None) ->Self:
+ return self.broadcast_equals(other, equiv=equiv)
+
+ def quantile(
+ self,
+ q: ArrayLike,
+ dim: str | Sequence[Hashable] | None = None,
+ method: QuantileMethods = "linear",
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ interpolation: QuantileMethods | None = None,
+ ) -> Self:
"""Compute the qth quantile of the data along the specified dimension.
Returns the qth quantiles(s) of the array elements.
@@ -886,7 +1885,64 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
"Sample quantiles in statistical packages,"
The American Statistician, 50(4), pp. 361-365, 1996
"""
- pass
+
+ from xarray.core.computation import apply_ufunc
+
+ if interpolation is not None:
+ warnings.warn(
+ "The `interpolation` argument to quantile was renamed to `method`.",
+ FutureWarning,
+ )
+
+ if method != "linear":
+ raise TypeError("Cannot pass interpolation and method keywords!")
+
+ method = interpolation
+
+ if skipna or (skipna is None and self.dtype.kind in "cfO"):
+ _quantile_func = nputils.nanquantile
+ else:
+ _quantile_func = np.quantile
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ scalar = utils.is_scalar(q)
+ q = np.atleast_1d(np.asarray(q, dtype=np.float64))
+
+ if dim is None:
+ dim = self.dims
+
+ if utils.is_scalar(dim):
+ dim = [dim]
+
+ def _wrapper(npa, **kwargs):
+ # move quantile axis to end. required for apply_ufunc
+ return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1)
+
+ axis = np.arange(-1, -1 * len(dim) - 1, -1)
+
+ kwargs = {"q": q, "axis": axis, "method": method}
+
+ result = apply_ufunc(
+ _wrapper,
+ self,
+ input_core_dims=[dim],
+ exclude_dims=set(dim),
+ output_core_dims=[["quantile"]],
+ output_dtypes=[np.float64],
+ dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
+ dask="parallelized",
+ kwargs=kwargs,
+ )
+
+ # for backward compatibility
+ result = result.transpose("quantile", ...)
+ if scalar:
+ result = result.squeeze("quantile")
+ if keep_attrs:
+ result.attrs = self._attrs
+ return result
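# Usage sketch (illustration only, not part of the patch): quantile() drops the
# 'quantile' dim for a scalar q and keeps it leading for an array-like q.
import numpy as np
import xarray as xr

v = xr.Variable(("x", "y"), np.arange(12.0).reshape(3, 4))
v.quantile(0.5, dim="y").dims                  # ('x',)
v.quantile([0.25, 0.75], dim=["x", "y"]).dims  # ('quantile',)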
def rank(self, dim, pct=False):
"""Ranks the data.
@@ -914,10 +1970,33 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
Dataset.rank, DataArray.rank
"""
- pass
-
- def rolling_window(self, dim, window, window_dim, center=False,
- fill_value=dtypes.NA):
+ # This could / should arguably be implemented at the DataArray & Dataset level
+ if not OPTIONS["use_bottleneck"]:
+ raise RuntimeError(
+ "rank requires bottleneck to be enabled."
+ " Call `xr.set_options(use_bottleneck=True)` to enable it."
+ )
+
+ import bottleneck as bn
+
+ func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata
+ ranked = xr.apply_ufunc(
+ func,
+ self,
+ input_core_dims=[[dim]],
+ output_core_dims=[[dim]],
+ dask="parallelized",
+ kwargs=dict(axis=-1),
+ ).transpose(*self.dims)
+
+ if pct:
+ count = self.notnull().sum(dim)
+ ranked /= count
+ return ranked
+
+ def rolling_window(
+ self, dim, window, window_dim, center=False, fill_value=dtypes.NA
+ ):
"""
Make a rolling_window along dim and add a new_dim to the last place.
@@ -972,22 +2051,157 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
[ 5., 6., 7.],
[ 6., 7., nan]]])
"""
- pass
-
- def coarsen(self, windows, func, boundary='exact', side='left',
- keep_attrs=None, **kwargs):
+ if fill_value is dtypes.NA: # np.nan is passed
+ dtype, fill_value = dtypes.maybe_promote(self.dtype)
+ var = duck_array_ops.astype(self, dtype, copy=False)
+ else:
+ dtype = self.dtype
+ var = self
+
+ if utils.is_scalar(dim):
+ for name, arg in zip(
+ ["window", "window_dim", "center"], [window, window_dim, center]
+ ):
+ if not utils.is_scalar(arg):
+ raise ValueError(
+ f"Expected {name}={arg!r} to be a scalar like 'dim'."
+ )
+ dim = (dim,)
+
+ # dim is now a list
+ nroll = len(dim)
+ if utils.is_scalar(window):
+ window = [window] * nroll
+ if utils.is_scalar(window_dim):
+ window_dim = [window_dim] * nroll
+ if utils.is_scalar(center):
+ center = [center] * nroll
+ if (
+ len(dim) != len(window)
+ or len(dim) != len(window_dim)
+ or len(dim) != len(center)
+ ):
+ raise ValueError(
+ "'dim', 'window', 'window_dim', and 'center' must be the same length. "
+ f"Received dim={dim!r}, window={window!r}, window_dim={window_dim!r},"
+ f" and center={center!r}."
+ )
+
+ pads = {}
+ for d, win, cent in zip(dim, window, center):
+ if cent:
+ start = win // 2 # 10 -> 5, 9 -> 4
+ end = win - 1 - start
+ pads[d] = (start, end)
+ else:
+ pads[d] = (win - 1, 0)
+
+ padded = var.pad(pads, mode="constant", constant_values=fill_value)
+ axis = self.get_axis_num(dim)
+ new_dims = self.dims + tuple(window_dim)
+ return Variable(
+ new_dims,
+ duck_array_ops.sliding_window_view(
+ padded.data, window_shape=window, axis=axis
+ ),
+ )
+
+ def coarsen(
+ self, windows, func, boundary="exact", side="left", keep_attrs=None, **kwargs
+ ):
"""
Apply reduction function.
"""
- pass
+ windows = {k: v for k, v in windows.items() if k in self.dims}
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+
+ if keep_attrs:
+ _attrs = self.attrs
+ else:
+ _attrs = None
+
+ if not windows:
+ return self._replace(attrs=_attrs)
+
+ reshaped, axes = self.coarsen_reshape(windows, boundary, side)
+ if isinstance(func, str):
+ name = func
+ func = getattr(duck_array_ops, name, None)
+ if func is None:
+ raise NameError(f"{name} is not a valid method.")
+
+ return self._replace(data=func(reshaped, axis=axes, **kwargs), attrs=_attrs)
def coarsen_reshape(self, windows, boundary, side):
"""
Construct a reshaped-array for coarsen
"""
- pass
+ if not is_dict_like(boundary):
+ boundary = {d: boundary for d in windows.keys()}
+
+ if not is_dict_like(side):
+ side = {d: side for d in windows.keys()}
+
+ # remove unrelated dimensions
+ boundary = {k: v for k, v in boundary.items() if k in windows}
+ side = {k: v for k, v in side.items() if k in windows}
+
+ for d, window in windows.items():
+ if window <= 0:
+ raise ValueError(
+ f"window must be > 0. Given {window} for dimension {d}"
+ )
+
+ variable = self
+ for d, window in windows.items():
+ # trim or pad the object
+ size = variable.shape[self._get_axis_num(d)]
+ n = int(size / window)
+ if boundary[d] == "exact":
+ if n * window != size:
+ raise ValueError(
+ f"Could not coarsen a dimension of size {size} with "
+ f"window {window} and boundary='exact'. Try a different 'boundary' option."
+ )
+ elif boundary[d] == "trim":
+ if side[d] == "left":
+ variable = variable.isel({d: slice(0, window * n)})
+ else:
+ excess = size - window * n
+ variable = variable.isel({d: slice(excess, None)})
+ elif boundary[d] == "pad": # pad
+ pad = window * n - size
+ if pad < 0:
+ pad += window
+ if side[d] == "left":
+ pad_width = {d: (0, pad)}
+ else:
+ pad_width = {d: (pad, 0)}
+ variable = variable.pad(pad_width, mode="constant")
+ else:
+ raise TypeError(
+ f"{boundary[d]} is invalid for boundary. Valid option is 'exact', "
+ "'trim' and 'pad'"
+ )
+
+ shape = []
+ axes = []
+ axis_count = 0
+ for i, d in enumerate(variable.dims):
+ if d in windows:
+ size = variable.shape[i]
+ shape.append(int(size / windows[d]))
+ shape.append(windows[d])
+ axis_count += 1
+ axes.append(i + axis_count)
+ else:
+ shape.append(variable.shape[i])
- def isnull(self, keep_attrs: (bool | None)=None):
+ return duck_array_ops.reshape(variable.data, shape), tuple(axes)
+
+ def isnull(self, keep_attrs: bool | None = None):
"""Test each value in the array for whether it is a missing value.
Returns
@@ -1009,9 +2223,19 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
<xarray.Variable (x: 3)> Size: 3B
array([False, True, False])
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ return apply_ufunc(
+ duck_array_ops.isnull,
+ self,
+ dask="allowed",
+ keep_attrs=keep_attrs,
+ )
- def notnull(self, keep_attrs: (bool | None)=None):
+ def notnull(self, keep_attrs: bool | None = None):
"""Test each value in the array for whether it is not a missing value.
Returns
@@ -1033,10 +2257,20 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
<xarray.Variable (x: 3)> Size: 3B
array([ True, False, True])
"""
- pass
+ from xarray.core.computation import apply_ufunc
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+
+ return apply_ufunc(
+ duck_array_ops.notnull,
+ self,
+ dask="allowed",
+ keep_attrs=keep_attrs,
+ )
@property
- def imag(self) ->Variable:
+ def imag(self) -> Variable:
"""
The imaginary part of the variable.
@@ -1044,10 +2278,10 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
numpy.ndarray.imag
"""
- pass
+ return self._new(data=self.data.imag)
@property
- def real(self) ->Variable:
+ def real(self) -> Variable:
"""
The real part of the variable.
@@ -1055,28 +2289,133 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
numpy.ndarray.real
"""
- pass
+ return self._new(data=self.data.real)
def __array_wrap__(self, obj, context=None):
return Variable(self.dims, obj)
+ def _unary_op(self, f, *args, **kwargs):
+ keep_attrs = kwargs.pop("keep_attrs", None)
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=True)
+ with np.errstate(all="ignore"):
+ result = self.__array_wrap__(f(self.data, *args, **kwargs))
+ if keep_attrs:
+ result.attrs = self.attrs
+ return result
+
+ def _binary_op(self, other, f, reflexive=False):
+ if isinstance(other, (xr.DataArray, xr.Dataset)):
+ return NotImplemented
+ if reflexive and issubclass(type(self), type(other)):
+ other_data, self_data, dims = _broadcast_compat_data(other, self)
+ else:
+ self_data, other_data, dims = _broadcast_compat_data(self, other)
+ keep_attrs = _get_keep_attrs(default=False)
+ attrs = self._attrs if keep_attrs else None
+ with np.errstate(all="ignore"):
+ new_data = (
+ f(self_data, other_data) if not reflexive else f(other_data, self_data)
+ )
+ result = Variable(dims, new_data, attrs=attrs)
+ return result
+
+ def _inplace_binary_op(self, other, f):
+ if isinstance(other, xr.Dataset):
+ raise TypeError("cannot add a Dataset to a Variable in-place")
+ self_data, other_data, dims = _broadcast_compat_data(self, other)
+ if dims != self.dims:
+ raise ValueError("dimensions cannot change for in-place operations")
+ with np.errstate(all="ignore"):
+ self.values = f(self_data, other_data)
+ return self
+
def _to_numeric(self, offset=None, datetime_unit=None, dtype=float):
"""A (private) method to convert datetime array to numeric dtype
See duck_array_ops.datetime_to_numeric
"""
- pass
-
- def _unravel_argminmax(self, argminmax: str, dim: Dims, axis: (int |
- None), keep_attrs: (bool | None), skipna: (bool | None)) ->(Variable |
- dict[Hashable, Variable]):
+ numeric_array = duck_array_ops.datetime_to_numeric(
+ self.data, offset, datetime_unit, dtype
+ )
+ return type(self)(self.dims, numeric_array, self._attrs)
+
+ def _unravel_argminmax(
+ self,
+ argminmax: str,
+ dim: Dims,
+ axis: int | None,
+ keep_attrs: bool | None,
+ skipna: bool | None,
+ ) -> Variable | dict[Hashable, Variable]:
"""Apply argmin or argmax over one or more dimensions, returning the result as a
dict of DataArray that can be passed directly to isel.
"""
- pass
-
- def argmin(self, dim: Dims=None, axis: (int | None)=None, keep_attrs: (
- bool | None)=None, skipna: (bool | None)=None) ->(Variable | dict[
- Hashable, Variable]):
+ if dim is None and axis is None:
+ warnings.warn(
+ "Behaviour of argmin/argmax with neither dim nor axis argument will "
+ "change to return a dict of indices of each dimension. To get a "
+ "single, flat index, please use np.argmin(da.data) or "
+ "np.argmax(da.data) instead of da.argmin() or da.argmax().",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+
+ argminmax_func = getattr(duck_array_ops, argminmax)
+
+ if dim is ...:
+ # In future, should do this also when (dim is None and axis is None)
+ dim = self.dims
+ if (
+ dim is None
+ or axis is not None
+ or not isinstance(dim, Sequence)
+ or isinstance(dim, str)
+ ):
+ # Return int index if single dimension is passed, and is not part of a
+ # sequence
+ return self.reduce(
+ argminmax_func, dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna
+ )
+
+ # Get a name for the new dimension that does not conflict with any existing
+ # dimension
+ newdimname = "_unravel_argminmax_dim_0"
+ count = 1
+ while newdimname in self.dims:
+ newdimname = f"_unravel_argminmax_dim_{count}"
+ count += 1
+
+ stacked = self.stack({newdimname: dim})
+
+ result_dims = stacked.dims[:-1]
+ reduce_shape = tuple(self.sizes[d] for d in dim)
+
+ result_flat_indices = stacked.reduce(argminmax_func, axis=-1, skipna=skipna)
+
+ result_unravelled_indices = duck_array_ops.unravel_index(
+ result_flat_indices.data, reduce_shape
+ )
+
+ result = {
+ d: Variable(dims=result_dims, data=i)
+ for d, i in zip(dim, result_unravelled_indices)
+ }
+
+ if keep_attrs is None:
+ keep_attrs = _get_keep_attrs(default=False)
+ if keep_attrs:
+ for v in result.values():
+ v.attrs = self.attrs
+
+ return result
+
+ def argmin(
+ self,
+ dim: Dims = None,
+ axis: int | None = None,
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ ) -> Variable | dict[Hashable, Variable]:
"""Index or indices of the minimum of the Variable over one or more dimensions.
If a sequence is passed to 'dim', then result returned as dict of Variables,
which can be passed directly to isel(). If a single str is passed to 'dim' then
@@ -1113,11 +2452,15 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
DataArray.argmin, DataArray.idxmin
"""
- pass
-
- def argmax(self, dim: Dims=None, axis: (int | None)=None, keep_attrs: (
- bool | None)=None, skipna: (bool | None)=None) ->(Variable | dict[
- Hashable, Variable]):
+ return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna)
+
+ def argmax(
+ self,
+ dim: Dims = None,
+ axis: int | None = None,
+ keep_attrs: bool | None = None,
+ skipna: bool | None = None,
+ ) -> Variable | dict[Hashable, Variable]:
"""Index or indices of the maximum of the Variable over one or more dimensions.
If a sequence is passed to 'dim', then result returned as dict of Variables,
which can be passed directly to isel(). If a single str is passed to 'dim' then
@@ -1154,25 +2497,40 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
--------
DataArray.argmax, DataArray.idxmax
"""
- pass
+ return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna)
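# Usage sketch (illustration only, not part of the patch): a single dim gives
# integer positions; a sequence of dims gives a dict that can be fed to isel().
import numpy as np
import xarray as xr

v = xr.Variable(("x", "y"), np.array([[1, 5, 2], [4, 0, 3]]))
v.argmin(dim="y").values        # [0, 1]
idx = v.argmin(dim=["x", "y"])  # {'x': 1, 'y': 1} as 0-d Variables
v.isel(idx).values              # 0, the global minimum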
- def _as_sparse(self, sparse_format=_default, fill_value=_default
- ) ->Variable:
+ def _as_sparse(self, sparse_format=_default, fill_value=_default) -> Variable:
"""
Use sparse-array as backend.
"""
- pass
+ from xarray.namedarray._typing import _default as _default_named
+
+ if sparse_format is _default:
+ sparse_format = _default_named
+
+ if fill_value is _default:
+ fill_value = _default_named
- def _to_dense(self) ->Variable:
+ out = super()._as_sparse(sparse_format, fill_value)
+ return cast("Variable", out)
+
+ def _to_dense(self) -> Variable:
"""
Change backend from sparse to np.array.
"""
- pass
-
- def chunk(self, chunks: T_Chunks={}, name: (str | None)=None, lock: (
- bool | None)=None, inline_array: (bool | None)=None,
- chunked_array_type: (str | ChunkManagerEntrypoint[Any] | None)=None,
- from_array_kwargs: Any=None, **chunks_kwargs: Any) ->Self:
+ out = super()._to_dense()
+ return cast("Variable", out)
+
+ def chunk( # type: ignore[override]
+ self,
+ chunks: T_Chunks = {},
+ name: str | None = None,
+ lock: bool | None = None,
+ inline_array: bool | None = None,
+ chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None,
+ from_array_kwargs: Any = None,
+ **chunks_kwargs: Any,
+ ) -> Self:
"""Coerce this array's data into a dask array with the given chunks.
If this variable is a non-dask array, it will be converted to dask
@@ -1221,7 +2579,29 @@ class Variable(NamedArray, AbstractArray, VariableArithmetic):
xarray.unify_chunks
dask.array.from_array
"""
- pass
+
+ if is_extension_array_dtype(self):
+ raise ValueError(
+ f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first."
+ )
+
+ if from_array_kwargs is None:
+ from_array_kwargs = {}
+
+ # TODO deprecate passing these dask-specific arguments explicitly. In future just pass everything via from_array_kwargs
+ _from_array_kwargs = consolidate_dask_from_array_kwargs(
+ from_array_kwargs,
+ name=name,
+ lock=lock,
+ inline_array=inline_array,
+ )
+
+ return super().chunk(
+ chunks=chunks,
+ chunked_array_type=chunked_array_type,
+ from_array_kwargs=_from_array_kwargs,
+ **chunks_kwargs,
+ )
class IndexVariable(Variable):
@@ -1234,37 +2614,131 @@ class IndexVariable(Variable):
They also have a name property, which is the name of their sole dimension
unless another name is given.
"""
+
__slots__ = ()
- _data: PandasIndexingAdapter
+
+ # TODO: PandasIndexingAdapter doesn't match the array api:
+ _data: PandasIndexingAdapter # type: ignore[assignment]
def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
super().__init__(dims, data, attrs, encoding, fastpath)
if self.ndim != 1:
- raise ValueError(
- f'{type(self).__name__} objects must be 1-dimensional')
+ raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")
+
+ # Unlike in Variable, always eagerly load values into memory
if not isinstance(self._data, PandasIndexingAdapter):
self._data = PandasIndexingAdapter(self._data)
- def __dask_tokenize__(self) ->object:
+ def __dask_tokenize__(self) -> object:
from dask.base import normalize_token
- return normalize_token((type(self), self._dims, self._data.array,
- self._attrs or None))
+
+ # Don't waste time converting pd.Index to np.ndarray
+ return normalize_token(
+ (type(self), self._dims, self._data.array, self._attrs or None)
+ )
+
+ def load(self):
+ # data is already loaded into memory for IndexVariable
+ return self
+
+ # https://github.com/python/mypy/issues/1465
+ @Variable.data.setter # type: ignore[attr-defined]
+ def data(self, data):
+ raise ValueError(
+ f"Cannot assign to the .data attribute of dimension coordinate a.k.a IndexVariable {self.name!r}. "
+ f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate."
+ )
+
+ @Variable.values.setter # type: ignore[attr-defined]
+ def values(self, values):
+ raise ValueError(
+ f"Cannot assign to the .values attribute of dimension coordinate a.k.a IndexVariable {self.name!r}. "
+ f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate."
+ )
+
+ def chunk(
+ self,
+ chunks={},
+ name=None,
+ lock=False,
+ inline_array=False,
+ chunked_array_type=None,
+ from_array_kwargs=None,
+ ):
+ # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
+ return self.copy(deep=False)
+
+ def _as_sparse(self, sparse_format=_default, fill_value=_default):
+ # Dummy
+ return self.copy(deep=False)
+
+ def _to_dense(self):
+ # Dummy
+ return self.copy(deep=False)
+
+ def _finalize_indexing_result(self, dims, data):
+ if getattr(data, "ndim", 0) != 1:
+ # returns Variable rather than IndexVariable if multi-dimensional
+ return Variable(dims, data, self._attrs, self._encoding)
+ else:
+ return self._replace(dims=dims, data=data)
def __setitem__(self, key, value):
- raise TypeError(f'{type(self).__name__} values cannot be modified')
+ raise TypeError(f"{type(self).__name__} values cannot be modified")
@classmethod
- def concat(cls, variables, dim='concat_dim', positions=None, shortcut=
- False, combine_attrs='override'):
+ def concat(
+ cls,
+ variables,
+ dim="concat_dim",
+ positions=None,
+ shortcut=False,
+ combine_attrs="override",
+ ):
"""Specialized version of Variable.concat for IndexVariable objects.
This exists because we want to avoid converting Index objects to NumPy
arrays, if possible.
"""
- pass
+ from xarray.core.merge import merge_attrs
- def copy(self, deep: bool=True, data: (T_DuckArray | ArrayLike | None)=None
- ):
+ if not isinstance(dim, str):
+ (dim,) = dim.dims
+
+ variables = list(variables)
+ first_var = variables[0]
+
+ if any(not isinstance(v, cls) for v in variables):
+ raise TypeError(
+ "IndexVariable.concat requires that all input "
+ "variables be IndexVariable objects"
+ )
+
+ indexes = [v._data.array for v in variables]
+
+ if not indexes:
+ data = []
+ else:
+ data = indexes[0].append(indexes[1:])
+
+ if positions is not None:
+ indices = nputils.inverse_permutation(np.concatenate(positions))
+ data = data.take(indices)
+
+ # keep as str if possible as pandas.Index uses object (converts to numpy array)
+ data = maybe_coerce_to_str(data, variables)
+
+ attrs = merge_attrs(
+ [var.attrs for var in variables], combine_attrs=combine_attrs
+ )
+ if not shortcut:
+ for var in variables:
+ if var.dims != first_var.dims:
+ raise ValueError("inconsistent dimensions")
+
+ return cls(first_var.dims, data, attrs)
+
+ def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None):
"""Returns a copy of this object.
`deep` is ignored since data is stored in the form of
@@ -1288,27 +2762,122 @@ class IndexVariable(Variable):
New object with dimensions, attributes, encodings, and optionally
data copied from original.
"""
- pass
+ if data is None:
+ ndata = self._data
- def to_index_variable(self) ->IndexVariable:
- """Return this variable as an xarray.IndexVariable"""
- pass
- to_coord = utils.alias(to_index_variable, 'to_coord')
+ if deep:
+ ndata = copy.deepcopy(ndata, None)
+
+ else:
+ ndata = as_compatible_data(data)
+ if self.shape != ndata.shape: # type: ignore[attr-defined]
+ raise ValueError(
+ f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
+ )
+
+ attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)
+ encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding)
+
+ return self._replace(data=ndata, attrs=attrs, encoding=encoding)
- def to_index(self) ->pd.Index:
+ def equals(self, other, equiv=None):
+ # if equiv is specified, super up
+ if equiv is not None:
+ return super().equals(other, equiv)
+
+ # otherwise use the native index equals, rather than looking at _data
+ other = getattr(other, "variable", other)
+ try:
+ return self.dims == other.dims and self._data_equals(other)
+ except (TypeError, AttributeError):
+ return False
+
+ def _data_equals(self, other):
+ return self._to_index().equals(other._to_index())
+
+ def to_index_variable(self) -> IndexVariable:
+ """Return this variable as an xarray.IndexVariable"""
+ return self.copy(deep=False)
+
+ to_coord = utils.alias(to_index_variable, "to_coord")
+
+ def _to_index(self) -> pd.Index:
+ # n.b. creating a new pandas.Index from an old pandas.Index is
+ # basically free as pandas.Index objects are immutable.
+ # n.b.2. this method returns the multi-index instance for
+ # a pandas multi-index level variable.
+ assert self.ndim == 1
+ index = self._data.array
+ if isinstance(index, pd.MultiIndex):
+ # set default names for multi-index unnamed levels so that
+ # we can safely rename dimension / coordinate later
+ valid_level_names = [
+ name or f"{self.dims[0]}_level_{i}"
+ for i, name in enumerate(index.names)
+ ]
+ index = index.set_names(valid_level_names)
+ else:
+ index = index.set_names(self.name)
+ return index
+
+ def to_index(self) -> pd.Index:
"""Convert this variable to a pandas.Index"""
- pass
+ index = self._to_index()
+ level = getattr(self._data, "level", None)
+ if level is not None:
+ # return multi-index level converted to a single index
+ return index.get_level_values(level)
+ else:
+ return index
@property
- def level_names(self) ->(list[str] | None):
+ def level_names(self) -> list[str] | None:
"""Return MultiIndex level names or None if this IndexVariable has no
MultiIndex.
"""
- pass
+ index = self.to_index()
+ if isinstance(index, pd.MultiIndex):
+ return index.names
+ else:
+ return None
def get_level_variable(self, level):
"""Return a new IndexVariable from a given MultiIndex level."""
- pass
+ if self.level_names is None:
+ raise ValueError(f"IndexVariable {self.name!r} has no MultiIndex")
+ index = self.to_index()
+ return type(self)(self.dims, index.get_level_values(level))
+
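# Usage sketch (illustration only, not part of the patch): for an IndexVariable
# wrapping a pandas.MultiIndex, the levels are exposed via the methods above.
import pandas as pd
import xarray as xr

midx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["letters", "numbers"])
iv = xr.IndexVariable("z", midx)
iv.level_names                    # ['letters', 'numbers']
iv.get_level_variable("letters")  # IndexVariable over 'z': ['a', 'a', 'b', 'b']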
+ @property
+ def name(self) -> Hashable:
+ return self.dims[0]
+
+ @name.setter
+ def name(self, value) -> NoReturn:
+ raise AttributeError("cannot modify name of IndexVariable in-place")
+
+ def _inplace_binary_op(self, other, f):
+ raise TypeError(
+ "Values of an IndexVariable are immutable and can not be modified inplace"
+ )
+
+
+def _unified_dims(variables):
+ # validate dimensions
+ all_dims = {}
+ for var in variables:
+ var_dims = var.dims
+ _raise_if_any_duplicate_dimensions(var_dims, err_context="Broadcasting")
+
+ for d, s in zip(var_dims, var.shape):
+ if d not in all_dims:
+ all_dims[d] = s
+ elif all_dims[d] != s:
+ raise ValueError(
+ "operands cannot be broadcast together "
+ f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}"
+ )
+ return all_dims
def _broadcast_compat_variables(*variables):
@@ -1317,10 +2886,11 @@ def _broadcast_compat_variables(*variables):
Unlike the result of broadcast_variables(), some variables may have
dimensions of size 1 instead of the size of the broadcast dimension.
"""
- pass
+ dims = tuple(_unified_dims(variables))
+ return tuple(var.set_dims(dims) if var.dims != dims else var for var in variables)
-def broadcast_variables(*variables: Variable) ->tuple[Variable, ...]:
+def broadcast_variables(*variables: Variable) -> tuple[Variable, ...]:
"""Given any number of variables, return variables with matching dimensions
and broadcast data.
@@ -1330,11 +2900,45 @@ def broadcast_variables(*variables: Variable) ->tuple[Variable, ...]:
dimensions are sorted in order of appearance in the first variable's
dimensions followed by the second variable's dimensions.
"""
- pass
+ dims_map = _unified_dims(variables)
+ dims_tuple = tuple(dims_map)
+ return tuple(
+ var.set_dims(dims_map) if var.dims != dims_tuple else var for var in variables
+ )
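# Usage sketch (illustration only, not part of the patch): broadcast_variables()
# is an internal helper; it unifies dims in order of first appearance.
import xarray as xr
from xarray.core.variable import broadcast_variables

a = xr.Variable(("x",), [1, 2, 3])
b = xr.Variable(("y",), [10, 20])
ax, bx = broadcast_variables(a, b)
ax.dims, ax.shape  # ('x', 'y'), (3, 2)
bx.dims, bx.shape  # ('x', 'y'), (3, 2)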
-def concat(variables, dim='concat_dim', positions=None, shortcut=False,
- combine_attrs='override'):
+def _broadcast_compat_data(self, other):
+ if not OPTIONS["arithmetic_broadcast"]:
+ if (isinstance(other, Variable) and self.dims != other.dims) or (
+ is_duck_array(other) and self.ndim != other.ndim
+ ):
+ raise ValueError(
+ "Broadcasting is necessary but automatic broadcasting is disabled via "
+ "global option `'arithmetic_broadcast'`. "
+ "Use `xr.set_options(arithmetic_broadcast=True)` to enable automatic broadcasting."
+ )
+
+ if all(hasattr(other, attr) for attr in ["dims", "data", "shape", "encoding"]):
+ # `other` satisfies the necessary Variable API for broadcast_variables
+ new_self, new_other = _broadcast_compat_variables(self, other)
+ self_data = new_self.data
+ other_data = new_other.data
+ dims = new_self.dims
+ else:
+ # rely on numpy broadcasting rules
+ self_data = self.data
+ other_data = other
+ dims = self.dims
+ return self_data, other_data, dims
+
+
+def concat(
+ variables,
+ dim="concat_dim",
+ positions=None,
+ shortcut=False,
+ combine_attrs="override",
+):
"""Concatenate variables along a new or existing dimension.
Parameters
@@ -1357,7 +2961,8 @@ def concat(variables, dim='concat_dim', positions=None, shortcut=False,
This option is used internally to speed-up groupby operations.
If `shortcut` is True, some checks of internal consistency between
arrays to concatenate are skipped.
- combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", "override"}, default: "override"
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
+ "override"}, default: "override"
String indicating how to combine attrs of the objects being merged:
- "drop": empty attrs on returned Dataset.
@@ -1375,14 +2980,34 @@ def concat(variables, dim='concat_dim', positions=None, shortcut=False,
Concatenated Variable formed by stacking all the supplied variables
along the given dimension.
"""
- pass
+ variables = list(variables)
+ if all(isinstance(v, IndexVariable) for v in variables):
+ return IndexVariable.concat(variables, dim, positions, shortcut, combine_attrs)
+ else:
+ return Variable.concat(variables, dim, positions, shortcut, combine_attrs)
-def calculate_dimensions(variables: Mapping[Any, Variable]) ->dict[Hashable,
- int]:
+def calculate_dimensions(variables: Mapping[Any, Variable]) -> dict[Hashable, int]:
"""Calculate the dimensions corresponding to a set of variables.
Returns dictionary mapping from dimension names to sizes. Raises ValueError
if any of the dimension sizes conflict.
"""
- pass
+ dims: dict[Hashable, int] = {}
+ last_used = {}
+ scalar_vars = {k for k, v in variables.items() if not v.dims}
+ for k, var in variables.items():
+ for dim, size in zip(var.dims, var.shape):
+ if dim in scalar_vars:
+ raise ValueError(
+ f"dimension {dim!r} already exists as a scalar variable"
+ )
+ if dim not in dims:
+ dims[dim] = size
+ last_used[dim] = k
+ elif dims[dim] != size:
+ raise ValueError(
+ f"conflicting sizes for dimension {dim!r}: "
+ f"length {size} on {k!r} and length {dims[dim]} on {last_used!r}"
+ )
+ return dims
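For illustration (the variable names here are hypothetical), `calculate_dimensions` unions the sizes it sees and raises as soon as two variables disagree:

```python
import numpy as np
import xarray as xr
from xarray.core.variable import calculate_dimensions

ok = {
    "a": xr.Variable(("x",), np.zeros(2)),
    "b": xr.Variable(("x", "y"), np.zeros((2, 3))),
}
assert calculate_dimensions(ok) == {"x": 2, "y": 3}

bad = {
    "a": xr.Variable(("x",), np.zeros(2)),
    "b": xr.Variable(("x",), np.zeros(4)),
}
try:
    calculate_dimensions(bad)
except ValueError as err:
    print(err)  # conflicting sizes for dimension 'x': ...
```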
diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py
index b93a079c..8cb90ac1 100644
--- a/xarray/core/weighted.py
+++ b/xarray/core/weighted.py
@@ -1,16 +1,28 @@
from __future__ import annotations
+
from collections.abc import Hashable, Iterable, Sequence
from typing import TYPE_CHECKING, Generic, Literal, cast
+
import numpy as np
from numpy.typing import ArrayLike
+
from xarray.core import duck_array_ops, utils
from xarray.core.alignment import align, broadcast
from xarray.core.computation import apply_ufunc, dot
from xarray.core.types import Dims, T_DataArray, T_Xarray
from xarray.namedarray.utils import is_duck_dask_array
from xarray.util.deprecation_helpers import _deprecate_positional_args
-QUANTILE_METHODS = Literal['linear', 'interpolated_inverted_cdf', 'hazen',
- 'weibull', 'median_unbiased', 'normal_unbiased']
+
+# Weighted quantile methods are a subset of the numpy supported quantile methods.
+QUANTILE_METHODS = Literal[
+ "linear",
+ "interpolated_inverted_cdf",
+ "hazen",
+ "weibull",
+ "median_unbiased",
+ "normal_unbiased",
+]
+
_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
@@ -39,6 +51,7 @@ _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
Returns {on_zero} if the ``weights`` sum to 0.0 along the reduced
dimension(s).
"""
+
_SUM_OF_WEIGHTS_DOCSTRING = """
Calculate the sum of weights, accounting for missing values in the data.
@@ -56,6 +69,7 @@ _SUM_OF_WEIGHTS_DOCSTRING = """
reduced : {cls}
New {cls} object with the sum of the weights over the given dimension.
"""
+
_WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE = """
Apply a weighted ``quantile`` to this {cls}'s data along some dimension(s).
@@ -111,6 +125,8 @@ _WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE = """
.. [3] Akinshin, A. (2023) "Weighted quantile estimators" arXiv:2304.07265 [stat.ME]
https://arxiv.org/abs/2304.07265
"""
+
+
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
@@ -127,9 +143,10 @@ class Weighted(Generic[T_Xarray]):
Dataset.weighted
DataArray.weighted
"""
- __slots__ = 'obj', 'weights'
- def __init__(self, obj: T_Xarray, weights: T_DataArray) ->None:
+ __slots__ = ("obj", "weights")
+
+ def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:
"""
Create a Weighted object
@@ -147,85 +164,426 @@ class Weighted(Generic[T_Xarray]):
``weights`` must be a ``DataArray`` and cannot contain missing values.
Missing values can be replaced by ``weights.fillna(0)``.
"""
+
from xarray.core.dataarray import DataArray
+
if not isinstance(weights, DataArray):
- raise ValueError('`weights` must be a DataArray')
+ raise ValueError("`weights` must be a DataArray")
def _weight_check(w):
+ # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670
if duck_array_ops.isnull(w).any():
raise ValueError(
- '`weights` cannot contain missing values. Missing values can be replaced by `weights.fillna(0)`.'
- )
+ "`weights` cannot contain missing values. "
+ "Missing values can be replaced by `weights.fillna(0)`."
+ )
return w
+
if is_duck_dask_array(weights.data):
- weights = weights.copy(data=weights.data.map_blocks(
- _weight_check, dtype=weights.dtype), deep=False)
+ # assign to copy - else the check is not triggered
+ weights = weights.copy(
+ data=weights.data.map_blocks(_weight_check, dtype=weights.dtype),
+ deep=False,
+ )
+
else:
_weight_check(weights.data)
+
self.obj: T_Xarray = obj
self.weights: T_DataArray = weights
def _check_dim(self, dim: Dims):
"""raise an error if any dimension is missing"""
- pass
+
+ dims: list[Hashable]
+ if isinstance(dim, str) or not isinstance(dim, Iterable):
+ dims = [dim] if dim else []
+ else:
+ dims = list(dim)
+ all_dims = set(self.obj.dims).union(set(self.weights.dims))
+ missing_dims = set(dims) - all_dims
+ if missing_dims:
+ raise ValueError(
+ f"Dimensions {tuple(missing_dims)} not found in {self.__class__.__name__} dimensions {tuple(all_dims)}"
+ )
@staticmethod
- def _reduce(da: T_DataArray, weights: T_DataArray, dim: Dims=None,
- skipna: (bool | None)=None) ->T_DataArray:
+ def _reduce(
+ da: T_DataArray,
+ weights: T_DataArray,
+ dim: Dims = None,
+ skipna: bool | None = None,
+ ) -> T_DataArray:
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
for internal use only
"""
- pass
- def _sum_of_weights(self, da: T_DataArray, dim: Dims=None) ->T_DataArray:
+ # need to infer dims as we use `dot`
+ if dim is None:
+ dim = ...
+
+ # need to mask invalid values in da, as `dot` does not implement skipna
+ if skipna or (skipna is None and da.dtype.kind in "cfO"):
+ da = da.fillna(0.0)
+
+ # `dot` does not broadcast arrays, so this avoids creating a large
+ # DataArray (if `weights` has additional dimensions)
+ return dot(da, weights, dim=dim)
+
+ def _sum_of_weights(self, da: T_DataArray, dim: Dims = None) -> T_DataArray:
"""Calculate the sum of weights, accounting for missing values"""
- pass
- def _sum_of_squares(self, da: T_DataArray, dim: Dims=None, skipna: (
- bool | None)=None) ->T_DataArray:
+ # we need to mask data values that are nan; else the weights are wrong
+ mask = da.notnull()
+
+ # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True
+ # (and not 2); GH4074
+ if self.weights.dtype == bool:
+ sum_of_weights = self._reduce(
+ mask,
+ duck_array_ops.astype(self.weights, dtype=int),
+ dim=dim,
+ skipna=False,
+ )
+ else:
+ sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False)
+
+ # 0-weights are not valid
+ valid_weights = sum_of_weights != 0.0
+
+ return sum_of_weights.where(valid_weights)
+
+ def _sum_of_squares(
+ self,
+ da: T_DataArray,
+ dim: Dims = None,
+ skipna: bool | None = None,
+ ) -> T_DataArray:
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""
- pass
- def _weighted_sum(self, da: T_DataArray, dim: Dims=None, skipna: (bool |
- None)=None) ->T_DataArray:
+ demeaned = da - da.weighted(self.weights).mean(dim=dim)
+
+ return self._reduce((demeaned**2), self.weights, dim=dim, skipna=skipna)
+
+ def _weighted_sum(
+ self,
+ da: T_DataArray,
+ dim: Dims = None,
+ skipna: bool | None = None,
+ ) -> T_DataArray:
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""
- pass
- def _weighted_mean(self, da: T_DataArray, dim: Dims=None, skipna: (bool |
- None)=None) ->T_DataArray:
+ return self._reduce(da, self.weights, dim=dim, skipna=skipna)
+
+ def _weighted_mean(
+ self,
+ da: T_DataArray,
+ dim: Dims = None,
+ skipna: bool | None = None,
+ ) -> T_DataArray:
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""
- pass
- def _weighted_var(self, da: T_DataArray, dim: Dims=None, skipna: (bool |
- None)=None) ->T_DataArray:
+ weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna)
+
+ sum_of_weights = self._sum_of_weights(da, dim=dim)
+
+ return weighted_sum / sum_of_weights
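To make the `_weighted_mean` path concrete, a small usage sketch via the public API (the values are made up). NaNs in the data are masked before both the weighted sum and the sum of weights, so only the valid entries contribute:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, 2.0, np.nan], dims="x")
weights = xr.DataArray([0.5, 0.25, 0.25], dims="x")

# weighted sum = 1.0 * 0.5 + 2.0 * 0.25 = 1.0; sum of weights over non-NaN data = 0.75
print(da.weighted(weights).mean().item())  # 1.333...
```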
+
+ def _weighted_var(
+ self,
+ da: T_DataArray,
+ dim: Dims = None,
+ skipna: bool | None = None,
+ ) -> T_DataArray:
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""
- pass
- def _weighted_std(self, da: T_DataArray, dim: Dims=None, skipna: (bool |
- None)=None) ->T_DataArray:
+ sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna)
+
+ sum_of_weights = self._sum_of_weights(da, dim=dim)
+
+ return sum_of_squares / sum_of_weights
+
+ def _weighted_std(
+ self,
+ da: T_DataArray,
+ dim: Dims = None,
+ skipna: bool | None = None,
+ ) -> T_DataArray:
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""
- pass
- def _weighted_quantile(self, da: T_DataArray, q: ArrayLike, dim: Dims=
- None, skipna: (bool | None)=None) ->T_DataArray:
+ return cast("T_DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))
+
+ def _weighted_quantile(
+ self,
+ da: T_DataArray,
+ q: ArrayLike,
+ dim: Dims = None,
+ skipna: bool | None = None,
+ ) -> T_DataArray:
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""
- pass
- def __repr__(self) ->str:
+ def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:
+ """Return the interpolation parameter."""
+ # Note that options are not yet exposed in the public API.
+ h: np.ndarray
+ if method == "linear":
+ h = (n - 1) * q + 1
+ elif method == "interpolated_inverted_cdf":
+ h = n * q
+ elif method == "hazen":
+ h = n * q + 0.5
+ elif method == "weibull":
+ h = (n + 1) * q
+ elif method == "median_unbiased":
+ h = (n + 1 / 3) * q + 1 / 3
+ elif method == "normal_unbiased":
+ h = (n + 1 / 4) * q + 3 / 8
+ else:
+ raise ValueError(f"Invalid method: {method}.")
+ return h.clip(1, n)
+
+ def _weighted_quantile_1d(
+ data: np.ndarray,
+ weights: np.ndarray,
+ q: np.ndarray,
+ skipna: bool,
+ method: QUANTILE_METHODS = "linear",
+ ) -> np.ndarray:
+ # This algorithm has been adapted from:
+ # https://aakinshin.net/posts/weighted-quantiles/#reference-implementation
+ is_nan = np.isnan(data)
+ if skipna:
+ # Remove nans from data and weights
+ not_nan = ~is_nan
+ data = data[not_nan]
+ weights = weights[not_nan]
+ elif is_nan.any():
+ # Return nan if data contains any nan
+ return np.full(q.size, np.nan)
+
+ # Filter out data (and weights) associated with zero weights, which also flattens them
+ nonzero_weights = weights != 0
+ data = data[nonzero_weights]
+ weights = weights[nonzero_weights]
+ n = data.size
+
+ if n == 0:
+ # Possibly empty after nan or zero weight filtering above
+ return np.full(q.size, np.nan)
+
+ # Kish's effective sample size
+ nw = weights.sum() ** 2 / (weights**2).sum()
+
+ # Sort data and weights
+ sorter = np.argsort(data)
+ data = data[sorter]
+ weights = weights[sorter]
+
+ # Normalize and sum the weights
+ weights = weights / weights.sum()
+ weights_cum = np.append(0, weights.cumsum())
+
+ # Vectorize the computation by transposing q with respect to weights
+ q = np.atleast_2d(q).T
+
+ # Get the interpolation parameter for each q
+ h = _get_h(nw, q, method)
+
+ # Find the samples contributing to the quantile computation (at *positions* between (h-1)/nw and h/nw)
+ u = np.maximum((h - 1) / nw, np.minimum(h / nw, weights_cum))
+
+ # Compute their relative weight
+ v = u * nw - h + 1
+ w = np.diff(v)
+
+ # Apply the weights
+ return (data * w).sum(axis=1)
+
+ if skipna is None and da.dtype.kind in "cfO":
+ skipna = True
+
+ q = np.atleast_1d(np.asarray(q, dtype=np.float64))
+
+ if q.ndim > 1:
+ raise ValueError("q must be a scalar or 1d")
+
+ if np.any((q < 0) | (q > 1)):
+ raise ValueError("q values must be between 0 and 1")
+
+ if dim is None:
+ dim = da.dims
+
+ if utils.is_scalar(dim):
+ dim = [dim]
+
+ # To satisfy mypy
+ dim = cast(Sequence, dim)
+
+ # need to align *and* broadcast
+ # - `_weighted_quantile_1d` requires arrays with the same shape
+ # - broadcast does an outer join, which can introduce NaN to weights
+ # - therefore we first need to do align(..., join="inner")
+
+ # TODO: use broadcast(..., join="inner") once available
+ # see https://github.com/pydata/xarray/issues/6304
+
+ da, weights = align(da, self.weights, join="inner")
+ da, weights = broadcast(da, weights)
+
+ result = apply_ufunc(
+ _weighted_quantile_1d,
+ da,
+ weights,
+ input_core_dims=[dim, dim],
+ output_core_dims=[["quantile"]],
+ output_dtypes=[np.float64],
+ dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}),
+ dask="parallelized",
+ vectorize=True,
+ kwargs={"q": q, "skipna": skipna},
+ )
+
+ result = result.transpose("quantile", ...)
+ result = result.assign_coords(quantile=q).squeeze()
+
+ return result
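A usage sketch for the quantile path, illustration only: with equal weights, Kish's effective sample size equals `n`, the default "linear" method gives `h = (n - 1) q + 1`, and the result matches the unweighted quantile.

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0, 4.0], dims="x")
weights = xr.DataArray(np.ones(4), dims="x")

wq = da.weighted(weights).quantile(0.5)
np.testing.assert_allclose(wq, da.quantile(0.5))  # both 2.5
```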
+
+ def _implementation(self, func, dim, **kwargs):
+ raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`")
+
+ @_deprecate_positional_args("v2023.10.0")
+ def sum_of_weights(
+ self,
+ dim: Dims = None,
+ *,
+ keep_attrs: bool | None = None,
+ ) -> T_Xarray:
+ return self._implementation(
+ self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def sum_of_squares(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ ) -> T_Xarray:
+ return self._implementation(
+ self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ ) -> T_Xarray:
+ return self._implementation(
+ self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ ) -> T_Xarray:
+ return self._implementation(
+ self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ ) -> T_Xarray:
+ return self._implementation(
+ self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ @_deprecate_positional_args("v2023.10.0")
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ keep_attrs: bool | None = None,
+ ) -> T_Xarray:
+ return self._implementation(
+ self._weighted_std, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ def quantile(
+ self,
+ q: ArrayLike,
+ *,
+ dim: Dims = None,
+ keep_attrs: bool | None = None,
+ skipna: bool = True,
+ ) -> T_Xarray:
+ return self._implementation(
+ self._weighted_quantile, q=q, dim=dim, skipna=skipna, keep_attrs=keep_attrs
+ )
+
+ def __repr__(self) -> str:
"""provide a nice str repr of our Weighted object"""
+
klass = self.__class__.__name__
- weight_dims = ', '.join(map(str, self.weights.dims))
- return f'{klass} with weights along dimensions: {weight_dims}'
+ weight_dims = ", ".join(map(str, self.weights.dims))
+ return f"{klass} with weights along dimensions: {weight_dims}"
+
+
+class DataArrayWeighted(Weighted["DataArray"]):
+ def _implementation(self, func, dim, **kwargs) -> DataArray:
+ self._check_dim(dim)
+
+ dataset = self.obj._to_temp_dataset()
+ dataset = dataset.map(func, dim=dim, **kwargs)
+ return self.obj._from_temp_dataset(dataset)
+
+
+class DatasetWeighted(Weighted["Dataset"]):
+ def _implementation(self, func, dim, **kwargs) -> Dataset:
+ self._check_dim(dim)
+
+ return self.obj.map(func, dim=dim, **kwargs)
+
+
+def _inject_docstring(cls, cls_name):
+ cls.sum_of_weights.__doc__ = _SUM_OF_WEIGHTS_DOCSTRING.format(cls=cls_name)
+
+ cls.sum.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+ cls=cls_name, fcn="sum", on_zero="0"
+ )
+
+ cls.mean.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+ cls=cls_name, fcn="mean", on_zero="NaN"
+ )
+ cls.sum_of_squares.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+ cls=cls_name, fcn="sum_of_squares", on_zero="0"
+ )
-class DataArrayWeighted(Weighted['DataArray']):
- pass
+ cls.var.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+ cls=cls_name, fcn="var", on_zero="NaN"
+ )
+ cls.std.__doc__ = _WEIGHTED_REDUCE_DOCSTRING_TEMPLATE.format(
+ cls=cls_name, fcn="std", on_zero="NaN"
+ )
-class DatasetWeighted(Weighted['Dataset']):
- pass
+ cls.quantile.__doc__ = _WEIGHTED_QUANTILE_DOCSTRING_TEMPLATE.format(cls=cls_name)
-_inject_docstring(DataArrayWeighted, 'DataArray')
-_inject_docstring(DatasetWeighted, 'Dataset')
+_inject_docstring(DataArrayWeighted, "DataArray")
+_inject_docstring(DatasetWeighted, "Dataset")
diff --git a/xarray/datatree_/docs/source/conf.py b/xarray/datatree_/docs/source/conf.py
index 2084912d..430dbb5b 100644
--- a/xarray/datatree_/docs/source/conf.py
+++ b/xarray/datatree_/docs/source/conf.py
@@ -1,91 +1,412 @@
+# -*- coding: utf-8 -*-
+# flake8: noqa
+# Ignoring F401: imported but unused
+
+# complexity documentation build configuration file, created by
+# sphinx-quickstart on Tue Jul 9 22:26:36 2013.
+#
+# This file is execfile()d with the current directory set to its containing dir.
+#
+# Note that not all possible configuration values are present in this
+# autogenerated file.
+#
+# All configuration values have a default; values that are commented out
+# serve to show the default.
+
import inspect
import os
import sys
-import sphinx_autosummary_accessors
-import datatree
+
+import sphinx_autosummary_accessors # type: ignore
+
+import datatree # type: ignore
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+# sys.path.insert(0, os.path.abspath('.'))
+
cwd = os.getcwd()
parent = os.path.dirname(cwd)
sys.path.insert(0, parent)
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode',
- 'sphinx.ext.linkcode', 'sphinx.ext.autosummary',
- 'sphinx.ext.intersphinx', 'sphinx.ext.extlinks', 'sphinx.ext.napoleon',
- 'sphinx_copybutton', 'sphinxext.opengraph',
- 'sphinx_autosummary_accessors',
- 'IPython.sphinxext.ipython_console_highlighting',
- 'IPython.sphinxext.ipython_directive', 'nbsphinx', 'sphinxcontrib.srclinks'
- ]
-extlinks = {'issue': (
- 'https://github.com/xarray-contrib/datatree/issues/%s', 'GH#%s'),
- 'pull': ('https://github.com/xarray-contrib/datatree/pull/%s', 'GH#%s')}
-templates_path = ['_templates', sphinx_autosummary_accessors.templates_path]
+
+
+# -- General configuration -----------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+# needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be extensions
+# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.viewcode",
+ "sphinx.ext.linkcode",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.extlinks",
+ "sphinx.ext.napoleon",
+ "sphinx_copybutton",
+ "sphinxext.opengraph",
+ "sphinx_autosummary_accessors",
+ "IPython.sphinxext.ipython_console_highlighting",
+ "IPython.sphinxext.ipython_directive",
+ "nbsphinx",
+ "sphinxcontrib.srclinks",
+]
+
+extlinks = {
+ "issue": ("https://github.com/xarray-contrib/datatree/issues/%s", "GH#%s"),
+ "pull": ("https://github.com/xarray-contrib/datatree/pull/%s", "GH#%s"),
+}
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates", sphinx_autosummary_accessors.templates_path]
+
+# Generate the API documentation when building
autosummary_generate = True
+
+
+# Napoleon configurations
+
napoleon_google_docstring = False
napoleon_numpy_docstring = True
napoleon_use_param = False
napoleon_use_rtype = False
napoleon_preprocess_types = True
-napoleon_type_aliases = {'sequence': ':term:`sequence`', 'iterable':
- ':term:`iterable`', 'callable': ':py:func:`callable`', 'dict_like':
- ':term:`dict-like <mapping>`', 'dict-like':
- ':term:`dict-like <mapping>`', 'path-like':
- ':term:`path-like <path-like object>`', 'mapping': ':term:`mapping`',
- 'file-like': ':term:`file-like <file-like object>`', 'MutableMapping':
- '~collections.abc.MutableMapping', 'sys.stdout': ':obj:`sys.stdout`',
- 'timedelta': '~datetime.timedelta', 'string': ':class:`string <str>`',
- 'array_like': ':term:`array_like`', 'array-like':
- ':term:`array-like <array_like>`', 'scalar': ':term:`scalar`', 'array':
- ':term:`array`', 'hashable': ':term:`hashable <name>`', 'color-like':
- ':py:func:`color-like <matplotlib.colors.is_color_like>`',
- 'matplotlib colormap name':
- ':doc:`matplotlib colormap name <matplotlib:gallery/color/colormap_reference>`'
- , 'matplotlib axes object':
- ':py:class:`matplotlib axes object <matplotlib.axes.Axes>`', 'colormap':
- ':py:class:`colormap <matplotlib.colors.Colormap>`', 'DataArray':
- '~xarray.DataArray', 'Dataset': '~xarray.Dataset', 'Variable':
- '~xarray.Variable', 'DatasetGroupBy':
- '~xarray.core.groupby.DatasetGroupBy', 'DataArrayGroupBy':
- '~xarray.core.groupby.DataArrayGroupBy', 'ndarray': '~numpy.ndarray',
- 'MaskedArray': '~numpy.ma.MaskedArray', 'dtype': '~numpy.dtype',
- 'ComplexWarning': '~numpy.ComplexWarning', 'Index': '~pandas.Index',
- 'MultiIndex': '~pandas.MultiIndex', 'CategoricalIndex':
- '~pandas.CategoricalIndex', 'TimedeltaIndex': '~pandas.TimedeltaIndex',
- 'DatetimeIndex': '~pandas.DatetimeIndex', 'Series': '~pandas.Series',
- 'DataFrame': '~pandas.DataFrame', 'Categorical': '~pandas.Categorical',
- 'Path': '~~pathlib.Path', 'pd.Index': '~pandas.Index', 'pd.NaT':
- '~pandas.NaT'}
-source_suffix = '.rst'
-master_doc = 'index'
-project = 'Datatree'
-copyright = '2021 onwards, Tom Nicholas and its Contributors'
-author = 'Tom Nicholas'
+napoleon_type_aliases = {
+ # general terms
+ "sequence": ":term:`sequence`",
+ "iterable": ":term:`iterable`",
+ "callable": ":py:func:`callable`",
+ "dict_like": ":term:`dict-like <mapping>`",
+ "dict-like": ":term:`dict-like <mapping>`",
+ "path-like": ":term:`path-like <path-like object>`",
+ "mapping": ":term:`mapping`",
+ "file-like": ":term:`file-like <file-like object>`",
+ # special terms
+ # "same type as caller": "*same type as caller*", # does not work, yet
+ # "same type as values": "*same type as values*", # does not work, yet
+ # stdlib type aliases
+ "MutableMapping": "~collections.abc.MutableMapping",
+ "sys.stdout": ":obj:`sys.stdout`",
+ "timedelta": "~datetime.timedelta",
+ "string": ":class:`string <str>`",
+ # numpy terms
+ "array_like": ":term:`array_like`",
+ "array-like": ":term:`array-like <array_like>`",
+ "scalar": ":term:`scalar`",
+ "array": ":term:`array`",
+ "hashable": ":term:`hashable <name>`",
+ # matplotlib terms
+ "color-like": ":py:func:`color-like <matplotlib.colors.is_color_like>`",
+ "matplotlib colormap name": ":doc:`matplotlib colormap name <matplotlib:gallery/color/colormap_reference>`",
+ "matplotlib axes object": ":py:class:`matplotlib axes object <matplotlib.axes.Axes>`",
+ "colormap": ":py:class:`colormap <matplotlib.colors.Colormap>`",
+ # objects without namespace: xarray
+ "DataArray": "~xarray.DataArray",
+ "Dataset": "~xarray.Dataset",
+ "Variable": "~xarray.Variable",
+ "DatasetGroupBy": "~xarray.core.groupby.DatasetGroupBy",
+ "DataArrayGroupBy": "~xarray.core.groupby.DataArrayGroupBy",
+ # objects without namespace: numpy
+ "ndarray": "~numpy.ndarray",
+ "MaskedArray": "~numpy.ma.MaskedArray",
+ "dtype": "~numpy.dtype",
+ "ComplexWarning": "~numpy.ComplexWarning",
+ # objects without namespace: pandas
+ "Index": "~pandas.Index",
+ "MultiIndex": "~pandas.MultiIndex",
+ "CategoricalIndex": "~pandas.CategoricalIndex",
+ "TimedeltaIndex": "~pandas.TimedeltaIndex",
+ "DatetimeIndex": "~pandas.DatetimeIndex",
+ "Series": "~pandas.Series",
+ "DataFrame": "~pandas.DataFrame",
+ "Categorical": "~pandas.Categorical",
+ "Path": "~~pathlib.Path",
+ # objects with abbreviated namespace (from pandas)
+ "pd.Index": "~pandas.Index",
+ "pd.NaT": "~pandas.NaT",
+}
+
+# The suffix of source filenames.
+source_suffix = ".rst"
+
+# The encoding of source files.
+# source_encoding = 'utf-8-sig'
+
+# The master toctree document.
+master_doc = "index"
+
+# General information about the project.
+project = "Datatree"
+copyright = "2021 onwards, Tom Nicholas and its Contributors"
+author = "Tom Nicholas"
+
html_show_sourcelink = True
-srclink_project = 'https://github.com/xarray-contrib/datatree'
-srclink_branch = 'main'
-srclink_src_path = 'docs/source'
+srclink_project = "https://github.com/xarray-contrib/datatree"
+srclink_branch = "main"
+srclink_src_path = "docs/source"
+
+# The version info for the project you're documenting, acts as replacement for
+# |version| and |release|, also used in various other places throughout the
+# built documents.
+#
+# The short X.Y version.
version = datatree.__version__
+# The full version, including alpha/beta/rc tags.
release = datatree.__version__
-exclude_patterns = ['_build']
-pygments_style = 'sphinx'
-intersphinx_mapping = {'python': ('https://docs.python.org/3.8/', None),
- 'numpy': ('https://numpy.org/doc/stable', None), 'xarray': (
- 'https://xarray.pydata.org/en/stable/', None)}
-html_theme = 'sphinx_book_theme'
-html_theme_options = {'repository_url':
- 'https://github.com/xarray-contrib/datatree', 'repository_branch':
- 'main', 'path_to_docs': 'docs/source', 'use_repository_button': True,
- 'use_issues_button': True, 'use_edit_page_button': True}
-htmlhelp_basename = 'datatree_doc'
-latex_elements: dict = {}
-latex_documents = [('index', 'datatree.tex', 'Datatree Documentation',
- author, 'manual')]
-man_pages = [('index', 'datatree', 'Datatree Documentation', [author], 1)]
-texinfo_documents = [('index', 'datatree', 'Datatree Documentation', author,
- 'datatree', 'Tree-like hierarchical data structure for xarray.',
- 'Miscellaneous')]
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+# language = None
+
+# There are two options for replacing |today|: either, you set today to some
+# non-false value, then it is used:
+# today = ''
+# Else, today_fmt is used as the format for a strftime call.
+# today_fmt = '%B %d, %Y'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+exclude_patterns = ["_build"]
+
+# The reST default role (used for this markup: `text`) to use for all documents.
+# default_role = None
+
+# If true, '()' will be appended to :func: etc. cross-reference text.
+# add_function_parentheses = True
+
+# If true, the current module name will be prepended to all description
+# unit titles (such as .. function::).
+# add_module_names = True
+
+# If true, sectionauthor and moduleauthor directives will be shown in the
+# output. They are ignored by default.
+# show_authors = False
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+# A list of ignored prefixes for module index sorting.
+# modindex_common_prefix = []
+
+# If true, keep warnings as "system message" paragraphs in the built documents.
+# keep_warnings = False
+
+
+# -- Intersphinx links ---------------------------------------------------------
+
+intersphinx_mapping = {
+ "python": ("https://docs.python.org/3.8/", None),
+ "numpy": ("https://numpy.org/doc/stable", None),
+ "xarray": ("https://xarray.pydata.org/en/stable/", None),
+}
+
+# -- Options for HTML output ---------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+html_theme = "sphinx_book_theme"
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+html_theme_options = {
+ "repository_url": "https://github.com/xarray-contrib/datatree",
+ "repository_branch": "main",
+ "path_to_docs": "docs/source",
+ "use_repository_button": True,
+ "use_issues_button": True,
+ "use_edit_page_button": True,
+}
+
+# Add any paths that contain custom themes here, relative to this directory.
+# html_theme_path = []
+
+# The name for this set of Sphinx documents. If None, it defaults to
+# "<project> v<release> documentation".
+# html_title = None
+
+# A shorter title for the navigation bar. Default is the same as html_title.
+# html_short_title = None
+
+# The name of an image file (relative to this directory) to place at the top
+# of the sidebar.
+# html_logo = None
+
+# The name of an image file (within the static path) to use as favicon of the
+# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
+# pixels large.
+# html_favicon = None
+
+# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
+# using the given strftime format.
+# html_last_updated_fmt = '%b %d, %Y'
+# If true, SmartyPants will be used to convert quotes and dashes to
+# typographically correct entities.
+# html_use_smartypants = True
+
+# Custom sidebar templates, maps document names to template names.
+# html_sidebars = {}
+
+# Additional templates that should be rendered to pages, maps page names to
+# template names.
+# html_additional_pages = {}
+
+# If false, no module index is generated.
+# html_domain_indices = True
+
+# If false, no index is generated.
+# html_use_index = True
+
+# If true, the index is split into individual pages for each letter.
+# html_split_index = False
+
+# If true, links to the reST sources are added to the pages.
+# html_show_sourcelink = True
+
+# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
+# html_show_sphinx = True
+
+# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
+# html_show_copyright = True
+
+# If true, an OpenSearch description file will be output, and all pages will
+# contain a <link> tag referring to it. The value of this option must be the
+# base URL from which the finished HTML is served.
+# html_use_opensearch = ''
+
+# This is the file name suffix for HTML files (e.g. ".xhtml").
+# html_file_suffix = None
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = "datatree_doc"
+
+
+# -- Options for LaTeX output --------------------------------------------------
+
+latex_elements: dict = {
+ # The paper size ('letterpaper' or 'a4paper').
+ # 'papersize': 'letterpaper',
+ # The font size ('10pt', '11pt' or '12pt').
+ # 'pointsize': '10pt',
+ # Additional stuff for the LaTeX preamble.
+ # 'preamble': '',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title, author, documentclass [howto/manual]).
+latex_documents = [
+ ("index", "datatree.tex", "Datatree Documentation", author, "manual")
+]
+
+# The name of an image file (relative to this directory) to place at the top of
+# the title page.
+# latex_logo = None
+
+# For "manual" documents, if this is true, then toplevel headings are parts,
+# not chapters.
+# latex_use_parts = False
+
+# If true, show page references after internal links.
+# latex_show_pagerefs = False
+
+# If true, show URL addresses after external links.
+# latex_show_urls = False
+
+# Documents to append as an appendix to all manuals.
+# latex_appendices = []
+
+# If false, no module index is generated.
+# latex_domain_indices = True
+
+
+# -- Options for manual page output --------------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [("index", "datatree", "Datatree Documentation", [author], 1)]
+
+# If true, show URL addresses after external links.
+# man_show_urls = False
+
+
+# -- Options for Texinfo output ------------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+# dir menu entry, description, category)
+texinfo_documents = [
+ (
+ "index",
+ "datatree",
+ "Datatree Documentation",
+ author,
+ "datatree",
+ "Tree-like hierarchical data structure for xarray.",
+ "Miscellaneous",
+ )
+]
+
+# Documents to append as an appendix to all manuals.
+# texinfo_appendices = []
+
+# If false, no module index is generated.
+# texinfo_domain_indices = True
+
+# How to display URL addresses: 'footnote', 'no', or 'inline'.
+# texinfo_show_urls = 'footnote'
+
+# If true, do not generate a @detailmenu in the "Top" node's menu.
+# texinfo_no_detailmenu = False
+
+
+# based on numpy doc/source/conf.py
def linkcode_resolve(domain, info):
"""
Determine the URL corresponding to Python object
"""
- pass
+ if domain != "py":
+ return None
+
+ modname = info["module"]
+ fullname = info["fullname"]
+
+ submod = sys.modules.get(modname)
+ if submod is None:
+ return None
+
+ obj = submod
+ for part in fullname.split("."):
+ try:
+ obj = getattr(obj, part)
+ except AttributeError:
+ return None
+
+ try:
+ fn = inspect.getsourcefile(inspect.unwrap(obj))
+ except TypeError:
+ fn = None
+ if not fn:
+ return None
+
+ try:
+ source, lineno = inspect.getsourcelines(obj)
+ except OSError:
+ lineno = None
+
+ if lineno:
+ linespec = f"#L{lineno}-L{lineno + len(source) - 1}"
+ else:
+ linespec = ""
+
+ fn = os.path.relpath(fn, start=os.path.dirname(datatree.__file__))
+
+ if "+" in datatree.__version__:
+ return f"https://github.com/xarray-contrib/datatree/blob/main/datatree/{fn}{linespec}"
+ else:
+ return (
+ f"https://github.com/xarray-contrib/datatree/blob/"
+ f"v{datatree.__version__}/datatree/{fn}{linespec}"
+ )
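Illustration only: Sphinx calls `linkcode_resolve` once per documented object with a domain and an info dict. The object name below is hypothetical, and the snippet only makes sense inside a docs build where `datatree` is importable.

```python
# Hypothetical call, mirroring what sphinx.ext.linkcode does per documented object:
url = linkcode_resolve("py", {"module": "datatree", "fullname": "DataTree.to_netcdf"})
# e.g. https://github.com/xarray-contrib/datatree/blob/v<release>/datatree/<file>.py#L<start>-L<end>
# (the `main` branch is used instead when the version string contains "+")
print(url)
```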
diff --git a/xarray/groupers.py b/xarray/groupers.py
index dcb05c0a..becb005b 100644
--- a/xarray/groupers.py
+++ b/xarray/groupers.py
@@ -3,13 +3,17 @@ This module provides Grouper objects that encapsulate the
"factorization" process - conversion of value we are grouping by
to integer codes (one per group).
"""
+
from __future__ import annotations
+
import datetime
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Literal, cast
+
import numpy as np
import pandas as pd
+
from xarray.coding.cftime_offsets import _new_to_legacy_freq
from xarray.core import duck_array_ops
from xarray.core.dataarray import DataArray
@@ -18,9 +22,17 @@ from xarray.core.indexes import safe_cast_to_index
from xarray.core.resample_cftime import CFTimeGrouper
from xarray.core.types import Bins, DatetimeLike, GroupIndices, SideOptions
from xarray.core.variable import Variable
-__all__ = ['EncodedGroups', 'Grouper', 'Resampler', 'UniqueGrouper',
- 'BinGrouper', 'TimeResampler']
-RESAMPLE_DIM = '__resample_dim__'
+
+__all__ = [
+ "EncodedGroups",
+ "Grouper",
+ "Resampler",
+ "UniqueGrouper",
+ "BinGrouper",
+ "TimeResampler",
+]
+
+RESAMPLE_DIM = "__resample_dim__"
@dataclass
@@ -42,6 +54,7 @@ class EncodedGroups:
unique_coord : Variable, optional
Unique group values present in dataset. Inferred if not provided
"""
+
codes: DataArray
full_index: pd.Index
group_indices: GroupIndices | None = field(default=None)
@@ -50,18 +63,19 @@ class EncodedGroups:
def __post_init__(self):
assert isinstance(self.codes, DataArray)
if self.codes.name is None:
- raise ValueError(
- 'Please set a name on the array you are grouping by.')
+ raise ValueError("Please set a name on the array you are grouping by.")
assert isinstance(self.full_index, pd.Index)
- assert isinstance(self.unique_coord, (Variable, _DummyGroup)
- ) or self.unique_coord is None
+ assert (
+ isinstance(self.unique_coord, (Variable, _DummyGroup))
+ or self.unique_coord is None
+ )
class Grouper(ABC):
"""Abstract base class for Grouper objects that allow specializing GroupBy instructions."""
@abstractmethod
- def factorize(self, group: T_Group) ->EncodedGroups:
+ def factorize(self, group: T_Group) -> EncodedGroups:
"""
Creates intermediates necessary for GroupBy.
@@ -83,18 +97,80 @@ class Resampler(Grouper):
Currently only used for TimeResampler, but could be used for SpaceResampler in the future.
"""
+
pass
@dataclass
class UniqueGrouper(Grouper):
"""Grouper object for grouping by a categorical variable."""
+
_group_as_index: pd.Index | None = field(default=None, repr=False)
@property
- def group_as_index(self) ->pd.Index:
+ def group_as_index(self) -> pd.Index:
"""Caches the group DataArray as a pandas Index."""
- pass
+ if self._group_as_index is None:
+ self._group_as_index = self.group.to_index()
+ return self._group_as_index
+
+ def factorize(self, group1d: T_Group) -> EncodedGroups:
+ self.group = group1d
+
+ index = self.group_as_index
+ is_unique_and_monotonic = isinstance(self.group, _DummyGroup) or (
+ index.is_unique
+ and (index.is_monotonic_increasing or index.is_monotonic_decreasing)
+ )
+ is_dimension = self.group.dims == (self.group.name,)
+ can_squeeze = is_dimension and is_unique_and_monotonic
+
+ if can_squeeze:
+ return self._factorize_dummy()
+ else:
+ return self._factorize_unique()
+
+ def _factorize_unique(self) -> EncodedGroups:
+ # look through group to find the unique values
+ sort = not isinstance(self.group_as_index, pd.MultiIndex)
+ unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
+ if (codes_ == -1).all():
+ raise ValueError(
+ "Failed to group data. Are you grouping by a variable that is all NaN?"
+ )
+ codes = self.group.copy(data=codes_)
+ unique_coord = Variable(
+ dims=codes.name, data=unique_values, attrs=self.group.attrs
+ )
+ full_index = pd.Index(unique_values)
+
+ return EncodedGroups(
+ codes=codes, full_index=full_index, unique_coord=unique_coord
+ )
+
+ def _factorize_dummy(self) -> EncodedGroups:
+ size = self.group.size
+ # no need to factorize
+ # use slices to do views instead of fancy indexing
+ # equivalent to: group_indices = group_indices.reshape(-1, 1)
+ group_indices: GroupIndices = tuple(slice(i, i + 1) for i in range(size))
+ size_range = np.arange(size)
+ full_index: pd.Index
+ if isinstance(self.group, _DummyGroup):
+ codes = self.group.to_dataarray().copy(data=size_range)
+ unique_coord = self.group
+ full_index = pd.RangeIndex(self.group.size)
+ else:
+ codes = self.group.copy(data=size_range)
+ unique_coord = self.group.variable.to_base_variable()
+ full_index = pd.Index(unique_coord.data)
+
+ return EncodedGroups(
+ codes=codes,
+ group_indices=group_indices,
+ full_index=full_index,
+ unique_coord=unique_coord,
+ )
@dataclass
@@ -136,16 +212,56 @@ class BinGrouper(Grouper):
duplicates : {"raise", "drop"}, default: "raise"
If bin edges are not unique, raise ValueError or drop non-uniques.
"""
+
bins: Bins
+ # The rest are copied from pandas
right: bool = True
labels: Any = None
precision: int = 3
include_lowest: bool = False
- duplicates: Literal['raise', 'drop'] = 'raise'
+ duplicates: Literal["raise", "drop"] = "raise"
- def __post_init__(self) ->None:
+ def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
- raise ValueError('All bin edges are NaN.')
+ raise ValueError("All bin edges are NaN.")
+
+ def factorize(self, group: T_Group) -> EncodedGroups:
+ from xarray.core.dataarray import DataArray
+
+ data = np.asarray(group.data) # Cast _DummyGroup data to array
+
+ binned, self.bins = pd.cut( # type: ignore [call-overload]
+ data,
+ bins=self.bins,
+ right=self.right,
+ labels=self.labels,
+ precision=self.precision,
+ include_lowest=self.include_lowest,
+ duplicates=self.duplicates,
+ retbins=True,
+ )
+
+ binned_codes = binned.codes
+ if (binned_codes == -1).all():
+ raise ValueError(
+ f"None of the data falls within bins with edges {self.bins!r}"
+ )
+
+ new_dim_name = f"{group.name}_bins"
+
+ full_index = binned.categories
+ uniques = np.sort(pd.unique(binned_codes))
+ unique_values = full_index[uniques[uniques != -1]]
+
+ codes = DataArray(
+ binned_codes, getattr(group, "coords", None), name=new_dim_name
+ )
+ unique_coord = Variable(
+ dims=new_dim_name, data=unique_values, attrs=group.attrs
+ )
+ return EncodedGroups(
+ codes=codes, full_index=full_index, unique_coord=unique_coord
+ )
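A small usage sketch (values made up), assuming a recent xarray where `groupby` accepts Grouper objects as keyword arguments; the binned dimension is named `"<group>_bins"`:

```python
import numpy as np
import xarray as xr
from xarray.groupers import BinGrouper

da = xr.DataArray(np.arange(6.0), dims="x", coords={"x": np.arange(6)})

# Bin the "x" coordinate; the grouped dimension is named "x_bins".
grouped = da.groupby(x=BinGrouper(bins=[0, 2, 5], include_lowest=True))
print(grouped.sum())  # sums over the [0, 2] and (2, 5] bins: 3.0 and 12.0
```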
@dataclass(repr=False)
@@ -176,18 +292,100 @@ class TimeResampler(Resampler):
offset : pd.Timedelta, datetime.timedelta, or str, default is None
An offset timedelta added to the origin.
"""
+
freq: str
closed: SideOptions | None = field(default=None)
label: SideOptions | None = field(default=None)
- origin: str | DatetimeLike = field(default='start_day')
- offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None
- )
+ origin: str | DatetimeLike = field(default="start_day")
+ offset: pd.Timedelta | datetime.timedelta | str | None = field(default=None)
+
index_grouper: CFTimeGrouper | pd.Grouper = field(init=False, repr=False)
group_as_index: pd.Index = field(init=False, repr=False)
+ def _init_properties(self, group: T_Group) -> None:
+ from xarray import CFTimeIndex
+
+ group_as_index = safe_cast_to_index(group)
+ offset = self.offset
+
+ if not group_as_index.is_monotonic_increasing:
+ # TODO: sort instead of raising an error
+ raise ValueError("Index must be monotonic for resampling")
-def unique_value_groups(ar, sort: bool=True) ->tuple[np.ndarray | pd.Index,
- np.ndarray]:
+ if isinstance(group_as_index, CFTimeIndex):
+ from xarray.core.resample_cftime import CFTimeGrouper
+
+ self.index_grouper = CFTimeGrouper(
+ freq=self.freq,
+ closed=self.closed,
+ label=self.label,
+ origin=self.origin,
+ offset=offset,
+ )
+ else:
+ self.index_grouper = pd.Grouper(
+ # TODO remove once requiring pandas >= 2.2
+ freq=_new_to_legacy_freq(self.freq),
+ closed=self.closed,
+ label=self.label,
+ origin=self.origin,
+ offset=offset,
+ )
+ self.group_as_index = group_as_index
+
+ def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]:
+ first_items, codes = self.first_items()
+ full_index = first_items.index
+ if first_items.isnull().any():
+ first_items = first_items.dropna()
+
+ full_index = full_index.rename("__resample_dim__")
+ return full_index, first_items, codes
+
+ def first_items(self) -> tuple[pd.Series, np.ndarray]:
+ from xarray.coding.cftimeindex import CFTimeIndex
+ from xarray.core.resample_cftime import CFTimeGrouper
+
+ if isinstance(self.index_grouper, CFTimeGrouper):
+ return self.index_grouper.first_items(
+ cast(CFTimeIndex, self.group_as_index)
+ )
+ else:
+ s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index)
+ grouped = s.groupby(self.index_grouper)
+ first_items = grouped.first()
+ counts = grouped.count()
+ # This way we generate codes for the final output index: full_index.
+ # So for _flox_reduce we avoid one reindex and copy by avoiding
+ # _maybe_restore_empty_groups
+ codes = np.repeat(np.arange(len(first_items)), counts)
+ return first_items, codes
+
+ def factorize(self, group: T_Group) -> EncodedGroups:
+ self._init_properties(group)
+ full_index, first_items, codes_ = self._get_index_and_items()
+ sbins = first_items.values.astype(np.int64)
+ group_indices: GroupIndices = tuple(
+ [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])]
+ + [slice(sbins[-1], None)]
+ )
+
+ unique_coord = Variable(
+ dims=group.name, data=first_items.index, attrs=group.attrs
+ )
+ codes = group.copy(data=codes_)
+
+ return EncodedGroups(
+ codes=codes,
+ group_indices=group_indices,
+ full_index=full_index,
+ unique_coord=unique_coord,
+ )
+
+
+def unique_value_groups(
+ ar, sort: bool = True
+) -> tuple[np.ndarray | pd.Index, np.ndarray]:
"""Group an array by its unique values.
Parameters
@@ -205,4 +403,7 @@ def unique_value_groups(ar, sort: bool=True) ->tuple[np.ndarray | pd.Index,
Each element provides the integer indices in `ar` with values given by
the corresponding value in `unique_values`.
"""
- pass
+ inverse, values = pd.factorize(ar, sort=sort)
+ if isinstance(values, pd.MultiIndex):
+ values.names = ar.names
+ return values, inverse
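For illustration, `unique_value_groups` is a thin wrapper over `pd.factorize` that returns the unique values plus one integer code per input element:

```python
from xarray.groupers import unique_value_groups

values, codes = unique_value_groups(["b", "a", "b", "c"])
print(list(values))  # ['a', 'b', 'c']  (sorted, since sort=True by default)
print(list(codes))   # [1, 0, 1, 2]
```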
diff --git a/xarray/namedarray/_aggregations.py b/xarray/namedarray/_aggregations.py
index c8843d3a..9f58aeb7 100644
--- a/xarray/namedarray/_aggregations.py
+++ b/xarray/namedarray/_aggregations.py
@@ -1,7 +1,12 @@
"""Mixin classes with reduction operations."""
+
+# This file was generated using xarray.util.generate_aggregations. Do not edit manually.
+
from __future__ import annotations
+
from collections.abc import Sequence
from typing import Any, Callable
+
from xarray.core import duck_array_ops
from xarray.core.types import Dims, Self
@@ -9,7 +14,22 @@ from xarray.core.types import Dims, Self
class NamedArrayAggregations:
__slots__ = ()
- def count(self, dim: Dims=None, **kwargs: Any) ->Self:
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ *,
+ axis: int | Sequence[int] | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> Self:
+ raise NotImplementedError()
+
+ def count(
+ self,
+ dim: Dims = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``count`` along some dimension(s).
@@ -53,9 +73,17 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(5)
"""
- pass
-
- def all(self, dim: Dims=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.count,
+ dim=dim,
+ **kwargs,
+ )
+
+ def all(
+ self,
+ dim: Dims = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``all`` along some dimension(s).
@@ -99,9 +127,17 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 1B
array(False)
"""
- pass
-
- def any(self, dim: Dims=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.array_all,
+ dim=dim,
+ **kwargs,
+ )
+
+ def any(
+ self,
+ dim: Dims = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``any`` along some dimension(s).
@@ -145,10 +181,19 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 1B
array(True)
"""
- pass
-
- def max(self, dim: Dims=None, *, skipna: (bool | None)=None, **kwargs: Any
- ) ->Self:
+ return self.reduce(
+ duck_array_ops.array_any,
+ dim=dim,
+ **kwargs,
+ )
+
+ def max(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``max`` along some dimension(s).
@@ -203,10 +248,20 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(nan)
"""
- pass
-
- def min(self, dim: Dims=None, *, skipna: (bool | None)=None, **kwargs: Any
- ) ->Self:
+ return self.reduce(
+ duck_array_ops.max,
+ dim=dim,
+ skipna=skipna,
+ **kwargs,
+ )
+
+ def min(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``min`` along some dimension(s).
@@ -261,10 +316,20 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(nan)
"""
- pass
-
- def mean(self, dim: Dims=None, *, skipna: (bool | None)=None, **kwargs: Any
- ) ->Self:
+ return self.reduce(
+ duck_array_ops.min,
+ dim=dim,
+ skipna=skipna,
+ **kwargs,
+ )
+
+ def mean(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``mean`` along some dimension(s).
@@ -323,10 +388,21 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(nan)
"""
- pass
-
- def prod(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.mean,
+ dim=dim,
+ skipna=skipna,
+ **kwargs,
+ )
+
+ def prod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``prod`` along some dimension(s).
@@ -397,10 +473,22 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(0.)
"""
- pass
-
- def sum(self, dim: Dims=None, *, skipna: (bool | None)=None, min_count:
- (int | None)=None, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.prod,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ **kwargs,
+ )
+
+ def sum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ min_count: int | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``sum`` along some dimension(s).
@@ -471,10 +559,22 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(8.)
"""
- pass
-
- def std(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.sum,
+ dim=dim,
+ skipna=skipna,
+ min_count=min_count,
+ **kwargs,
+ )
+
+ def std(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``std`` along some dimension(s).
@@ -542,10 +642,22 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(1.14017543)
"""
- pass
-
- def var(self, dim: Dims=None, *, skipna: (bool | None)=None, ddof: int=
- 0, **kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.std,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ **kwargs,
+ )
+
+ def var(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ ddof: int = 0,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``var`` along some dimension(s).
@@ -613,10 +725,21 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(1.3)
"""
- pass
-
- def median(self, dim: Dims=None, *, skipna: (bool | None)=None, **
- kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.var,
+ dim=dim,
+ skipna=skipna,
+ ddof=ddof,
+ **kwargs,
+ )
+
+ def median(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``median`` along some dimension(s).
@@ -675,10 +798,20 @@ class NamedArrayAggregations:
<xarray.NamedArray ()> Size: 8B
array(nan)
"""
- pass
-
- def cumsum(self, dim: Dims=None, *, skipna: (bool | None)=None, **
- kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.median,
+ dim=dim,
+ skipna=skipna,
+ **kwargs,
+ )
+
+ def cumsum(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``cumsum`` along some dimension(s).
@@ -737,10 +870,20 @@ class NamedArrayAggregations:
<xarray.NamedArray (x: 6)> Size: 48B
array([ 1., 3., 6., 6., 8., nan])
"""
- pass
-
- def cumprod(self, dim: Dims=None, *, skipna: (bool | None)=None, **
- kwargs: Any) ->Self:
+ return self.reduce(
+ duck_array_ops.cumsum,
+ dim=dim,
+ skipna=skipna,
+ **kwargs,
+ )
+
+ def cumprod(
+ self,
+ dim: Dims = None,
+ *,
+ skipna: bool | None = None,
+ **kwargs: Any,
+ ) -> Self:
"""
Reduce this NamedArray's data by applying ``cumprod`` along some dimension(s).
@@ -799,4 +942,9 @@ class NamedArrayAggregations:
<xarray.NamedArray (x: 6)> Size: 48B
array([ 1., 2., 6., 0., 0., nan])
"""
- pass
+ return self.reduce(
+ duck_array_ops.cumprod,
+ dim=dim,
+ skipna=skipna,
+ **kwargs,
+ )
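All of these generated methods funnel through `reduce` with the matching `duck_array_ops` function. A minimal sketch, with made-up array values:

```python
import numpy as np
from xarray.namedarray.core import NamedArray

na = NamedArray(("x",), np.array([1.0, 2.0, np.nan]))

print(na.sum(dim="x").data)                # 3.0 -- skipna defaults to True for float data
print(na.sum(dim="x", skipna=False).data)  # nan
```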
diff --git a/xarray/namedarray/_array_api.py b/xarray/namedarray/_array_api.py
index b9c696ab..acbfc8af 100644
--- a/xarray/namedarray/_array_api.py
+++ b/xarray/namedarray/_array_api.py
@@ -1,13 +1,39 @@
from __future__ import annotations
+
from types import ModuleType
from typing import Any
+
import numpy as np
-from xarray.namedarray._typing import Default, _arrayapi, _Axes, _Axis, _default, _Dim, _DType, _ScalarType, _ShapeType, _SupportsImag, _SupportsReal
+
+from xarray.namedarray._typing import (
+ Default,
+ _arrayapi,
+ _Axes,
+ _Axis,
+ _default,
+ _Dim,
+ _DType,
+ _ScalarType,
+ _ShapeType,
+ _SupportsImag,
+ _SupportsReal,
+)
from xarray.namedarray.core import NamedArray
-def astype(x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool=True
- ) ->NamedArray[_ShapeType, _DType]:
+def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
+ if isinstance(x._data, _arrayapi):
+ return x._data.__array_namespace__()
+
+ return np
+
+
+# %% Creation Functions
+
+
+def astype(
+ x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True
+) -> NamedArray[_ShapeType, _DType]:
"""
Copies an array to a specified data type irrespective of Type Promotion Rules rules.
@@ -41,11 +67,20 @@ def astype(x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool=True
<xarray.NamedArray (x: 2)> Size: 8B
array([1, 2], dtype=int32)
"""
- pass
+ if isinstance(x._data, _arrayapi):
+ xp = x._data.__array_namespace__()
+ return x._new(data=xp.astype(x._data, dtype, copy=copy))
+
+ # np.astype doesn't exist yet:
+ return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined]
+
+
+# %% Elementwise Functions
-def imag(x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], /
- ) ->NamedArray[_ShapeType, np.dtype[_ScalarType]]:
+def imag(
+ x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var]
+) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
"""
Returns the imaginary component of a complex number for each element x_i of the
input array x.
@@ -70,11 +105,14 @@ def imag(x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], /
<xarray.NamedArray (x: 2)> Size: 16B
array([2., 4.])
"""
- pass
+ xp = _get_data_namespace(x)
+ out = x._new(data=xp.imag(x._data))
+ return out
-def real(x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], /
- ) ->NamedArray[_ShapeType, np.dtype[_ScalarType]]:
+def real(
+ x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var]
+) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
"""
Returns the real component of a complex number for each element x_i of the
input array x.
@@ -99,11 +137,19 @@ def real(x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], /
<xarray.NamedArray (x: 2)> Size: 16B
array([1., 2.])
"""
- pass
-
-
-def expand_dims(x: NamedArray[Any, _DType], /, *, dim: (_Dim | Default)=
- _default, axis: _Axis=0) ->NamedArray[Any, _DType]:
+ xp = _get_data_namespace(x)
+ out = x._new(data=xp.real(x._data))
+ return out
+
+
+# %% Manipulation functions
+def expand_dims(
+ x: NamedArray[Any, _DType],
+ /,
+ *,
+ dim: _Dim | Default = _default,
+ axis: _Axis = 0,
+) -> NamedArray[Any, _DType]:
"""
Expands the shape of an array by inserting a new dimension of size one at the
position specified by dims.
@@ -134,11 +180,17 @@ def expand_dims(x: NamedArray[Any, _DType], /, *, dim: (_Dim | Default)=
array([[[1., 2.],
[3., 4.]]])
"""
- pass
+ xp = _get_data_namespace(x)
+ dims = x.dims
+ if dim is _default:
+ dim = f"dim_{len(dims)}"
+ d = list(dims)
+ d.insert(axis, dim)
+ out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis))
+ return out
-def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) ->NamedArray[Any,
- _DType]:
+def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) -> NamedArray[Any, _DType]:
"""
Permutes the dimensions of an array.
@@ -156,4 +208,12 @@ def permute_dims(x: NamedArray[Any, _DType], axes: _Axes) ->NamedArray[Any,
data type as x.
"""
- pass
+
+ dims = x.dims
+ new_dims = tuple(dims[i] for i in axes)
+ if isinstance(x._data, _arrayapi):
+ xp = _get_data_namespace(x)
+ out = x._new(dims=new_dims, data=xp.permute_dims(x._data, axes))
+ else:
+ out = x._new(dims=new_dims, data=x._data.transpose(axes)) # type: ignore[attr-defined]
+ return out
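A short sketch (dimension names made up) of the two manipulation helpers; note that `expand_dims` defaults the new dimension name to `dim_<ndim>`:

```python
import numpy as np
from xarray.namedarray._array_api import expand_dims, permute_dims
from xarray.namedarray.core import NamedArray

x = NamedArray(("x", "y"), np.arange(6.0).reshape(2, 3))

y = expand_dims(x)           # dims ("dim_2", "x", "y"), shape (1, 2, 3)
z = permute_dims(x, (1, 0))  # dims ("y", "x"), shape (3, 2)
print(y.dims, y.shape)
print(z.dims, z.shape)
```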
diff --git a/xarray/namedarray/_typing.py b/xarray/namedarray/_typing.py
index 809c3b16..c8d6953f 100644
--- a/xarray/namedarray/_typing.py
+++ b/xarray/namedarray/_typing.py
@@ -1,10 +1,25 @@
from __future__ import annotations
+
import sys
from collections.abc import Hashable, Iterable, Mapping, Sequence
from enum import Enum
from types import ModuleType
-from typing import TYPE_CHECKING, Any, Callable, Final, Literal, Protocol, SupportsIndex, TypeVar, Union, overload, runtime_checkable
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Final,
+ Literal,
+ Protocol,
+ SupportsIndex,
+ TypeVar,
+ Union,
+ overload,
+ runtime_checkable,
+)
+
import numpy as np
+
try:
if sys.version_info >= (3, 11):
from typing import TypeAlias
@@ -17,54 +32,81 @@ except ImportError:
Self: Any = None
+# Singleton type, as per https://github.com/python/typing/pull/240
class Default(Enum):
token: Final = 0
_default = Default.token
-_T = TypeVar('_T')
-_T_co = TypeVar('_T_co', covariant=True)
+
+# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+
_dtype = np.dtype
-_DType = TypeVar('_DType', bound=np.dtype[Any])
-_DType_co = TypeVar('_DType_co', covariant=True, bound=np.dtype[Any])
-_ScalarType = TypeVar('_ScalarType', bound=np.generic)
-_ScalarType_co = TypeVar('_ScalarType_co', bound=np.generic, covariant=True)
+_DType = TypeVar("_DType", bound=np.dtype[Any])
+_DType_co = TypeVar("_DType_co", covariant=True, bound=np.dtype[Any])
+# A subset of `npt.DTypeLike` that can be parametrized w.r.t. `np.generic`
+
+_ScalarType = TypeVar("_ScalarType", bound=np.generic)
+_ScalarType_co = TypeVar("_ScalarType_co", bound=np.generic, covariant=True)
+# A protocol for anything with the dtype attribute
@runtime_checkable
class _SupportsDType(Protocol[_DType_co]):
- pass
+ @property
+ def dtype(self) -> _DType_co: ...
-_DTypeLike = Union[np.dtype[_ScalarType], type[_ScalarType], _SupportsDType
- [np.dtype[_ScalarType]]]
+_DTypeLike = Union[
+ np.dtype[_ScalarType],
+ type[_ScalarType],
+ _SupportsDType[np.dtype[_ScalarType]],
+]
+
+# For unknown shapes Dask uses np.nan, array_api uses None:
_IntOrUnknown = int
_Shape = tuple[_IntOrUnknown, ...]
_ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]]
-_ShapeType = TypeVar('_ShapeType', bound=Any)
-_ShapeType_co = TypeVar('_ShapeType_co', bound=Any, covariant=True)
+_ShapeType = TypeVar("_ShapeType", bound=Any)
+_ShapeType_co = TypeVar("_ShapeType_co", bound=Any, covariant=True)
+
_Axis = int
_Axes = tuple[_Axis, ...]
_AxisLike = Union[_Axis, _Axes]
+
_Chunks = tuple[_Shape, ...]
_NormalizedChunks = tuple[tuple[int, ...], ...]
-T_ChunkDim: TypeAlias = Union[int, Literal['auto'], None, tuple[int, ...]]
+# FYI in some cases we don't allow `None`, which this doesn't take account of.
+T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]
+# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]]
+
_Dim = Hashable
_Dims = tuple[_Dim, ...]
+
_DimsLike = Union[str, Iterable[_Dim]]
-_IndexKey = Union[int, slice, 'ellipsis']
-_IndexKeys = tuple[Union[_IndexKey], ...]
+
+# https://data-apis.org/array-api/latest/API_specification/indexing.html
+# TODO: np.array_api was bugged and didn't allow (None,), but should!
+# https://github.com/numpy/numpy/pull/25022
+# https://github.com/data-apis/array-api/pull/674
+_IndexKey = Union[int, slice, "ellipsis"]
+_IndexKeys = tuple[Union[_IndexKey], ...] # tuple[Union[_IndexKey, None], ...]
_IndexKeyLike = Union[_IndexKey, _IndexKeys]
+
_AttrsLike = Union[Mapping[Any, Any], None]
class _SupportsReal(Protocol[_T_co]):
- pass
+ @property
+ def real(self) -> _T_co: ...
class _SupportsImag(Protocol[_T_co]):
- pass
+ @property
+ def imag(self) -> _T_co: ...
@runtime_checkable
@@ -75,10 +117,17 @@ class _array(Protocol[_ShapeType_co, _DType_co]):
Corresponds to np.ndarray.
"""
+ @property
+ def shape(self) -> _Shape: ...
+
+ @property
+ def dtype(self) -> _DType_co: ...
+
@runtime_checkable
-class _arrayfunction(_array[_ShapeType_co, _DType_co], Protocol[
- _ShapeType_co, _DType_co]):
+class _arrayfunction(
+ _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]
+):
"""
Duck array supporting NEP 18.
@@ -86,131 +135,188 @@ class _arrayfunction(_array[_ShapeType_co, _DType_co], Protocol[
"""
@overload
- def __getitem__(self, key: (_arrayfunction[Any, Any] | tuple[
- _arrayfunction[Any, Any], ...]), /) ->_arrayfunction[Any, _DType_co]:
- ...
+ def __getitem__(
+ self, key: _arrayfunction[Any, Any] | tuple[_arrayfunction[Any, Any], ...], /
+ ) -> _arrayfunction[Any, _DType_co]: ...
@overload
- def __getitem__(self, key: _IndexKeyLike, /) ->Any:
- ...
-
- def __getitem__(self, key: (_IndexKeyLike | _arrayfunction[Any, Any] |
- tuple[_arrayfunction[Any, Any], ...]), /) ->(_arrayfunction[Any,
- _DType_co] | Any):
- ...
+ def __getitem__(self, key: _IndexKeyLike, /) -> Any: ...
+
+ def __getitem__(
+ self,
+ key: (
+ _IndexKeyLike
+ | _arrayfunction[Any, Any]
+ | tuple[_arrayfunction[Any, Any], ...]
+ ),
+ /,
+ ) -> _arrayfunction[Any, _DType_co] | Any: ...
@overload
- def __array__(self, dtype: None=..., /, *, copy: (None | bool)=...
- ) ->np.ndarray[Any, _DType_co]:
- ...
-
+ def __array__(
+ self, dtype: None = ..., /, *, copy: None | bool = ...
+ ) -> np.ndarray[Any, _DType_co]: ...
@overload
- def __array__(self, dtype: _DType, /, *, copy: (None | bool)=...
- ) ->np.ndarray[Any, _DType]:
- ...
-
- def __array__(self, dtype: (_DType | None)=..., /, *, copy: (None |
- bool)=...) ->(np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]):
- ...
-
- def __array_ufunc__(self, ufunc: Any, method: Any, *inputs: Any, **
- kwargs: Any) ->Any:
- ...
-
- def __array_function__(self, func: Callable[..., Any], types: Iterable[
- type], args: Iterable[Any], kwargs: Mapping[str, Any]) ->Any:
- ...
+ def __array__(
+ self, dtype: _DType, /, *, copy: None | bool = ...
+ ) -> np.ndarray[Any, _DType]: ...
+
+ def __array__(
+ self, dtype: _DType | None = ..., /, *, copy: None | bool = ...
+ ) -> np.ndarray[Any, _DType] | np.ndarray[Any, _DType_co]: ...
+
+ # TODO: Should return the same subclass but with a new dtype generic.
+ # https://github.com/python/typing/issues/548
+ def __array_ufunc__(
+ self,
+ ufunc: Any,
+ method: Any,
+ *inputs: Any,
+ **kwargs: Any,
+ ) -> Any: ...
+
+ # TODO: Should return the same subclass but with a new dtype generic.
+ # https://github.com/python/typing/issues/548
+ def __array_function__(
+ self,
+ func: Callable[..., Any],
+ types: Iterable[type],
+ args: Iterable[Any],
+ kwargs: Mapping[str, Any],
+ ) -> Any: ...
+
+ @property
+ def imag(self) -> _arrayfunction[_ShapeType_co, Any]: ...
+
+ @property
+ def real(self) -> _arrayfunction[_ShapeType_co, Any]: ...
@runtime_checkable
-class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co,
- _DType_co]):
+class _arrayapi(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]):
"""
Duck array supporting NEP 47.
Corresponds to np.ndarray.
"""
- def __getitem__(self, key: (_IndexKeyLike | Any), /) ->_arrayapi[Any, Any]:
- ...
+ def __getitem__(
+ self,
+ key: (
+ _IndexKeyLike | Any
+ ), # TODO: Any should be _arrayapi[Any, _dtype[np.integer]]
+ /,
+ ) -> _arrayapi[Any, Any]: ...
+
+ def __array_namespace__(self) -> ModuleType: ...
+
- def __array_namespace__(self) ->ModuleType:
- ...
+# NamedArray can most likely use both __array_function__ and __array_namespace__:
+_arrayfunction_or_api = (_arrayfunction, _arrayapi)
+duckarray = Union[
+ _arrayfunction[_ShapeType_co, _DType_co], _arrayapi[_ShapeType_co, _DType_co]
+]
-_arrayfunction_or_api = _arrayfunction, _arrayapi
-duckarray = Union[_arrayfunction[_ShapeType_co, _DType_co], _arrayapi[
- _ShapeType_co, _DType_co]]
+# Corresponds to np.typing.NDArray:
DuckArray = _arrayfunction[Any, np.dtype[_ScalarType_co]]
@runtime_checkable
-class _chunkedarray(_array[_ShapeType_co, _DType_co], Protocol[
- _ShapeType_co, _DType_co]):
+class _chunkedarray(
+ _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]
+):
"""
Minimal chunked duck array.
Corresponds to np.ndarray.
"""
+ @property
+ def chunks(self) -> _Chunks: ...
+
@runtime_checkable
-class _chunkedarrayfunction(_arrayfunction[_ShapeType_co, _DType_co],
- Protocol[_ShapeType_co, _DType_co]):
+class _chunkedarrayfunction(
+ _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]
+):
"""
Chunked duck array supporting NEP 18.
Corresponds to np.ndarray.
"""
+ @property
+ def chunks(self) -> _Chunks: ...
+
@runtime_checkable
-class _chunkedarrayapi(_arrayapi[_ShapeType_co, _DType_co], Protocol[
- _ShapeType_co, _DType_co]):
+class _chunkedarrayapi(
+ _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]
+):
"""
Chunked duck array supporting NEP 47.
Corresponds to np.ndarray.
"""
+ @property
+ def chunks(self) -> _Chunks: ...
-_chunkedarrayfunction_or_api = _chunkedarrayfunction, _chunkedarrayapi
-chunkedduckarray = Union[_chunkedarrayfunction[_ShapeType_co, _DType_co],
- _chunkedarrayapi[_ShapeType_co, _DType_co]]
+
+# NamedArray can most likely use both __array_function__ and __array_namespace__:
+_chunkedarrayfunction_or_api = (_chunkedarrayfunction, _chunkedarrayapi)
+chunkedduckarray = Union[
+ _chunkedarrayfunction[_ShapeType_co, _DType_co],
+ _chunkedarrayapi[_ShapeType_co, _DType_co],
+]
@runtime_checkable
-class _sparsearray(_array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co,
- _DType_co]):
+class _sparsearray(
+ _array[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]
+):
"""
Minimal sparse duck array.
Corresponds to np.ndarray.
"""
+ def todense(self) -> np.ndarray[Any, _DType_co]: ...
+
@runtime_checkable
-class _sparsearrayfunction(_arrayfunction[_ShapeType_co, _DType_co],
- Protocol[_ShapeType_co, _DType_co]):
+class _sparsearrayfunction(
+ _arrayfunction[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]
+):
"""
Sparse duck array supporting NEP 18.
Corresponds to np.ndarray.
"""
+ def todense(self) -> np.ndarray[Any, _DType_co]: ...
+
@runtime_checkable
-class _sparsearrayapi(_arrayapi[_ShapeType_co, _DType_co], Protocol[
- _ShapeType_co, _DType_co]):
+class _sparsearrayapi(
+ _arrayapi[_ShapeType_co, _DType_co], Protocol[_ShapeType_co, _DType_co]
+):
"""
Sparse duck array supporting NEP 47.
Corresponds to np.ndarray.
"""
+ def todense(self) -> np.ndarray[Any, _DType_co]: ...
+
+
+# NamedArray can most likely use both __array_function__ and __array_namespace__:
+_sparsearrayfunction_or_api = (_sparsearrayfunction, _sparsearrayapi)
+sparseduckarray = Union[
+ _sparsearrayfunction[_ShapeType_co, _DType_co],
+ _sparsearrayapi[_ShapeType_co, _DType_co],
+]
-_sparsearrayfunction_or_api = _sparsearrayfunction, _sparsearrayapi
-sparseduckarray = Union[_sparsearrayfunction[_ShapeType_co, _DType_co],
- _sparsearrayapi[_ShapeType_co, _DType_co]]
-ErrorOptions = Literal['raise', 'ignore']
-ErrorOptionsWithWarn = Literal['raise', 'warn', 'ignore']
+ErrorOptions = Literal["raise", "ignore"]
+ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"]
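
# --- Example (not part of the patch) ---------------------------------------
# The protocols above are runtime_checkable, so isinstance() only verifies
# that the required attributes exist. A toy stand-in with shape, dtype,
# __getitem__ and __array_namespace__ already satisfies _arrayapi; the class
# name TinyDuck is made up purely for illustration.
import numpy as np
from xarray.namedarray._typing import _arrayapi, _arrayfunction_or_api

class TinyDuck:
    def __init__(self, data):
        self._data = np.asarray(data)

    @property
    def shape(self):
        return self._data.shape

    @property
    def dtype(self):
        return self._data.dtype

    def __getitem__(self, key):
        return TinyDuck(self._data[key])

    def __array_namespace__(self):
        return np  # pretend numpy is our array API namespace

duck = TinyDuck([1, 2, 3])
print(isinstance(duck, _arrayapi))              # True
print(isinstance(duck, _arrayfunction_or_api))  # True, via the _arrayapi half
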
diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py
index 881d23bb..e9668d89 100644
--- a/xarray/namedarray/core.py
+++ b/xarray/namedarray/core.py
@@ -1,44 +1,125 @@
from __future__ import annotations
+
import copy
import math
import sys
import warnings
from collections.abc import Hashable, Iterable, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, cast, overload
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generic,
+ Literal,
+ TypeVar,
+ cast,
+ overload,
+)
+
import numpy as np
+
+# TODO: get rid of this after migrating this class to array API
from xarray.core import dtypes, formatting, formatting_html
-from xarray.core.indexing import ExplicitlyIndexed, ImplicitToExplicitIndexingAdapter, OuterIndexer
+from xarray.core.indexing import (
+ ExplicitlyIndexed,
+ ImplicitToExplicitIndexingAdapter,
+ OuterIndexer,
+)
from xarray.namedarray._aggregations import NamedArrayAggregations
-from xarray.namedarray._typing import ErrorOptionsWithWarn, _arrayapi, _arrayfunction_or_api, _chunkedarray, _default, _dtype, _DType_co, _ScalarType_co, _ShapeType_co, _sparsearrayfunction_or_api, _SupportsImag, _SupportsReal
+from xarray.namedarray._typing import (
+ ErrorOptionsWithWarn,
+ _arrayapi,
+ _arrayfunction_or_api,
+ _chunkedarray,
+ _default,
+ _dtype,
+ _DType_co,
+ _ScalarType_co,
+ _ShapeType_co,
+ _sparsearrayfunction_or_api,
+ _SupportsImag,
+ _SupportsReal,
+)
from xarray.namedarray.parallelcompat import guess_chunkmanager
from xarray.namedarray.pycompat import to_numpy
-from xarray.namedarray.utils import either_dict_or_kwargs, infix_dims, is_dict_like, is_duck_dask_array, to_0d_object_array
+from xarray.namedarray.utils import (
+ either_dict_or_kwargs,
+ infix_dims,
+ is_dict_like,
+ is_duck_dask_array,
+ to_0d_object_array,
+)
+
if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray
+
from xarray.core.types import Dims, T_Chunks
- from xarray.namedarray._typing import Default, _AttrsLike, _Chunks, _Dim, _Dims, _DimsLike, _DType, _IntOrUnknown, _ScalarType, _Shape, _ShapeType, duckarray
+ from xarray.namedarray._typing import (
+ Default,
+ _AttrsLike,
+ _Chunks,
+ _Dim,
+ _Dims,
+ _DimsLike,
+ _DType,
+ _IntOrUnknown,
+ _ScalarType,
+ _Shape,
+ _ShapeType,
+ duckarray,
+ )
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
+
try:
- from dask.typing import Graph, NestedKeys, PostComputeCallable, PostPersistCallable, SchedulerGetCallable
+ from dask.typing import (
+ Graph,
+ NestedKeys,
+ PostComputeCallable,
+ PostPersistCallable,
+ SchedulerGetCallable,
+ )
except ImportError:
- Graph: Any
- NestedKeys: Any
- SchedulerGetCallable: Any
- PostComputeCallable: Any
- PostPersistCallable: Any
+ Graph: Any # type: ignore[no-redef]
+ NestedKeys: Any # type: ignore[no-redef]
+ SchedulerGetCallable: Any # type: ignore[no-redef]
+ PostComputeCallable: Any # type: ignore[no-redef]
+ PostPersistCallable: Any # type: ignore[no-redef]
+
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
- T_NamedArray = TypeVar('T_NamedArray', bound='_NamedArray[Any]')
- T_NamedArrayInteger = TypeVar('T_NamedArrayInteger', bound=
- '_NamedArray[np.integer[Any]]')
+
+ T_NamedArray = TypeVar("T_NamedArray", bound="_NamedArray[Any]")
+ T_NamedArrayInteger = TypeVar(
+ "T_NamedArrayInteger", bound="_NamedArray[np.integer[Any]]"
+ )
+
+
+@overload
+def _new(
+ x: NamedArray[Any, _DType_co],
+ dims: _DimsLike | Default = ...,
+ data: duckarray[_ShapeType, _DType] = ...,
+ attrs: _AttrsLike | Default = ...,
+) -> NamedArray[_ShapeType, _DType]: ...
+
+
+@overload
+def _new(
+ x: NamedArray[_ShapeType_co, _DType_co],
+ dims: _DimsLike | Default = ...,
+ data: Default = ...,
+ attrs: _AttrsLike | Default = ...,
+) -> NamedArray[_ShapeType_co, _DType_co]: ...
-def _new(x: NamedArray[Any, _DType_co], dims: (_DimsLike | Default)=
- _default, data: (duckarray[_ShapeType, _DType] | Default)=_default,
- attrs: (_AttrsLike | Default)=_default) ->(NamedArray[_ShapeType,
- _DType] | NamedArray[Any, _DType_co]):
+def _new(
+ x: NamedArray[Any, _DType_co],
+ dims: _DimsLike | Default = _default,
+ data: duckarray[_ShapeType, _DType] | Default = _default,
+ attrs: _AttrsLike | Default = _default,
+) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, _DType_co]:
"""
Create a new array with new typing information.
@@ -58,12 +139,42 @@ def _new(x: NamedArray[Any, _DType_co], dims: (_DimsLike | Default)=
attributes you want to store with the array.
Will copy the attrs from x by default.
"""
- pass
+ dims_ = copy.copy(x._dims) if dims is _default else dims
+
+ attrs_: Mapping[Any, Any] | None
+ if attrs is _default:
+ attrs_ = None if x._attrs is None else x._attrs.copy()
+ else:
+ attrs_ = attrs
+
+ if data is _default:
+ return type(x)(dims_, copy.copy(x._data), attrs_)
+ else:
+ cls_ = cast("type[NamedArray[_ShapeType, _DType]]", type(x))
+ return cls_(dims_, data, attrs_)
+
+
+@overload
+def from_array(
+ dims: _DimsLike,
+ data: duckarray[_ShapeType, _DType],
+ attrs: _AttrsLike = ...,
+) -> NamedArray[_ShapeType, _DType]: ...
-def from_array(dims: _DimsLike, data: (duckarray[_ShapeType, _DType] |
- ArrayLike), attrs: _AttrsLike=None) ->(NamedArray[_ShapeType, _DType] |
- NamedArray[Any, Any]):
+@overload
+def from_array(
+ dims: _DimsLike,
+ data: ArrayLike,
+ attrs: _AttrsLike = ...,
+) -> NamedArray[Any, Any]: ...
+
+
+def from_array(
+ dims: _DimsLike,
+ data: duckarray[_ShapeType, _DType] | ArrayLike,
+ attrs: _AttrsLike = None,
+) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]:
"""
Create a Named array from an array-like object.
@@ -79,7 +190,29 @@ def from_array(dims: _DimsLike, data: (duckarray[_ShapeType, _DType] |
attributes you want to store with the array.
Default is None, meaning no attributes will be stored.
"""
- pass
+ if isinstance(data, NamedArray):
+ raise TypeError(
+ "Array is already a Named array. Use 'data.data' to retrieve the data array"
+ )
+
+ # TODO: dask.array.ma.MaskedArray also exists, better way?
+ if isinstance(data, np.ma.MaskedArray):
+ mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call]
+ if mask.any():
+ # TODO: requires refactoring/vendoring xarray.core.dtypes and
+ # xarray.core.duck_array_ops
+ raise NotImplementedError("MaskedArray is not supported yet")
+
+ return NamedArray(dims, data, attrs)
+
+ if isinstance(data, _arrayfunction_or_api):
+ return NamedArray(dims, data, attrs)
+
+ if isinstance(data, tuple):
+ return NamedArray(dims, to_0d_object_array(data), attrs)
+
+ # validate whether the data is valid data types.
+ return NamedArray(dims, np.asarray(data), attrs)
class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
@@ -114,27 +247,54 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
>>> data = np.array([1.5, 2, 3], dtype=float)
>>> narr = NamedArray(("x",), data, {"units": "m"}) # TODO: Better name than narr?
"""
- __slots__ = '_data', '_dims', '_attrs'
+
+ __slots__ = ("_data", "_dims", "_attrs")
+
_data: duckarray[Any, _DType_co]
_dims: _Dims
_attrs: dict[Any, Any] | None
- def __init__(self, dims: _DimsLike, data: duckarray[Any, _DType_co],
- attrs: _AttrsLike=None):
+ def __init__(
+ self,
+ dims: _DimsLike,
+ data: duckarray[Any, _DType_co],
+ attrs: _AttrsLike = None,
+ ):
self._data = data
self._dims = self._parse_dimensions(dims)
self._attrs = dict(attrs) if attrs else None
- def __init_subclass__(cls, **kwargs: Any) ->None:
- if NamedArray in cls.__bases__ and cls._new == NamedArray._new:
+ def __init_subclass__(cls, **kwargs: Any) -> None:
+ if NamedArray in cls.__bases__ and (cls._new == NamedArray._new):
+ # Type hinting does not work for subclasses unless _new is
+ # overridden with the correct class.
raise TypeError(
- 'Subclasses of `NamedArray` must override the `_new` method.')
+ "Subclasses of `NamedArray` must override the `_new` method."
+ )
super().__init_subclass__(**kwargs)
- def _new(self, dims: (_DimsLike | Default)=_default, data: (duckarray[
- Any, _DType] | Default)=_default, attrs: (_AttrsLike | Default)=
- _default) ->(NamedArray[_ShapeType, _DType] | NamedArray[
- _ShapeType_co, _DType_co]):
+ @overload
+ def _new(
+ self,
+ dims: _DimsLike | Default = ...,
+ data: duckarray[_ShapeType, _DType] = ...,
+ attrs: _AttrsLike | Default = ...,
+ ) -> NamedArray[_ShapeType, _DType]: ...
+
+ @overload
+ def _new(
+ self,
+ dims: _DimsLike | Default = ...,
+ data: Default = ...,
+ attrs: _AttrsLike | Default = ...,
+ ) -> NamedArray[_ShapeType_co, _DType_co]: ...
+
+ def _new(
+ self,
+ dims: _DimsLike | Default = _default,
+ data: duckarray[Any, _DType] | Default = _default,
+ attrs: _AttrsLike | Default = _default,
+ ) -> NamedArray[_ShapeType, _DType] | NamedArray[_ShapeType_co, _DType_co]:
"""
Create a new array with new typing information.
@@ -156,11 +316,14 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
attributes you want to store with the array.
Will copy the attrs from x by default.
"""
- pass
-
- def _replace(self, dims: (_DimsLike | Default)=_default, data: (
- duckarray[_ShapeType_co, _DType_co] | Default)=_default, attrs: (
- _AttrsLike | Default)=_default) ->Self:
+ return _new(self, dims, data, attrs)
+
+ def _replace(
+ self,
+ dims: _DimsLike | Default = _default,
+ data: duckarray[_ShapeType_co, _DType_co] | Default = _default,
+ attrs: _AttrsLike | Default = _default,
+ ) -> Self:
"""
Create a new array with the same typing information.
@@ -181,16 +344,39 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
attributes you want to store with the array.
Will copy the attrs from x by default.
"""
- pass
+ return cast("Self", self._new(dims, data, attrs))
+
+ def _copy(
+ self,
+ deep: bool = True,
+ data: duckarray[_ShapeType_co, _DType_co] | None = None,
+ memo: dict[int, Any] | None = None,
+ ) -> Self:
+ if data is None:
+ ndata = self._data
+ if deep:
+ ndata = copy.deepcopy(ndata, memo=memo)
+ else:
+ ndata = data
+ self._check_shape(ndata)
- def __copy__(self) ->Self:
+ attrs = (
+ copy.deepcopy(self._attrs, memo=memo) if deep else copy.copy(self._attrs)
+ )
+
+ return self._replace(data=ndata, attrs=attrs)
+
+ def __copy__(self) -> Self:
return self._copy(deep=False)
- def __deepcopy__(self, memo: (dict[int, Any] | None)=None) ->Self:
+ def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
return self._copy(deep=True, memo=memo)
- def copy(self, deep: bool=True, data: (duckarray[_ShapeType_co,
- _DType_co] | None)=None) ->Self:
+ def copy(
+ self,
+ deep: bool = True,
+ data: duckarray[_ShapeType_co, _DType_co] | None = None,
+ ) -> Self:
"""Returns a copy of this object.
If `deep=True`, the data array is loaded into memory and copied onto
@@ -216,10 +402,10 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
"""
- pass
+ return self._copy(deep=deep, data=data)
@property
- def ndim(self) ->int:
+ def ndim(self) -> int:
"""
Number of array dimensions.
@@ -227,10 +413,10 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
--------
numpy.ndarray.ndim
"""
- pass
+ return len(self.shape)
@property
- def size(self) ->_IntOrUnknown:
+ def size(self) -> _IntOrUnknown:
"""
Number of elements in the array.
@@ -240,16 +426,16 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
--------
numpy.ndarray.size
"""
- pass
+ return math.prod(self.shape)
- def __len__(self) ->_IntOrUnknown:
+ def __len__(self) -> _IntOrUnknown:
try:
return self.shape[0]
except Exception as exc:
- raise TypeError('len() of unsized object') from exc
+ raise TypeError("len() of unsized object") from exc
@property
- def dtype(self) ->_DType_co:
+ def dtype(self) -> _DType_co:
"""
Data-type of the array’s elements.
@@ -258,10 +444,10 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
ndarray.dtype
numpy.dtype
"""
- pass
+ return self._data.dtype
@property
- def shape(self) ->_Shape:
+ def shape(self) -> _Shape:
"""
Get the shape of the array.
@@ -274,40 +460,103 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
--------
numpy.ndarray.shape
"""
- pass
+ return self._data.shape
@property
- def nbytes(self) ->_IntOrUnknown:
+ def nbytes(self) -> _IntOrUnknown:
"""
Total bytes consumed by the elements of the data array.
If the underlying data array does not include ``nbytes``, estimates
the bytes consumed based on the ``size`` and ``dtype``.
"""
- pass
+ from xarray.namedarray._array_api import _get_data_namespace
+
+ if hasattr(self._data, "nbytes"):
+ return self._data.nbytes # type: ignore[no-any-return]
+
+ if hasattr(self.dtype, "itemsize"):
+ itemsize = self.dtype.itemsize
+ elif isinstance(self._data, _arrayapi):
+ xp = _get_data_namespace(self)
+
+ if xp.isdtype(self.dtype, "bool"):
+ itemsize = 1
+ elif xp.isdtype(self.dtype, "integral"):
+ itemsize = xp.iinfo(self.dtype).bits // 8
+ else:
+ itemsize = xp.finfo(self.dtype).bits // 8
+ else:
+ raise TypeError(
+ "cannot compute the number of bytes (no array API nor nbytes / itemsize)"
+ )
+
+ return self.size * itemsize
@property
- def dims(self) ->_Dims:
+ def dims(self) -> _Dims:
"""Tuple of dimension names with which this NamedArray is associated."""
- pass
+ return self._dims
+
+ @dims.setter
+ def dims(self, value: _DimsLike) -> None:
+ self._dims = self._parse_dimensions(value)
+
+ def _parse_dimensions(self, dims: _DimsLike) -> _Dims:
+ dims = (dims,) if isinstance(dims, str) else tuple(dims)
+ if len(dims) != self.ndim:
+ raise ValueError(
+ f"dimensions {dims} must have the same length as the "
+ f"number of data dimensions, ndim={self.ndim}"
+ )
+ if len(set(dims)) < len(dims):
+ repeated_dims = {d for d in dims if dims.count(d) > 1}
+ warnings.warn(
+ f"Duplicate dimension names present: dimensions {repeated_dims} appear more than once in dims={dims}. "
+ "We do not yet support duplicate dimension names, but we do allow initial construction of the object. "
+ "We recommend you rename the dims immediately to become distinct, as most xarray functionality is likely to fail silently if you do not. "
+ "To rename the dimensions you will need to set the ``.dims`` attribute of each variable, ``e.g. var.dims=('x0', 'x1')``.",
+ UserWarning,
+ )
+ return dims
@property
- def attrs(self) ->dict[Any, Any]:
+ def attrs(self) -> dict[Any, Any]:
"""Dictionary of local attributes on this NamedArray."""
- pass
+ if self._attrs is None:
+ self._attrs = {}
+ return self._attrs
+
+ @attrs.setter
+ def attrs(self, value: Mapping[Any, Any]) -> None:
+ self._attrs = dict(value) if value else None
+
+ def _check_shape(self, new_data: duckarray[Any, _DType_co]) -> None:
+ if new_data.shape != self.shape:
+ raise ValueError(
+ f"replacement data must match the {self.__class__.__name__}'s shape. "
+ f"replacement data has shape {new_data.shape}; {self.__class__.__name__} has shape {self.shape}"
+ )
@property
- def data(self) ->duckarray[Any, _DType_co]:
+ def data(self) -> duckarray[Any, _DType_co]:
"""
The NamedArray's data as an array. The underlying array type
(e.g. dask, sparse, pint) is preserved.
"""
- pass
+
+ return self._data
+
+ @data.setter
+ def data(self, data: duckarray[Any, _DType_co]) -> None:
+ self._check_shape(data)
+ self._data = data
@property
- def imag(self: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]]
- ) ->NamedArray[_ShapeType, _dtype[_ScalarType]]:
+ def imag(
+ self: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
+ ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]:
"""
The imaginary part of the array.
@@ -315,11 +564,17 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
--------
numpy.ndarray.imag
"""
- pass
+ if isinstance(self._data, _arrayapi):
+ from xarray.namedarray._array_api import imag
+
+ return imag(self)
+
+ return self._new(data=self._data.imag)
@property
- def real(self: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]]
- ) ->NamedArray[_ShapeType, _dtype[_ScalarType]]:
+ def real(
+ self: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
+ ) -> NamedArray[_ShapeType, _dtype[_ScalarType]]:
"""
The real part of the array.
@@ -327,71 +582,99 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
--------
numpy.ndarray.real
"""
- pass
+ if isinstance(self._data, _arrayapi):
+ from xarray.namedarray._array_api import real
+
+ return real(self)
+ return self._new(data=self._data.real)
- def __dask_tokenize__(self) ->object:
+ def __dask_tokenize__(self) -> object:
+ # Use v.data, instead of v._data, in order to cope with the wrappers
+ # around NetCDF and the like
from dask.base import normalize_token
- return normalize_token((type(self), self._dims, self.data, self.
- _attrs or None))
- def __dask_graph__(self) ->(Graph | None):
+ return normalize_token((type(self), self._dims, self.data, self._attrs or None))
+
+ def __dask_graph__(self) -> Graph | None:
if is_duck_dask_array(self._data):
return self._data.__dask_graph__()
else:
+ # TODO: Should this method just raise instead?
+ # raise NotImplementedError("Method requires self.data to be a dask array")
return None
- def __dask_keys__(self) ->NestedKeys:
+ def __dask_keys__(self) -> NestedKeys:
if is_duck_dask_array(self._data):
return self._data.__dask_keys__()
else:
- raise AttributeError(
- 'Method requires self.data to be a dask array.')
+ raise AttributeError("Method requires self.data to be a dask array.")
- def __dask_layers__(self) ->Sequence[str]:
+ def __dask_layers__(self) -> Sequence[str]:
if is_duck_dask_array(self._data):
return self._data.__dask_layers__()
else:
- raise AttributeError(
- 'Method requires self.data to be a dask array.')
+ raise AttributeError("Method requires self.data to be a dask array.")
@property
- def __dask_optimize__(self) ->Callable[..., dict[Any, Any]]:
+ def __dask_optimize__(
+ self,
+ ) -> Callable[..., dict[Any, Any]]:
if is_duck_dask_array(self._data):
- return self._data.__dask_optimize__
+ return self._data.__dask_optimize__ # type: ignore[no-any-return]
else:
- raise AttributeError(
- 'Method requires self.data to be a dask array.')
+ raise AttributeError("Method requires self.data to be a dask array.")
@property
- def __dask_scheduler__(self) ->SchedulerGetCallable:
+ def __dask_scheduler__(self) -> SchedulerGetCallable:
if is_duck_dask_array(self._data):
return self._data.__dask_scheduler__
else:
- raise AttributeError(
- 'Method requires self.data to be a dask array.')
+ raise AttributeError("Method requires self.data to be a dask array.")
- def __dask_postcompute__(self) ->tuple[PostComputeCallable, tuple[Any, ...]
- ]:
+ def __dask_postcompute__(
+ self,
+ ) -> tuple[PostComputeCallable, tuple[Any, ...]]:
if is_duck_dask_array(self._data):
- array_func, array_args = self._data.__dask_postcompute__()
+ array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call]
return self._dask_finalize, (array_func,) + array_args
else:
- raise AttributeError(
- 'Method requires self.data to be a dask array.')
-
- def __dask_postpersist__(self) ->tuple[Callable[[Graph,
- PostPersistCallable[Any], Any, Any], Self], tuple[Any, ...]]:
+ raise AttributeError("Method requires self.data to be a dask array.")
+
+ def __dask_postpersist__(
+ self,
+ ) -> tuple[
+ Callable[
+ [Graph, PostPersistCallable[Any], Any, Any],
+ Self,
+ ],
+ tuple[Any, ...],
+ ]:
if is_duck_dask_array(self._data):
a: tuple[PostPersistCallable[Any], tuple[Any, ...]]
- a = self._data.__dask_postpersist__()
+ a = self._data.__dask_postpersist__() # type: ignore[no-untyped-call]
array_func, array_args = a
+
return self._dask_finalize, (array_func,) + array_args
else:
- raise AttributeError(
- 'Method requires self.data to be a dask array.')
+ raise AttributeError("Method requires self.data to be a dask array.")
- def get_axis_num(self, dim: (Hashable | Iterable[Hashable])) ->(int |
- tuple[int, ...]):
+ def _dask_finalize(
+ self,
+ results: Graph,
+ array_func: PostPersistCallable[Any],
+ *args: Any,
+ **kwargs: Any,
+ ) -> Self:
+ data = array_func(results, *args, **kwargs)
+ return type(self)(self._dims, data, attrs=self._attrs)
+
+ @overload
+ def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]: ...
+
+ @overload
+ def get_axis_num(self, dim: Hashable) -> int: ...
+
+ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
"""Return axis number(s) corresponding to dimension(s) in this array.
Parameters
@@ -404,10 +687,20 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
int or tuple of int
Axis number or numbers corresponding to the given dimensions.
"""
- pass
+ if not isinstance(dim, str) and isinstance(dim, Iterable):
+ return tuple(self._get_axis_num(d) for d in dim)
+ else:
+ return self._get_axis_num(dim)
+
+ def _get_axis_num(self: Any, dim: Hashable) -> int:
+ _raise_if_any_duplicate_dimensions(self.dims)
+ try:
+ return self.dims.index(dim) # type: ignore[no-any-return]
+ except ValueError:
+ raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}")
@property
- def chunks(self) ->(_Chunks | None):
+ def chunks(self) -> _Chunks | None:
"""
Tuple of block lengths for this NamedArray's data, in order of dimensions, or None if
the underlying data is not a dask array.
@@ -418,10 +711,16 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
NamedArray.chunksizes
xarray.unify_chunks
"""
- pass
+ data = self._data
+ if isinstance(data, _chunkedarray):
+ return data.chunks
+ else:
+ return None
@property
- def chunksizes(self) ->Mapping[_Dim, _Shape]:
+ def chunksizes(
+ self,
+ ) -> Mapping[_Dim, _Shape]:
"""
Mapping from dimension names to block lengths for this namedArray's data, or None if
the underlying data is not a dask array.
@@ -436,16 +735,24 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
NamedArray.chunks
xarray.unify_chunks
"""
- pass
+ data = self._data
+ if isinstance(data, _chunkedarray):
+ return dict(zip(self.dims, data.chunks))
+ else:
+ return {}
@property
- def sizes(self) ->dict[_Dim, _IntOrUnknown]:
+ def sizes(self) -> dict[_Dim, _IntOrUnknown]:
"""Ordered mapping from dimension names to lengths."""
- pass
-
- def chunk(self, chunks: T_Chunks={}, chunked_array_type: (str |
- ChunkManagerEntrypoint[Any] | None)=None, from_array_kwargs: Any=
- None, **chunks_kwargs: Any) ->Self:
+ return dict(zip(self.dims, self.shape))
+
+ def chunk(
+ self,
+ chunks: T_Chunks = {},
+ chunked_array_type: str | ChunkManagerEntrypoint[Any] | None = None,
+ from_array_kwargs: Any = None,
+ **chunks_kwargs: Any,
+ ) -> Self:
"""Coerce this array's data into a dask array with the given chunks.
If this variable is a non-dask array, it will be converted to dask
@@ -485,19 +792,76 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
xarray.unify_chunks
dask.array.from_array
"""
- pass
- def to_numpy(self) ->np.ndarray[Any, Any]:
+ if from_array_kwargs is None:
+ from_array_kwargs = {}
+
+ if chunks is None:
+ warnings.warn(
+ "None value for 'chunks' is deprecated. "
+ "It will raise an error in the future. Use instead '{}'",
+ category=FutureWarning,
+ )
+ chunks = {}
+
+ if isinstance(chunks, (float, str, int, tuple, list)):
+ # TODO we shouldn't assume here that other chunkmanagers can handle these types
+ # TODO should we call normalize_chunks here?
+ pass # dask.array.from_array can handle these directly
+ else:
+ chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
+
+ if is_dict_like(chunks):
+ # This method of iteration allows for duplicated dimension names, GH8579
+ chunks = {
+ dim_number: chunks[dim]
+ for dim_number, dim in enumerate(self.dims)
+ if dim in chunks
+ }
+
+ chunkmanager = guess_chunkmanager(chunked_array_type)
+
+ data_old = self._data
+ if chunkmanager.is_chunked_array(data_old):
+ data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type]
+ else:
+ if not isinstance(data_old, ExplicitlyIndexed):
+ ndata = data_old
+ else:
+ # Unambiguously handle array storage backends (like NetCDF4 and h5py)
+ # that can't handle general array indexing. For example, in netCDF4 you
+ # can do "outer" indexing along two dimensions independent, which works
+ # differently from how NumPy handles it.
+ # da.from_array works by using lazy indexing with a tuple of slices.
+ # Using OuterIndexer is a pragmatic choice: dask does not yet handle
+ # different indexing types in an explicit way:
+ # https://github.com/dask/dask/issues/2883
+ ndata = ImplicitToExplicitIndexingAdapter(data_old, OuterIndexer) # type: ignore[assignment]
+
+ if is_dict_like(chunks):
+ chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape))
+
+ data_chunked = chunkmanager.from_array(ndata, chunks, **from_array_kwargs) # type: ignore[arg-type]
+
+ return self._replace(data=data_chunked)
+
+ def to_numpy(self) -> np.ndarray[Any, Any]:
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
- pass
+ # TODO an entrypoint so array libraries can choose coercion method?
+ return to_numpy(self._data)
- def as_numpy(self) ->Self:
+ def as_numpy(self) -> Self:
"""Coerces wrapped data into a numpy array, returning a Variable."""
- pass
-
- def reduce(self, func: Callable[..., Any], dim: Dims=None, axis: (int |
- Sequence[int] | None)=None, keepdims: bool=False, **kwargs: Any
- ) ->NamedArray[Any, Any]:
+ return self._replace(data=self.to_numpy())
+
+ def reduce(
+ self,
+ func: Callable[..., Any],
+ dim: Dims = None,
+ axis: int | Sequence[int] | None = None,
+ keepdims: bool = False,
+ **kwargs: Any,
+ ) -> NamedArray[Any, Any]:
"""Reduce this array by applying `func` along some dimension(s).
Parameters
@@ -526,31 +890,116 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
Array with summarized data and the indicated dimension(s)
removed.
"""
- pass
-
- def _nonzero(self: T_NamedArrayInteger) ->tuple[T_NamedArrayInteger, ...]:
+ if dim == ...:
+ dim = None
+ if dim is not None and axis is not None:
+ raise ValueError("cannot supply both 'axis' and 'dim' arguments")
+
+ if dim is not None:
+ axis = self.get_axis_num(dim)
+
+ with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore", r"Mean of empty slice", category=RuntimeWarning
+ )
+ if axis is not None:
+ if isinstance(axis, tuple) and len(axis) == 1:
+ # unpack axis for the benefit of functions
+ # like np.argmin which can't handle tuple arguments
+ axis = axis[0]
+ data = func(self.data, axis=axis, **kwargs)
+ else:
+ data = func(self.data, **kwargs)
+
+ if getattr(data, "shape", ()) == self.shape:
+ dims = self.dims
+ else:
+ removed_axes: Iterable[int]
+ if axis is None:
+ removed_axes = range(self.ndim)
+ else:
+ removed_axes = np.atleast_1d(axis) % self.ndim
+ if keepdims:
+ # Insert np.newaxis for removed dims
+ slices = tuple(
+ np.newaxis if i in removed_axes else slice(None, None)
+ for i in range(self.ndim)
+ )
+ if getattr(data, "shape", None) is None:
+ # Reduce has produced a scalar value, not an array-like
+ data = np.asanyarray(data)[slices]
+ else:
+ data = data[slices]
+ dims = self.dims
+ else:
+ dims = tuple(
+ adim for n, adim in enumerate(self.dims) if n not in removed_axes
+ )
+
+ # Return NamedArray to handle IndexVariable when data is nD
+ return from_array(dims, data, attrs=self._attrs)
+
+ def _nonzero(self: T_NamedArrayInteger) -> tuple[T_NamedArrayInteger, ...]:
"""Equivalent numpy's nonzero but returns a tuple of NamedArrays."""
- pass
-
- def __repr__(self) ->str:
+ # TODO: we should replace dask's native nonzero
+ # after https://github.com/dask/dask/issues/1076 is implemented.
+ # TODO: cast to ndarray and back to T_DuckArray is a workaround
+ nonzeros = np.nonzero(cast("NDArray[np.integer[Any]]", self.data))
+ _attrs = self.attrs
+ return tuple(
+ cast("T_NamedArrayInteger", self._new((dim,), nz, _attrs))
+ for nz, dim in zip(nonzeros, self.dims)
+ )
+
+ def __repr__(self) -> str:
return formatting.array_repr(self)
- def _as_sparse(self, sparse_format: (Literal['coo'] | Default)=_default,
- fill_value: (ArrayLike | Default)=_default) ->NamedArray[Any, _DType_co
- ]:
+ def _repr_html_(self) -> str:
+ return formatting_html.array_repr(self)
+
+ def _as_sparse(
+ self,
+ sparse_format: Literal["coo"] | Default = _default,
+ fill_value: ArrayLike | Default = _default,
+ ) -> NamedArray[Any, _DType_co]:
"""
Use sparse-array as backend.
"""
- pass
+ import sparse
+
+ from xarray.namedarray._array_api import astype
+
+ # TODO: what to do if dask-backended?
+ if fill_value is _default:
+ dtype, fill_value = dtypes.maybe_promote(self.dtype)
+ else:
+ dtype = dtypes.result_type(self.dtype, fill_value)
+
+ if sparse_format is _default:
+ sparse_format = "coo"
+ try:
+ as_sparse = getattr(sparse, f"as_{sparse_format.lower()}")
+ except AttributeError as exc:
+ raise ValueError(f"{sparse_format} is not a valid sparse format") from exc
- def _to_dense(self) ->NamedArray[Any, _DType_co]:
+ data = as_sparse(astype(self, dtype).data, fill_value=fill_value)
+ return self._new(data=data)
+
+ def _to_dense(self) -> NamedArray[Any, _DType_co]:
"""
Change backend from sparse to np.array.
"""
- pass
+ if isinstance(self._data, _sparsearrayfunction_or_api):
+ data_dense: np.ndarray[Any, _DType_co] = self._data.todense()
+ return self._new(data=data_dense)
+ else:
+ raise TypeError("self.data is not a sparse array")
- def permute_dims(self, *dim: (Iterable[_Dim] | ellipsis), missing_dims:
- ErrorOptionsWithWarn='raise') ->NamedArray[Any, _DType_co]:
+ def permute_dims(
+ self,
+ *dim: Iterable[_Dim] | ellipsis,
+ missing_dims: ErrorOptionsWithWarn = "raise",
+ ) -> NamedArray[Any, _DType_co]:
"""Return a new object with transposed dimensions.
Parameters
@@ -576,15 +1025,37 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
--------
numpy.transpose
"""
- pass
+
+ from xarray.namedarray._array_api import permute_dims
+
+ if not dim:
+ dims = self.dims[::-1]
+ else:
+ dims = tuple(infix_dims(dim, self.dims, missing_dims)) # type: ignore[arg-type]
+
+ if len(dims) < 2 or dims == self.dims:
+ # no need to transpose if only one dimension
+ # or dims are in same order
+ return self.copy(deep=False)
+
+ axes_result = self.get_axis_num(dims)
+ axes = (axes_result,) if isinstance(axes_result, int) else axes_result
+
+ return permute_dims(self, axes)
@property
- def T(self) ->NamedArray[Any, _DType_co]:
+ def T(self) -> NamedArray[Any, _DType_co]:
"""Return a new object with transposed dimensions."""
- pass
+ if self.ndim != 2:
+ raise ValueError(
+ f"x.T requires x to have 2 dimensions, got {self.ndim}. Use x.permute_dims() to permute dimensions."
+ )
- def broadcast_to(self, dim: (Mapping[_Dim, int] | None)=None, **
- dim_kwargs: Any) ->NamedArray[Any, _DType_co]:
+ return self.permute_dims()
+
+ def broadcast_to(
+ self, dim: Mapping[_Dim, int] | None = None, **dim_kwargs: Any
+ ) -> NamedArray[Any, _DType_co]:
"""
Broadcast the NamedArray to a new shape. New dimensions are not allowed.
@@ -617,10 +1088,35 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
>>> broadcasted.sizes
{'x': 2, 'y': 2}
"""
- pass
- def expand_dims(self, dim: (_Dim | Default)=_default) ->NamedArray[Any,
- _DType_co]:
+ from xarray.core import duck_array_ops
+
+ combined_dims = either_dict_or_kwargs(dim, dim_kwargs, "broadcast_to")
+
+ # Check that no new dimensions are added
+ if new_dims := set(combined_dims) - set(self.dims):
+ raise ValueError(
+ f"Cannot add new dimensions: {new_dims}. Only existing dimensions are allowed. "
+ "Use `expand_dims` method to add new dimensions."
+ )
+
+ # Create a dictionary of the current dimensions and their sizes
+ current_shape = self.sizes
+
+ # Update the current shape with the new dimensions, keeping the order of the original dimensions
+ broadcast_shape = {d: current_shape.get(d, 1) for d in self.dims}
+ broadcast_shape |= combined_dims
+
+ # Ensure the dimensions are in the correct order
+ ordered_dims = list(broadcast_shape.keys())
+ ordered_shape = tuple(broadcast_shape[d] for d in ordered_dims)
+ data = duck_array_ops.broadcast_to(self._data, ordered_shape) # type: ignore[no-untyped-call] # TODO: use array-api-compat function
+ return self._new(data=data, dims=ordered_dims)
+
+ def expand_dims(
+ self,
+ dim: _Dim | Default = _default,
+ ) -> NamedArray[Any, _DType_co]:
"""
Expand the dimensions of the NamedArray.
@@ -650,7 +1146,20 @@ class NamedArray(NamedArrayAggregations, Generic[_ShapeType_co, _DType_co]):
('z', 'x', 'y')
"""
- pass
+
+ from xarray.namedarray._array_api import expand_dims
+
+ return expand_dims(self, dim=dim)
_NamedArray = NamedArray[Any, np.dtype[_ScalarType_co]]
+
+
+def _raise_if_any_duplicate_dimensions(
+ dims: _Dims, err_context: str = "This function"
+) -> None:
+ if len(set(dims)) < len(dims):
+ repeated_dims = {d for d in dims if dims.count(d) > 1}
+ raise ValueError(
+ f"{err_context} cannot handle duplicate dimensions, but dimensions {repeated_dims} appear more than once on this object's dims: {dims}"
+ )
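
# --- Example (not part of the patch) ---------------------------------------
# Minimal sketch of the NamedArray behaviours filled in above, assuming the
# xarray.namedarray.core module path from this patch.
import numpy as np
from xarray.namedarray.core import from_array

arr = from_array(("x", "y"), np.arange(3.0).reshape(1, 3), attrs={"units": "m"})

print(arr.sizes)              # {'x': 1, 'y': 3}
print(arr.nbytes)             # 24 (three float64 elements)
print(arr.get_axis_num("y"))  # 1

# broadcast_to only resizes existing dimensions; unknown names raise ValueError.
wider = arr.broadcast_to(x=4)
print(wider.sizes)            # {'x': 4, 'y': 3}

# .T is restricted to 2-D arrays; otherwise use permute_dims().
print(wider.T.dims)           # ('y', 'x')
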
diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py
index cfbd1e84..963d12fd 100644
--- a/xarray/namedarray/daskmanager.py
+++ b/xarray/namedarray/daskmanager.py
@@ -1,31 +1,253 @@
from __future__ import annotations
+
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Callable
+
import numpy as np
from packaging.version import Version
+
from xarray.core.indexing import ImplicitToExplicitIndexingAdapter
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint, T_ChunkedArray
from xarray.namedarray.utils import is_duck_dask_array, module_available
+
if TYPE_CHECKING:
- from xarray.namedarray._typing import T_Chunks, _DType_co, _NormalizedChunks, duckarray
+ from xarray.namedarray._typing import (
+ T_Chunks,
+ _DType_co,
+ _NormalizedChunks,
+ duckarray,
+ )
+
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray[Any, Any]
-dask_available = module_available('dask')
-class DaskManager(ChunkManagerEntrypoint['DaskArray']):
+dask_available = module_available("dask")
+
+
+class DaskManager(ChunkManagerEntrypoint["DaskArray"]):
array_cls: type[DaskArray]
available: bool = dask_available
- def __init__(self) ->None:
+ def __init__(self) -> None:
+ # TODO can we replace this with a class attribute instead?
+
from dask.array import Array
+
self.array_cls = Array
- def normalize_chunks(self, chunks: (T_Chunks | _NormalizedChunks),
- shape: (tuple[int, ...] | None)=None, limit: (int | None)=None,
- dtype: (_DType_co | None)=None, previous_chunks: (_NormalizedChunks |
- None)=None) ->Any:
+ def is_chunked_array(self, data: duckarray[Any, Any]) -> bool:
+ return is_duck_dask_array(data)
+
+ def chunks(self, data: Any) -> _NormalizedChunks:
+ return data.chunks # type: ignore[no-any-return]
+
+ def normalize_chunks(
+ self,
+ chunks: T_Chunks | _NormalizedChunks,
+ shape: tuple[int, ...] | None = None,
+ limit: int | None = None,
+ dtype: _DType_co | None = None,
+ previous_chunks: _NormalizedChunks | None = None,
+ ) -> Any:
"""Called by open_dataset"""
- pass
+ from dask.array.core import normalize_chunks
+
+ return normalize_chunks(
+ chunks,
+ shape=shape,
+ limit=limit,
+ dtype=dtype,
+ previous_chunks=previous_chunks,
+ ) # type: ignore[no-untyped-call]
+
+ def from_array(
+ self, data: Any, chunks: T_Chunks | _NormalizedChunks, **kwargs: Any
+ ) -> DaskArray | Any:
+ import dask.array as da
+
+ if isinstance(data, ImplicitToExplicitIndexingAdapter):
+ # lazily loaded backend array classes should use NumPy array operations.
+ kwargs["meta"] = np.ndarray
+
+ return da.from_array(
+ data,
+ chunks,
+ **kwargs,
+ ) # type: ignore[no-untyped-call]
+
+ def compute(
+ self, *data: Any, **kwargs: Any
+ ) -> tuple[np.ndarray[Any, _DType_co], ...]:
+ from dask.array import compute
+
+ return compute(*data, **kwargs) # type: ignore[no-untyped-call, no-any-return]
+
+ @property
+ def array_api(self) -> Any:
+ from dask import array as da
+
+ return da
+
+ def reduction(
+ self,
+ arr: T_ChunkedArray,
+ func: Callable[..., Any],
+ combine_func: Callable[..., Any] | None = None,
+ aggregate_func: Callable[..., Any] | None = None,
+ axis: int | Sequence[int] | None = None,
+ dtype: _DType_co | None = None,
+ keepdims: bool = False,
+ ) -> DaskArray | Any:
+ from dask.array import reduction
+
+ return reduction(
+ arr,
+ chunk=func,
+ combine=combine_func,
+ aggregate=aggregate_func,
+ axis=axis,
+ dtype=dtype,
+ keepdims=keepdims,
+ ) # type: ignore[no-untyped-call]
+
+ def scan(
+ self,
+ func: Callable[..., Any],
+ binop: Callable[..., Any],
+ ident: float,
+ arr: T_ChunkedArray,
+ axis: int | None = None,
+ dtype: _DType_co | None = None,
+ **kwargs: Any,
+ ) -> DaskArray | Any:
+ from dask.array.reductions import cumreduction
+
+ return cumreduction(
+ func,
+ binop,
+ ident,
+ arr,
+ axis=axis,
+ dtype=dtype,
+ **kwargs,
+ ) # type: ignore[no-untyped-call]
+
+ def apply_gufunc(
+ self,
+ func: Callable[..., Any],
+ signature: str,
+ *args: Any,
+ axes: Sequence[tuple[int, ...]] | None = None,
+ axis: int | None = None,
+ keepdims: bool = False,
+ output_dtypes: Sequence[_DType_co] | None = None,
+ output_sizes: dict[str, int] | None = None,
+ vectorize: bool | None = None,
+ allow_rechunk: bool = False,
+ meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ from dask.array.gufunc import apply_gufunc
+
+ return apply_gufunc(
+ func,
+ signature,
+ *args,
+ axes=axes,
+ axis=axis,
+ keepdims=keepdims,
+ output_dtypes=output_dtypes,
+ output_sizes=output_sizes,
+ vectorize=vectorize,
+ allow_rechunk=allow_rechunk,
+ meta=meta,
+ **kwargs,
+ ) # type: ignore[no-untyped-call]
+
+ def map_blocks(
+ self,
+ func: Callable[..., Any],
+ *args: Any,
+ dtype: _DType_co | None = None,
+ chunks: tuple[int, ...] | None = None,
+ drop_axis: int | Sequence[int] | None = None,
+ new_axis: int | Sequence[int] | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ import dask
+ from dask.array import map_blocks
+
+ if drop_axis is None and Version(dask.__version__) < Version("2022.9.1"):
+ # See https://github.com/pydata/xarray/pull/7019#discussion_r1196729489
+ # TODO remove once dask minimum version >= 2022.9.1
+ drop_axis = []
+
+ # pass through name, meta, token as kwargs
+ return map_blocks(
+ func,
+ *args,
+ dtype=dtype,
+ chunks=chunks,
+ drop_axis=drop_axis,
+ new_axis=new_axis,
+ **kwargs,
+ ) # type: ignore[no-untyped-call]
+
+ def blockwise(
+ self,
+ func: Callable[..., Any],
+ out_ind: Iterable[Any],
+ *args: Any,
+ # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types
+ name: str | None = None,
+ token: Any | None = None,
+ dtype: _DType_co | None = None,
+ adjust_chunks: dict[Any, Callable[..., Any]] | None = None,
+ new_axes: dict[Any, int] | None = None,
+ align_arrays: bool = True,
+ concatenate: bool | None = None,
+ meta: tuple[np.ndarray[Any, _DType_co], ...] | None = None,
+ **kwargs: Any,
+ ) -> DaskArray | Any:
+ from dask.array import blockwise
+
+ return blockwise(
+ func,
+ out_ind,
+ *args,
+ name=name,
+ token=token,
+ dtype=dtype,
+ adjust_chunks=adjust_chunks,
+ new_axes=new_axes,
+ align_arrays=align_arrays,
+ concatenate=concatenate,
+ meta=meta,
+ **kwargs,
+ ) # type: ignore[no-untyped-call]
+
+ def unify_chunks(
+ self,
+ *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
+ **kwargs: Any,
+ ) -> tuple[dict[str, _NormalizedChunks], list[DaskArray]]:
+ from dask.array.core import unify_chunks
+
+ return unify_chunks(*args, **kwargs) # type: ignore[no-any-return, no-untyped-call]
+
+ def store(
+ self,
+ sources: Any | Sequence[Any],
+ targets: Any,
+ **kwargs: Any,
+ ) -> Any:
+ from dask.array import store
+
+ return store(
+ sources=sources,
+ targets=targets,
+ **kwargs,
+ )
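
# --- Example (not part of the patch) ---------------------------------------
# Sketch of how the DaskManager methods above are exercised through
# NamedArray.chunk(); assumes dask is installed so the "dask" chunkmanager
# entrypoint can be discovered by guess_chunkmanager().
import numpy as np
from xarray.namedarray.core import NamedArray

arr = NamedArray(("x", "y"), np.zeros((4, 6)))
chunked = arr.chunk({"x": 2, "y": 3})  # dispatches to DaskManager.from_array

print(type(chunked.data).__name__)  # Array (a dask array)
print(chunked.chunks)               # ((2, 2), (3, 3))
print(chunked.chunksizes)           # {'x': (2, 2), 'y': (3, 3)}
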
diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py
index 5bebd041..7a83bd17 100644
--- a/xarray/namedarray/dtypes.py
+++ b/xarray/namedarray/dtypes.py
@@ -1,44 +1,57 @@
from __future__ import annotations
+
import functools
import sys
from typing import Any, Literal
+
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
+
import numpy as np
+
from xarray.namedarray import utils
-NA = utils.ReprObject('<NA>')
+
+# Use as a sentinel value to indicate a dtype appropriate NA value.
+NA = utils.ReprObject("<NA>")
@functools.total_ordering
class AlwaysGreaterThan:
-
- def __gt__(self, other: Any) ->Literal[True]:
+ def __gt__(self, other: Any) -> Literal[True]:
return True
- def __eq__(self, other: Any) ->bool:
+ def __eq__(self, other: Any) -> bool:
return isinstance(other, type(self))
@functools.total_ordering
class AlwaysLessThan:
-
- def __lt__(self, other: Any) ->Literal[True]:
+ def __lt__(self, other: Any) -> Literal[True]:
return True
- def __eq__(self, other: Any) ->bool:
+ def __eq__(self, other: Any) -> bool:
return isinstance(other, type(self))
+# Equivalence to np.inf (-np.inf) for object-type
INF = AlwaysGreaterThan()
NINF = AlwaysLessThan()
-PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = ((
- np.number, np.character), (np.bool_, np.character), (np.bytes_, np.str_))
-def maybe_promote(dtype: np.dtype[np.generic]) ->tuple[np.dtype[np.generic],
- Any]:
+# Pairs of types that, if both found, should be promoted to object dtype
+# instead of following NumPy's own type-promotion rules. These type promotion
+# rules match pandas instead. For reference, see the NumPy type hierarchy:
+# https://numpy.org/doc/stable/reference/arrays.scalars.html
+PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
+ (np.number, np.character), # numpy promotes to character
+ (np.bool_, np.character), # numpy promotes to character
+ (np.bytes_, np.str_), # numpy promotes to unicode
+)
+
+
+def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]:
"""Simpler equivalent of pandas.core.common._maybe_promote
Parameters
@@ -50,13 +63,40 @@ def maybe_promote(dtype: np.dtype[np.generic]) ->tuple[np.dtype[np.generic],
dtype : Promoted dtype that can hold missing values.
fill_value : Valid missing value for the promoted dtype.
"""
- pass
-
-
-NAT_TYPES = {np.datetime64('NaT').dtype, np.timedelta64('NaT').dtype}
-
-
-def get_fill_value(dtype: np.dtype[np.generic]) ->Any:
+ # N.B. these casting rules should match pandas
+ dtype_: np.typing.DTypeLike
+ fill_value: Any
+ if np.issubdtype(dtype, np.floating):
+ dtype_ = dtype
+ fill_value = np.nan
+ elif np.issubdtype(dtype, np.timedelta64):
+ # See https://github.com/numpy/numpy/issues/10685
+ # np.timedelta64 is a subclass of np.integer
+ # Check np.timedelta64 before np.integer
+ fill_value = np.timedelta64("NaT")
+ dtype_ = dtype
+ elif np.issubdtype(dtype, np.integer):
+ dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
+ fill_value = np.nan
+ elif np.issubdtype(dtype, np.complexfloating):
+ dtype_ = dtype
+ fill_value = np.nan + np.nan * 1j
+ elif np.issubdtype(dtype, np.datetime64):
+ dtype_ = dtype
+ fill_value = np.datetime64("NaT")
+ else:
+ dtype_ = object
+ fill_value = np.nan
+
+ dtype_out = np.dtype(dtype_)
+ fill_value = dtype_out.type(fill_value)
+ return dtype_out, fill_value
+
+
+NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
+
+
+def get_fill_value(dtype: np.dtype[np.generic]) -> Any:
"""Return an appropriate fill value for this dtype.
Parameters
@@ -67,11 +107,13 @@ def get_fill_value(dtype: np.dtype[np.generic]) ->Any:
-------
fill_value : Missing value corresponding to this dtype.
"""
- pass
+ _, fill_value = maybe_promote(dtype)
+ return fill_value
-def get_pos_infinity(dtype: np.dtype[np.generic], max_for_int: bool=False) ->(
- float | complex | AlwaysGreaterThan):
+def get_pos_infinity(
+ dtype: np.dtype[np.generic], max_for_int: bool = False
+) -> float | complex | AlwaysGreaterThan:
"""Return an appropriate positive infinity for this dtype.
Parameters
@@ -84,11 +126,20 @@ def get_pos_infinity(dtype: np.dtype[np.generic], max_for_int: bool=False) ->(
-------
fill_value : positive infinity value corresponding to this dtype.
"""
- pass
+ if issubclass(dtype.type, np.floating):
+ return np.inf
+
+ if issubclass(dtype.type, np.integer):
+ return np.iinfo(dtype.type).max if max_for_int else np.inf
+ if issubclass(dtype.type, np.complexfloating):
+ return np.inf + 1j * np.inf
+
+ return INF
-def get_neg_infinity(dtype: np.dtype[np.generic], min_for_int: bool=False) ->(
- float | complex | AlwaysLessThan):
+def get_neg_infinity(
+ dtype: np.dtype[np.generic], min_for_int: bool = False
+) -> float | complex | AlwaysLessThan:
"""Return an appropriate positive infinity for this dtype.
Parameters
@@ -101,17 +152,27 @@ def get_neg_infinity(dtype: np.dtype[np.generic], min_for_int: bool=False) ->(
-------
fill_value : positive infinity value corresponding to this dtype.
"""
- pass
+ if issubclass(dtype.type, np.floating):
+ return -np.inf
+ if issubclass(dtype.type, np.integer):
+ return np.iinfo(dtype.type).min if min_for_int else -np.inf
+ if issubclass(dtype.type, np.complexfloating):
+ return -np.inf - 1j * np.inf
-def is_datetime_like(dtype: np.dtype[np.generic]) ->TypeGuard[np.datetime64 |
- np.timedelta64]:
+ return NINF
+
+
+def is_datetime_like(
+ dtype: np.dtype[np.generic],
+) -> TypeGuard[np.datetime64 | np.timedelta64]:
"""Check if a dtype is a subclass of the numpy datetime types"""
- pass
+ return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
-def result_type(*arrays_and_dtypes: (np.typing.ArrayLike | np.typing.DTypeLike)
- ) ->np.dtype[np.generic]:
+def result_type(
+ *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
+) -> np.dtype[np.generic]:
"""Like np.result_type, but with type promotion rules matching pandas.
Examples of changed behavior:
@@ -127,4 +188,12 @@ def result_type(*arrays_and_dtypes: (np.typing.ArrayLike | np.typing.DTypeLike)
-------
numpy.dtype for the result.
"""
- pass
+ types = {np.result_type(t).type for t in arrays_and_dtypes}
+
+ for left, right in PROMOTE_TO_OBJECT:
+ if any(issubclass(t, left) for t in types) and any(
+ issubclass(t, right) for t in types
+ ):
+ return np.dtype(object)
+
+ return np.result_type(*arrays_and_dtypes)
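
# --- Example (not part of the patch) ---------------------------------------
# Quick check of the pandas-style promotion rules implemented above, assuming
# the xarray.namedarray.dtypes module path from this patch.
import numpy as np
from xarray.namedarray.dtypes import maybe_promote, result_type

# Integers are widened to floats so NaN can serve as the missing value.
print(maybe_promote(np.dtype("int16")))  # (float32, nan)
print(maybe_promote(np.dtype("int64")))  # (float64, nan)

# Mixing numbers and strings falls back to object dtype instead of NumPy's
# own promotion toward a unicode dtype.
print(result_type(np.array([1, 2]), np.array(["a"])))  # object
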
diff --git a/xarray/namedarray/parallelcompat.py b/xarray/namedarray/parallelcompat.py
index edcc83cd..dd555fe2 100644
--- a/xarray/namedarray/parallelcompat.py
+++ b/xarray/namedarray/parallelcompat.py
@@ -3,29 +3,51 @@ The code in this module is an experiment in going from N=1 to N=2 parallel compu
It could later be used as the basis for a public interface allowing any N frameworks to interoperate with xarray,
but for now it is just a private experiment.
"""
+
from __future__ import annotations
+
import functools
import sys
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from importlib.metadata import EntryPoint, entry_points
from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, TypeVar
+
import numpy as np
+
from xarray.core.utils import emit_user_level_warning
from xarray.namedarray.pycompat import is_chunked_array
+
if TYPE_CHECKING:
- from xarray.namedarray._typing import _Chunks, _DType, _DType_co, _NormalizedChunks, _ShapeType, duckarray
+ from xarray.namedarray._typing import (
+ _Chunks,
+ _DType,
+ _DType_co,
+ _NormalizedChunks,
+ _ShapeType,
+ duckarray,
+ )
class ChunkedArrayMixinProtocol(Protocol):
- pass
+ def rechunk(self, chunks: Any, **kwargs: Any) -> Any: ...
+
+ @property
+ def dtype(self) -> np.dtype[Any]: ...
+ @property
+ def chunks(self) -> _NormalizedChunks: ...
+
+ def compute(
+ self, *data: Any, **kwargs: Any
+ ) -> tuple[np.ndarray[Any, _DType_co], ...]: ...
-T_ChunkedArray = TypeVar('T_ChunkedArray', bound=ChunkedArrayMixinProtocol)
+
+T_ChunkedArray = TypeVar("T_ChunkedArray", bound=ChunkedArrayMixinProtocol)
@functools.lru_cache(maxsize=1)
-def list_chunkmanagers() ->dict[str, ChunkManagerEntrypoint[Any]]:
+def list_chunkmanagers() -> dict[str, ChunkManagerEntrypoint[Any]]:
"""
Return a dictionary of available chunk managers and their ChunkManagerEntrypoint subclass objects.
@@ -39,33 +61,114 @@ def list_chunkmanagers() ->dict[str, ChunkManagerEntrypoint[Any]]:
-----
# New selection mechanism introduced with Python 3.10. See GH6514.
"""
- pass
+ if sys.version_info >= (3, 10):
+ entrypoints = entry_points(group="xarray.chunkmanagers")
+ else:
+ entrypoints = entry_points().get("xarray.chunkmanagers", ())
+ return load_chunkmanagers(entrypoints)
-def load_chunkmanagers(entrypoints: Sequence[EntryPoint]) ->dict[str,
- ChunkManagerEntrypoint[Any]]:
- """Load entrypoints and instantiate chunkmanagers only once."""
- pass
+def load_chunkmanagers(
+ entrypoints: Sequence[EntryPoint],
+) -> dict[str, ChunkManagerEntrypoint[Any]]:
+ """Load entrypoints and instantiate chunkmanagers only once."""
-def guess_chunkmanager(manager: (str | ChunkManagerEntrypoint[Any] | None)
- ) ->ChunkManagerEntrypoint[Any]:
+ loaded_entrypoints = {}
+ for entrypoint in entrypoints:
+ try:
+ loaded_entrypoints[entrypoint.name] = entrypoint.load()
+ except ModuleNotFoundError as e:
+ emit_user_level_warning(
+ f"Failed to load chunk manager entrypoint {entrypoint.name} due to {e}. Skipping.",
+ )
+ pass
+
+ available_chunkmanagers = {
+ name: chunkmanager()
+ for name, chunkmanager in loaded_entrypoints.items()
+ if chunkmanager.available
+ }
+ return available_chunkmanagers
+
+
+def guess_chunkmanager(
+ manager: str | ChunkManagerEntrypoint[Any] | None,
+) -> ChunkManagerEntrypoint[Any]:
"""
Get namespace of chunk-handling methods, guessing from what's available.
If the name of a specific ChunkManager is given (e.g. "dask"), then use that.
Else use whatever is installed, defaulting to dask if there are multiple options.
"""
- pass
-
-def get_chunked_array_type(*args: Any) ->ChunkManagerEntrypoint[Any]:
+ chunkmanagers = list_chunkmanagers()
+
+ if manager is None:
+ if len(chunkmanagers) == 1:
+ # use the only option available
+ manager = next(iter(chunkmanagers.keys()))
+ else:
+ # default to trying to use dask
+ manager = "dask"
+
+ if isinstance(manager, str):
+ if manager not in chunkmanagers:
+ raise ValueError(
+ f"unrecognized chunk manager {manager} - must be one of: {list(chunkmanagers)}"
+ )
+
+ return chunkmanagers[manager]
+ elif isinstance(manager, ChunkManagerEntrypoint):
+ # already a valid ChunkManager so just pass through
+ return manager
+ else:
+ raise TypeError(
+ f"manager must be a string or instance of ChunkManagerEntrypoint, but received type {type(manager)}"
+ )
+
+
+def get_chunked_array_type(*args: Any) -> ChunkManagerEntrypoint[Any]:
"""
Detects which parallel backend should be used for given set of arrays.
Also checks that all arrays are of the same chunking type (i.e. not a mix of cubed and dask).
"""
- pass
+
+ # TODO this list is probably redundant with something inside xarray.apply_ufunc
+ ALLOWED_NON_CHUNKED_TYPES = {int, float, np.ndarray}
+
+ chunked_arrays = [
+ a
+ for a in args
+ if is_chunked_array(a) and type(a) not in ALLOWED_NON_CHUNKED_TYPES
+ ]
+
+ # Asserts all arrays are the same type (or numpy etc.)
+ chunked_array_types = {type(a) for a in chunked_arrays}
+ if len(chunked_array_types) > 1:
+ raise TypeError(
+ f"Mixing chunked array types is not supported, but received multiple types: {chunked_array_types}"
+ )
+ elif len(chunked_array_types) == 0:
+ raise TypeError("Expected a chunked array but none were found")
+
+ # iterate over defined chunk managers, seeing if each recognises this array type
+ chunked_arr = chunked_arrays[0]
+ chunkmanagers = list_chunkmanagers()
+ selected = [
+ chunkmanager
+ for chunkmanager in chunkmanagers.values()
+ if chunkmanager.is_chunked_array(chunked_arr)
+ ]
+ if not selected:
+ raise TypeError(
+ f"Could not find a Chunk Manager which recognises type {type(chunked_arr)}"
+ )
+ elif len(selected) >= 2:
+ raise TypeError(f"Multiple ChunkManagers recognise type {type(chunked_arr)}")
+ else:
+ return selected[0]
class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
@@ -86,15 +189,16 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
Parallel frameworks need to provide an array class that supports the array API standard.
This attribute is used for array instance type checking at runtime.
"""
+
array_cls: type[T_ChunkedArray]
available: bool = True
@abstractmethod
- def __init__(self) ->None:
+ def __init__(self) -> None:
"""Used to set the array_cls attribute at import time."""
raise NotImplementedError()
- def is_chunked_array(self, data: duckarray[Any, Any]) ->bool:
+ def is_chunked_array(self, data: duckarray[Any, Any]) -> bool:
"""
Check if the given object is an instance of this type of chunked array.
@@ -112,10 +216,10 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
--------
dask.is_dask_collection
"""
- pass
+ return isinstance(data, self.array_cls)
@abstractmethod
- def chunks(self, data: T_ChunkedArray) ->_NormalizedChunks:
+ def chunks(self, data: T_ChunkedArray) -> _NormalizedChunks:
"""
Return the current chunks of the given array.
@@ -136,13 +240,17 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.Array.chunks
cubed.Array.chunks
"""
- pass
+ raise NotImplementedError()
@abstractmethod
- def normalize_chunks(self, chunks: (_Chunks | _NormalizedChunks), shape:
- (_ShapeType | None)=None, limit: (int | None)=None, dtype: (_DType |
- None)=None, previous_chunks: (_NormalizedChunks | None)=None
- ) ->_NormalizedChunks:
+ def normalize_chunks(
+ self,
+ chunks: _Chunks | _NormalizedChunks,
+ shape: _ShapeType | None = None,
+ limit: int | None = None,
+ dtype: _DType | None = None,
+ previous_chunks: _NormalizedChunks | None = None,
+ ) -> _NormalizedChunks:
"""
Normalize given chunking pattern into an explicit tuple of tuples representation.
@@ -169,11 +277,12 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
--------
dask.array.core.normalize_chunks
"""
- pass
+ raise NotImplementedError()
@abstractmethod
- def from_array(self, data: duckarray[Any, Any], chunks: _Chunks, **
- kwargs: Any) ->T_ChunkedArray:
+ def from_array(
+ self, data: duckarray[Any, Any], chunks: _Chunks, **kwargs: Any
+ ) -> T_ChunkedArray:
"""
Create a chunked array from a non-chunked numpy-like array.
@@ -194,10 +303,14 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.from_array
cubed.from_array
"""
- pass
+ raise NotImplementedError()
- def rechunk(self, data: T_ChunkedArray, chunks: (_NormalizedChunks |
- tuple[int, ...] | _Chunks), **kwargs: Any) ->Any:
+ def rechunk(
+ self,
+ data: T_ChunkedArray,
+ chunks: _NormalizedChunks | tuple[int, ...] | _Chunks,
+ **kwargs: Any,
+ ) -> Any:
"""
Changes the chunking pattern of the given array.
@@ -221,11 +334,12 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.Array.rechunk
cubed.Array.rechunk
"""
- pass
+ return data.rechunk(chunks, **kwargs)
@abstractmethod
- def compute(self, *data: (T_ChunkedArray | Any), **kwargs: Any) ->tuple[
- np.ndarray[Any, _DType_co], ...]:
+ def compute(
+ self, *data: T_ChunkedArray | Any, **kwargs: Any
+ ) -> tuple[np.ndarray[Any, _DType_co], ...]:
"""
Computes one or more chunked arrays, returning them as eager numpy arrays.
@@ -248,10 +362,10 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.compute
cubed.compute
"""
- pass
+ raise NotImplementedError()
@property
- def array_api(self) ->Any:
+ def array_api(self) -> Any:
"""
Return the array_api namespace following the python array API standard.
@@ -264,13 +378,18 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array
cubed.array_api
"""
- pass
+ raise NotImplementedError()
- def reduction(self, arr: T_ChunkedArray, func: Callable[..., Any],
- combine_func: (Callable[..., Any] | None)=None, aggregate_func: (
- Callable[..., Any] | None)=None, axis: (int | Sequence[int] | None)
- =None, dtype: (_DType_co | None)=None, keepdims: bool=False
- ) ->T_ChunkedArray:
+ def reduction(
+ self,
+ arr: T_ChunkedArray,
+ func: Callable[..., Any],
+ combine_func: Callable[..., Any] | None = None,
+ aggregate_func: Callable[..., Any] | None = None,
+ axis: int | Sequence[int] | None = None,
+ dtype: _DType_co | None = None,
+ keepdims: bool = False,
+ ) -> T_ChunkedArray:
"""
A general version of array reductions along one or more axes.
@@ -308,11 +427,18 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.reduction
cubed.core.reduction
"""
- pass
+ raise NotImplementedError()
- def scan(self, func: Callable[..., Any], binop: Callable[..., Any],
- ident: float, arr: T_ChunkedArray, axis: (int | None)=None, dtype:
- (_DType_co | None)=None, **kwargs: Any) ->T_ChunkedArray:
+ def scan(
+ self,
+ func: Callable[..., Any],
+ binop: Callable[..., Any],
+ ident: float,
+ arr: T_ChunkedArray,
+ axis: int | None = None,
+ dtype: _DType_co | None = None,
+ **kwargs: Any,
+ ) -> T_ChunkedArray:
"""
General version of a 1D scan, also known as a cumulative array reduction.
@@ -338,13 +464,20 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
--------
dask.array.cumreduction
"""
- pass
+ raise NotImplementedError()
@abstractmethod
- def apply_gufunc(self, func: Callable[..., Any], signature: str, *args:
- Any, axes: (Sequence[tuple[int, ...]] | None)=None, keepdims: bool=
- False, output_dtypes: (Sequence[_DType_co] | None)=None, vectorize:
- (bool | None)=None, **kwargs: Any) ->Any:
+ def apply_gufunc(
+ self,
+ func: Callable[..., Any],
+ signature: str,
+ *args: Any,
+ axes: Sequence[tuple[int, ...]] | None = None,
+ keepdims: bool = False,
+ output_dtypes: Sequence[_DType_co] | None = None,
+ vectorize: bool | None = None,
+ **kwargs: Any,
+ ) -> Any:
"""
Apply a generalized ufunc or similar python function to arrays.
@@ -418,12 +551,18 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
.. [1] https://docs.scipy.org/doc/numpy/reference/ufuncs.html
.. [2] https://docs.scipy.org/doc/numpy/reference/c-api/generalized-ufuncs.html
"""
- pass
+ raise NotImplementedError()
- def map_blocks(self, func: Callable[..., Any], *args: Any, dtype: (
- _DType_co | None)=None, chunks: (tuple[int, ...] | None)=None,
- drop_axis: (int | Sequence[int] | None)=None, new_axis: (int |
- Sequence[int] | None)=None, **kwargs: Any) ->Any:
+ def map_blocks(
+ self,
+ func: Callable[..., Any],
+ *args: Any,
+ dtype: _DType_co | None = None,
+ chunks: tuple[int, ...] | None = None,
+ drop_axis: int | Sequence[int] | None = None,
+ new_axis: int | Sequence[int] | None = None,
+ **kwargs: Any,
+ ) -> Any:
"""
Map a function across all blocks of a chunked array.
@@ -460,12 +599,18 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.map_blocks
cubed.map_blocks
"""
- pass
+ raise NotImplementedError()
- def blockwise(self, func: Callable[..., Any], out_ind: Iterable[Any], *
- args: Any, adjust_chunks: (dict[Any, Callable[..., Any]] | None)=
- None, new_axes: (dict[Any, int] | None)=None, align_arrays: bool=
- True, **kwargs: Any) ->Any:
+ def blockwise(
+ self,
+ func: Callable[..., Any],
+ out_ind: Iterable[Any],
+ *args: Any, # can't type this as mypy assumes args are all same type, but dask blockwise args alternate types
+ adjust_chunks: dict[Any, Callable[..., Any]] | None = None,
+ new_axes: dict[Any, int] | None = None,
+ align_arrays: bool = True,
+ **kwargs: Any,
+ ) -> Any:
"""
Tensor operation: Generalized inner and outer products.
@@ -505,10 +650,13 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.blockwise
cubed.core.blockwise
"""
- pass
+ raise NotImplementedError()
- def unify_chunks(self, *args: Any, **kwargs: Any) ->tuple[dict[str,
- _NormalizedChunks], list[T_ChunkedArray]]:
+ def unify_chunks(
+ self,
+ *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
+ **kwargs: Any,
+ ) -> tuple[dict[str, _NormalizedChunks], list[T_ChunkedArray]]:
"""
Unify chunks across a sequence of arrays.
@@ -524,10 +672,14 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.core.unify_chunks
cubed.core.unify_chunks
"""
- pass
+ raise NotImplementedError()
- def store(self, sources: (T_ChunkedArray | Sequence[T_ChunkedArray]),
- targets: Any, **kwargs: dict[str, Any]) ->Any:
+ def store(
+ self,
+ sources: T_ChunkedArray | Sequence[T_ChunkedArray],
+ targets: Any,
+ **kwargs: dict[str, Any],
+ ) -> Any:
"""
Store chunked arrays in array-like objects, overwriting data in target.
@@ -553,4 +705,4 @@ class ChunkManagerEntrypoint(ABC, Generic[T_ChunkedArray]):
dask.array.store
cubed.store
"""
- pass
+ raise NotImplementedError()
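
A small usage sketch (not part of the diff) of the restored chunk-manager resolution: `get_chunked_array_type` inspects the registered entrypoints and returns the one whose `is_chunked_array` check matches the given array, while `guess_chunkmanager` resolves a manager by name (defaulting to dask when several are installed). Assumes dask is installed so its entrypoint is available.

```python
import dask.array as da
from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager

x = da.ones((4, 4), chunks=2)

cm = get_chunked_array_type(x)     # picks the entrypoint that recognises this array type
print(cm.chunks(x))                # ((2, 2), (2, 2))

print(guess_chunkmanager("dask"))  # explicit lookup by name; unknown names raise ValueError
```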
diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py
index c632ef6f..3ce33d4d 100644
--- a/xarray/namedarray/pycompat.py
+++ b/xarray/namedarray/pycompat.py
@@ -1,15 +1,20 @@
from __future__ import annotations
+
from importlib import import_module
from types import ModuleType
from typing import TYPE_CHECKING, Any, Literal
+
import numpy as np
from packaging.version import Version
+
from xarray.core.utils import is_scalar
from xarray.namedarray.utils import is_duck_array, is_duck_dask_array
-integer_types = int, np.integer
+
+integer_types = (int, np.integer)
+
if TYPE_CHECKING:
- ModType = Literal['dask', 'pint', 'cupy', 'sparse', 'cubed', 'numbagg']
- DuckArrayTypes = tuple[type[Any], ...]
+ ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"]
+ DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic
from xarray.namedarray._typing import _DType, _ShapeType, duckarray
@@ -20,36 +25,41 @@ class DuckArrayModule:
Motivated by having to only import pint when required (as pint currently imports xarray)
https://github.com/pydata/xarray/pull/5561#discussion_r664815718
"""
+
module: ModuleType | None
version: Version
type: DuckArrayTypes
available: bool
- def __init__(self, mod: ModType) ->None:
+ def __init__(self, mod: ModType) -> None:
duck_array_module: ModuleType | None
duck_array_version: Version
duck_array_type: DuckArrayTypes
try:
duck_array_module = import_module(mod)
duck_array_version = Version(duck_array_module.__version__)
- if mod == 'dask':
- duck_array_type = import_module('dask.array').Array,
- elif mod == 'pint':
- duck_array_type = duck_array_module.Quantity,
- elif mod == 'cupy':
- duck_array_type = duck_array_module.ndarray,
- elif mod == 'sparse':
- duck_array_type = duck_array_module.SparseArray,
- elif mod == 'cubed':
- duck_array_type = duck_array_module.Array,
- elif mod == 'numbagg':
+
+ if mod == "dask":
+ duck_array_type = (import_module("dask.array").Array,)
+ elif mod == "pint":
+ duck_array_type = (duck_array_module.Quantity,)
+ elif mod == "cupy":
+ duck_array_type = (duck_array_module.ndarray,)
+ elif mod == "sparse":
+ duck_array_type = (duck_array_module.SparseArray,)
+ elif mod == "cubed":
+ duck_array_type = (duck_array_module.Array,)
+ # Not a duck array module, but using this system regardless, to get lazy imports
+ elif mod == "numbagg":
duck_array_type = ()
else:
raise NotImplementedError
- except (ImportError, AttributeError):
+
+ except (ImportError, AttributeError): # pragma: no cover
duck_array_module = None
- duck_array_version = Version('0.0.0')
+ duck_array_version = Version("0.0.0")
duck_array_type = ()
+
self.module = duck_array_module
self.version = duck_array_version
self.type = duck_array_type
@@ -59,11 +69,70 @@ class DuckArrayModule:
_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {}
-def array_type(mod: ModType) ->DuckArrayTypes:
+def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule:
+ if mod not in _cached_duck_array_modules:
+ duckmod = DuckArrayModule(mod)
+ _cached_duck_array_modules[mod] = duckmod
+ return duckmod
+ else:
+ return _cached_duck_array_modules[mod]
+
+
+def array_type(mod: ModType) -> DuckArrayTypes:
"""Quick wrapper to get the array class of the module."""
- pass
+ return _get_cached_duck_array_module(mod).type
-def mod_version(mod: ModType) ->Version:
+def mod_version(mod: ModType) -> Version:
"""Quick wrapper to get the version of the module."""
- pass
+ return _get_cached_duck_array_module(mod).version
+
+
+def is_chunked_array(x: duckarray[Any, Any]) -> bool:
+ return is_duck_dask_array(x) or (is_duck_array(x) and hasattr(x, "chunks"))
+
+
+def is_0d_dask_array(x: duckarray[Any, Any]) -> bool:
+ return is_duck_dask_array(x) and is_scalar(x)
+
+
+def to_numpy(
+ data: duckarray[Any, Any], **kwargs: dict[str, Any]
+) -> np.ndarray[Any, np.dtype[Any]]:
+ from xarray.core.indexing import ExplicitlyIndexed
+ from xarray.namedarray.parallelcompat import get_chunked_array_type
+
+ if isinstance(data, ExplicitlyIndexed):
+ data = data.get_duck_array() # type: ignore[no-untyped-call]
+
+ # TODO first attempt to call .to_numpy() once some libraries implement it
+ if is_chunked_array(data):
+ chunkmanager = get_chunked_array_type(data)
+ data, *_ = chunkmanager.compute(data, **kwargs)
+ if isinstance(data, array_type("cupy")):
+ data = data.get()
+ # pint has to be imported dynamically as pint imports xarray
+ if isinstance(data, array_type("pint")):
+ data = data.magnitude
+ if isinstance(data, array_type("sparse")):
+ data = data.todense()
+ data = np.asarray(data)
+
+ return data
+
+
+def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, _DType]:
+ from xarray.core.indexing import ExplicitlyIndexed
+ from xarray.namedarray.parallelcompat import get_chunked_array_type
+
+ if is_chunked_array(data):
+ chunkmanager = get_chunked_array_type(data)
+ loaded_data, *_ = chunkmanager.compute(data, **kwargs) # type: ignore[var-annotated]
+ return loaded_data
+
+ if isinstance(data, ExplicitlyIndexed):
+ return data.get_duck_array() # type: ignore[no-untyped-call, no-any-return]
+ elif is_duck_array(data):
+ return data
+ else:
+ return np.asarray(data) # type: ignore[return-value]
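
For context, a brief sketch (not part of the diff) of the new `pycompat` conversion helpers: `to_numpy` first computes chunked arrays through the detected chunk manager, then unwraps cupy/pint/sparse containers if present, and finally falls back to `np.asarray`. Assumes dask is installed for the chunked case.

```python
import numpy as np
import dask.array as da
from xarray.namedarray.pycompat import is_chunked_array, to_numpy

x = da.arange(6, chunks=3)
print(is_chunked_array(x))     # True: a duck array exposing .chunks

y = to_numpy(x)                # computed via the detected chunk manager
print(type(y).__name__, y)     # ndarray [0 1 2 3 4 5]

print(to_numpy(np.arange(3)))  # non-chunked input just passes through np.asarray
```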
diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py
index 04211ef3..b82a80b5 100644
--- a/xarray/namedarray/utils.py
+++ b/xarray/namedarray/utils.py
@@ -1,33 +1,42 @@
from __future__ import annotations
+
import importlib
import sys
import warnings
from collections.abc import Hashable, Iterable, Iterator, Mapping
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar, cast
+
import numpy as np
from packaging.version import Version
+
from xarray.namedarray._typing import ErrorOptionsWithWarn, _DimsLike
+
if TYPE_CHECKING:
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
+
from numpy.typing import NDArray
+
try:
from dask.array.core import Array as DaskArray
from dask.typing import DaskCollection
except ImportError:
- DaskArray = NDArray
- DaskCollection: Any = NDArray
+ DaskArray = NDArray # type: ignore
+ DaskCollection: Any = NDArray # type: ignore
+
from xarray.namedarray._typing import _Dim, duckarray
-K = TypeVar('K')
-V = TypeVar('V')
-T = TypeVar('T')
+
+
+K = TypeVar("K")
+V = TypeVar("V")
+T = TypeVar("T")
@lru_cache
-def module_available(module: str, minversion: (str | None)=None) ->bool:
+def module_available(module: str, minversion: str | None = None) -> bool:
"""Checks whether a module is installed without importing it.
Use this for a lightweight check and lazy imports.
@@ -44,16 +53,65 @@ def module_available(module: str, minversion: (str | None)=None) ->bool:
available : bool
Whether the module is installed.
"""
- pass
+ if importlib.util.find_spec(module) is None:
+ return False
+
+ if minversion is not None:
+ version = importlib.metadata.version(module)
+
+ return Version(version) >= Version(minversion)
+
+ return True
+
+
+def is_dask_collection(x: object) -> TypeGuard[DaskCollection]:
+ if module_available("dask"):
+ from dask.base import is_dask_collection
+
+ # use is_dask_collection function instead of dask.typing.DaskCollection
+ # see https://github.com/pydata/xarray/pull/8241#discussion_r1476276023
+ return is_dask_collection(x)
+ return False
+
+def is_duck_array(value: Any) -> TypeGuard[duckarray[Any, Any]]:
+ # TODO: replace is_duck_array with runtime checks via _arrayfunction_or_api protocol on
+ # python 3.12 and higher (see https://github.com/pydata/xarray/issues/8696#issuecomment-1924588981)
+ if isinstance(value, np.ndarray):
+ return True
+ return (
+ hasattr(value, "ndim")
+ and hasattr(value, "shape")
+ and hasattr(value, "dtype")
+ and (
+ (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__"))
+ or hasattr(value, "__array_namespace__")
+ )
+ )
-def to_0d_object_array(value: object) ->NDArray[np.object_]:
+
+def is_duck_dask_array(x: duckarray[Any, Any]) -> TypeGuard[DaskArray]:
+ return is_duck_array(x) and is_dask_collection(x)
+
+
+def to_0d_object_array(
+ value: object,
+) -> NDArray[np.object_]:
"""Given a value, wrap it in a 0-D numpy.ndarray with dtype=object."""
- pass
+ result = np.empty((), dtype=object)
+ result[()] = value
+ return result
+
+def is_dict_like(value: Any) -> TypeGuard[Mapping[Any, Any]]:
+ return hasattr(value, "keys") and hasattr(value, "__getitem__")
-def drop_missing_dims(supplied_dims: Iterable[_Dim], dims: Iterable[_Dim],
- missing_dims: ErrorOptionsWithWarn) ->_DimsLike:
+
+def drop_missing_dims(
+ supplied_dims: Iterable[_Dim],
+ dims: Iterable[_Dim],
+ missing_dims: ErrorOptionsWithWarn,
+) -> _DimsLike:
"""Depending on the setting of missing_dims, drop any dimensions from supplied_dims that
are not present in dims.
@@ -63,36 +121,104 @@ def drop_missing_dims(supplied_dims: Iterable[_Dim], dims: Iterable[_Dim],
dims : Iterable of Hashable
missing_dims : {"raise", "warn", "ignore"}
"""
- pass
+ if missing_dims == "raise":
+ supplied_dims_set = {val for val in supplied_dims if val is not ...}
+ if invalid := supplied_dims_set - set(dims):
+ raise ValueError(
+ f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
+ )
+
+ return supplied_dims
+
+ elif missing_dims == "warn":
+ if invalid := set(supplied_dims) - set(dims):
+ warnings.warn(
+ f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
+ )
+
+ return [val for val in supplied_dims if val in dims or val is ...]
-def infix_dims(dims_supplied: Iterable[_Dim], dims_all: Iterable[_Dim],
- missing_dims: ErrorOptionsWithWarn='raise') ->Iterator[_Dim]:
+ elif missing_dims == "ignore":
+ return [val for val in supplied_dims if val in dims or val is ...]
+
+ else:
+ raise ValueError(
+ f"Unrecognised option {missing_dims} for missing_dims argument"
+ )
+
+
+def infix_dims(
+ dims_supplied: Iterable[_Dim],
+ dims_all: Iterable[_Dim],
+ missing_dims: ErrorOptionsWithWarn = "raise",
+) -> Iterator[_Dim]:
"""
Resolves a supplied list containing an ellipsis representing other items, to
a generator with the 'realized' list of all items
"""
- pass
+ if ... in dims_supplied:
+ dims_all_list = list(dims_all)
+ if len(set(dims_all)) != len(dims_all_list):
+ raise ValueError("Cannot use ellipsis with repeated dims")
+ if list(dims_supplied).count(...) > 1:
+ raise ValueError("More than one ellipsis supplied")
+ other_dims = [d for d in dims_all if d not in dims_supplied]
+ existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
+ for d in existing_dims:
+ if d is ...:
+ yield from other_dims
+ else:
+ yield d
+ else:
+ existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
+ if set(existing_dims) ^ set(dims_all):
+ raise ValueError(
+ f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
+ )
+ yield from existing_dims
+
+
+def either_dict_or_kwargs(
+ pos_kwargs: Mapping[Any, T] | None,
+ kw_kwargs: Mapping[str, T],
+ func_name: str,
+) -> Mapping[Hashable, T]:
+ if pos_kwargs is None or pos_kwargs == {}:
+ # Need an explicit cast to appease mypy due to invariance; see
+ # https://github.com/python/mypy/issues/6228
+ return cast(Mapping[Hashable, T], kw_kwargs)
+
+ if not is_dict_like(pos_kwargs):
+ raise ValueError(f"the first argument to .{func_name} must be a dictionary")
+ if kw_kwargs:
+ raise ValueError(
+ f"cannot specify both keyword and positional arguments to .{func_name}"
+ )
+ return pos_kwargs
class ReprObject:
"""Object that prints as the given value, for use with sentinel values."""
- __slots__ = '_value',
+
+ __slots__ = ("_value",)
+
_value: str
def __init__(self, value: str):
self._value = value
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
return self._value
- def __eq__(self, other: (ReprObject | Any)) ->bool:
- return self._value == other._value if isinstance(other, ReprObject
- ) else False
+ def __eq__(self, other: ReprObject | Any) -> bool:
+ # TODO: What type can other be? ArrayLike?
+ return self._value == other._value if isinstance(other, ReprObject) else False
- def __hash__(self) ->int:
+ def __hash__(self) -> int:
return hash((type(self), self._value))
- def __dask_tokenize__(self) ->object:
+ def __dask_tokenize__(self) -> object:
from dask.base import normalize_token
+
return normalize_token((type(self), self._value))
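
A short sketch (not part of the diff) exercising the dimension helpers added to `xarray/namedarray/utils.py`: `infix_dims` expands an ellipsis into the remaining dimensions, `drop_missing_dims` filters supplied dimensions according to the `missing_dims` policy, and `either_dict_or_kwargs` enforces the dict-or-keywords calling convention.

```python
from xarray.namedarray.utils import drop_missing_dims, either_dict_or_kwargs, infix_dims

dims_all = ("time", "x", "y")

print(list(infix_dims(["y", ...], dims_all)))             # ['y', 'time', 'x']
print(drop_missing_dims(["x", "z"], dims_all, "ignore"))  # ['x']

# Options may be given either as a mapping or as keywords, never both.
print(either_dict_or_kwargs({"x": 1}, {}, "sel"))    # {'x': 1}
print(either_dict_or_kwargs(None, {"x": 1}, "sel"))  # {'x': 1}
```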
diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py
index 029a5a78..9db4ae4e 100644
--- a/xarray/plot/accessor.py
+++ b/xarray/plot/accessor.py
@@ -1,9 +1,14 @@
from __future__ import annotations
+
import functools
from collections.abc import Hashable, Iterable
from typing import TYPE_CHECKING, Any, Literal, NoReturn, overload
+
import numpy as np
+
+# Accessor methods have the same name as plotting methods, so we need a different namespace
from xarray.plot import dataarray_plot, dataset_plot
+
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection, PathCollection, QuadMesh
@@ -15,6 +20,7 @@ if TYPE_CHECKING:
from matplotlib.quiver import Quiver
from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection
from numpy.typing import ArrayLike
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import AspectOptions, HueStyleOptions, ScaleOptions
@@ -26,31 +32,1241 @@ class DataArrayPlotAccessor:
Enables use of xarray.plot functions as attributes on a DataArray.
For example, DataArray.plot.imshow
"""
+
_da: DataArray
- __slots__ = '_da',
+
+ __slots__ = ("_da",)
__doc__ = dataarray_plot.plot.__doc__
- def __init__(self, darray: DataArray) ->None:
+ def __init__(self, darray: DataArray) -> None:
self._da = darray
- @functools.wraps(dataarray_plot.plot, assigned=('__doc__',
- '__annotations__'))
- def __call__(self, **kwargs) ->Any:
+ # Should return Any such that the user does not run into problems
+ # with the many possible return values
+ @functools.wraps(dataarray_plot.plot, assigned=("__doc__", "__annotations__"))
+ def __call__(self, **kwargs) -> Any:
return dataarray_plot.plot(self._da, **kwargs)
+ @functools.wraps(dataarray_plot.hist)
+ def hist(
+ self, *args, **kwargs
+ ) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]:
+ return dataarray_plot.hist(self._da, *args, **kwargs)
+
+ @overload
+ def line( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ figsize: Iterable[float] | None = None,
+ aspect: AspectOptions = None,
+ size: float | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ add_legend: bool = True,
+ _labels: bool = True,
+ **kwargs: Any,
+ ) -> list[Line3D]: ...
+
+ @overload
+ def line(
+ self,
+ *args: Any,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ aspect: AspectOptions = None,
+ size: float | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ add_legend: bool = True,
+ _labels: bool = True,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @overload
+ def line(
+ self,
+ *args: Any,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ figsize: Iterable[float] | None = None,
+ aspect: AspectOptions = None,
+ size: float | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ add_legend: bool = True,
+ _labels: bool = True,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @functools.wraps(dataarray_plot.line, assigned=("__doc__",))
+ def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]:
+ return dataarray_plot.line(self._da, *args, **kwargs)
+
+ @overload
+ def step( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ where: Literal["pre", "post", "mid"] = "pre",
+ drawstyle: str | None = None,
+ ds: str | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ **kwargs: Any,
+ ) -> list[Line3D]: ...
+
+ @overload
+ def step(
+ self,
+ *args: Any,
+ where: Literal["pre", "post", "mid"] = "pre",
+ drawstyle: str | None = None,
+ ds: str | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @overload
+ def step(
+ self,
+ *args: Any,
+ where: Literal["pre", "post", "mid"] = "pre",
+ drawstyle: str | None = None,
+ ds: str | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @functools.wraps(dataarray_plot.step, assigned=("__doc__",))
+ def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]:
+ return dataarray_plot.step(self._da, *args, **kwargs)
+
+ @overload
+ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ cmap=None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend=None,
+ levels=None,
+ **kwargs,
+ ) -> PathCollection: ...
+
+ @overload
+ def scatter(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ cmap=None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend=None,
+ levels=None,
+ **kwargs,
+ ) -> FacetGrid[DataArray]: ...
+
+ @overload
+ def scatter(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ cmap=None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend=None,
+ levels=None,
+ **kwargs,
+ ) -> FacetGrid[DataArray]: ...
+
+ @functools.wraps(dataarray_plot.scatter, assigned=("__doc__",))
+ def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]:
+ return dataarray_plot.scatter(self._da, *args, **kwargs)
+
+ @overload
+ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> AxesImage: ...
+
+ @overload
+ def imshow(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @overload
+ def imshow(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @functools.wraps(dataarray_plot.imshow, assigned=("__doc__",))
+ def imshow(self, *args, **kwargs) -> AxesImage | FacetGrid[DataArray]:
+ return dataarray_plot.imshow(self._da, *args, **kwargs)
+
+ @overload
+ def contour( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> QuadContourSet: ...
+
+ @overload
+ def contour(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @overload
+ def contour(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @functools.wraps(dataarray_plot.contour, assigned=("__doc__",))
+ def contour(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]:
+ return dataarray_plot.contour(self._da, *args, **kwargs)
+
+ @overload
+ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> QuadContourSet: ...
+
+ @overload
+ def contourf(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @overload
+ def contourf(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid: ...
+
+ @functools.wraps(dataarray_plot.contourf, assigned=("__doc__",))
+ def contourf(self, *args, **kwargs) -> QuadContourSet | FacetGrid[DataArray]:
+ return dataarray_plot.contourf(self._da, *args, **kwargs)
+
+ @overload
+ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> QuadMesh: ...
+
+ @overload
+ def pcolormesh(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @overload
+ def pcolormesh(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid[DataArray]: ...
+
+ @functools.wraps(dataarray_plot.pcolormesh, assigned=("__doc__",))
+ def pcolormesh(self, *args, **kwargs) -> QuadMesh | FacetGrid[DataArray]:
+ return dataarray_plot.pcolormesh(self._da, *args, **kwargs)
+
+ @overload
+ def surface(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> Poly3DCollection: ...
+
+ @overload
+ def surface(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid: ...
+
+ @overload
+ def surface(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap=None,
+ center=None,
+ robust: bool = False,
+ extend=None,
+ levels=None,
+ infer_intervals=None,
+ colors=None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> FacetGrid: ...
+
+ @functools.wraps(dataarray_plot.surface, assigned=("__doc__",))
+ def surface(self, *args, **kwargs) -> Poly3DCollection:
+ return dataarray_plot.surface(self._da, *args, **kwargs)
+
class DatasetPlotAccessor:
"""
Enables use of xarray.plot functions as attributes on a Dataset.
For example, Dataset.plot.scatter
"""
+
_ds: Dataset
- __slots__ = '_ds',
+ __slots__ = ("_ds",)
- def __init__(self, dataset: Dataset) ->None:
+ def __init__(self, dataset: Dataset) -> None:
self._ds = dataset
- def __call__(self, *args, **kwargs) ->NoReturn:
+ def __call__(self, *args, **kwargs) -> NoReturn:
raise ValueError(
- 'Dataset.plot cannot be called directly. Use an explicit plot method, e.g. ds.plot.scatter(...)'
- )
+ "Dataset.plot cannot be called directly. Use "
+ "an explicit plot method, e.g. ds.plot.scatter(...)"
+ )
+
+ @overload
+ def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ cmap=None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend=None,
+ levels=None,
+ **kwargs: Any,
+ ) -> PathCollection: ...
+
+ @overload
+ def scatter(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ cmap=None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend=None,
+ levels=None,
+ **kwargs: Any,
+ ) -> FacetGrid[Dataset]: ...
+
+ @overload
+ def scatter(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ cmap=None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend=None,
+ levels=None,
+ **kwargs: Any,
+ ) -> FacetGrid[Dataset]: ...
+
+ @functools.wraps(dataset_plot.scatter, assigned=("__doc__",))
+ def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[Dataset]:
+ return dataset_plot.scatter(self._ds, *args, **kwargs)
+
+ @overload
+ def quiver( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: None = None, # no wrap -> primitive
+ row: None = None, # no wrap -> primitive
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals=None,
+ center=None,
+ levels=None,
+ robust: bool | None = None,
+ colors=None,
+ extend=None,
+ cmap=None,
+ **kwargs: Any,
+ ) -> Quiver: ...
+
+ @overload
+ def quiver(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable, # wrap -> FacetGrid
+ row: Hashable | None = None,
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals=None,
+ center=None,
+ levels=None,
+ robust: bool | None = None,
+ colors=None,
+ extend=None,
+ cmap=None,
+ **kwargs: Any,
+ ) -> FacetGrid[Dataset]: ...
+
+ @overload
+ def quiver(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals=None,
+ center=None,
+ levels=None,
+ robust: bool | None = None,
+ colors=None,
+ extend=None,
+ cmap=None,
+ **kwargs: Any,
+ ) -> FacetGrid[Dataset]: ...
+
+ @functools.wraps(dataset_plot.quiver, assigned=("__doc__",))
+ def quiver(self, *args, **kwargs) -> Quiver | FacetGrid[Dataset]:
+ return dataset_plot.quiver(self._ds, *args, **kwargs)
+
+ @overload
+ def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: None = None, # no wrap -> primitive
+ row: None = None, # no wrap -> primitive
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals=None,
+ center=None,
+ levels=None,
+ robust: bool | None = None,
+ colors=None,
+ extend=None,
+ cmap=None,
+ **kwargs: Any,
+ ) -> LineCollection: ...
+
+ @overload
+ def streamplot(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable, # wrap -> FacetGrid
+ row: Hashable | None = None,
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals=None,
+ center=None,
+ levels=None,
+ robust: bool | None = None,
+ colors=None,
+ extend=None,
+ cmap=None,
+ **kwargs: Any,
+ ) -> FacetGrid[Dataset]: ...
+
+ @overload
+ def streamplot(
+ self,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals=None,
+ center=None,
+ levels=None,
+ robust: bool | None = None,
+ colors=None,
+ extend=None,
+ cmap=None,
+ **kwargs: Any,
+ ) -> FacetGrid[Dataset]: ...
+
+ @functools.wraps(dataset_plot.streamplot, assigned=("__doc__",))
+ def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid[Dataset]:
+ return dataset_plot.streamplot(self._ds, *args, **kwargs)
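
The three accessor methods above all follow the same overload pattern: calling them without `row`/`col` returns the underlying matplotlib primitive, while passing either one switches the return type to a `FacetGrid[Dataset]`. A minimal usage sketch of that behaviour (not part of the patch; it assumes a small synthetic Dataset and a working matplotlib backend):

```python
import numpy as np
import xarray as xr

# Synthetic example data; variable and dimension names are illustrative only.
ds = xr.Dataset(
    {
        "a": (("x", "w"), np.random.rand(4, 3)),
        "b": (("x", "w"), np.random.rand(4, 3)),
    },
    coords={"x": np.arange(4), "w": ["u", "v", "z"]},
)

# No row/col -> the first overload applies and a PathCollection is returned.
pc = ds.plot.scatter(x="a", y="b")

# col given -> the faceting overloads apply and a FacetGrid[Dataset] is returned.
fg = ds.plot.scatter(x="a", y="b", col="w")
```
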
diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py
index f56da420..ed752d34 100644
--- a/xarray/plot/dataarray_plot.py
+++ b/xarray/plot/dataarray_plot.py
@@ -1,14 +1,38 @@
from __future__ import annotations
+
import functools
import warnings
from collections.abc import Hashable, Iterable, MutableMapping
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast, overload
+
import numpy as np
import pandas as pd
+
from xarray.core.alignment import broadcast
from xarray.core.concat import concat
from xarray.plot.facetgrid import _easy_facetgrid
-from xarray.plot.utils import _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, _add_colorbar, _add_legend, _assert_valid_xy, _determine_guide, _ensure_plottable, _guess_coords_to_plot, _infer_interval_breaks, _infer_xy_labels, _Normalize, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, _resolve_intervals_1dplot, _resolve_intervals_2dplot, _set_concise_date, _update_axes, get_axis, label_from_attrs
+from xarray.plot.utils import (
+ _LINEWIDTH_RANGE,
+ _MARKERSIZE_RANGE,
+ _add_colorbar,
+ _add_legend,
+ _assert_valid_xy,
+ _determine_guide,
+ _ensure_plottable,
+ _guess_coords_to_plot,
+ _infer_interval_breaks,
+ _infer_xy_labels,
+ _Normalize,
+ _process_cmap_cbar_kwargs,
+ _rescale_imshow_rgb,
+ _resolve_intervals_1dplot,
+ _resolve_intervals_2dplot,
+ _set_concise_date,
+ _update_axes,
+ get_axis,
+ label_from_attrs,
+)
+
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.collections import PathCollection, QuadMesh
@@ -19,15 +43,113 @@ if TYPE_CHECKING:
from matplotlib.patches import Polygon
from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection
from numpy.typing import ArrayLike
+
from xarray.core.dataarray import DataArray
- from xarray.core.types import AspectOptions, ExtendOptions, HueStyleOptions, ScaleOptions, T_DataArray
+ from xarray.core.types import (
+ AspectOptions,
+ ExtendOptions,
+ HueStyleOptions,
+ ScaleOptions,
+ T_DataArray,
+ )
from xarray.plot.facetgrid import FacetGrid
-_styles: dict[str, Any] = {'scatter.edgecolors': 'w'}
+_styles: dict[str, Any] = {
+ # Add a white border to make it easier seeing overlapping markers:
+ "scatter.edgecolors": "w",
+}
+
+
+def _infer_line_data(
+ darray: DataArray, x: Hashable | None, y: Hashable | None, hue: Hashable | None
+) -> tuple[DataArray, DataArray, DataArray | None, str]:
+ ndims = len(darray.dims)
+
+ if x is not None and y is not None:
+ raise ValueError("Cannot specify both x and y kwargs for line plots.")
+
+ if x is not None:
+ _assert_valid_xy(darray, x, "x")
+
+ if y is not None:
+ _assert_valid_xy(darray, y, "y")
+
+ if ndims == 1:
+ huename = None
+ hueplt = None
+ huelabel = ""
+
+ if x is not None:
+ xplt = darray[x]
+ yplt = darray
+
+ elif y is not None:
+ xplt = darray
+ yplt = darray[y]
+
+ else: # Both x & y are None
+ dim = darray.dims[0]
+ xplt = darray[dim]
+ yplt = darray
+
+ else:
+ if x is None and y is None and hue is None:
+ raise ValueError("For 2D inputs, please specify either hue, x or y.")
+
+ if y is None:
+ if hue is not None:
+ _assert_valid_xy(darray, hue, "hue")
+ xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue)
+ xplt = darray[xname]
+ if xplt.ndim > 1:
+ if huename in darray.dims:
+ otherindex = 1 if darray.dims.index(huename) == 0 else 0
+ otherdim = darray.dims[otherindex]
+ yplt = darray.transpose(otherdim, huename, transpose_coords=False)
+ xplt = xplt.transpose(otherdim, huename, transpose_coords=False)
+ else:
+ raise ValueError(
+ "For 2D inputs, hue must be a dimension"
+ " i.e. one of " + repr(darray.dims)
+ )
+
+ else:
+ (xdim,) = darray[xname].dims
+ (huedim,) = darray[huename].dims
+ yplt = darray.transpose(xdim, huedim)
+        else:
+            yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue)
+            yplt = darray[yname]
+            if yplt.ndim > 1:
+                if huename in darray.dims:
+                    otherindex = 1 if darray.dims.index(huename) == 0 else 0
+                    otherdim = darray.dims[otherindex]
+                    xplt = darray.transpose(otherdim, huename, transpose_coords=False)
+                    yplt = yplt.transpose(otherdim, huename, transpose_coords=False)
+                else:
+                    raise ValueError(
+                        "For 2D inputs, hue must be a dimension"
+                        " i.e. one of " + repr(darray.dims)
+                    )
+
+            else:
+                (ydim,) = darray[yname].dims
+                (huedim,) = darray[huename].dims
+                xplt = darray.transpose(ydim, huedim)
+
+        huelabel = label_from_attrs(darray[huename])
+        hueplt = darray[huename]
+
+    return xplt, yplt, hueplt, huelabel
+
+
-def _prepare_plot1d_data(darray: T_DataArray, coords_to_plot:
-    MutableMapping[str, Hashable], plotfunc_name: (str | None)=None,
-    _is_facetgrid: bool=False) ->dict[str, T_DataArray]:
+def _prepare_plot1d_data(
+ darray: T_DataArray,
+ coords_to_plot: MutableMapping[str, Hashable],
+ plotfunc_name: str | None = None,
+ _is_facetgrid: bool = False,
+) -> dict[str, T_DataArray]:
"""
Prepare data for usage with plt.scatter.
@@ -61,13 +183,50 @@ def _prepare_plot1d_data(darray: T_DataArray, coords_to_plot:
>>> print({k: v.name for k, v in plts.items()})
{'y': 'a', 'x': 1}
"""
- pass
+ # If there are more than 1 dimension in the array than stack all the
+ # dimensions so the plotter can plot anything:
+ if darray.ndim > 1:
+ # When stacking dims the lines will continue connecting. For floats
+ # this can be solved by adding a nan element in between the flattening
+ # points:
+ dims_T = []
+ if np.issubdtype(darray.dtype, np.floating):
+ for v in ["z", "x"]:
+ dim = coords_to_plot.get(v, None)
+ if (dim is not None) and (dim in darray.dims):
+ darray_nan = np.nan * darray.isel({dim: -1})
+ darray = concat([darray, darray_nan], dim=dim)
+ dims_T.append(coords_to_plot[v])
+
+ # Lines should never connect to the same coordinate when stacked,
+ # transpose to avoid this as much as possible:
+ darray = darray.transpose(..., *dims_T)
+
+ # Array is now ready to be stacked:
+ darray = darray.stack(_stacked_dim=darray.dims)
+
+ # Broadcast together all the chosen variables:
+ plts = dict(y=darray)
+ plts.update(
+ {k: darray.coords[v] for k, v in coords_to_plot.items() if v is not None}
+ )
+ plts = dict(zip(plts.keys(), broadcast(*(plts.values()))))
+ return plts
-def plot(darray: DataArray, *, row: (Hashable | None)=None, col: (Hashable |
- None)=None, col_wrap: (int | None)=None, ax: (Axes | None)=None, hue: (
- Hashable | None)=None, subplot_kws: (dict[str, Any] | None)=None, **
- kwargs: Any) ->Any:
+
+# return type is Any due to the many different possibilities
+def plot(
+ darray: DataArray,
+ *,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ **kwargs: Any,
+) -> Any:
"""
Default plot of DataArray using :py:mod:`matplotlib:matplotlib.pyplot`.
@@ -106,19 +265,157 @@ def plot(darray: DataArray, *, row: (Hashable | None)=None, col: (Hashable |
--------
xarray.DataArray.squeeze
"""
- pass
-
-
-def line(darray: T_DataArray, *args: Any, row: (Hashable | None)=None, col:
- (Hashable | None)=None, figsize: (Iterable[float] | None)=None, aspect:
- AspectOptions=None, size: (float | None)=None, ax: (Axes | None)=None,
- hue: (Hashable | None)=None, x: (Hashable | None)=None, y: (Hashable |
- None)=None, xincrease: (bool | None)=None, yincrease: (bool | None)=
- None, xscale: ScaleOptions=None, yscale: ScaleOptions=None, xticks: (
- ArrayLike | None)=None, yticks: (ArrayLike | None)=None, xlim: (tuple[
- float, float] | None)=None, ylim: (tuple[float, float] | None)=None,
- add_legend: bool=True, _labels: bool=True, **kwargs: Any) ->(list[
- Line3D] | FacetGrid[T_DataArray]):
+ darray = darray.squeeze(
+ d for d, s in darray.sizes.items() if s == 1 and d not in (row, col, hue)
+ ).compute()
+
+ plot_dims = set(darray.dims)
+ plot_dims.discard(row)
+ plot_dims.discard(col)
+ plot_dims.discard(hue)
+
+ ndims = len(plot_dims)
+
+ plotfunc: Callable
+
+ if ndims == 0 or darray.size == 0:
+ raise TypeError("No numeric data to plot.")
+ if ndims in (1, 2):
+ if row or col:
+ kwargs["subplot_kws"] = subplot_kws
+ kwargs["row"] = row
+ kwargs["col"] = col
+ kwargs["col_wrap"] = col_wrap
+ if ndims == 1:
+ plotfunc = line
+ kwargs["hue"] = hue
+ elif ndims == 2:
+ if hue:
+ plotfunc = line
+ kwargs["hue"] = hue
+ else:
+ plotfunc = pcolormesh
+ kwargs["subplot_kws"] = subplot_kws
+ else:
+ if row or col or hue:
+ raise ValueError(
+ "Only 1d and 2d plots are supported for facets in xarray. "
+ "See the package `Seaborn` for more options."
+ )
+ plotfunc = hist
+
+ kwargs["ax"] = ax
+
+ return plotfunc(darray, **kwargs)
+
+
+@overload
+def line( # type: ignore[misc,unused-ignore] # None is hashable :(
+ darray: DataArray,
+ *args: Any,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ figsize: Iterable[float] | None = None,
+ aspect: AspectOptions = None,
+ size: float | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ add_legend: bool = True,
+ _labels: bool = True,
+ **kwargs: Any,
+) -> list[Line3D]: ...
+
+
+@overload
+def line(
+ darray: T_DataArray,
+ *args: Any,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ aspect: AspectOptions = None,
+ size: float | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ add_legend: bool = True,
+ _labels: bool = True,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
+
+
+@overload
+def line(
+ darray: T_DataArray,
+ *args: Any,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ figsize: Iterable[float] | None = None,
+ aspect: AspectOptions = None,
+ size: float | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ add_legend: bool = True,
+ _labels: bool = True,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
+
+
+# This function signature should not change so that it can use
+# matplotlib format strings
+def line(
+ darray: T_DataArray,
+ *args: Any,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ aspect: AspectOptions = None,
+ size: float | None = None,
+ ax: Axes | None = None,
+ hue: Hashable | None = None,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ add_legend: bool = True,
+ _labels: bool = True,
+ **kwargs: Any,
+) -> list[Line3D] | FacetGrid[T_DataArray]:
"""
Line plot of DataArray values.
@@ -175,13 +472,114 @@ def line(darray: T_DataArray, *args: Any, row: (Hashable | None)=None, col:
When either col or row is given, returns a FacetGrid, otherwise
a list of matplotlib Line3D objects.
"""
- pass
+ # Handle facetgrids first
+ if row or col:
+ allargs = locals().copy()
+ allargs.update(allargs.pop("kwargs"))
+ allargs.pop("darray")
+ return _easy_facetgrid(darray, line, kind="line", **allargs)
+
+ ndims = len(darray.dims)
+ if ndims == 0 or darray.size == 0:
+ # TypeError to be consistent with pandas
+ raise TypeError("No numeric data to plot.")
+ if ndims > 2:
+ raise ValueError(
+ "Line plots are for 1- or 2-dimensional DataArrays. "
+ f"Passed DataArray has {ndims} "
+ "dimensions"
+ )
+
+ # The allargs dict passed to _easy_facetgrid above contains args
+ if args == ():
+ args = kwargs.pop("args", ())
+ else:
+ assert "args" not in kwargs
+
+ ax = get_axis(figsize, size, aspect, ax)
+ xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue)
+
+ # Remove pd.Intervals if contained in xplt.values and/or yplt.values.
+ xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
+ xplt.to_numpy(), yplt.to_numpy(), kwargs
+ )
+ xlabel = label_from_attrs(xplt, extra=x_suffix)
+ ylabel = label_from_attrs(yplt, extra=y_suffix)
+
+ _ensure_plottable(xplt_val, yplt_val)
+
+ primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs)
+
+ if _labels:
+ if xlabel is not None:
+ ax.set_xlabel(xlabel)
+
+ if ylabel is not None:
+ ax.set_ylabel(ylabel)
+
+ ax.set_title(darray._title_for_slice())
+ if darray.ndim == 2 and add_legend:
+ assert hueplt is not None
+ ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
+    if np.issubdtype(xplt.dtype, np.datetime64):
+        _set_concise_date(ax, axis="x")
+
+    _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)
+
+    return primitive
+
+
-def step(darray: DataArray, *args: Any, where: Literal['pre', 'post', 'mid'
-    ]='pre', drawstyle: (str | None)=None, ds: (str | None)=None, row: (
-    Hashable | None)=None, col: (Hashable | None)=None, **kwargs: Any) ->(list
-    [Line3D] | FacetGrid[DataArray]):
+@overload
+def step( # type: ignore[misc,unused-ignore] # None is hashable :(
+ darray: DataArray,
+ *args: Any,
+ where: Literal["pre", "post", "mid"] = "pre",
+ drawstyle: str | None = None,
+ ds: str | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ **kwargs: Any,
+) -> list[Line3D]: ...
+
+
+@overload
+def step(
+ darray: DataArray,
+ *args: Any,
+ where: Literal["pre", "post", "mid"] = "pre",
+ drawstyle: str | None = None,
+ ds: str | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ **kwargs: Any,
+) -> FacetGrid[DataArray]: ...
+
+
+@overload
+def step(
+ darray: DataArray,
+ *args: Any,
+ where: Literal["pre", "post", "mid"] = "pre",
+ drawstyle: str | None = None,
+ ds: str | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ **kwargs: Any,
+) -> FacetGrid[DataArray]: ...
+
+
+def step(
+ darray: DataArray,
+ *args: Any,
+ where: Literal["pre", "post", "mid"] = "pre",
+ drawstyle: str | None = None,
+ ds: str | None = None,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ **kwargs: Any,
+) -> list[Line3D] | FacetGrid[DataArray]:
"""
Step plot of DataArray values.
@@ -219,16 +617,38 @@ def step(darray: DataArray, *args: Any, where: Literal['pre', 'post', 'mid'
When either col or row is given, returns a FacetGrid, otherwise
a list of matplotlib Line3D objects.
"""
- pass
+ if where not in {"pre", "post", "mid"}:
+ raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'")
+
+ if ds is not None:
+ if drawstyle is None:
+ drawstyle = ds
+ else:
+ raise TypeError("ds and drawstyle are mutually exclusive")
+ if drawstyle is None:
+ drawstyle = ""
+ drawstyle = "steps-" + where + drawstyle
+
+ return line(darray, *args, drawstyle=drawstyle, col=col, row=row, **kwargs)
-def hist(darray: DataArray, *args: Any, figsize: (Iterable[float] | None)=
- None, size: (float | None)=None, aspect: AspectOptions=None, ax: (Axes |
- None)=None, xincrease: (bool | None)=None, yincrease: (bool | None)=
- None, xscale: ScaleOptions=None, yscale: ScaleOptions=None, xticks: (
- ArrayLike | None)=None, yticks: (ArrayLike | None)=None, xlim: (tuple[
- float, float] | None)=None, ylim: (tuple[float, float] | None)=None, **
- kwargs: Any) ->tuple[np.ndarray, np.ndarray, BarContainer | Polygon]:
+def hist(
+ darray: DataArray,
+ *args: Any,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ xincrease: bool | None = None,
+ yincrease: bool | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ **kwargs: Any,
+) -> tuple[np.ndarray, np.ndarray, BarContainer | Polygon]:
"""
Histogram of DataArray.
@@ -268,38 +688,1094 @@ def hist(darray: DataArray, *args: Any, figsize: (Iterable[float] | None)=
Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`.
"""
- pass
+ assert len(args) == 0
+
+ if darray.ndim == 0 or darray.size == 0:
+ # TypeError to be consistent with pandas
+ raise TypeError("No numeric data to plot.")
+
+ ax = get_axis(figsize, size, aspect, ax)
+
+ no_nan = np.ravel(darray.to_numpy())
+ no_nan = no_nan[pd.notnull(no_nan)]
+
+ n, bins, patches = cast(
+ tuple[np.ndarray, np.ndarray, Union["BarContainer", "Polygon"]],
+ ax.hist(no_nan, **kwargs),
+ )
+
+ ax.set_title(darray._title_for_slice())
+ ax.set_xlabel(label_from_attrs(darray))
+
+ _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)
+
+ return n, bins, patches
def _plot1d(plotfunc):
"""Decorator for common 1d plotting logic."""
- pass
+ commondoc = """
+ Parameters
+ ----------
+ darray : DataArray
+ Must be 2 dimensional, unless creating faceted plots.
+ x : Hashable or None, optional
+ Coordinate for x axis. If None use darray.dims[1].
+ y : Hashable or None, optional
+ Coordinate for y axis. If None use darray.dims[0].
+ z : Hashable or None, optional
+ If specified plot 3D and use this coordinate for *z* axis.
+ hue : Hashable or None, optional
+ Dimension or coordinate for which you want multiple lines plotted.
+ markersize: Hashable or None, optional
+ scatter only. Variable by which to vary size of scattered points.
+ linewidth: Hashable or None, optional
+ Variable by which to vary linewidth.
+ row : Hashable, optional
+ If passed, make row faceted plots on this dimension name.
+ col : Hashable, optional
+ If passed, make column faceted plots on this dimension name.
+ col_wrap : int, optional
+ Use together with ``col`` to wrap faceted plots
+ ax : matplotlib axes object, optional
+ If None, uses the current axis. Not applicable when using facets.
+ figsize : Iterable[float] or None, optional
+ A tuple (width, height) of the figure in inches.
+ Mutually exclusive with ``size`` and ``ax``.
+ size : scalar, optional
+ If provided, create a new figure for the plot with the given size.
+ Height (in inches) of each plot. See also: ``aspect``.
+ aspect : "auto", "equal", scalar or None, optional
+ Aspect ratio of plot, so that ``aspect * size`` gives the width in
+ inches. Only used if a ``size`` is provided.
+ xincrease : bool or None, default: True
+ Should the values on the x axes be increasing from left to right?
+ if None, use the default for the matplotlib function.
+ yincrease : bool or None, default: True
+ Should the values on the y axes be increasing from top to bottom?
+ if None, use the default for the matplotlib function.
+ add_legend : bool or None, optional
+ If True use xarray metadata to add a legend.
+ add_colorbar : bool or None, optional
+ If True add a colorbar.
+ add_labels : bool or None, optional
+ If True use xarray metadata to label axes
+ add_title : bool or None, optional
+ If True use xarray metadata to add a title
+ subplot_kws : dict, optional
+ Dictionary of keyword arguments for matplotlib subplots. Only applies
+ to FacetGrid plotting.
+ xscale : {'linear', 'symlog', 'log', 'logit'} or None, optional
+ Specifies scaling for the x-axes.
+ yscale : {'linear', 'symlog', 'log', 'logit'} or None, optional
+ Specifies scaling for the y-axes.
+ xticks : ArrayLike or None, optional
+ Specify tick locations for x-axes.
+ yticks : ArrayLike or None, optional
+ Specify tick locations for y-axes.
+ xlim : tuple[float, float] or None, optional
+ Specify x-axes limits.
+ ylim : tuple[float, float] or None, optional
+ Specify y-axes limits.
+ cmap : matplotlib colormap name or colormap, optional
+ The mapping from data values to color space. Either a
+ Matplotlib colormap name or object. If not provided, this will
+ be either ``'viridis'`` (if the function infers a sequential
+ dataset) or ``'RdBu_r'`` (if the function infers a diverging
+ dataset).
+ See :doc:`Choosing Colormaps in Matplotlib <matplotlib:users/explain/colors/colormaps>`
+ for more information.
+
+ If *seaborn* is installed, ``cmap`` may also be a
+ `seaborn color palette <https://seaborn.pydata.org/tutorial/color_palettes.html>`_.
+ Note: if ``cmap`` is a seaborn color palette,
+ ``levels`` must also be specified.
+ vmin : float or None, optional
+ Lower value to anchor the colormap, otherwise it is inferred from the
+ data and other keyword arguments. When a diverging dataset is inferred,
+ setting `vmin` or `vmax` will fix the other by symmetry around
+ ``center``. Setting both values prevents use of a diverging colormap.
+ If discrete levels are provided as an explicit list, both of these
+ values are ignored.
+ vmax : float or None, optional
+ Upper value to anchor the colormap, otherwise it is inferred from the
+ data and other keyword arguments. When a diverging dataset is inferred,
+ setting `vmin` or `vmax` will fix the other by symmetry around
+ ``center``. Setting both values prevents use of a diverging colormap.
+ If discrete levels are provided as an explicit list, both of these
+ values are ignored.
+ norm : matplotlib.colors.Normalize, optional
+ If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding
+ kwarg must be ``None``.
+ extend : {'neither', 'both', 'min', 'max'}, optional
+ How to draw arrows extending the colorbar beyond its limits. If not
+ provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits.
+ levels : int or array-like, optional
+ Split the colormap (``cmap``) into discrete color intervals. If an integer
+ is provided, "nice" levels are chosen based on the data range: this can
+ imply that the final number of levels is not exactly the expected one.
+ Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to
+ setting ``levels=np.linspace(vmin, vmax, N)``.
+ **kwargs : optional
+ Additional arguments to wrapped matplotlib function
+
+ Returns
+ -------
+ artist :
+ The same type of primitive artist that the wrapped matplotlib
+ function returns
+ """
+
+ # Build on the original docstring
+ plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}"
+
+ @functools.wraps(
+ plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__")
+ )
+ def newplotfunc(
+ darray: DataArray,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs,
+ ) -> Any:
+ # All 1d plots in xarray share this function signature.
+ # Method signature below should be consistent.
+
+ import matplotlib.pyplot as plt
+
+ if subplot_kws is None:
+ subplot_kws = dict()
+
+ # Handle facetgrids first
+ if row or col:
+ if z is not None:
+ subplot_kws.update(projection="3d")
+
+ allargs = locals().copy()
+ allargs.update(allargs.pop("kwargs"))
+ allargs.pop("darray")
+ allargs.pop("plt")
+ allargs["plotfunc"] = globals()[plotfunc.__name__]
+
+ return _easy_facetgrid(darray, kind="plot1d", **allargs)
+ if darray.ndim == 0 or darray.size == 0:
+ # TypeError to be consistent with pandas
+ raise TypeError("No numeric data to plot.")
-def _add_labels(add_labels: (bool | Iterable[bool]), darrays: Iterable[
- DataArray | None], suffixes: Iterable[str], ax: Axes) ->None:
+ # The allargs dict passed to _easy_facetgrid above contains args
+ if args == ():
+ args = kwargs.pop("args", ())
+
+ if args:
+ assert "args" not in kwargs
+ # TODO: Deprecated since 2022.10:
+ msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead."
+ assert x is None
+ x = args[0]
+ if len(args) > 1:
+ assert y is None
+ y = args[1]
+ if len(args) > 2:
+ assert z is None
+ z = args[2]
+ if len(args) > 3:
+ assert hue is None
+ hue = args[3]
+ if len(args) > 4:
+ raise ValueError(msg)
+ else:
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
+ del args
+
+ if hue_style is not None:
+ # TODO: Not used since 2022.10. Deprecated since 2023.07.
+ warnings.warn(
+ (
+ "hue_style is no longer used for plot1d plots "
+ "and the argument will eventually be removed. "
+ "Convert numbers to string for a discrete hue "
+ "and use add_legend or add_colorbar to control which guide to display."
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ _is_facetgrid = kwargs.pop("_is_facetgrid", False)
+
+ if plotfunc.__name__ == "scatter":
+ size_ = kwargs.pop("_size", markersize)
+ size_r = _MARKERSIZE_RANGE
+
+ # Remove any nulls, .where(m, drop=True) doesn't work when m is
+ # a dask array, so load the array to memory.
+ # It will have to be loaded to memory at some point anyway:
+ darray = darray.load()
+ darray = darray.where(darray.notnull(), drop=True)
+ else:
+ size_ = kwargs.pop("_size", linewidth)
+ size_r = _LINEWIDTH_RANGE
+
+ # Get data to plot:
+ coords_to_plot: MutableMapping[str, Hashable | None] = dict(
+ x=x, z=z, hue=hue, size=size_
+ )
+ if not _is_facetgrid:
+ # Guess what coords to use if some of the values in coords_to_plot are None:
+ coords_to_plot = _guess_coords_to_plot(darray, coords_to_plot, kwargs)
+ plts = _prepare_plot1d_data(darray, coords_to_plot, plotfunc.__name__)
+ xplt = plts.pop("x", None)
+ yplt = plts.pop("y", None)
+ zplt = plts.pop("z", None)
+ kwargs.update(zplt=zplt)
+ hueplt = plts.pop("hue", None)
+ sizeplt = plts.pop("size", None)
+
+ # Handle size and hue:
+ hueplt_norm = _Normalize(data=hueplt)
+ kwargs.update(hueplt=hueplt_norm.values)
+ sizeplt_norm = _Normalize(
+ data=sizeplt, width=size_r, _is_facetgrid=_is_facetgrid
+ )
+ kwargs.update(sizeplt=sizeplt_norm.values)
+ cmap_params_subset = kwargs.pop("cmap_params_subset", {})
+ cbar_kwargs = kwargs.pop("cbar_kwargs", {})
+
+ if hueplt_norm.data is not None:
+ if not hueplt_norm.data_is_numeric:
+ # Map hue values back to its original value:
+ cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks)
+ levels = kwargs.get("levels", hueplt_norm.levels)
+
+ cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
+ plotfunc,
+ cast("DataArray", hueplt_norm.values).data,
+ **locals(),
+ )
+
+ # subset that can be passed to scatter, hist2d
+ if not cmap_params_subset:
+ ckw = {vv: cmap_params[vv] for vv in ("vmin", "vmax", "norm", "cmap")}
+ cmap_params_subset.update(**ckw)
+
+ with plt.rc_context(_styles):
+ if z is not None:
+ import mpl_toolkits
+
+ if ax is None:
+ subplot_kws.update(projection="3d")
+ ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
+ assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D)
+
+ # Using 30, 30 minimizes rotation of the plot. Making it easier to
+ # build on your intuition from 2D plots:
+ ax.view_init(azim=30, elev=30, vertical_axis="y")
+ else:
+ ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
+
+ primitive = plotfunc(
+ xplt,
+ yplt,
+ ax=ax,
+ add_labels=add_labels,
+ **cmap_params_subset,
+ **kwargs,
+ )
+
+ if np.any(np.asarray(add_labels)) and add_title:
+ ax.set_title(darray._title_for_slice())
+
+ add_colorbar_, add_legend_ = _determine_guide(
+ hueplt_norm,
+ sizeplt_norm,
+ add_colorbar,
+ add_legend,
+ plotfunc_name=plotfunc.__name__,
+ )
+
+ if add_colorbar_:
+ if "label" not in cbar_kwargs:
+ cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data)
+
+ _add_colorbar(
+ primitive, ax, kwargs.get("cbar_ax", None), cbar_kwargs, cmap_params
+ )
+
+ if add_legend_:
+ if plotfunc.__name__ in ["scatter", "line"]:
+ _add_legend(
+ (
+ hueplt_norm
+ if add_legend or not add_colorbar_
+ else _Normalize(None)
+ ),
+ sizeplt_norm,
+ primitive,
+ legend_ax=ax,
+ plotfunc=plotfunc.__name__,
+ )
+ else:
+ hueplt_norm_values: list[np.ndarray | None]
+ if hueplt_norm.data is not None:
+ hueplt_norm_values = list(hueplt_norm.data.to_numpy())
+ else:
+ hueplt_norm_values = [hueplt_norm.data]
+
+ if plotfunc.__name__ == "hist":
+ ax.legend(
+ handles=primitive[-1],
+ labels=hueplt_norm_values,
+ title=label_from_attrs(hueplt_norm.data),
+ )
+ else:
+ ax.legend(
+ handles=primitive,
+ labels=hueplt_norm_values,
+ title=label_from_attrs(hueplt_norm.data),
+ )
+
+ _update_axes(
+ ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim
+ )
+
+ return primitive
+
+ # we want to actually expose the signature of newplotfunc
+ # and not the copied **kwargs from the plotfunc which
+ # functools.wraps adds, so delete the wrapped attr
+ del newplotfunc.__wrapped__
+
+ return newplotfunc
+
+
+def _add_labels(
+ add_labels: bool | Iterable[bool],
+ darrays: Iterable[DataArray | None],
+ suffixes: Iterable[str],
+ ax: Axes,
+) -> None:
"""Set x, y, z labels."""
- pass
+ add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels
+ axes: tuple[Literal["x", "y", "z"], ...] = ("x", "y", "z")
+ for axis, add_label, darray, suffix in zip(axes, add_labels, darrays, suffixes):
+ if darray is None:
+ continue
+
+ if add_label:
+ label = label_from_attrs(darray, extra=suffix)
+ if label is not None:
+ getattr(ax, f"set_{axis}label")(label)
+
+ if np.issubdtype(darray.dtype, np.datetime64):
+ _set_concise_date(ax, axis=axis)
+
+
+@overload
+def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
+ darray: DataArray,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs,
+) -> PathCollection: ...
+
+
+@overload
+def scatter(
+ darray: T_DataArray,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs,
+) -> FacetGrid[T_DataArray]: ...
+
+
+@overload
+def scatter(
+ darray: T_DataArray,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs,
+) -> FacetGrid[T_DataArray]: ...
@_plot1d
-def scatter(xplt: (DataArray | None), yplt: (DataArray | None), ax: Axes,
- add_labels: (bool | Iterable[bool])=True, **kwargs) ->PathCollection:
+def scatter(
+ xplt: DataArray | None,
+ yplt: DataArray | None,
+ ax: Axes,
+ add_labels: bool | Iterable[bool] = True,
+ **kwargs,
+) -> PathCollection:
"""Scatter variables against each other.
Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`.
"""
- pass
+ if "u" in kwargs or "v" in kwargs:
+ raise ValueError("u, v are not allowed in scatter plots.")
+
+ zplt: DataArray | None = kwargs.pop("zplt", None)
+ hueplt: DataArray | None = kwargs.pop("hueplt", None)
+ sizeplt: DataArray | None = kwargs.pop("sizeplt", None)
+
+ if hueplt is not None:
+ kwargs.update(c=hueplt.to_numpy().ravel())
+
+ if sizeplt is not None:
+ kwargs.update(s=sizeplt.to_numpy().ravel())
+
+ plts_or_none = (xplt, yplt, zplt)
+ _add_labels(add_labels, plts_or_none, ("", "", ""), ax)
+
+ xplt_np = None if xplt is None else xplt.to_numpy().ravel()
+ yplt_np = None if yplt is None else yplt.to_numpy().ravel()
+ zplt_np = None if zplt is None else zplt.to_numpy().ravel()
+ plts_np = tuple(p for p in (xplt_np, yplt_np, zplt_np) if p is not None)
+
+ if len(plts_np) == 3:
+ import mpl_toolkits
+
+ assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D)
+ return ax.scatter(xplt_np, yplt_np, zplt_np, **kwargs)
+
+ if len(plts_np) == 2:
+ return ax.scatter(plts_np[0], plts_np[1], **kwargs)
+
+ raise ValueError("At least two variables required for a scatter plot.")
def _plot2d(plotfunc):
"""Decorator for common 2d plotting logic."""
- pass
+ commondoc = """
+ Parameters
+ ----------
+ darray : DataArray
+ Must be two-dimensional, unless creating faceted plots.
+ x : Hashable or None, optional
+ Coordinate for *x* axis. If ``None``, use ``darray.dims[1]``.
+ y : Hashable or None, optional
+ Coordinate for *y* axis. If ``None``, use ``darray.dims[0]``.
+ figsize : Iterable or float or None, optional
+ A tuple (width, height) of the figure in inches.
+ Mutually exclusive with ``size`` and ``ax``.
+ size : scalar, optional
+ If provided, create a new figure for the plot with the given size:
+ *height* (in inches) of each plot. See also: ``aspect``.
+ aspect : "auto", "equal", scalar or None, optional
+ Aspect ratio of plot, so that ``aspect * size`` gives the *width* in
+ inches. Only used if a ``size`` is provided.
+ ax : matplotlib axes object, optional
+ Axes on which to plot. By default, use the current axes.
+ Mutually exclusive with ``size`` and ``figsize``.
+ row : Hashable or None, optional
+ If passed, make row faceted plots on this dimension name.
+ col : Hashable or None, optional
+ If passed, make column faceted plots on this dimension name.
+ col_wrap : int, optional
+ Use together with ``col`` to wrap faceted plots.
+ xincrease : None, True, or False, optional
+ Should the values on the *x* axis be increasing from left to right?
+ If ``None``, use the default for the Matplotlib function.
+ yincrease : None, True, or False, optional
+ Should the values on the *y* axis be increasing from top to bottom?
+ If ``None``, use the default for the Matplotlib function.
+ add_colorbar : bool, optional
+ Add colorbar to axes.
+ add_labels : bool, optional
+ Use xarray metadata to label axes.
+ vmin : float or None, optional
+ Lower value to anchor the colormap, otherwise it is inferred from the
+ data and other keyword arguments. When a diverging dataset is inferred,
+ setting `vmin` or `vmax` will fix the other by symmetry around
+ ``center``. Setting both values prevents use of a diverging colormap.
+ If discrete levels are provided as an explicit list, both of these
+ values are ignored.
+ vmax : float or None, optional
+ Upper value to anchor the colormap, otherwise it is inferred from the
+ data and other keyword arguments. When a diverging dataset is inferred,
+ setting `vmin` or `vmax` will fix the other by symmetry around
+ ``center``. Setting both values prevents use of a diverging colormap.
+ If discrete levels are provided as an explicit list, both of these
+ values are ignored.
+ cmap : matplotlib colormap name or colormap, optional
+ The mapping from data values to color space. If not provided, this
+        will either be ``'viridis'`` (if the function infers a sequential
+ dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset).
+ See :doc:`Choosing Colormaps in Matplotlib <matplotlib:users/explain/colors/colormaps>`
+ for more information.
+
+ If *seaborn* is installed, ``cmap`` may also be a
+ `seaborn color palette <https://seaborn.pydata.org/tutorial/color_palettes.html>`_.
+ Note: if ``cmap`` is a seaborn color palette and the plot type
+ is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified.
+ center : float or False, optional
+ The value at which to center the colormap. Passing this value implies
+ use of a diverging colormap. Setting it to ``False`` prevents use of a
+ diverging colormap.
+ robust : bool, optional
+ If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is
+ computed with 2nd and 98th percentiles instead of the extreme values.
+ extend : {'neither', 'both', 'min', 'max'}, optional
+ How to draw arrows extending the colorbar beyond its limits. If not
+ provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits.
+ levels : int or array-like, optional
+ Split the colormap (``cmap``) into discrete color intervals. If an integer
+ is provided, "nice" levels are chosen based on the data range: this can
+ imply that the final number of levels is not exactly the expected one.
+ Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to
+ setting ``levels=np.linspace(vmin, vmax, N)``.
+ infer_intervals : bool, optional
+ Only applies to pcolormesh. If ``True``, the coordinate intervals are
+ passed to pcolormesh. If ``False``, the original coordinates are used
+ (this can be useful for certain map projections). The default is to
+ always infer intervals, unless the mesh is irregular and plotted on
+ a map projection.
+ colors : str or array-like of color-like, optional
+ A single color or a sequence of colors. If the plot type is not ``'contour'``
+ or ``'contourf'``, the ``levels`` argument is required.
+ subplot_kws : dict, optional
+ Dictionary of keyword arguments for Matplotlib subplots. Only used
+ for 2D and faceted plots.
+ (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`).
+ cbar_ax : matplotlib axes object, optional
+ Axes in which to draw the colorbar.
+ cbar_kwargs : dict, optional
+ Dictionary of keyword arguments to pass to the colorbar
+ (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`).
+ xscale : {'linear', 'symlog', 'log', 'logit'} or None, optional
+ Specifies scaling for the x-axes.
+ yscale : {'linear', 'symlog', 'log', 'logit'} or None, optional
+ Specifies scaling for the y-axes.
+ xticks : ArrayLike or None, optional
+ Specify tick locations for x-axes.
+ yticks : ArrayLike or None, optional
+ Specify tick locations for y-axes.
+ xlim : tuple[float, float] or None, optional
+ Specify x-axes limits.
+ ylim : tuple[float, float] or None, optional
+ Specify y-axes limits.
+ norm : matplotlib.colors.Normalize, optional
+ If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding
+ kwarg must be ``None``.
+ **kwargs : optional
+ Additional keyword arguments to wrapped Matplotlib function.
+
+ Returns
+ -------
+ artist :
+ The same type of primitive artist that the wrapped Matplotlib
+ function returns.
+ """
+
+ # Build on the original docstring
+ plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}"
+
+ @functools.wraps(
+ plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__")
+ )
+ def newplotfunc(
+ darray: DataArray,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ # All 2d plots in xarray share this function signature.
+
+ if args:
+ # TODO: Deprecated since 2022.10:
+ msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead."
+ assert x is None
+ x = args[0]
+ if len(args) > 1:
+ assert y is None
+ y = args[1]
+ if len(args) > 2:
+ raise ValueError(msg)
+ else:
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
+ del args
+
+ # Decide on a default for the colorbar before facetgrids
+ if add_colorbar is None:
+ add_colorbar = True
+ if plotfunc.__name__ == "contour" or (
+ plotfunc.__name__ == "surface" and cmap is None
+ ):
+ add_colorbar = False
+ imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == (
+ 3 + (row is not None) + (col is not None)
+ )
+ if imshow_rgb:
+ # Don't add a colorbar when showing an image with explicit colors
+ add_colorbar = False
+ # Matplotlib does not support normalising RGB data, so do it here.
+ # See eg. https://github.com/matplotlib/matplotlib/pull/10220
+ if robust or vmax is not None or vmin is not None:
+ darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust)
+ vmin, vmax, robust = None, None, False
+
+ if subplot_kws is None:
+ subplot_kws = dict()
+
+ if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False):
+ if ax is None:
+ # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2.
+ # Remove when minimum requirement of matplotlib is 3.2:
+ from mpl_toolkits.mplot3d import Axes3D # noqa: F401
+
+ # delete so it does not end up in locals()
+ del Axes3D
+
+ # Need to create a "3d" Axes instance for surface plots
+ subplot_kws["projection"] = "3d"
+
+ # In facet grids, shared axis labels don't make sense for surface plots
+ sharex = False
+ sharey = False
+
+ # Handle facetgrids first
+ if row or col:
+ allargs = locals().copy()
+ del allargs["darray"]
+ del allargs["imshow_rgb"]
+ allargs.update(allargs.pop("kwargs"))
+ # Need the decorated plotting function
+ allargs["plotfunc"] = globals()[plotfunc.__name__]
+ return _easy_facetgrid(darray, kind="dataarray", **allargs)
+
+ if darray.ndim == 0 or darray.size == 0:
+ # TypeError to be consistent with pandas
+ raise TypeError("No numeric data to plot.")
+
+ if (
+ plotfunc.__name__ == "surface"
+ and not kwargs.get("_is_facetgrid", False)
+ and ax is not None
+ ):
+ import mpl_toolkits
+
+ if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D):
+ raise ValueError(
+ "If ax is passed to surface(), it must be created with "
+ 'projection="3d"'
+ )
+
+ rgb = kwargs.pop("rgb", None)
+ if rgb is not None and plotfunc.__name__ != "imshow":
+ raise ValueError('The "rgb" keyword is only valid for imshow()')
+ elif rgb is not None and not imshow_rgb:
+ raise ValueError(
+ 'The "rgb" keyword is only valid for imshow()'
+ "with a three-dimensional array (per facet)"
+ )
+
+ xlab, ylab = _infer_xy_labels(
+ darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb
+ )
+
+ xval = darray[xlab]
+ yval = darray[ylab]
+
+ if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface":
+ # Passing 2d coordinate values, need to ensure they are transposed the same
+ # way as darray.
+ # Also surface plots always need 2d coordinates
+ xval = xval.broadcast_like(darray)
+ yval = yval.broadcast_like(darray)
+ dims = darray.dims
+ else:
+ dims = (yval.dims[0], xval.dims[0])
+
+ # May need to transpose for correct x, y labels
+ # xlab may be the name of a coord, we have to check for dim names
+ if imshow_rgb:
+ # For RGB[A] images, matplotlib requires the color dimension
+ # to be last. In Xarray the order should be unimportant, so
+ # we transpose to (y, x, color) to make this work.
+ yx_dims = (ylab, xlab)
+ dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
+
+ if dims != darray.dims:
+ darray = darray.transpose(*dims, transpose_coords=True)
+
+ # better to pass the ndarrays directly to plotting functions
+ xvalnp = xval.to_numpy()
+ yvalnp = yval.to_numpy()
+
+ # Pass the data as a masked ndarray too
+ zval = darray.to_masked_array(copy=False)
+
+ # Replace pd.Intervals if contained in xval or yval.
+ xplt, xlab_extra = _resolve_intervals_2dplot(xvalnp, plotfunc.__name__)
+ yplt, ylab_extra = _resolve_intervals_2dplot(yvalnp, plotfunc.__name__)
+
+ _ensure_plottable(xplt, yplt, zval)
+
+ cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
+ plotfunc,
+ zval.data,
+ **locals(),
+ _is_facetgrid=kwargs.pop("_is_facetgrid", False),
+ )
+
+ if "contour" in plotfunc.__name__:
+ # extend is a keyword argument only for contour and contourf, but
+ # passing it to the colorbar is sufficient for imshow and
+ # pcolormesh
+ kwargs["extend"] = cmap_params["extend"]
+ kwargs["levels"] = cmap_params["levels"]
+ # if colors == a single color, matplotlib draws dashed negative
+ # contours. we lose this feature if we pass cmap and not colors
+ if isinstance(colors, str):
+ cmap_params["cmap"] = None
+ kwargs["colors"] = colors
+
+ if "pcolormesh" == plotfunc.__name__:
+ kwargs["infer_intervals"] = infer_intervals
+ kwargs["xscale"] = xscale
+ kwargs["yscale"] = yscale
+
+ if "imshow" == plotfunc.__name__ and isinstance(aspect, str):
+ # forbid usage of mpl strings
+ raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray")
+
+ ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
+
+ primitive = plotfunc(
+ xplt,
+ yplt,
+ zval,
+ ax=ax,
+ cmap=cmap_params["cmap"],
+ vmin=cmap_params["vmin"],
+ vmax=cmap_params["vmax"],
+ norm=cmap_params["norm"],
+ **kwargs,
+ )
+
+ # Label the plot with metadata
+ if add_labels:
+ ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra))
+ ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra))
+ ax.set_title(darray._title_for_slice())
+ if plotfunc.__name__ == "surface":
+ import mpl_toolkits
+
+ assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D)
+ ax.set_zlabel(label_from_attrs(darray))
+
+ if add_colorbar:
+ if add_labels and "label" not in cbar_kwargs:
+ cbar_kwargs["label"] = label_from_attrs(darray)
+ cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params)
+ elif cbar_ax is not None or cbar_kwargs:
+ # inform the user about keywords which aren't used
+ raise ValueError(
+ "cbar_ax and cbar_kwargs can't be used with add_colorbar=False."
+ )
+
+ # origin kwarg overrides yincrease
+ if "origin" in kwargs:
+ yincrease = None
+
+ _update_axes(
+ ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim
+ )
+
+ if np.issubdtype(xplt.dtype, np.datetime64):
+ _set_concise_date(ax, "x")
+
+ return primitive
+
+ # we want to actually expose the signature of newplotfunc
+ # and not the copied **kwargs from the plotfunc which
+ # functools.wraps adds, so delete the wrapped attr
+ del newplotfunc.__wrapped__
+
+ return newplotfunc
+
+
+@overload
+def imshow( # type: ignore[misc,unused-ignore] # None is hashable :(
+ darray: DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> AxesImage: ...
+
+
+@overload
+def imshow(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
+
+
+@overload
+def imshow(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
@_plot2d
-def imshow(x: np.ndarray, y: np.ndarray, z: np.ma.core.MaskedArray, ax:
- Axes, **kwargs: Any) ->AxesImage:
+def imshow(
+ x: np.ndarray, y: np.ndarray, z: np.ma.core.MaskedArray, ax: Axes, **kwargs: Any
+) -> AxesImage:
"""
Image plot of 2D DataArray.
@@ -326,49 +1802,654 @@ def imshow(x: np.ndarray, y: np.ndarray, z: np.ma.core.MaskedArray, ax:
The pixels are centered on the coordinates. For example, if the coordinate
value is 3.2, then the pixels for those coordinates will be centered on 3.2.
"""
- pass
+
+ if x.ndim != 1 or y.ndim != 1:
+ raise ValueError(
+ "imshow requires 1D coordinates, try using pcolormesh or contour(f)"
+ )
+
+ def _center_pixels(x):
+ """Center the pixels on the coordinates."""
+ if np.issubdtype(x.dtype, str):
+            # When using strings as inputs imshow converts them to
+            # integers. Choose extent values that put the indices in
+            # the center of the pixels:
+ return 0 - 0.5, len(x) - 0.5
+
+ try:
+ # Center the pixels assuming uniform spacing:
+ xstep = 0.5 * (x[1] - x[0])
+ except IndexError:
+ # Arbitrary default value, similar to matplotlib behaviour:
+ xstep = 0.1
+
+ return x[0] - xstep, x[-1] + xstep
+
+ # Center the pixels:
+ left, right = _center_pixels(x)
+ top, bottom = _center_pixels(y)
+
+ defaults: dict[str, Any] = {"origin": "upper", "interpolation": "nearest"}
+
+ if not hasattr(ax, "projection"):
+ # not for cartopy geoaxes
+ defaults["aspect"] = "auto"
+
+ # Allow user to override these defaults
+ defaults.update(kwargs)
+
+ if defaults["origin"] == "upper":
+ defaults["extent"] = [left, right, bottom, top]
+ else:
+ defaults["extent"] = [left, right, top, bottom]
+
+ if z.ndim == 3:
+ # matplotlib imshow uses black for missing data, but Xarray makes
+ # missing data transparent. We therefore add an alpha channel if
+ # there isn't one, and set it to transparent where data is masked.
+ if z.shape[-1] == 3:
+ safe_dtype = np.promote_types(z.dtype, np.uint8)
+ alpha = np.ma.ones(z.shape[:2] + (1,), dtype=safe_dtype)
+ if np.issubdtype(z.dtype, np.integer):
+ alpha[:] = 255
+ z = np.ma.concatenate((z, alpha), axis=2)
+ else:
+ z = z.copy()
+ z[np.any(z.mask, axis=-1), -1] = 0
+
+ primitive = ax.imshow(z, **defaults)
+
+ # If x or y are strings the ticklabels have been replaced with
+ # integer indices. Replace them back to strings:
+ for axis, v in [("x", x), ("y", y)]:
+ if np.issubdtype(v.dtype, str):
+ getattr(ax, f"set_{axis}ticks")(np.arange(len(v)))
+ getattr(ax, f"set_{axis}ticklabels")(v)
+
+ return primitive
+
+
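
A quick numeric illustration of the extent computation above, assuming uniformly spaced 1D coordinates (the values are made up for the example):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])  # uniformly spaced coordinate values
half_step = 0.5 * (x[1] - x[0])
left, right = x[0] - half_step, x[-1] + half_step
print(left, right)  # 0.5 4.5 -> pixel edges, so each coordinate sits at a pixel centre
```
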
+@overload
+def contour( # type: ignore[misc,unused-ignore] # None is hashable :(
+ darray: DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> QuadContourSet: ...
+
+
+@overload
+def contour(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
+
+
+@overload
+def contour(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
@_plot2d
-def contour(x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs:
- Any) ->QuadContourSet:
+def contour(
+ x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any
+) -> QuadContourSet:
"""
Contour plot of 2D DataArray.
Wraps :py:func:`matplotlib:matplotlib.pyplot.contour`.
"""
- pass
+ primitive = ax.contour(x, y, z, **kwargs)
+ return primitive
+
+
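
A usage sketch of the single-color branch handled by the wrapper earlier in this hunk (passing `colors` instead of `cmap` keeps matplotlib's dashed negative contours); the data here is invented:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.random.randn(20, 30), dims=("y", "x"), name="anomaly")

# A single color string drops the colormap in favour of `colors`, so
# matplotlib draws negative contours dashed.
da.plot.contour(colors="k", levels=9)
```
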
+@overload
+def contourf( # type: ignore[misc,unused-ignore] # None is hashable :(
+ darray: DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> QuadContourSet: ...
+
+
+@overload
+def contourf(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
+
+
+@overload
+def contourf(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
@_plot2d
-def contourf(x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **
- kwargs: Any) ->QuadContourSet:
+def contourf(
+ x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any
+) -> QuadContourSet:
"""
Filled contour plot of 2D DataArray.
Wraps :py:func:`matplotlib:matplotlib.pyplot.contourf`.
"""
- pass
+ primitive = ax.contourf(x, y, z, **kwargs)
+ return primitive
+
+
+@overload
+def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :(
+ darray: DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> QuadMesh: ...
+
+
+@overload
+def pcolormesh(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
+
+
+@overload
+def pcolormesh(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
@_plot2d
-def pcolormesh(x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes,
- xscale: (ScaleOptions | None)=None, yscale: (ScaleOptions | None)=None,
- infer_intervals=None, **kwargs: Any) ->QuadMesh:
+def pcolormesh(
+ x: np.ndarray,
+ y: np.ndarray,
+ z: np.ndarray,
+ ax: Axes,
+ xscale: ScaleOptions | None = None,
+ yscale: ScaleOptions | None = None,
+ infer_intervals=None,
+ **kwargs: Any,
+) -> QuadMesh:
"""
Pseudocolor plot of 2D DataArray.
Wraps :py:func:`matplotlib:matplotlib.pyplot.pcolormesh`.
"""
- pass
+
+ # decide on a default for infer_intervals (GH781)
+ x = np.asarray(x)
+ if infer_intervals is None:
+ if hasattr(ax, "projection"):
+ if len(x.shape) == 1:
+ infer_intervals = True
+ else:
+ infer_intervals = False
+ else:
+ infer_intervals = True
+
+ if any(np.issubdtype(k.dtype, str) for k in (x, y)):
+ # do not infer intervals if any axis contains str ticks, see #6775
+ infer_intervals = False
+
+ if infer_intervals and (
+ (np.shape(x)[0] == np.shape(z)[1])
+ or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1]))
+ ):
+ if x.ndim == 1:
+ x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale)
+ else:
+ # we have to infer the intervals on both axes
+ x = _infer_interval_breaks(x, axis=1, scale=xscale)
+ x = _infer_interval_breaks(x, axis=0, scale=xscale)
+
+ if infer_intervals and (np.shape(y)[0] == np.shape(z)[0]):
+ if y.ndim == 1:
+ y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale)
+ else:
+ # we have to infer the intervals on both axes
+ y = _infer_interval_breaks(y, axis=1, scale=yscale)
+ y = _infer_interval_breaks(y, axis=0, scale=yscale)
+
+ ax.grid(False)
+ primitive = ax.pcolormesh(x, y, z, **kwargs)
+
+ # by default, pcolormesh picks "round" values for bounds
+ # this results in ugly looking plots with lots of surrounding whitespace
+ if not hasattr(ax, "projection") and x.ndim == 1 and y.ndim == 1:
+ # not a cartopy geoaxis
+ ax.set_xlim(x[0], x[-1])
+ ax.set_ylim(y[0], y[-1])
+
+ return primitive
+
+
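
A minimal sketch of the interval-breaks idea that `infer_intervals` relies on for linear 1D coordinates: N cell centres become N + 1 cell edges built from midpoints plus extrapolated end points. This mirrors the intent of `_infer_interval_breaks` but is not the library routine itself.

```python
import numpy as np


def midpoint_breaks(coord: np.ndarray) -> np.ndarray:
    # N cell centres -> N + 1 cell edges, assuming roughly linear spacing
    mid = 0.5 * (coord[:-1] + coord[1:])
    first = coord[0] - (mid[0] - coord[0])
    last = coord[-1] + (coord[-1] - mid[-1])
    return np.concatenate([[first], mid, [last]])


print(midpoint_breaks(np.array([0.0, 1.0, 2.0, 4.0])))
# [-0.5  0.5  1.5  3.   5. ]
```
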
+@overload
+def surface(
+ darray: DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> Poly3DCollection: ...
+
+
+@overload
+def surface(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
+
+
+@overload
+def surface(
+ darray: T_DataArray,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ *,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_colorbar: bool | None = None,
+ add_labels: bool = True,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ cmap: str | Colormap | None = None,
+ center: float | Literal[False] | None = None,
+ robust: bool = False,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ infer_intervals=None,
+ colors: str | ArrayLike | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ norm: Normalize | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArray]: ...
@_plot2d
-def surface(x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs:
- Any) ->Poly3DCollection:
+def surface(
+ x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any
+) -> Poly3DCollection:
"""
Surface plot of 2D DataArray.
Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`.
"""
- pass
+ import mpl_toolkits
+
+ assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D)
+ primitive = ax.plot_surface(x, y, z, **kwargs)
+ return primitive
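
Taken together, these wrappers back the DataArray 2D plotting accessors. A small usage sketch with made-up data and names:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(4, 10, 15),
    dims=("time", "y", "x"),
    coords={"time": range(4)},
    name="z",
)

da.isel(time=0).plot.pcolormesh()        # -> QuadMesh
da.isel(time=0).plot.contourf(levels=7)  # -> QuadContourSet
da.plot.imshow(col="time", col_wrap=2)   # col= -> wrapped into a FacetGrid
```
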
diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py
index 458a594d..96b59f61 100644
--- a/xarray/plot/dataset_plot.py
+++ b/xarray/plot/dataset_plot.py
@@ -1,49 +1,656 @@
from __future__ import annotations
+
import functools
import inspect
import warnings
from collections.abc import Hashable, Iterable
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
+
from xarray.core.alignment import broadcast
from xarray.plot import dataarray_plot
from xarray.plot.facetgrid import _easy_facetgrid
-from xarray.plot.utils import _add_colorbar, _get_nice_quiver_magnitude, _infer_meta_data, _process_cmap_cbar_kwargs, get_axis
+from xarray.plot.utils import (
+ _add_colorbar,
+ _get_nice_quiver_magnitude,
+ _infer_meta_data,
+ _process_cmap_cbar_kwargs,
+ get_axis,
+)
+
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection, PathCollection
from matplotlib.colors import Colormap, Normalize
from matplotlib.quiver import Quiver
from numpy.typing import ArrayLike
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
- from xarray.core.types import AspectOptions, ExtendOptions, HueStyleOptions, ScaleOptions
+ from xarray.core.types import (
+ AspectOptions,
+ ExtendOptions,
+ HueStyleOptions,
+ ScaleOptions,
+ )
from xarray.plot.facetgrid import FacetGrid
+def _dsplot(plotfunc):
+ commondoc = """
+ Parameters
+ ----------
+
+ ds : Dataset
+ x : Hashable or None, optional
+ Variable name for x-axis.
+ y : Hashable or None, optional
+ Variable name for y-axis.
+ u : Hashable or None, optional
+ Variable name for the *u* velocity (in *x* direction).
+ quiver/streamplot plots only.
+ v : Hashable or None, optional
+ Variable name for the *v* velocity (in *y* direction).
+ quiver/streamplot plots only.
+ hue: Hashable or None, optional
+ Variable by which to color scatter points or arrows.
+ hue_style: {'continuous', 'discrete'} or None, optional
+ How to use the ``hue`` variable:
+
+ - ``'continuous'`` -- continuous color scale
+ (default for numeric ``hue`` variables)
+ - ``'discrete'`` -- a color for each unique value, using the default color cycle
+ (default for non-numeric ``hue`` variables)
+
+ row : Hashable or None, optional
+ If passed, make row faceted plots on this dimension name.
+ col : Hashable or None, optional
+ If passed, make column faceted plots on this dimension name.
+ col_wrap : int, optional
+ Use together with ``col`` to wrap faceted plots.
+ ax : matplotlib axes object or None, optional
+ If ``None``, use the current axes. Not applicable when using facets.
+ figsize : Iterable[float] or None, optional
+ A tuple (width, height) of the figure in inches.
+ Mutually exclusive with ``size`` and ``ax``.
+ size : scalar, optional
+ If provided, create a new figure for the plot with the given size.
+ Height (in inches) of each plot. See also: ``aspect``.
+ aspect : "auto", "equal", scalar or None, optional
+ Aspect ratio of plot, so that ``aspect * size`` gives the width in
+ inches. Only used if a ``size`` is provided.
+ sharex : bool or None, optional
+ If True all subplots share the same x-axis.
+ sharey : bool or None, optional
+ If True all subplots share the same y-axis.
+ add_guide: bool or None, optional
+ Add a guide that depends on ``hue_style``:
+
+ - ``'continuous'`` -- build a colorbar
+ - ``'discrete'`` -- build a legend
+
+ subplot_kws : dict or None, optional
+ Dictionary of keyword arguments for Matplotlib subplots
+ (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`).
+ Only applies to FacetGrid plotting.
+ cbar_kwargs : dict, optional
+ Dictionary of keyword arguments to pass to the colorbar
+ (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`).
+ cbar_ax : matplotlib axes object, optional
+ Axes in which to draw the colorbar.
+ cmap : matplotlib colormap name or colormap, optional
+ The mapping from data values to color space. Either a
+ Matplotlib colormap name or object. If not provided, this will
+ be either ``'viridis'`` (if the function infers a sequential
+ dataset) or ``'RdBu_r'`` (if the function infers a diverging
+ dataset).
+ See :doc:`Choosing Colormaps in Matplotlib <matplotlib:users/explain/colors/colormaps>`
+ for more information.
+
+ If *seaborn* is installed, ``cmap`` may also be a
+ `seaborn color palette <https://seaborn.pydata.org/tutorial/color_palettes.html>`_.
+ Note: if ``cmap`` is a seaborn color palette,
+ ``levels`` must also be specified.
+ vmin : float or None, optional
+ Lower value to anchor the colormap, otherwise it is inferred from the
+ data and other keyword arguments. When a diverging dataset is inferred,
+ setting `vmin` or `vmax` will fix the other by symmetry around
+ ``center``. Setting both values prevents use of a diverging colormap.
+ If discrete levels are provided as an explicit list, both of these
+ values are ignored.
+ vmax : float or None, optional
+ Upper value to anchor the colormap, otherwise it is inferred from the
+ data and other keyword arguments. When a diverging dataset is inferred,
+ setting `vmin` or `vmax` will fix the other by symmetry around
+ ``center``. Setting both values prevents use of a diverging colormap.
+ If discrete levels are provided as an explicit list, both of these
+ values are ignored.
+ norm : matplotlib.colors.Normalize, optional
+ If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding
+ kwarg must be ``None``.
+ infer_intervals: bool | None
+ If True the intervals are inferred.
+ center : float, optional
+ The value at which to center the colormap. Passing this value implies
+ use of a diverging colormap. Setting it to ``False`` prevents use of a
+ diverging colormap.
+ robust : bool, optional
+ If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is
+ computed with 2nd and 98th percentiles instead of the extreme values.
+ colors : str or array-like of color-like, optional
+ A single color or a list of colors. The ``levels`` argument
+ is required.
+ extend : {'neither', 'both', 'min', 'max'}, optional
+ How to draw arrows extending the colorbar beyond its limits. If not
+ provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits.
+ levels : int or array-like, optional
+ Split the colormap (``cmap``) into discrete color intervals. If an integer
+ is provided, "nice" levels are chosen based on the data range: this can
+ imply that the final number of levels is not exactly the expected one.
+ Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to
+ setting ``levels=np.linspace(vmin, vmax, N)``.
+ **kwargs : optional
+ Additional keyword arguments to wrapped Matplotlib function.
+ """
+
+ # Build on the original docstring
+ plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}"
+
+ @functools.wraps(
+ plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__")
+ )
+ def newplotfunc(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ add_guide: bool | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals: bool | None = None,
+ center: float | None = None,
+ robust: bool | None = None,
+ colors: str | ArrayLike | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ if args:
+ # TODO: Deprecated since 2022.10:
+ msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead."
+ assert x is None
+ x = args[0]
+ if len(args) > 1:
+ assert y is None
+ y = args[1]
+ if len(args) > 2:
+ assert u is None
+ u = args[2]
+ if len(args) > 3:
+ assert v is None
+ v = args[3]
+ if len(args) > 4:
+ assert hue is None
+ hue = args[4]
+ if len(args) > 5:
+ raise ValueError(msg)
+ else:
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
+ del args
+
+ _is_facetgrid = kwargs.pop("_is_facetgrid", False)
+ if _is_facetgrid: # facetgrid call
+ meta_data = kwargs.pop("meta_data")
+ else:
+ meta_data = _infer_meta_data(
+ ds, x, y, hue, hue_style, add_guide, funcname=plotfunc.__name__
+ )
+
+ hue_style = meta_data["hue_style"]
+
+ # handle facetgrids first
+ if col or row:
+ allargs = locals().copy()
+ allargs["plotfunc"] = globals()[plotfunc.__name__]
+ allargs["data"] = ds
+ # remove kwargs to avoid passing the information twice
+ for arg in ["meta_data", "kwargs", "ds"]:
+ del allargs[arg]
+
+ return _easy_facetgrid(kind="dataset", **allargs, **kwargs)
+
+ figsize = kwargs.pop("figsize", None)
+ ax = get_axis(figsize, size, aspect, ax)
+
+ if hue_style == "continuous" and hue is not None:
+ if _is_facetgrid:
+ cbar_kwargs = meta_data["cbar_kwargs"]
+ cmap_params = meta_data["cmap_params"]
+ else:
+ cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
+ plotfunc, ds[hue].values, **locals()
+ )
+
+ # subset that can be passed to scatter, hist2d
+ cmap_params_subset = {
+ vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"]
+ }
+
+ else:
+ cmap_params_subset = {}
+
+ if (u is not None or v is not None) and plotfunc.__name__ not in (
+ "quiver",
+ "streamplot",
+ ):
+ raise ValueError("u, v are only allowed for quiver or streamplot plots.")
+
+ primitive = plotfunc(
+ ds=ds,
+ x=x,
+ y=y,
+ ax=ax,
+ u=u,
+ v=v,
+ hue=hue,
+ hue_style=hue_style,
+ cmap_params=cmap_params_subset,
+ **kwargs,
+ )
+
+ if _is_facetgrid: # if this was called from Facetgrid.map_dataset,
+ return primitive # finish here. Else, make labels
+
+ if meta_data.get("xlabel", None):
+ ax.set_xlabel(meta_data.get("xlabel"))
+ if meta_data.get("ylabel", None):
+ ax.set_ylabel(meta_data.get("ylabel"))
+
+ if meta_data["add_legend"]:
+ ax.legend(handles=primitive, title=meta_data.get("hue_label", None))
+ if meta_data["add_colorbar"]:
+ cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
+ if "label" not in cbar_kwargs:
+ cbar_kwargs["label"] = meta_data.get("hue_label", None)
+ _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params)
+
+ if meta_data["add_quiverkey"]:
+ magnitude = _get_nice_quiver_magnitude(ds[u], ds[v])
+ units = ds[u].attrs.get("units", "")
+ ax.quiverkey(
+ primitive,
+ X=0.85,
+ Y=0.9,
+ U=magnitude,
+ label=f"{magnitude}\n{units}",
+ labelpos="E",
+ coordinates="figure",
+ )
+
+ if plotfunc.__name__ in ("quiver", "streamplot"):
+ title = ds[u]._title_for_slice()
+ else:
+ title = ds[x]._title_for_slice()
+ ax.set_title(title)
+
+ return primitive
+
+ # we want to actually expose the signature of newplotfunc
+ # and not the copied **kwargs from the plotfunc which
+ # functools.wraps adds, so delete the wrapped attr
+ del newplotfunc.__wrapped__
+
+ return newplotfunc
+
+
+@overload
+def quiver( # type: ignore[misc,unused-ignore] # None is hashable :(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: None = None, # no wrap -> primitive
+ row: None = None, # no wrap -> primitive
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals: bool | None = None,
+ center: float | None = None,
+ levels: ArrayLike | None = None,
+ robust: bool | None = None,
+ colors: str | ArrayLike | None = None,
+ extend: ExtendOptions = None,
+ cmap: str | Colormap | None = None,
+ **kwargs: Any,
+) -> Quiver: ...
+
+
+@overload
+def quiver(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable, # wrap -> FacetGrid
+ row: Hashable | None = None,
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals: bool | None = None,
+ center: float | None = None,
+ levels: ArrayLike | None = None,
+ robust: bool | None = None,
+ colors: str | ArrayLike | None = None,
+ extend: ExtendOptions = None,
+ cmap: str | Colormap | None = None,
+ **kwargs: Any,
+) -> FacetGrid[Dataset]: ...
+
+
+@overload
+def quiver(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals: bool | None = None,
+ center: float | None = None,
+ levels: ArrayLike | None = None,
+ robust: bool | None = None,
+ colors: str | ArrayLike | None = None,
+ extend: ExtendOptions = None,
+ cmap: str | Colormap | None = None,
+ **kwargs: Any,
+) -> FacetGrid[Dataset]: ...
+
+
@_dsplot
-def quiver(ds: Dataset, x: Hashable, y: Hashable, ax: Axes, u: Hashable, v:
- Hashable, **kwargs: Any) ->Quiver:
+def quiver(
+ ds: Dataset,
+ x: Hashable,
+ y: Hashable,
+ ax: Axes,
+ u: Hashable,
+ v: Hashable,
+ **kwargs: Any,
+) -> Quiver:
"""Quiver plot of Dataset variables.
Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`.
"""
- pass
+ import matplotlib as mpl
+
+ if x is None or y is None or u is None or v is None:
+ raise ValueError("Must specify x, y, u, v for quiver plots.")
+
+ dx, dy, du, dv = broadcast(ds[x], ds[y], ds[u], ds[v])
+
+ args = [dx.values, dy.values, du.values, dv.values]
+ hue = kwargs.pop("hue")
+ cmap_params = kwargs.pop("cmap_params")
+
+ if hue:
+ args.append(ds[hue].values)
+
+ # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
+ if not cmap_params["norm"]:
+ cmap_params["norm"] = mpl.colors.Normalize(
+ cmap_params.pop("vmin"), cmap_params.pop("vmax")
+ )
+
+ kwargs.pop("hue_style")
+ kwargs.setdefault("pivot", "middle")
+ hdl = ax.quiver(*args, **kwargs, **cmap_params)
+ return hdl
+
+
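
A hedged usage sketch for the Dataset quiver wrapper (variable names invented); keyword arguments are the supported form, the positional shim above exists only for backwards compatibility:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "u": (("y", "x"), np.random.randn(6, 8)),
        "v": (("y", "x"), np.random.randn(6, 8)),
        "speed": (("y", "x"), np.random.rand(6, 8)),
    },
    coords={"x": np.arange(8), "y": np.arange(6)},
)

# hue="speed" is continuous, so the wrapper builds cmap_params and a colorbar.
ds.plot.quiver(x="x", y="y", u="u", v="v", hue="speed")
```
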
+@overload
+def streamplot( # type: ignore[misc,unused-ignore] # None is hashable :(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: None = None, # no wrap -> primitive
+ row: None = None, # no wrap -> primitive
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals: bool | None = None,
+ center: float | None = None,
+ levels: ArrayLike | None = None,
+ robust: bool | None = None,
+ colors: str | ArrayLike | None = None,
+ extend: ExtendOptions = None,
+ cmap: str | Colormap | None = None,
+ **kwargs: Any,
+) -> LineCollection: ...
+
+
+@overload
+def streamplot(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable, # wrap -> FacetGrid
+ row: Hashable | None = None,
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals: bool | None = None,
+ center: float | None = None,
+ levels: ArrayLike | None = None,
+ robust: bool | None = None,
+ colors: str | ArrayLike | None = None,
+ extend: ExtendOptions = None,
+ cmap: str | Colormap | None = None,
+ **kwargs: Any,
+) -> FacetGrid[Dataset]: ...
+
+
+@overload
+def streamplot(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ u: Hashable | None = None,
+ v: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ col: Hashable | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: AspectOptions = None,
+ subplot_kws: dict[str, Any] | None = None,
+ add_guide: bool | None = None,
+ cbar_kwargs: dict[str, Any] | None = None,
+ cbar_ax: Axes | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ infer_intervals: bool | None = None,
+ center: float | None = None,
+ levels: ArrayLike | None = None,
+ robust: bool | None = None,
+ colors: str | ArrayLike | None = None,
+ extend: ExtendOptions = None,
+ cmap: str | Colormap | None = None,
+ **kwargs: Any,
+) -> FacetGrid[Dataset]: ...
@_dsplot
-def streamplot(ds: Dataset, x: Hashable, y: Hashable, ax: Axes, u: Hashable,
- v: Hashable, **kwargs: Any) ->LineCollection:
+def streamplot(
+ ds: Dataset,
+ x: Hashable,
+ y: Hashable,
+ ax: Axes,
+ u: Hashable,
+ v: Hashable,
+ **kwargs: Any,
+) -> LineCollection:
"""Plot streamlines of Dataset variables.
Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`.
"""
- pass
+ import matplotlib as mpl
+
+ if x is None or y is None or u is None or v is None:
+ raise ValueError("Must specify x, y, u, v for streamplot plots.")
+
+    # Matplotlib's streamplot has strong restrictions on what x and y can be, so we
+    # need to get the arrays transposed the 'right' way around. 'x' cannot vary within
+    # 'rows', so the dimension of x must be the second dimension. 'y' cannot vary within
+    # 'columns', so the dimension of y must be the first dimension. If x and y are both
+    # 2d, assume the user has got them right already.
+ xdim = ds[x].dims[0] if len(ds[x].dims) == 1 else None
+ ydim = ds[y].dims[0] if len(ds[y].dims) == 1 else None
+ if xdim is not None and ydim is None:
+ ydims = set(ds[y].dims) - {xdim}
+ if len(ydims) == 1:
+ ydim = next(iter(ydims))
+ if ydim is not None and xdim is None:
+ xdims = set(ds[x].dims) - {ydim}
+ if len(xdims) == 1:
+ xdim = next(iter(xdims))
+
+ dx, dy, du, dv = broadcast(ds[x], ds[y], ds[u], ds[v])
+
+ if xdim is not None and ydim is not None:
+ # Need to ensure the arrays are transposed correctly
+ dx = dx.transpose(ydim, xdim)
+ dy = dy.transpose(ydim, xdim)
+ du = du.transpose(ydim, xdim)
+ dv = dv.transpose(ydim, xdim)
+
+ hue = kwargs.pop("hue")
+ cmap_params = kwargs.pop("cmap_params")
+ if hue:
+ kwargs["color"] = ds[hue].values
+    # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
+    if not cmap_params["norm"]:
+        cmap_params["norm"] = mpl.colors.Normalize(
+            cmap_params.pop("vmin"), cmap_params.pop("vmax")
+        )
+    kwargs.pop("hue_style")
+    hdl = ax.streamplot(
+        dx.values, dy.values, du.values, dv.values, **kwargs, **cmap_params
+    )
+    # Return .lines so colorbar creation works properly
+    return hdl.lines
+
+
-F = TypeVar('F', bound=Callable)
+F = TypeVar("F", bound=Callable)
+
+
-def _update_doc_to_dataset(dataarray_plotfunc: Callable) ->Callable[[F], F]:
+def _update_doc_to_dataset(dataarray_plotfunc: Callable) -> Callable[[F], F]:
"""
Add a common docstring by re-using the DataArray one.
@@ -60,32 +667,252 @@ def _update_doc_to_dataset(dataarray_plotfunc: Callable) ->Callable[[F], F]:
dataarray_plotfunc : Callable
Function that returns a finished plot primitive.
"""
- pass
+ # Build on the original docstring
+ da_doc = dataarray_plotfunc.__doc__
+ if da_doc is None:
+ raise NotImplementedError("DataArray plot method requires a docstring")
+
+ da_str = """
+ Parameters
+ ----------
+ darray : DataArray
+ """
+ ds_str = """
+
+ The `y` DataArray will be used as base, any other variables are added as coords.
+
+ Parameters
+ ----------
+ ds : Dataset
+ """
+ # TODO: improve this?
+ if da_str in da_doc:
+ ds_doc = da_doc.replace(da_str, ds_str).replace("darray", "ds")
+ else:
+ ds_doc = da_doc
+
+ @functools.wraps(dataarray_plotfunc)
+ def wrapper(dataset_plotfunc: F) -> F:
+ dataset_plotfunc.__doc__ = ds_doc
+ return dataset_plotfunc
+
+ return wrapper
+
+
+def _normalize_args(
+ plotmethod: str, args: tuple[Any, ...], kwargs: dict[str, Any]
+) -> dict[str, Any]:
+ from xarray.core.dataarray import DataArray
+
+ # Determine positional arguments keyword by inspecting the
+ # signature of the plotmethod:
+ locals_ = dict(
+ inspect.signature(getattr(DataArray().plot, plotmethod))
+ .bind(*args, **kwargs)
+ .arguments.items()
+ )
+ locals_.update(locals_.pop("kwargs", {}))
+
+ return locals_
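
_normalize_args leans on `inspect.signature(...).bind` to fold positional arguments into a keyword mapping; the same mechanism in isolation, on a toy function (illustrative only):

```python
import inspect


def plotmethod(x=None, y=None, hue=None, **kwargs):
    ...


bound = inspect.signature(plotmethod).bind("lon", "lat", marker="o")
locals_ = dict(bound.arguments)
locals_.update(locals_.pop("kwargs", {}))
print(locals_)  # {'x': 'lon', 'y': 'lat', 'marker': 'o'}
```
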
-def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]
- ) ->DataArray:
+
+def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataArray:
"""Create a temporary datarray with extra coords."""
- pass
+ from xarray.core.dataarray import DataArray
+
+ coords = dict(ds[y].coords)
+ dims = set(ds[y].dims)
+
+    # Add extra coords to the DataArray from valid kwargs. If we used all
+    # kwargs there would be a risk of adding unnecessary dataarrays as
+    # coords, straining RAM further. For example, ds.both together with
+    # extend="both" would add ds.both to the coords.
+ valid_coord_kwargs = {"x", "z", "markersize", "hue", "row", "col", "u", "v"}
+ coord_kwargs = locals_.keys() & valid_coord_kwargs
+ for k in coord_kwargs:
+ key = locals_[k]
+ darray = ds.get(key)
+ if darray is not None:
+ coords[key] = darray
+ dims.update(darray.dims)
+
+    # Trim the dataset of unnecessary dims:
+ ds_trimmed = ds.drop_dims(ds.sizes.keys() - dims) # TODO: Use ds.dims in the future
+
+ # The dataarray has to include all the dims. Broadcast to that shape
+ # and add the additional coords:
+ _y = ds[y].broadcast_like(ds_trimmed)
+
+ return DataArray(_y, coords=coords)
+
+
+@overload
+def scatter( # type: ignore[misc,unused-ignore] # None is hashable :(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: None = None, # no wrap -> primitive
+ col: None = None, # no wrap -> primitive
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs: Any,
+) -> PathCollection: ...
+
+
+@overload
+def scatter(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable, # wrap -> FacetGrid
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs: Any,
+) -> FacetGrid[DataArray]: ...
+
+
+@overload
+def scatter(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable, # wrap -> FacetGrid
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs: Any,
+) -> FacetGrid[DataArray]: ...
@_update_doc_to_dataset(dataarray_plot.scatter)
-def scatter(ds: Dataset, *args: Any, x: (Hashable | None)=None, y: (
- Hashable | None)=None, z: (Hashable | None)=None, hue: (Hashable | None
- )=None, hue_style: HueStyleOptions=None, markersize: (Hashable | None)=
- None, linewidth: (Hashable | None)=None, figsize: (Iterable[float] |
- None)=None, size: (float | None)=None, aspect: (float | None)=None, ax:
- (Axes | None)=None, row: (Hashable | None)=None, col: (Hashable | None)
- =None, col_wrap: (int | None)=None, xincrease: (bool | None)=True,
- yincrease: (bool | None)=True, add_legend: (bool | None)=None,
- add_colorbar: (bool | None)=None, add_labels: (bool | Iterable[bool])=
- True, add_title: bool=True, subplot_kws: (dict[str, Any] | None)=None,
- xscale: ScaleOptions=None, yscale: ScaleOptions=None, xticks: (
- ArrayLike | None)=None, yticks: (ArrayLike | None)=None, xlim: (
- ArrayLike | None)=None, ylim: (ArrayLike | None)=None, cmap: (str |
- Colormap | None)=None, vmin: (float | None)=None, vmax: (float | None)=
- None, norm: (Normalize | None)=None, extend: ExtendOptions=None, levels:
- (ArrayLike | None)=None, **kwargs: Any) ->(PathCollection | FacetGrid[
- DataArray]):
+def scatter(
+ ds: Dataset,
+ *args: Any,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: float | None = None,
+ ax: Axes | None = None,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ xincrease: bool | None = True,
+ yincrease: bool | None = True,
+ add_legend: bool | None = None,
+ add_colorbar: bool | None = None,
+ add_labels: bool | Iterable[bool] = True,
+ add_title: bool = True,
+ subplot_kws: dict[str, Any] | None = None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: ArrayLike | None = None,
+ ylim: ArrayLike | None = None,
+ cmap: str | Colormap | None = None,
+ vmin: float | None = None,
+ vmax: float | None = None,
+ norm: Normalize | None = None,
+ extend: ExtendOptions = None,
+ levels: ArrayLike | None = None,
+ **kwargs: Any,
+) -> PathCollection | FacetGrid[DataArray]:
"""Scatter plot Dataset data variables against each other."""
- pass
+ locals_ = locals()
+ del locals_["ds"]
+ locals_.update(locals_.pop("kwargs", {}))
+ da = _temp_dataarray(ds, y, locals_)
+
+ return da.plot.scatter(*locals_.pop("args", ()), **locals_)
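
And a usage sketch for the Dataset scatter wrapper (made-up variables): x and y name data variables, while hue is attached to the temporary DataArray as an extra coordinate by _temp_dataarray.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "a": ("obs", np.random.rand(50)),
        "b": ("obs", np.random.rand(50)),
        "w": ("obs", np.random.rand(50)),
    }
)

ds.plot.scatter(x="a", y="b", hue="w")  # points colored by "w", colorbar added
```
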
diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py
index af73de2d..613362ed 100644
--- a/xarray/plot/facetgrid.py
+++ b/xarray/plot/facetgrid.py
@@ -1,13 +1,29 @@
from __future__ import annotations
+
import functools
import itertools
import warnings
from collections.abc import Hashable, Iterable, MutableMapping
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, cast
+
import numpy as np
+
from xarray.core.formatting import format_item
from xarray.core.types import HueStyleOptions, T_DataArrayOrSet
-from xarray.plot.utils import _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, _add_legend, _determine_guide, _get_nice_quiver_magnitude, _guess_coords_to_plot, _infer_xy_labels, _Normalize, _parse_size, _process_cmap_cbar_kwargs, label_from_attrs
+from xarray.plot.utils import (
+ _LINEWIDTH_RANGE,
+ _MARKERSIZE_RANGE,
+ _add_legend,
+ _determine_guide,
+ _get_nice_quiver_magnitude,
+ _guess_coords_to_plot,
+ _infer_xy_labels,
+ _Normalize,
+ _parse_size,
+ _process_cmap_cbar_kwargs,
+ label_from_attrs,
+)
+
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
@@ -16,8 +32,14 @@ if TYPE_CHECKING:
from matplotlib.legend import Legend
from matplotlib.quiver import QuiverKey
from matplotlib.text import Annotation
+
from xarray.core.dataarray import DataArray
-_FONTSIZE = 'small'
+
+
+# Overrides axes.labelsize, xtick.major.size, ytick.major.size
+# from mpl.rcParams
+_FONTSIZE = "small"
+# For major ticks on x, y axes
_NTICKS = 5
@@ -25,10 +47,16 @@ def _nicetitle(coord, value, maxchar, template):
"""
Put coord, value in template and truncate at maxchar
"""
- pass
+ prettyvalue = format_item(value, quote_strings=False)
+ title = template.format(coord=coord, value=prettyvalue)
+
+ if len(title) > maxchar:
+ title = title[: (maxchar - 3)] + "..."
+ return title
-T_FacetGrid = TypeVar('T_FacetGrid', bound='FacetGrid')
+
+T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid")
class FacetGrid(Generic[T_DataArrayOrSet]):
@@ -71,6 +99,7 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
used as a sentinel value for axes that should remain empty, i.e.,
sometimes the rightmost grid positions in the bottom row.
"""
+
data: T_DataArrayOrSet
name_dicts: np.ndarray
fig: Figure
@@ -94,11 +123,19 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
_mappables: list[ScalarMappable]
_finalized: bool
- def __init__(self, data: T_DataArrayOrSet, col: (Hashable | None)=None,
- row: (Hashable | None)=None, col_wrap: (int | None)=None, sharex:
- bool=True, sharey: bool=True, figsize: (Iterable[float] | None)=
- None, aspect: float=1, size: float=3, subplot_kws: (dict[str, Any] |
- None)=None) ->None:
+ def __init__(
+ self,
+ data: T_DataArrayOrSet,
+ col: Hashable | None = None,
+ row: Hashable | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ figsize: Iterable[float] | None = None,
+ aspect: float = 1,
+ size: float = 3,
+ subplot_kws: dict[str, Any] | None = None,
+ ) -> None:
"""
Parameters
----------
@@ -127,13 +164,19 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
(:py:func:`matplotlib.pyplot.subplots`).
"""
+
import matplotlib.pyplot as plt
+
+ # Handle corner case of nonunique coordinates
rep_col = col is not None and not data[col].to_index().is_unique
rep_row = row is not None and not data[row].to_index().is_unique
if rep_col or rep_row:
raise ValueError(
- 'Coordinates used for faceting cannot contain repeated (nonunique) values.'
- )
+ "Coordinates used for faceting cannot "
+ "contain repeated (nonunique) values."
+ )
+
+ # single_group is the grouping variable, if there is exactly one
single_group: bool | Hashable
if col and row:
single_group = False
@@ -141,51 +184,81 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
ncol = len(data[col])
nfacet = nrow * ncol
if col_wrap is not None:
- warnings.warn(
- 'Ignoring col_wrap since both col and row were passed')
+ warnings.warn("Ignoring col_wrap since both col and row were passed")
elif row and not col:
single_group = row
elif not row and col:
single_group = col
else:
- raise ValueError(
- 'Pass a coordinate name as an argument for row or col')
+ raise ValueError("Pass a coordinate name as an argument for row or col")
+
+ # Compute grid shape
if single_group:
nfacet = len(data[single_group])
if col:
+ # idea - could add heuristic for nice shapes like 3x4
ncol = nfacet
if row:
ncol = 1
if col_wrap is not None:
+ # Overrides previous settings
ncol = col_wrap
nrow = int(np.ceil(nfacet / ncol))
+
+ # Set the subplot kwargs
subplot_kws = {} if subplot_kws is None else subplot_kws
+
if figsize is None:
+ # Calculate the base figure size with extra horizontal space for a
+ # colorbar
cbar_space = 1
- figsize = ncol * size * aspect + cbar_space, nrow * size
- fig, axs = plt.subplots(nrow, ncol, sharex=sharex, sharey=sharey,
- squeeze=False, figsize=figsize, subplot_kw=subplot_kws)
+ figsize = (ncol * size * aspect + cbar_space, nrow * size)
+
+ fig, axs = plt.subplots(
+ nrow,
+ ncol,
+ sharex=sharex,
+ sharey=sharey,
+ squeeze=False,
+ figsize=figsize,
+ subplot_kw=subplot_kws,
+ )
+
+ # Set up the lists of names for the row and column facet variables
col_names = list(data[col].to_numpy()) if col else []
row_names = list(data[row].to_numpy()) if row else []
+
if single_group:
- full: list[dict[Hashable, Any] | None] = [{single_group: x} for
- x in data[single_group].to_numpy()]
- empty: list[dict[Hashable, Any] | None] = [None for x in range(
- nrow * ncol - len(full))]
+ full: list[dict[Hashable, Any] | None] = [
+ {single_group: x} for x in data[single_group].to_numpy()
+ ]
+ empty: list[dict[Hashable, Any] | None] = [
+ None for x in range(nrow * ncol - len(full))
+ ]
name_dict_list = full + empty
else:
rowcols = itertools.product(row_names, col_names)
name_dict_list = [{row: r, col: c} for r, c in rowcols]
+
name_dicts = np.array(name_dict_list).reshape(nrow, ncol)
+
+ # Set up the class attributes
+ # ---------------------------
+
+ # First the public API
self.data = data
self.name_dicts = name_dicts
self.fig = fig
self.axs = axs
self.row_names = row_names
self.col_names = col_names
+
+ # guides
self.figlegend = None
self.quiverkey = None
self.cbar = None
+
+ # Next the private variables
self._single_group = single_group
self._nrow = nrow
self._row_var = row
@@ -200,8 +273,45 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
self._mappables = []
self._finalized = False
- def map_dataarray(self: T_FacetGrid, func: Callable, x: (Hashable |
- None), y: (Hashable | None), **kwargs: Any) ->T_FacetGrid:
+ @property
+ def axes(self) -> np.ndarray:
+ warnings.warn(
+ (
+ "self.axes is deprecated since 2022.11 in order to align with "
+ "matplotlibs plt.subplots, use self.axs instead."
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.axs
+
+ @axes.setter
+ def axes(self, axs: np.ndarray) -> None:
+ warnings.warn(
+ (
+ "self.axes is deprecated since 2022.11 in order to align with "
+ "matplotlibs plt.subplots, use self.axs instead."
+ ),
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.axs = axs
+
+ @property
+ def _left_axes(self) -> np.ndarray:
+ return self.axs[:, 0]
+
+ @property
+ def _bottom_axes(self) -> np.ndarray:
+ return self.axs[-1, :]
+
+ def map_dataarray(
+ self: T_FacetGrid,
+ func: Callable,
+ x: Hashable | None,
+ y: Hashable | None,
+ **kwargs: Any,
+ ) -> T_FacetGrid:
"""
Apply a plotting function to a 2d facet's subset of the data.
@@ -222,12 +332,64 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
self : FacetGrid object
"""
- pass
- def map_plot1d(self: T_FacetGrid, func: Callable, x: (Hashable | None),
- y: (Hashable | None), *, z: (Hashable | None)=None, hue: (Hashable |
- None)=None, markersize: (Hashable | None)=None, linewidth: (
- Hashable | None)=None, **kwargs: Any) ->T_FacetGrid:
+ if kwargs.get("cbar_ax", None) is not None:
+ raise ValueError("cbar_ax not supported by FacetGrid.")
+
+ cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
+ func, self.data.to_numpy(), **kwargs
+ )
+
+ self._cmap_extend = cmap_params.get("extend")
+
+ # Order is important
+ func_kwargs = {
+ k: v
+ for k, v in kwargs.items()
+ if k not in {"cmap", "colors", "cbar_kwargs", "levels"}
+ }
+ func_kwargs.update(cmap_params)
+ func_kwargs["add_colorbar"] = False
+ if func.__name__ != "surface":
+ func_kwargs["add_labels"] = False
+
+ # Get x, y labels for the first subplot
+ x, y = _infer_xy_labels(
+ darray=self.data.loc[self.name_dicts.flat[0]],
+ x=x,
+ y=y,
+ imshow=func.__name__ == "imshow",
+ rgb=kwargs.get("rgb", None),
+ )
+
+ for d, ax in zip(self.name_dicts.flat, self.axs.flat):
+ # None is the sentinel value
+ if d is not None:
+ subset = self.data.loc[d]
+ mappable = func(
+ subset, x=x, y=y, ax=ax, **func_kwargs, _is_facetgrid=True
+ )
+ self._mappables.append(mappable)
+
+ self._finalize_grid(x, y)
+
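+        # A single colorbar is shared by the whole grid; per-facet colorbars were disabled above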
+ if kwargs.get("add_colorbar", True):
+ self.add_colorbar(**cbar_kwargs)
+
+ return self
+
+ def map_plot1d(
+ self: T_FacetGrid,
+ func: Callable,
+ x: Hashable | None,
+ y: Hashable | None,
+ *,
+ z: Hashable | None = None,
+ hue: Hashable | None = None,
+ markersize: Hashable | None = None,
+ linewidth: Hashable | None = None,
+ **kwargs: Any,
+ ) -> T_FacetGrid:
"""
Apply a plotting function to a 1d facet's subset of the data.
@@ -248,17 +410,357 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
self : FacetGrid object
"""
- pass
+ # Copy data to allow converting categoricals to integers and storing
+        # them in self.data. Unfortunately it is not possible to copy in
+        # __init__, as there are tests that rely on self.data being mutable
+        # (test_names_appear_somewhere()). Maybe something to deprecate;
+        # it is not clear how much that is relied on outside these tests.
+ self.data = self.data.copy()
+
+ if kwargs.get("cbar_ax", None) is not None:
+ raise ValueError("cbar_ax not supported by FacetGrid.")
+
+ if func.__name__ == "scatter":
+ size_ = kwargs.pop("_size", markersize)
+ size_r = _MARKERSIZE_RANGE
+ else:
+ size_ = kwargs.pop("_size", linewidth)
+ size_r = _LINEWIDTH_RANGE
+
+ # Guess what coords to use if some of the values in coords_to_plot are None:
+ coords_to_plot: MutableMapping[str, Hashable | None] = dict(
+ x=x, z=z, hue=hue, size=size_
+ )
+ coords_to_plot = _guess_coords_to_plot(self.data, coords_to_plot, kwargs)
+
+ # Handle hues:
+ hue = coords_to_plot["hue"]
+ hueplt = self.data.coords[hue] if hue else None # TODO: _infer_line_data2 ?
+ hueplt_norm = _Normalize(hueplt)
+ self._hue_var = hueplt
+ cbar_kwargs = kwargs.pop("cbar_kwargs", {})
+ if hueplt_norm.data is not None:
+ if not hueplt_norm.data_is_numeric:
+ # TODO: Ticks seems a little too hardcoded, since it will always
+ # show all the values. But maybe it's ok, since plotting hundreds
+ # of categorical data isn't that meaningful anyway.
+ cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks)
+ kwargs.update(levels=hueplt_norm.levels)
+
+ cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
+ func,
+ cast("DataArray", hueplt_norm.values).data,
+ cbar_kwargs=cbar_kwargs,
+ **kwargs,
+ )
+ self._cmap_extend = cmap_params.get("extend")
+ else:
+ cmap_params = {}
+
+ # Handle sizes:
+ size_ = coords_to_plot["size"]
+ sizeplt = self.data.coords[size_] if size_ else None
+ sizeplt_norm = _Normalize(data=sizeplt, width=size_r)
+ if sizeplt_norm.data is not None:
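+            # Store the normalized sizes back on the data so every facet uses the same size scaling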
+ self.data[size_] = sizeplt_norm.values
+
+        # Add kwargs that are sent to the plotting function (the order may be important)
+ func_kwargs = {
+ k: v
+ for k, v in kwargs.items()
+ if k not in {"cmap", "colors", "cbar_kwargs", "levels"}
+ }
+ func_kwargs.update(cmap_params)
+ # Annotations will be handled later, skip those parts in the plotfunc:
+ func_kwargs["add_colorbar"] = False
+ func_kwargs["add_legend"] = False
+ func_kwargs["add_title"] = False
+
+ add_labels_ = np.zeros(self.axs.shape + (3,), dtype=bool)
+ if kwargs.get("z") is not None:
+            # 3d plots look better with all labels. 3d plots can't share the x-axis
+            # either, so it is easy to get lost while rotating the plots:
+ add_labels_[:] = True
+ else:
+ # Subplots should have labels on the left and bottom edges only:
+ add_labels_[-1, :, 0] = True # x
+ add_labels_[:, 0, 1] = True # y
+ # add_labels_[:, :, 2] = True # z
+
+ # Set up the lists of names for the row and column facet variables:
+ if self._single_group:
+ full = tuple(
+ {self._single_group: x}
+ for x in range(0, self.data[self._single_group].size)
+ )
+ empty = tuple(None for x in range(self._nrow * self._ncol - len(full)))
+ name_d = full + empty
+ else:
+ rowcols = itertools.product(
+ range(0, self.data[self._row_var].size),
+ range(0, self.data[self._col_var].size),
+ )
+ name_d = tuple({self._row_var: r, self._col_var: c} for r, c in rowcols)
+ name_dicts = np.array(name_d).reshape(self._nrow, self._ncol)
+
+ # Plot the data for each subplot:
+ for add_lbls, d, ax in zip(
+ add_labels_.reshape((self.axs.size, -1)), name_dicts.flat, self.axs.flat
+ ):
+ func_kwargs["add_labels"] = add_lbls
+ # None is the sentinel value
+ if d is not None:
+ subset = self.data.isel(d)
+ mappable = func(
+ subset,
+ x=x,
+ y=y,
+ ax=ax,
+ hue=hue,
+ _size=size_,
+ **func_kwargs,
+ _is_facetgrid=True,
+ )
+ self._mappables.append(mappable)
+
+ # Add titles and some touch ups:
+ self._finalize_grid()
+ self._set_lims()
+
+ add_colorbar, add_legend = _determine_guide(
+ hueplt_norm,
+ sizeplt_norm,
+ kwargs.get("add_colorbar", None),
+ kwargs.get("add_legend", None),
+ # kwargs.get("add_guide", None),
+ # kwargs.get("hue_style", None),
+ )
+
+ if add_legend:
+ use_legend_elements = False if func.__name__ == "hist" else True
+ if use_legend_elements:
+ self.add_legend(
+ use_legend_elements=use_legend_elements,
+ hueplt_norm=hueplt_norm if not add_colorbar else _Normalize(None),
+ sizeplt_norm=sizeplt_norm,
+ primitive=self._mappables,
+ legend_ax=self.fig,
+ plotfunc=func.__name__,
+ )
+ else:
+ self.add_legend(use_legend_elements=use_legend_elements)
- def _finalize_grid(self, *axlabels: Hashable) ->None:
+ if add_colorbar:
+            # The colorbar is added after the legend so it fits the plot correctly:
+ if "label" not in cbar_kwargs:
+ cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data)
+
+ self.add_colorbar(**cbar_kwargs)
+
+ return self
+
+ def map_dataarray_line(
+ self: T_FacetGrid,
+ func: Callable,
+ x: Hashable | None,
+ y: Hashable | None,
+ hue: Hashable | None,
+ add_legend: bool = True,
+ _labels=None,
+ **kwargs: Any,
+ ) -> T_FacetGrid:
+ from xarray.plot.dataarray_plot import _infer_line_data
+
+ for d, ax in zip(self.name_dicts.flat, self.axs.flat):
+ # None is the sentinel value
+ if d is not None:
+ subset = self.data.loc[d]
+ mappable = func(
+ subset,
+ x=x,
+ y=y,
+ ax=ax,
+ hue=hue,
+ add_legend=False,
+ _labels=False,
+ **kwargs,
+ )
+ self._mappables.append(mappable)
+
+ xplt, yplt, hueplt, huelabel = _infer_line_data(
+ darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue
+ )
+ xlabel = label_from_attrs(xplt)
+ ylabel = label_from_attrs(yplt)
+
+ self._hue_var = hueplt
+ self._finalize_grid(xlabel, ylabel)
+
+ if add_legend and hueplt is not None and huelabel is not None:
+ self.add_legend(label=huelabel)
+
+ return self
+
+ def map_dataset(
+ self: T_FacetGrid,
+ func: Callable,
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ hue: Hashable | None = None,
+ hue_style: HueStyleOptions = None,
+ add_guide: bool | None = None,
+ **kwargs: Any,
+ ) -> T_FacetGrid:
+ from xarray.plot.dataset_plot import _infer_meta_data
+
+ kwargs["add_guide"] = False
+
+ if kwargs.get("markersize", None):
+ kwargs["size_mapping"] = _parse_size(
+ self.data[kwargs["markersize"]], kwargs.pop("size_norm", None)
+ )
+
+ meta_data = _infer_meta_data(
+ self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__
+ )
+ kwargs["meta_data"] = meta_data
+
+ if hue and meta_data["hue_style"] == "continuous":
+ cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
+ func, self.data[hue].to_numpy(), **kwargs
+ )
+ kwargs["meta_data"]["cmap_params"] = cmap_params
+ kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs
+
+ kwargs["_is_facetgrid"] = True
+
+ if func.__name__ == "quiver" and "scale" not in kwargs:
+ raise ValueError("Please provide scale.")
+ # TODO: come up with an algorithm for reasonable scale choice
+
+ for d, ax in zip(self.name_dicts.flat, self.axs.flat):
+ # None is the sentinel value
+ if d is not None:
+ subset = self.data.loc[d]
+ maybe_mappable = func(
+ ds=subset, x=x, y=y, hue=hue, hue_style=hue_style, ax=ax, **kwargs
+ )
+ # TODO: this is needed to get legends to work.
+ # but maybe_mappable is a list in that case :/
+ self._mappables.append(maybe_mappable)
+
+ self._finalize_grid(meta_data["xlabel"], meta_data["ylabel"])
+
+ if hue:
+ hue_label = meta_data.pop("hue_label", None)
+ self._hue_label = hue_label
+ if meta_data["add_legend"]:
+ self._hue_var = meta_data["hue"]
+ self.add_legend(label=hue_label)
+ elif meta_data["add_colorbar"]:
+ self.add_colorbar(label=hue_label, **cbar_kwargs)
+
+ if meta_data["add_quiverkey"]:
+ self.add_quiverkey(kwargs["u"], kwargs["v"])
+
+ return self
+
+ def _finalize_grid(self, *axlabels: Hashable) -> None:
"""Finalize the annotations and layout."""
- pass
+ if not self._finalized:
+ self.set_axis_labels(*axlabels)
+ self.set_titles()
+ self.fig.tight_layout()
+
+ for ax, namedict in zip(self.axs.flat, self.name_dicts.flat):
+ if namedict is None:
+ ax.set_visible(False)
+
+ self._finalized = True
+
+ def _adjust_fig_for_guide(self, guide) -> None:
+ # Draw the plot to set the bounding boxes correctly
+ if hasattr(self.fig.canvas, "get_renderer"):
+ renderer = self.fig.canvas.get_renderer()
+ else:
+ raise RuntimeError("MPL backend has no renderer")
+ self.fig.draw(renderer)
+
+ # Calculate and set the new width of the figure so the legend fits
+ guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
+ figure_width = self.fig.get_figwidth()
+ total_width = figure_width + guide_width
+ self.fig.set_figwidth(total_width)
+
+ # Draw the plot again to get the new transformations
+ self.fig.draw(renderer)
+
+ # Now calculate how much space we need on the right side
+ guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
+ space_needed = guide_width / total_width + 0.02
+ # margin = .01
+ # _space_needed = margin + space_needed
+ right = 1 - space_needed
- def add_colorbar(self, **kwargs: Any) ->None:
+ # Place the subplot axes to give space for the legend
+ self.fig.subplots_adjust(right=right)
+
+ def add_legend(
+ self,
+ *,
+ label: str | None = None,
+ use_legend_elements: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ if use_legend_elements:
+ self.figlegend = _add_legend(**kwargs)
+ else:
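+            # Reuse the artists from the most recent mappable as legend handles,
+            # labelled by the values of the hue coordinate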
+ self.figlegend = self.fig.legend(
+ handles=self._mappables[-1],
+ labels=list(self._hue_var.to_numpy()),
+ title=label if label is not None else label_from_attrs(self._hue_var),
+ loc=kwargs.pop("loc", "center right"),
+ **kwargs,
+ )
+ self._adjust_fig_for_guide(self.figlegend)
+
+ def add_colorbar(self, **kwargs: Any) -> None:
"""Draw a colorbar."""
- pass
+ kwargs = kwargs.copy()
+ if self._cmap_extend is not None:
+ kwargs.setdefault("extend", self._cmap_extend)
+        # don't pass extend as a kwarg if it is already set on the mappable
+ if hasattr(self._mappables[-1], "extend"):
+ kwargs.pop("extend", None)
+ if "label" not in kwargs:
+ from xarray import DataArray
+
+ assert isinstance(self.data, DataArray)
+ kwargs.setdefault("label", label_from_attrs(self.data))
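+        # Draw one colorbar spanning all facet axes, driven by the last mappable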
+ self.cbar = self.fig.colorbar(
+ self._mappables[-1], ax=list(self.axs.flat), **kwargs
+ )
- def _get_largest_lims(self) ->dict[str, tuple[float, float]]:
+ def add_quiverkey(self, u: Hashable, v: Hashable, **kwargs: Any) -> None:
+ kwargs = kwargs.copy()
+
+ magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v])
+ units = self.data[u].attrs.get("units", "")
+ self.quiverkey = self.axs.flat[-1].quiverkey(
+ self._mappables[-1],
+ X=0.8,
+ Y=0.9,
+ U=magnitude,
+ label=f"{magnitude}\n{units}",
+ labelpos="E",
+ coordinates="figure",
+ )
+
+ # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0
+ # https://github.com/matplotlib/matplotlib/issues/18530
+ # self._adjust_fig_for_guide(self.quiverkey.text)
+
+ def _get_largest_lims(self) -> dict[str, tuple[float, float]]:
"""
Get largest limits in the facetgrid.
@@ -274,11 +776,29 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
>>> round(fg._get_largest_lims()["x"][0], 3)
np.float64(-0.334)
"""
- pass
+ lims_largest: dict[str, tuple[float, float]] = dict(
+ x=(np.inf, -np.inf), y=(np.inf, -np.inf), z=(np.inf, -np.inf)
+ )
+ for axis in ("x", "y", "z"):
+ # Find the plot with the largest xlim values:
+ lower, upper = lims_largest[axis]
+ for ax in self.axs.flat:
+ get_lim: None | Callable[[], tuple[float, float]] = getattr(
+ ax, f"get_{axis}lim", None
+ )
+ if get_lim:
+ lower_new, upper_new = get_lim()
+ lower, upper = (min(lower, lower_new), max(upper, upper_new))
+ lims_largest[axis] = (lower, upper)
- def _set_lims(self, x: (tuple[float, float] | None)=None, y: (tuple[
- float, float] | None)=None, z: (tuple[float, float] | None)=None
- ) ->None:
+ return lims_largest
+
+ def _set_lims(
+ self,
+ x: tuple[float, float] | None = None,
+ y: tuple[float, float] | None = None,
+ z: tuple[float, float] | None = None,
+ ) -> None:
"""
Set the same limits for all the subplots in the facetgrid.
@@ -299,26 +819,55 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
>>> fg.axs[0, 0].get_xlim(), fg.axs[0, 0].get_ylim()
((np.float64(-0.3), np.float64(0.3)), (np.float64(0.0), np.float64(2.0)))
"""
- pass
+ lims_largest = self._get_largest_lims()
+
+ # Set limits:
+ for ax in self.axs.flat:
+ for (axis, data_limit), parameter_limit in zip(
+ lims_largest.items(), (x, y, z)
+ ):
+ set_lim = getattr(ax, f"set_{axis}lim", None)
+ if set_lim:
+ set_lim(data_limit if parameter_limit is None else parameter_limit)
- def set_axis_labels(self, *axlabels: Hashable) ->None:
+ def set_axis_labels(self, *axlabels: Hashable) -> None:
"""Set axis labels on the left column and bottom row of the grid."""
- pass
+ from xarray.core.dataarray import DataArray
+
+ for var, axis in zip(axlabels, ["x", "y", "z"]):
+ if var is not None:
+ if isinstance(var, DataArray):
+ getattr(self, f"set_{axis}labels")(label_from_attrs(var))
+ else:
+ getattr(self, f"set_{axis}labels")(str(var))
+
+ def _set_labels(
+ self, axis: str, axes: Iterable, label: str | None = None, **kwargs
+ ) -> None:
+ if label is None:
+ label = label_from_attrs(self.data[getattr(self, f"_{axis}_var")])
+ for ax in axes:
+ getattr(ax, f"set_{axis}label")(label, **kwargs)
- def set_xlabels(self, label: (None | str)=None, **kwargs: Any) ->None:
+ def set_xlabels(self, label: None | str = None, **kwargs: Any) -> None:
"""Label the x axis on the bottom row of the grid."""
- pass
+ self._set_labels("x", self._bottom_axes, label, **kwargs)
- def set_ylabels(self, label: (None | str)=None, **kwargs: Any) ->None:
+ def set_ylabels(self, label: None | str = None, **kwargs: Any) -> None:
"""Label the y axis on the left column of the grid."""
- pass
+ self._set_labels("y", self._left_axes, label, **kwargs)
- def set_zlabels(self, label: (None | str)=None, **kwargs: Any) ->None:
+ def set_zlabels(self, label: None | str = None, **kwargs: Any) -> None:
"""Label the z axis."""
- pass
+ self._set_labels("z", self._left_axes, label, **kwargs)
- def set_titles(self, template: str='{coord} = {value}', maxchar: int=30,
- size=None, **kwargs) ->None:
+ def set_titles(
+ self,
+ template: str = "{coord} = {value}",
+ maxchar: int = 30,
+ size=None,
+ **kwargs,
+ ) -> None:
"""
Draw titles either above each facet or on the grid margins.
@@ -336,10 +885,57 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
self: FacetGrid object
"""
- pass
+ import matplotlib as mpl
- def set_ticks(self, max_xticks: int=_NTICKS, max_yticks: int=_NTICKS,
- fontsize: (str | int)=_FONTSIZE) ->None:
+ if size is None:
+ size = mpl.rcParams["axes.labelsize"]
+
+ nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)
+
+ if self._single_group:
+ for d, ax in zip(self.name_dicts.flat, self.axs.flat):
+ # Only label the ones with data
+ if d is not None:
+ coord, value = list(d.items()).pop()
+ title = nicetitle(coord, value, maxchar=maxchar)
+ ax.set_title(title, size=size, **kwargs)
+ else:
+ # The row titles on the right edge of the grid
+ for index, (ax, row_name, handle) in enumerate(
+ zip(self.axs[:, -1], self.row_names, self.row_labels)
+ ):
+ title = nicetitle(coord=self._row_var, value=row_name, maxchar=maxchar)
+ if not handle:
+ self.row_labels[index] = ax.annotate(
+ title,
+ xy=(1.02, 0.5),
+ xycoords="axes fraction",
+ rotation=270,
+ ha="left",
+ va="center",
+ **kwargs,
+ )
+ else:
+ handle.set_text(title)
+ handle.update(kwargs)
+
+ # The column titles on the top row
+ for index, (ax, col_name, handle) in enumerate(
+ zip(self.axs[0, :], self.col_names, self.col_labels)
+ ):
+ title = nicetitle(coord=self._col_var, value=col_name, maxchar=maxchar)
+ if not handle:
+ self.col_labels[index] = ax.set_title(title, size=size, **kwargs)
+ else:
+ handle.set_text(title)
+ handle.update(kwargs)
+
+ def set_ticks(
+ self,
+ max_xticks: int = _NTICKS,
+ max_yticks: int = _NTICKS,
+ fontsize: str | int = _FONTSIZE,
+ ) -> None:
"""
Set and control tick behavior.
@@ -355,10 +951,23 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
self : FacetGrid object
"""
- pass
+ from matplotlib.ticker import MaxNLocator
+
+ # Both are necessary
+ x_major_locator = MaxNLocator(nbins=max_xticks)
+ y_major_locator = MaxNLocator(nbins=max_yticks)
+
+ for ax in self.axs.flat:
+ ax.xaxis.set_major_locator(x_major_locator)
+ ax.yaxis.set_major_locator(y_major_locator)
+ for tick in itertools.chain(
+ ax.xaxis.get_major_ticks(), ax.yaxis.get_major_ticks()
+ ):
+ tick.label1.set_fontsize(fontsize)
- def map(self: T_FacetGrid, func: Callable, *args: Hashable, **kwargs: Any
- ) ->T_FacetGrid:
+ def map(
+ self: T_FacetGrid, func: Callable, *args: Hashable, **kwargs: Any
+ ) -> T_FacetGrid:
"""
Apply a plotting function to each facet's subset of the data.
@@ -381,20 +990,85 @@ class FacetGrid(Generic[T_DataArrayOrSet]):
self : FacetGrid object
"""
- pass
+ import matplotlib.pyplot as plt
+
+ for ax, namedict in zip(self.axs.flat, self.name_dicts.flat):
+ if namedict is not None:
+ data = self.data.loc[namedict]
+ plt.sca(ax)
+ innerargs = [data[a].to_numpy() for a in args]
+ maybe_mappable = func(*innerargs, **kwargs)
+ # TODO: better way to verify that an artist is mappable?
+ # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522
+ if maybe_mappable and hasattr(maybe_mappable, "autoscale_None"):
+ self._mappables.append(maybe_mappable)
+
+ self._finalize_grid(*args[:2])
+ return self
-def _easy_facetgrid(data: T_DataArrayOrSet, plotfunc: Callable, kind:
- Literal['line', 'dataarray', 'dataset', 'plot1d'], x: (Hashable | None)
- =None, y: (Hashable | None)=None, row: (Hashable | None)=None, col: (
- Hashable | None)=None, col_wrap: (int | None)=None, sharex: bool=True,
- sharey: bool=True, aspect: (float | None)=None, size: (float | None)=
- None, subplot_kws: (dict[str, Any] | None)=None, ax: (Axes | None)=None,
- figsize: (Iterable[float] | None)=None, **kwargs: Any) ->FacetGrid[
- T_DataArrayOrSet]:
+
+def _easy_facetgrid(
+ data: T_DataArrayOrSet,
+ plotfunc: Callable,
+ kind: Literal["line", "dataarray", "dataset", "plot1d"],
+ x: Hashable | None = None,
+ y: Hashable | None = None,
+ row: Hashable | None = None,
+ col: Hashable | None = None,
+ col_wrap: int | None = None,
+ sharex: bool = True,
+ sharey: bool = True,
+ aspect: float | None = None,
+ size: float | None = None,
+ subplot_kws: dict[str, Any] | None = None,
+ ax: Axes | None = None,
+ figsize: Iterable[float] | None = None,
+ **kwargs: Any,
+) -> FacetGrid[T_DataArrayOrSet]:
"""
Convenience method to call xarray.plot.FacetGrid from 2d plotting methods
kwargs are the arguments to 2d plotting method
"""
- pass
+ if ax is not None:
+ raise ValueError("Can't use axes when making faceted plots.")
+ if aspect is None:
+ aspect = 1
+ if size is None:
+ size = 3
+ elif figsize is not None:
+ raise ValueError("cannot provide both `figsize` and `size` arguments")
+ if kwargs.get("z") is not None:
+        # 3d plots don't support sharex/sharey; reset to mpl defaults:
+ sharex = False
+ sharey = False
+
+ g = FacetGrid(
+ data=data,
+ col=col,
+ row=row,
+ col_wrap=col_wrap,
+ sharex=sharex,
+ sharey=sharey,
+ figsize=figsize,
+ aspect=aspect,
+ size=size,
+ subplot_kws=subplot_kws,
+ )
+
+ if kind == "line":
+ return g.map_dataarray_line(plotfunc, x, y, **kwargs)
+
+ if kind == "dataarray":
+ return g.map_dataarray(plotfunc, x, y, **kwargs)
+
+ if kind == "plot1d":
+ return g.map_plot1d(plotfunc, x, y, **kwargs)
+
+ if kind == "dataset":
+ return g.map_dataset(plotfunc, x, y, **kwargs)
+
+ raise ValueError(
+ f"kind must be one of `line`, `dataarray`, `dataset` or `plot1d`, got {kind}"
+ )
diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py
index 6309e4c6..a0abe0c8 100644
--- a/xarray/plot/utils.py
+++ b/xarray/plot/utils.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+
import itertools
import textwrap
import warnings
@@ -6,44 +7,163 @@ from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequenc
from datetime import date, datetime
from inspect import getfullargspec
from typing import TYPE_CHECKING, Any, Callable, Literal, overload
+
import numpy as np
import pandas as pd
+
from xarray.core.indexes import PandasMultiIndex
from xarray.core.options import OPTIONS
from xarray.core.utils import is_scalar, module_available
from xarray.namedarray.pycompat import DuckArrayModule
-nc_time_axis_available = module_available('nc_time_axis')
+
+nc_time_axis_available = module_available("nc_time_axis")
+
+
try:
import cftime
except ImportError:
cftime = None
+
+
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.colors import Normalize
from matplotlib.ticker import FuncFormatter
from numpy.typing import ArrayLike
+
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import AspectOptions, ScaleOptions
+
try:
import matplotlib.pyplot as plt
except ImportError:
- plt: Any = None
+ plt: Any = None # type: ignore
+
ROBUST_PERCENTILE = 2.0
-_MARKERSIZE_RANGE = 18.0, 36.0, 72.0
-_LINEWIDTH_RANGE = 1.5, 1.5, 6.0
+
+# copied from seaborn
+_MARKERSIZE_RANGE = (18.0, 36.0, 72.0)
+_LINEWIDTH_RANGE = (1.5, 1.5, 6.0)
+
+
+def _determine_extend(calc_data, vmin, vmax):
+ extend_min = calc_data.min() < vmin
+ extend_max = calc_data.max() > vmax
+ if extend_min and extend_max:
+ return "both"
+ elif extend_min:
+ return "min"
+ elif extend_max:
+ return "max"
+ else:
+ return "neither"
def _build_discrete_cmap(cmap, levels, extend, filled):
"""
Build a discrete colormap and normalization of the data.
"""
- pass
+ import matplotlib as mpl
+
+ if len(levels) == 1:
+ levels = [levels[0], levels[0]]
+
+ if not filled:
+ # non-filled contour plots
+ extend = "max"
+ if extend == "both":
+ ext_n = 2
+ elif extend in ["min", "max"]:
+ ext_n = 1
+ else:
+ ext_n = 0
-def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
- center=None, robust=False, extend=None, levels=None, filled=True, norm=
- None, _is_facetgrid=False):
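+    # Interior colors (len(levels) - 1) plus one extra color per colorbar "extend" arrow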
+ n_colors = len(levels) + ext_n - 1
+ pal = _color_palette(cmap, n_colors)
+
+ new_cmap, cnorm = mpl.colors.from_levels_and_colors(levels, pal, extend=extend)
+ # copy the old cmap name, for easier testing
+ new_cmap.name = getattr(cmap, "name", cmap)
+
+ # copy colors to use for bad, under, and over values in case they have been
+ # set to non-default values
+ try:
+ # matplotlib<3.2 only uses bad color for masked values
+ bad = cmap(np.ma.masked_invalid([np.nan]))[0]
+ except TypeError:
+ # cmap was a str or list rather than a color-map object, so there are
+ # no bad, under or over values to check or copy
+ pass
+ else:
+ under = cmap(-np.inf)
+ over = cmap(np.inf)
+
+ new_cmap.set_bad(bad)
+
+ # Only update under and over if they were explicitly changed by the user
+ # (i.e. are different from the lowest or highest values in cmap). Otherwise
+ # leave unchanged so new_cmap uses its default values (its own lowest and
+ # highest values).
+ if under != cmap(0):
+ new_cmap.set_under(under)
+ if over != cmap(cmap.N - 1):
+ new_cmap.set_over(over)
+
+ return new_cmap, cnorm
+
+
+def _color_palette(cmap, n_colors):
+ import matplotlib.pyplot as plt
+ from matplotlib.colors import ListedColormap
+
+ colors_i = np.linspace(0, 1.0, n_colors)
+ if isinstance(cmap, (list, tuple)):
+ # we have a list of colors
+ cmap = ListedColormap(cmap, N=n_colors)
+ pal = cmap(colors_i)
+ elif isinstance(cmap, str):
+ # we have some sort of named palette
+ try:
+ # is this a matplotlib cmap?
+ cmap = plt.get_cmap(cmap)
+ pal = cmap(colors_i)
+ except ValueError:
+ # ValueError happens when mpl doesn't like a colormap, try seaborn
+ try:
+ from seaborn import color_palette
+
+ pal = color_palette(cmap, n_colors=n_colors)
+ except (ValueError, ImportError):
+ # or maybe we just got a single color as a string
+ cmap = ListedColormap([cmap], N=n_colors)
+ pal = cmap(colors_i)
+ else:
+ # cmap better be a LinearSegmentedColormap (e.g. viridis)
+ pal = cmap(colors_i)
+
+ return pal
+
+
+# _determine_cmap_params is adapted from Seaborn:
+# https://github.com/mwaskom/seaborn/blob/v0.6/seaborn/matrix.py#L158
+# Used under the terms of Seaborn's license, see licenses/SEABORN_LICENSE.
+
+
+def _determine_cmap_params(
+ plot_data,
+ vmin=None,
+ vmax=None,
+ cmap=None,
+ center=None,
+ robust=False,
+ extend=None,
+ levels=None,
+ filled=True,
+ norm=None,
+ _is_facetgrid=False,
+):
"""
Use some heuristics to set good defaults for colorbar and range.
@@ -57,84 +177,448 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
cmap_params : dict
Use depends on the type of the plotting function
"""
- pass
-
-
-def _infer_xy_labels_3d(darray: (DataArray | Dataset), x: (Hashable | None),
- y: (Hashable | None), rgb: (Hashable | None)) ->tuple[Hashable, Hashable]:
+ import matplotlib as mpl
+
+ if isinstance(levels, Iterable):
+ levels = sorted(levels)
+
+ calc_data = np.ravel(plot_data[np.isfinite(plot_data)])
+
+ # Handle all-NaN input data gracefully
+ if calc_data.size == 0:
+ # Arbitrary default for when all values are NaN
+ calc_data = np.array(0.0)
+
+ # Setting center=False prevents a divergent cmap
+ possibly_divergent = center is not False
+
+ # Set center to 0 so math below makes sense but remember its state
+ center_is_none = False
+ if center is None:
+ center = 0
+ center_is_none = True
+
+ # Setting both vmin and vmax prevents a divergent cmap
+ if (vmin is not None) and (vmax is not None):
+ possibly_divergent = False
+
+ # Setting vmin or vmax implies linspaced levels
+ user_minmax = (vmin is not None) or (vmax is not None)
+
+ # vlim might be computed below
+ vlim = None
+
+ # save state; needed later
+ vmin_was_none = vmin is None
+ vmax_was_none = vmax is None
+
+ if vmin is None:
+ if robust:
+ vmin = np.percentile(calc_data, ROBUST_PERCENTILE)
+ else:
+ vmin = calc_data.min()
+ elif possibly_divergent:
+ vlim = abs(vmin - center)
+
+ if vmax is None:
+ if robust:
+ vmax = np.percentile(calc_data, 100 - ROBUST_PERCENTILE)
+ else:
+ vmax = calc_data.max()
+ elif possibly_divergent:
+ vlim = abs(vmax - center)
+
+ if possibly_divergent:
+ levels_are_divergent = (
+ isinstance(levels, Iterable) and levels[0] * levels[-1] < 0
+ )
+ # kwargs not specific about divergent or not: infer defaults from data
+ divergent = (
+ ((vmin < 0) and (vmax > 0)) or not center_is_none or levels_are_divergent
+ )
+ else:
+ divergent = False
+
+ # A divergent map should be symmetric around the center value
+ if divergent:
+ if vlim is None:
+ vlim = max(abs(vmin - center), abs(vmax - center))
+ vmin, vmax = -vlim, vlim
+
+ # Now add in the centering value and set the limits
+ vmin += center
+ vmax += center
+
+ # now check norm and harmonize with vmin, vmax
+ if norm is not None:
+ if norm.vmin is None:
+ norm.vmin = vmin
+ else:
+ if not vmin_was_none and vmin != norm.vmin:
+ raise ValueError("Cannot supply vmin and a norm with a different vmin.")
+ vmin = norm.vmin
+
+ if norm.vmax is None:
+ norm.vmax = vmax
+ else:
+ if not vmax_was_none and vmax != norm.vmax:
+ raise ValueError("Cannot supply vmax and a norm with a different vmax.")
+ vmax = norm.vmax
+
+ # if BoundaryNorm, then set levels
+ if isinstance(norm, mpl.colors.BoundaryNorm):
+ levels = norm.boundaries
+
+ # Choose default colormaps if not provided
+ if cmap is None:
+ if divergent:
+ cmap = OPTIONS["cmap_divergent"]
+ else:
+ cmap = OPTIONS["cmap_sequential"]
+
+ # Handle discrete levels
+ if levels is not None:
+ if is_scalar(levels):
+ if user_minmax:
+ levels = np.linspace(vmin, vmax, levels)
+ elif levels == 1:
+ levels = np.asarray([(vmin + vmax) / 2])
+ else:
+ # N in MaxNLocator refers to bins, not ticks
+ ticker = mpl.ticker.MaxNLocator(levels - 1)
+ levels = ticker.tick_values(vmin, vmax)
+ vmin, vmax = levels[0], levels[-1]
+
+ # GH3734
+ if vmin == vmax:
+ vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax)
+
+ if extend is None:
+ extend = _determine_extend(calc_data, vmin, vmax)
+
+ if (levels is not None) and (not isinstance(norm, mpl.colors.BoundaryNorm)):
+ cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled)
+ norm = newnorm if norm is None else norm
+
+ # vmin & vmax needs to be None if norm is passed
+ # TODO: always return a norm with vmin and vmax
+ if norm is not None:
+ vmin = None
+ vmax = None
+
+ return dict(
+ vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm
+ )
+
+
+def _infer_xy_labels_3d(
+ darray: DataArray | Dataset,
+ x: Hashable | None,
+ y: Hashable | None,
+ rgb: Hashable | None,
+) -> tuple[Hashable, Hashable]:
"""
Determine x and y labels for showing RGB images.
Attempts to infer which dimension is RGB/RGBA by size and order of dims.
"""
- pass
-
-
-def _infer_xy_labels(darray: (DataArray | Dataset), x: (Hashable | None), y:
- (Hashable | None), imshow: bool=False, rgb: (Hashable | None)=None
- ) ->tuple[Hashable, Hashable]:
+ assert rgb is None or rgb != x
+ assert rgb is None or rgb != y
+ # Start by detecting and reporting invalid combinations of arguments
+ assert darray.ndim == 3
+ not_none = [a for a in (x, y, rgb) if a is not None]
+ if len(set(not_none)) < len(not_none):
+ raise ValueError(
+ "Dimension names must be None or unique strings, but imshow was "
+ f"passed x={x!r}, y={y!r}, and rgb={rgb!r}."
+ )
+ for label in not_none:
+ if label not in darray.dims:
+ raise ValueError(f"{label!r} is not a dimension")
+
+ # Then calculate rgb dimension if certain and check validity
+ could_be_color = [
+ label
+ for label in darray.dims
+ if darray[label].size in (3, 4) and label not in (x, y)
+ ]
+ if rgb is None and not could_be_color:
+ raise ValueError(
+ "A 3-dimensional array was passed to imshow(), but there is no "
+ "dimension that could be color. At least one dimension must be "
+ "of size 3 (RGB) or 4 (RGBA), and not given as x or y."
+ )
+ if rgb is None and len(could_be_color) == 1:
+ rgb = could_be_color[0]
+ if rgb is not None and darray[rgb].size not in (3, 4):
+ raise ValueError(
+ f"Cannot interpret dim {rgb!r} of size {darray[rgb].size} as RGB or RGBA."
+ )
+
+ # If rgb dimension is still unknown, there must be two or three dimensions
+ # in could_be_color. We therefore warn, and use a heuristic to break ties.
+ if rgb is None:
+ assert len(could_be_color) in (2, 3)
+ rgb = could_be_color[-1]
+ warnings.warn(
+ "Several dimensions of this array could be colors. Xarray "
+ f"will use the last possible dimension ({rgb!r}) to match "
+ "matplotlib.pyplot.imshow. You can pass names of x, y, "
+ "and/or rgb dimensions to override this guess."
+ )
+ assert rgb is not None
+
+ # Finally, we pick out the red slice and delegate to the 2D version:
+ return _infer_xy_labels(darray.isel({rgb: 0}), x, y)
+
+
+def _infer_xy_labels(
+ darray: DataArray | Dataset,
+ x: Hashable | None,
+ y: Hashable | None,
+ imshow: bool = False,
+ rgb: Hashable | None = None,
+) -> tuple[Hashable, Hashable]:
"""
Determine x and y labels. For use in _plot2d
darray must be a 2 dimensional data array, or 3d for imshow only.
"""
- pass
-
-
-def _assert_valid_xy(darray: (DataArray | Dataset), xy: (Hashable | None),
- name: str) ->None:
+ if (x is not None) and (x == y):
+ raise ValueError("x and y cannot be equal.")
+
+ if imshow and darray.ndim == 3:
+ return _infer_xy_labels_3d(darray, x, y, rgb)
+
+ if x is None and y is None:
+ if darray.ndim != 2:
+ raise ValueError("DataArray must be 2d")
+ y, x = darray.dims
+ elif x is None:
+ _assert_valid_xy(darray, y, "y")
+ x = darray.dims[0] if y == darray.dims[1] else darray.dims[1]
+ elif y is None:
+ _assert_valid_xy(darray, x, "x")
+ y = darray.dims[0] if x == darray.dims[1] else darray.dims[1]
+ else:
+ _assert_valid_xy(darray, x, "x")
+ _assert_valid_xy(darray, y, "y")
+
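+    # The distinct fallback sentinels (1 and 2) make the identity check fail when either index is missing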
+ if darray._indexes.get(x, 1) is darray._indexes.get(y, 2):
+ if isinstance(darray._indexes[x], PandasMultiIndex):
+ raise ValueError("x and y cannot be levels of the same MultiIndex")
+
+ return x, y
+
+
+# TODO: Can be used for more than x or y; rename?
+def _assert_valid_xy(
+ darray: DataArray | Dataset, xy: Hashable | None, name: str
+) -> None:
"""
make sure x and y passed to plotting functions are valid
"""
- pass
+ # MultiIndex cannot be plotted; no point in allowing them here
+ multiindex_dims = {
+ idx.dim
+ for idx in darray.xindexes.get_unique()
+ if isinstance(idx, PandasMultiIndex)
+ }
+
+ valid_xy = (set(darray.dims) | set(darray.coords)) - multiindex_dims
+
+ if (xy is not None) and (xy not in valid_xy):
+ valid_xy_str = "', '".join(sorted(tuple(str(v) for v in valid_xy)))
+ raise ValueError(
+ f"{name} must be one of None, '{valid_xy_str}'. Received '{xy}' instead."
+ )
+
+
+def get_axis(
+ figsize: Iterable[float] | None = None,
+ size: float | None = None,
+ aspect: AspectOptions = None,
+ ax: Axes | None = None,
+ **subplot_kws: Any,
+) -> Axes:
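+    # Resolve the figsize / size+aspect / ax precedence and return a matplotlib Axes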
+ try:
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+ except ImportError:
+ raise ImportError("matplotlib is required for plot.utils.get_axis")
-def _get_units_from_attrs(da: DataArray) ->str:
- """Extracts and formats the unit/units from a attributes."""
- pass
+ if figsize is not None:
+ if ax is not None:
+ raise ValueError("cannot provide both `figsize` and `ax` arguments")
+ if size is not None:
+ raise ValueError("cannot provide both `figsize` and `size` arguments")
+ _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws)
+ return ax
+ if size is not None:
+ if ax is not None:
+ raise ValueError("cannot provide both `size` and `ax` arguments")
+ if aspect is None or aspect == "auto":
+ width, height = mpl.rcParams["figure.figsize"]
+ faspect = width / height
+ elif aspect == "equal":
+ faspect = 1
+ else:
+ faspect = aspect
+ figsize = (size * faspect, size)
+ _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws)
+ return ax
-def label_from_attrs(da: (DataArray | None), extra: str='') ->str:
- """Makes informative labels if variable metadata (attrs) follows
- CF conventions."""
- pass
+ if aspect is not None:
+ raise ValueError("cannot provide `aspect` argument without `size`")
+
+ if subplot_kws and ax is not None:
+ raise ValueError("cannot use subplot_kws with existing ax")
+
+ if ax is None:
+ ax = _maybe_gca(**subplot_kws)
+
+ return ax
+
+
+def _maybe_gca(**subplot_kws: Any) -> Axes:
+ import matplotlib.pyplot as plt
+
+ # can call gcf unconditionally: either it exists or would be created by plt.axes
+ f = plt.gcf()
+ # only call gca if an active axes exists
+ if f.axes:
+ # can not pass kwargs to active axes
+ return plt.gca()
-def _interval_to_mid_points(array: Iterable[pd.Interval]) ->np.ndarray:
+ return plt.axes(**subplot_kws)
+
+
+def _get_units_from_attrs(da: DataArray) -> str:
+ """Extracts and formats the unit/units from a attributes."""
+ pint_array_type = DuckArrayModule("pint").type
+ units = " [{}]"
+ if isinstance(da.data, pint_array_type):
+ return units.format(str(da.data.units))
+ if "units" in da.attrs:
+ return units.format(da.attrs["units"])
+ if "unit" in da.attrs:
+ return units.format(da.attrs["unit"])
+ return ""
+
+
+def label_from_attrs(da: DataArray | None, extra: str = "") -> str:
+ """Makes informative labels if variable metadata (attrs) follows
+ CF conventions."""
+ if da is None:
+ return ""
+
+ name: str = "{}"
+ if "long_name" in da.attrs:
+ name = name.format(da.attrs["long_name"])
+ elif "standard_name" in da.attrs:
+ name = name.format(da.attrs["standard_name"])
+ elif da.name is not None:
+ name = name.format(da.name)
+ else:
+ name = ""
+
+ units = _get_units_from_attrs(da)
+
+ # Treat `name` differently if it's a latex sequence
+ if name.startswith("$") and (name.count("$") % 2 == 0):
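+        # Joining with "$\n$" keeps the "$" delimiters balanced on each wrapped line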
+ return "$\n$".join(
+ textwrap.wrap(name + extra + units, 60, break_long_words=False)
+ )
+ else:
+ return "\n".join(textwrap.wrap(name + extra + units, 30))
+
+
+def _interval_to_mid_points(array: Iterable[pd.Interval]) -> np.ndarray:
"""
Helper function which returns an array
with the Intervals' mid points.
"""
- pass
+
+ return np.array([x.mid for x in array])
-def _interval_to_bound_points(array: Sequence[pd.Interval]) ->np.ndarray:
+def _interval_to_bound_points(array: Sequence[pd.Interval]) -> np.ndarray:
"""
Helper function which returns an array
with the Intervals' boundaries.
"""
- pass
+ array_boundaries = np.array([x.left for x in array])
+ array_boundaries = np.concatenate((array_boundaries, np.array([array[-1].right])))
-def _interval_to_double_bound_points(xarray: Iterable[pd.Interval], yarray:
- Iterable) ->tuple[np.ndarray, np.ndarray]:
+ return array_boundaries
+
+
+def _interval_to_double_bound_points(
+ xarray: Iterable[pd.Interval], yarray: Iterable
+) -> tuple[np.ndarray, np.ndarray]:
"""
Helper function to deal with a xarray consisting of pd.Intervals. Each
interval is replaced with both boundaries. I.e. the length of xarray
doubles. yarray is modified so it matches the new shape of xarray.
"""
- pass
+ xarray1 = np.array([x.left for x in xarray])
+ xarray2 = np.array([x.right for x in xarray])
+
+ xarray_out = np.array(list(itertools.chain.from_iterable(zip(xarray1, xarray2))))
+ yarray_out = np.array(list(itertools.chain.from_iterable(zip(yarray, yarray))))
-def _resolve_intervals_1dplot(xval: np.ndarray, yval: np.ndarray, kwargs: dict
- ) ->tuple[np.ndarray, np.ndarray, str, str, dict]:
+ return xarray_out, yarray_out
+
+
+def _resolve_intervals_1dplot(
+ xval: np.ndarray, yval: np.ndarray, kwargs: dict
+) -> tuple[np.ndarray, np.ndarray, str, str, dict]:
"""
Helper function to replace the values of x and/or y coordinate arrays
containing pd.Interval with their mid-points or - for step plots - double
points which double the length.
"""
- pass
+ x_suffix = ""
+ y_suffix = ""
+
+ # Is it a step plot? (see matplotlib.Axes.step)
+ if kwargs.get("drawstyle", "").startswith("steps-"):
+ remove_drawstyle = False
+
+ # Convert intervals to double points
+ x_is_interval = _valid_other_type(xval, pd.Interval)
+ y_is_interval = _valid_other_type(yval, pd.Interval)
+ if x_is_interval and y_is_interval:
+ raise TypeError("Can't step plot intervals against intervals.")
+ elif x_is_interval:
+ xval, yval = _interval_to_double_bound_points(xval, yval)
+ remove_drawstyle = True
+ elif y_is_interval:
+ yval, xval = _interval_to_double_bound_points(yval, xval)
+ remove_drawstyle = True
+
+ # Remove steps-* to be sure that matplotlib is not confused
+ if remove_drawstyle:
+ del kwargs["drawstyle"]
+
+ # Is it another kind of plot?
+ else:
+ # Convert intervals to mid points and adjust labels
+ if _valid_other_type(xval, pd.Interval):
+ xval = _interval_to_mid_points(xval)
+ x_suffix = "_center"
+ if _valid_other_type(yval, pd.Interval):
+ yval = _interval_to_mid_points(yval)
+ y_suffix = "_center"
+
+ # return converted arguments
+ return xval, yval, x_suffix, y_suffix, kwargs
def _resolve_intervals_2dplot(val, func_name):
@@ -143,41 +627,185 @@ def _resolve_intervals_2dplot(val, func_name):
pd.Interval with their mid-points or - for pcolormesh - boundaries which
increases length by 1.
"""
- pass
+ label_extra = ""
+ if _valid_other_type(val, pd.Interval):
+ if func_name == "pcolormesh":
+ val = _interval_to_bound_points(val)
+ else:
+ val = _interval_to_mid_points(val)
+ label_extra = "_center"
+
+ return val, label_extra
-def _valid_other_type(x: ArrayLike, types: (type[object] | tuple[type[
- object], ...])) ->bool:
+def _valid_other_type(
+ x: ArrayLike, types: type[object] | tuple[type[object], ...]
+) -> bool:
"""
Do all elements of x have a type from types?
"""
- pass
+ return all(isinstance(el, types) for el in np.ravel(x))
def _valid_numpy_subdtype(x, numpy_types):
"""
Is any dtype from numpy_types superior to the dtype of x?
"""
- pass
+ # If any of the types given in numpy_types is understood as numpy.generic,
+ # all possible x will be considered valid. This is probably unwanted.
+ for t in numpy_types:
+ assert not np.issubdtype(np.generic, t)
+ return any(np.issubdtype(x.dtype, t) for t in numpy_types)
-def _ensure_plottable(*args) ->None:
+
+def _ensure_plottable(*args) -> None:
"""
Raise exception if there is anything in args that can't be plotted on an
axis by matplotlib.
"""
- pass
-
-
-def _update_axes(ax: Axes, xincrease: (bool | None), yincrease: (bool |
- None), xscale: ScaleOptions=None, yscale: ScaleOptions=None, xticks: (
- ArrayLike | None)=None, yticks: (ArrayLike | None)=None, xlim: (tuple[
- float, float] | None)=None, ylim: (tuple[float, float] | None)=None
- ) ->None:
+ numpy_types: tuple[type[object], ...] = (
+ np.floating,
+ np.integer,
+ np.timedelta64,
+ np.datetime64,
+ np.bool_,
+ np.str_,
+ )
+ other_types: tuple[type[object], ...] = (datetime, date)
+ cftime_datetime_types: tuple[type[object], ...] = (
+ () if cftime is None else (cftime.datetime,)
+ )
+ other_types += cftime_datetime_types
+
+ for x in args:
+ if not (
+ _valid_numpy_subdtype(np.asarray(x), numpy_types)
+ or _valid_other_type(np.asarray(x), other_types)
+ ):
+ raise TypeError(
+ "Plotting requires coordinates to be numeric, boolean, "
+ "or dates of type numpy.datetime64, "
+ "datetime.datetime, cftime.datetime or "
+ f"pandas.Interval. Received data of type {np.asarray(x).dtype} instead."
+ )
+ if _valid_other_type(np.asarray(x), cftime_datetime_types):
+ if nc_time_axis_available:
+ # Register cftime datetypes to matplotlib.units.registry,
+ # otherwise matplotlib will raise an error:
+ import nc_time_axis # noqa: F401
+ else:
+ raise ImportError(
+ "Plotting of arrays of cftime.datetime "
+ "objects or arrays indexed by "
+ "cftime.datetime objects requires the "
+ "optional `nc-time-axis` (v1.2.0 or later) "
+ "package."
+ )
+
+
+def _is_numeric(arr):
+ numpy_types = [np.floating, np.integer]
+ return _valid_numpy_subdtype(arr, numpy_types)
+
+
+def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params):
+ cbar_kwargs.setdefault("extend", cmap_params["extend"])
+ if cbar_ax is None:
+ cbar_kwargs.setdefault("ax", ax)
+ else:
+ cbar_kwargs.setdefault("cax", cbar_ax)
+
+    # don't pass extend as a kwarg if it is already set on the mappable
+ if hasattr(primitive, "extend"):
+ cbar_kwargs.pop("extend")
+
+ fig = ax.get_figure()
+ cbar = fig.colorbar(primitive, **cbar_kwargs)
+
+ return cbar
+
+
+def _rescale_imshow_rgb(darray, vmin, vmax, robust):
+ assert robust or vmin is not None or vmax is not None
+
+ # Calculate vmin and vmax automatically for `robust=True`
+ if robust:
+ if vmax is None:
+ vmax = np.nanpercentile(darray, 100 - ROBUST_PERCENTILE)
+ if vmin is None:
+ vmin = np.nanpercentile(darray, ROBUST_PERCENTILE)
+ # If not robust and one bound is None, calculate the default other bound
+ # and check that an interval between them exists.
+ elif vmax is None:
+ vmax = 255 if np.issubdtype(darray.dtype, np.integer) else 1
+ if vmax < vmin:
+ raise ValueError(
+ f"vmin={vmin!r} is less than the default vmax ({vmax!r}) - you must supply "
+ "a vmax > vmin in this case."
+ )
+ elif vmin is None:
+ vmin = 0
+ if vmin > vmax:
+ raise ValueError(
+ f"vmax={vmax!r} is less than the default vmin (0) - you must supply "
+ "a vmin < vmax in this case."
+ )
+ # Scale interval [vmin .. vmax] to [0 .. 1], with darray as 64-bit float
+ # to avoid precision loss, integer over/underflow, etc with extreme inputs.
+ # After scaling, downcast to 32-bit float. This substantially reduces
+ # memory usage after we hand `darray` off to matplotlib.
+ darray = ((darray.astype("f8") - vmin) / (vmax - vmin)).astype("f4")
+ return np.minimum(np.maximum(darray, 0), 1)
+
+
+def _update_axes(
+ ax: Axes,
+ xincrease: bool | None,
+ yincrease: bool | None,
+ xscale: ScaleOptions = None,
+ yscale: ScaleOptions = None,
+ xticks: ArrayLike | None = None,
+ yticks: ArrayLike | None = None,
+ xlim: tuple[float, float] | None = None,
+ ylim: tuple[float, float] | None = None,
+) -> None:
"""
Update axes with provided parameters
"""
- pass
+ if xincrease is None:
+ pass
+ elif xincrease and ax.xaxis_inverted():
+ ax.invert_xaxis()
+ elif not xincrease and not ax.xaxis_inverted():
+ ax.invert_xaxis()
+
+ if yincrease is None:
+ pass
+ elif yincrease and ax.yaxis_inverted():
+ ax.invert_yaxis()
+ elif not yincrease and not ax.yaxis_inverted():
+ ax.invert_yaxis()
+
+ # The default xscale, yscale needs to be None.
+ # If we set a scale it resets the axes formatters,
+ # This means that set_xscale('linear') on a datetime axis
+ # will remove the date labels. So only set the scale when explicitly
+ # asked to. https://github.com/matplotlib/matplotlib/issues/8740
+ if xscale is not None:
+ ax.set_xscale(xscale)
+ if yscale is not None:
+ ax.set_yscale(yscale)
+
+ if xticks is not None:
+ ax.set_xticks(xticks)
+ if yticks is not None:
+ ax.set_yticks(yticks)
+
+ if xlim is not None:
+ ax.set_xlim(xlim)
+ if ylim is not None:
+ ax.set_ylim(ylim)
def _is_monotonic(coord, axis=0):
@@ -189,7 +817,17 @@ def _is_monotonic(coord, axis=0):
>>> _is_monotonic(np.array([0, 2, 1]))
np.False_
"""
- pass
+ if coord.shape[axis] < 3:
+ return True
+ else:
+ n = coord.shape[axis]
+ delta_pos = coord.take(np.arange(1, n), axis=axis) >= coord.take(
+ np.arange(0, n - 1), axis=axis
+ )
+ delta_neg = coord.take(np.arange(1, n), axis=axis) <= coord.take(
+ np.arange(0, n - 1), axis=axis
+ )
+ return np.all(delta_pos) or np.all(delta_neg)
def _infer_interval_breaks(coord, axis=0, scale=None, check_monotonic=False):
@@ -203,13 +841,54 @@ def _infer_interval_breaks(coord, axis=0, scale=None, check_monotonic=False):
array([3.16227766e-03, 3.16227766e-02, 3.16227766e-01, 3.16227766e+00,
3.16227766e+01, 3.16227766e+02])
"""
- pass
-
-
-def _process_cmap_cbar_kwargs(func, data, cmap=None, colors=None,
- cbar_kwargs: (Iterable[tuple[str, Any]] | Mapping[str, Any] | None)=
- None, levels=None, _is_facetgrid=False, **kwargs) ->tuple[dict[str, Any
- ], dict[str, Any]]:
+ coord = np.asarray(coord)
+
+ if check_monotonic and not _is_monotonic(coord, axis=axis):
+ raise ValueError(
+ "The input coordinate is not sorted in increasing "
+ "order along axis %d. This can lead to unexpected "
+ "results. Consider calling the `sortby` method on "
+ "the input DataArray. To plot data with categorical "
+ "axes, consider using the `heatmap` function from "
+ "the `seaborn` statistical plotting library." % axis
+ )
+
+ # If logscale, compute the intervals in the logarithmic space
+ if scale == "log":
+ if (coord <= 0).any():
+ raise ValueError(
+ "Found negative or zero value in coordinates. "
+ + "Coordinates must be positive on logscale plots."
+ )
+ coord = np.log10(coord)
+
+ deltas = 0.5 * np.diff(coord, axis=axis)
+ if deltas.size == 0:
+ deltas = np.array(0.0)
+ first = np.take(coord, [0], axis=axis) - np.take(deltas, [0], axis=axis)
+ last = np.take(coord, [-1], axis=axis) + np.take(deltas, [-1], axis=axis)
+ trim_last = tuple(
+ slice(None, -1) if n == axis else slice(None) for n in range(coord.ndim)
+ )
+ interval_breaks = np.concatenate(
+ [first, coord[trim_last] + deltas, last], axis=axis
+ )
+ if scale == "log":
+        # Convert the intervals back into linear space
+ return np.power(10, interval_breaks)
+ return interval_breaks
+
+
+def _process_cmap_cbar_kwargs(
+ func,
+ data,
+ cmap=None,
+ colors=None,
+ cbar_kwargs: Iterable[tuple[str, Any]] | Mapping[str, Any] | None = None,
+ levels=None,
+ _is_facetgrid=False,
+ **kwargs,
+) -> tuple[dict[str, Any], dict[str, Any]]:
"""
Parameters
----------
@@ -222,11 +901,70 @@ def _process_cmap_cbar_kwargs(func, data, cmap=None, colors=None,
cmap_params : dict
cbar_kwargs : dict
"""
- pass
-
-
-def legend_elements(self, prop='colors', num='auto', fmt=None, func=lambda
- x: x, **kwargs):
+ if func.__name__ == "surface":
+ # Leave user to specify cmap settings for surface plots
+ kwargs["cmap"] = cmap
+ return {
+ k: kwargs.get(k, None)
+ for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"]
+ }, {}
+
+ cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)
+
+ if "contour" in func.__name__ and levels is None:
+ levels = 7 # this is the matplotlib default
+
+ # colors is mutually exclusive with cmap
+ if cmap and colors:
+ raise ValueError("Can't specify both cmap and colors.")
+
+ # colors is only valid when levels is supplied or the plot is of type
+ # contour or contourf
+ if colors and (("contour" not in func.__name__) and (levels is None)):
+ raise ValueError("Can only specify colors with contour or levels")
+
+ # we should not be getting a list of colors in cmap anymore
+ # is there a better way to do this test?
+ if isinstance(cmap, (list, tuple)):
+ raise ValueError(
+ "Specifying a list of colors in cmap is deprecated. "
+ "Use colors keyword instead."
+ )
+
+ cmap_kwargs = {
+ "plot_data": data,
+ "levels": levels,
+ "cmap": colors if colors else cmap,
+ "filled": func.__name__ != "contour",
+ }
+
+ cmap_args = getfullargspec(_determine_cmap_params).args
+ cmap_kwargs.update((a, kwargs[a]) for a in cmap_args if a in kwargs)
+ if not _is_facetgrid:
+ cmap_params = _determine_cmap_params(**cmap_kwargs)
+ else:
+ cmap_params = {
+ k: cmap_kwargs[k]
+ for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"]
+ }
+
+ return cmap_params, cbar_kwargs
+
+
+def _get_nice_quiver_magnitude(u, v):
+ import matplotlib as mpl
+
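+    # Pick a "nice" round reference value near the mean vector magnitude for the quiver key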
+ ticker = mpl.ticker.MaxNLocator(3)
+ mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
+ magnitude = ticker.tick_values(0, mean)[-2]
+ return magnitude
+
+
+# Copied from matplotlib, tweaked so func can return strings.
+# https://github.com/matplotlib/matplotlib/issues/19555
+def legend_elements(
+ self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs
+):
"""
Create legend handles and labels for a PathCollection.
@@ -289,17 +1027,322 @@ def legend_elements(self, prop='colors', num='auto', fmt=None, func=lambda
labels : list of str
The string labels for elements of the legend.
"""
- pass
+ import warnings
+
+ import matplotlib as mpl
+
+ mlines = mpl.lines
+
+ handles = []
+ labels = []
+
+ if prop == "colors":
+ arr = self.get_array()
+ if arr is None:
+ warnings.warn(
+ "Collection without array used. Make sure to "
+ "specify the values to be colormapped via the "
+ "`c` argument."
+ )
+ return handles, labels
+ _size = kwargs.pop("size", mpl.rcParams["lines.markersize"])
+
+ def _get_color_and_size(value):
+ return self.cmap(self.norm(value)), _size
+
+ elif prop == "sizes":
+ if isinstance(self, mpl.collections.LineCollection):
+ arr = self.get_linewidths()
+ else:
+ arr = self.get_sizes()
+ _color = kwargs.pop("color", "k")
+
+ def _get_color_and_size(value):
+ return _color, np.sqrt(value)
+
+ else:
+ raise ValueError(
+ "Valid values for `prop` are 'colors' or "
+ f"'sizes'. You supplied '{prop}' instead."
+ )
+
+ # Get the unique values and their labels:
+ values = np.unique(arr)
+ label_values = np.asarray(func(values))
+ label_values_are_numeric = np.issubdtype(label_values.dtype, np.number)
+
+ # Handle the label format:
+ if fmt is None and label_values_are_numeric:
+ fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
+ elif fmt is None and not label_values_are_numeric:
+ fmt = mpl.ticker.StrMethodFormatter("{x}")
+ elif isinstance(fmt, str):
+ fmt = mpl.ticker.StrMethodFormatter(fmt)
+ fmt.create_dummy_axis()
+
+ if num == "auto":
+ num = 9
+ if len(values) <= num:
+ num = None
+
+ if label_values_are_numeric:
+ label_values_min = label_values.min()
+ label_values_max = label_values.max()
+ fmt.axis.set_view_interval(label_values_min, label_values_max)
+ fmt.axis.set_data_interval(label_values_min, label_values_max)
+
+ if num is not None:
+            # Labels are numeric but there are more than the target
+            # number of elements; reduce to the target using matplotlib's
+            # ticker classes:
+ if isinstance(num, mpl.ticker.Locator):
+ loc = num
+ elif np.iterable(num):
+ loc = mpl.ticker.FixedLocator(num)
+ else:
+ num = int(num)
+ loc = mpl.ticker.MaxNLocator(
+ nbins=num, min_n_ticks=num - 1, steps=[1, 2, 2.5, 3, 5, 6, 8, 10]
+ )
+
+ # Get nicely spaced label_values:
+ label_values = loc.tick_values(label_values_min, label_values_max)
+
+ # Remove extrapolated label_values:
+ cond = (label_values >= label_values_min) & (
+ label_values <= label_values_max
+ )
+ label_values = label_values[cond]
+
+ # Get the corresponding values by creating a linear interpolant
+ # with small step size:
+ values_interp = np.linspace(values.min(), values.max(), 256)
+ label_values_interp = func(values_interp)
+ ix = np.argsort(label_values_interp)
+ values = np.interp(label_values, label_values_interp[ix], values_interp[ix])
+ elif num is not None and not label_values_are_numeric:
+        # Labels are not numeric, so modifying label_values is not
+        # possible; instead, filter the array with nicely distributed
+        # indexes:
+ if type(num) == int: # noqa: E721
+ loc = mpl.ticker.LinearLocator(num)
+ else:
+ raise ValueError("`num` only supports integers for non-numeric labels.")
+
+ ind = loc.tick_values(0, len(label_values) - 1).astype(int)
+ label_values = label_values[ind]
+ values = values[ind]
+
+    # Some formatters require set_locs:
+ if hasattr(fmt, "set_locs"):
+ fmt.set_locs(label_values)
+
+ # Default settings for handles, add or override with kwargs:
+ kw = dict(markeredgewidth=self.get_linewidths()[0], alpha=self.get_alpha())
+ kw.update(kwargs)
+
+ for val, lab in zip(values, label_values):
+ color, size = _get_color_and_size(val)
+
+ if isinstance(self, mpl.collections.PathCollection):
+ kw.update(linestyle="", marker=self.get_paths()[0], markersize=size)
+ elif isinstance(self, mpl.collections.LineCollection):
+ kw.update(linestyle=self.get_linestyle()[0], linewidth=size)
+
+ h = mlines.Line2D([0], [0], color=color, **kw)
+
+ handles.append(h)
+ labels.append(fmt(lab))
+
+ return handles, labels
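
A minimal usage sketch for the helper above, assuming a hypothetical scatter collection `pc` (it is invoked the same way as matplotlib's own `PathCollection.legend_elements`):

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    pc = ax.scatter([0, 1, 2], [0, 1, 4], c=[10.0, 20.0, 30.0])
    # Build proxy handles/labels from the mapped colors and attach them:
    handles, labels = legend_elements(pc, prop="colors", num=3)
    ax.legend(handles, labels, title="c")
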
def _legend_add_subtitle(handles, labels, text):
"""Add a subtitle to legend handles."""
- pass
+ import matplotlib.pyplot as plt
+
+ if text and len(handles) > 1:
+        # Create a blank handle that's not visible; the
+        # invisibility is later used to tell subtitles apart
+        # from regular entries:
+ blank_handle = plt.Line2D([], [], label=text)
+ blank_handle.set_visible(False)
+
+ # Subtitles are shown first:
+ handles = [blank_handle] + handles
+ labels = [text] + labels
+
+ return handles, labels
def _adjust_legend_subtitles(legend):
"""Make invisible-handle "subtitles" entries look more like titles."""
- pass
+ import matplotlib.pyplot as plt
+
+ # Legend title not in rcParams until 3.0
+ font_size = plt.rcParams.get("legend.title_fontsize", None)
+ hpackers = legend.findobj(plt.matplotlib.offsetbox.VPacker)[0].get_children()
+ hpackers = [v for v in hpackers if isinstance(v, plt.matplotlib.offsetbox.HPacker)]
+ for hpack in hpackers:
+ areas = hpack.get_children()
+ if len(areas) < 2:
+ continue
+ draw_area, text_area = areas
+
+ handles = draw_area.get_children()
+
+ # Assume that all artists that are not visible are
+ # subtitles:
+ if not all(artist.get_visible() for artist in handles):
+ # Remove the dummy marker which will bring the text
+ # more to the center:
+ draw_area.set_width(0)
+ for text in text_area.get_children():
+ if font_size is not None:
+                    # The subtitles should have the same font size
+                    # as normal legend titles:
+ text.set_size(font_size)
+
+
+def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
+ dvars = set(ds.variables.keys())
+ error_msg = f" must be one of ({', '.join(sorted(tuple(str(v) for v in dvars)))})"
+
+ if x not in dvars:
+ raise ValueError(f"Expected 'x' {error_msg}. Received {x} instead.")
+
+ if y not in dvars:
+ raise ValueError(f"Expected 'y' {error_msg}. Received {y} instead.")
+
+ if hue is not None and hue not in dvars:
+ raise ValueError(f"Expected 'hue' {error_msg}. Received {hue} instead.")
+
+ if hue:
+ hue_is_numeric = _is_numeric(ds[hue].values)
+
+ if hue_style is None:
+ hue_style = "continuous" if hue_is_numeric else "discrete"
+
+ if not hue_is_numeric and (hue_style == "continuous"):
+ raise ValueError(
+ f"Cannot create a colorbar for a non numeric coordinate: {hue}"
+ )
+
+ if add_guide is None or add_guide is True:
+ add_colorbar = True if hue_style == "continuous" else False
+ add_legend = True if hue_style == "discrete" else False
+ else:
+ add_colorbar = False
+ add_legend = False
+ else:
+ if add_guide is True and funcname not in ("quiver", "streamplot"):
+ raise ValueError("Cannot set add_guide when hue is None.")
+ add_legend = False
+ add_colorbar = False
+
+ if (add_guide or add_guide is None) and funcname == "quiver":
+ add_quiverkey = True
+ if hue:
+ add_colorbar = True
+ if not hue_style:
+ hue_style = "continuous"
+ elif hue_style != "continuous":
+ raise ValueError(
+ "hue_style must be 'continuous' or None for .plot.quiver or "
+ ".plot.streamplot"
+ )
+ else:
+ add_quiverkey = False
+
+ if (add_guide or add_guide is None) and funcname == "streamplot":
+ if hue:
+ add_colorbar = True
+ if not hue_style:
+ hue_style = "continuous"
+ elif hue_style != "continuous":
+ raise ValueError(
+ "hue_style must be 'continuous' or None for .plot.quiver or "
+ ".plot.streamplot"
+ )
+
+ if hue_style is not None and hue_style not in ["discrete", "continuous"]:
+ raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.")
+
+ if hue:
+ hue_label = label_from_attrs(ds[hue])
+ hue = ds[hue]
+ else:
+ hue_label = None
+ hue = None
+
+ return {
+ "add_colorbar": add_colorbar,
+ "add_legend": add_legend,
+ "add_quiverkey": add_quiverkey,
+ "hue_label": hue_label,
+ "hue_style": hue_style,
+ "xlabel": label_from_attrs(ds[x]),
+ "ylabel": label_from_attrs(ds[y]),
+ "hue": hue,
+ }
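
A rough sketch of the decision logic above, assuming a small hypothetical Dataset with numeric data variables "A" and "B" and a string-valued coordinate "w":

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"A": ("p", np.arange(4.0)), "B": ("p", np.arange(4.0) ** 2)},
        coords={"w": ("p", ["a", "b", "a", "b"])},
    )
    # A non-numeric hue falls back to hue_style="discrete", so a legend is
    # requested rather than a colorbar; add_quiverkey stays False for scatter.
    meta = _infer_meta_data(
        ds, x="A", y="B", hue="w", hue_style=None, add_guide=None, funcname="scatter"
    )
    assert meta["add_legend"] and not meta["add_colorbar"]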
+
+
+@overload
+def _parse_size(
+ data: None,
+ norm: tuple[float | None, float | None, bool] | Normalize | None,
+) -> None: ...
+
+
+@overload
+def _parse_size(
+ data: DataArray,
+ norm: tuple[float | None, float | None, bool] | Normalize | None,
+) -> pd.Series: ...
+
+
+# copied from seaborn
+def _parse_size(
+ data: DataArray | None,
+ norm: tuple[float | None, float | None, bool] | Normalize | None,
+) -> None | pd.Series:
+ import matplotlib as mpl
+
+ if data is None:
+ return None
+
+ flatdata = data.values.flatten()
+
+ if not _is_numeric(flatdata):
+ levels = np.unique(flatdata)
+ numbers = np.arange(1, 1 + len(levels))[::-1]
+ else:
+ levels = numbers = np.sort(np.unique(flatdata))
+
+ min_width, default_width, max_width = _MARKERSIZE_RANGE
+ # width_range = min_width, max_width
+
+ if norm is None:
+ norm = mpl.colors.Normalize()
+ elif isinstance(norm, tuple):
+ norm = mpl.colors.Normalize(*norm)
+ elif not isinstance(norm, mpl.colors.Normalize):
+ err = "``size_norm`` must be None, tuple, or Normalize object."
+ raise ValueError(err)
+ assert isinstance(norm, mpl.colors.Normalize)
+
+ norm.clip = True
+ if not norm.scaled():
+ norm(np.asarray(numbers))
+ # limits = norm.vmin, norm.vmax
+
+ scl = norm(numbers)
+ widths = np.asarray(min_width + scl * (max_width - min_width))
+ if scl.mask.any():
+ widths[scl.mask] = 0
+ sizes = dict(zip(levels, widths))
+
+ return pd.Series(sizes)
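
A quick sketch of what the seaborn-derived helper above returns, assuming `_MARKERSIZE_RANGE` is a (min, default, max) width tuple such as the (18, 36, 72) used in the doctests below:

    import xarray as xr

    sizes = _parse_size(xr.DataArray(["a", "b", "c"]), norm=None)
    # `sizes` is a pandas Series mapping each unique level to a marker width;
    # categorical levels are ranked in reverse before being normalized into the
    # width range, so "a" gets the largest width and "c" the smallest.
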
class _Normalize(Sequence):
@@ -317,42 +1360,63 @@ class _Normalize(Sequence):
Normalize the data to these (min, default, max) values.
The default is None.
"""
+
_data: DataArray | None
_data_unique: np.ndarray
_data_unique_index: np.ndarray
_data_unique_inverse: np.ndarray
_data_is_numeric: bool
_width: tuple[float, float, float] | None
- __slots__ = ('_data', '_data_unique', '_data_unique_index',
- '_data_unique_inverse', '_data_is_numeric', '_width')
- def __init__(self, data: (DataArray | None), width: (tuple[float, float,
- float] | None)=None, _is_facetgrid: bool=False) ->None:
+ __slots__ = (
+ "_data",
+ "_data_unique",
+ "_data_unique_index",
+ "_data_unique_inverse",
+ "_data_is_numeric",
+ "_width",
+ )
+
+ def __init__(
+ self,
+ data: DataArray | None,
+ width: tuple[float, float, float] | None = None,
+ _is_facetgrid: bool = False,
+ ) -> None:
self._data = data
self._width = width if not _is_facetgrid else None
- pint_array_type = DuckArrayModule('pint').type
- to_unique = data.to_numpy() if isinstance(data if data is None else
- data.data, pint_array_type) else data
- data_unique, data_unique_inverse = np.unique(to_unique,
- return_inverse=True)
+
+ pint_array_type = DuckArrayModule("pint").type
+ to_unique = (
+ data.to_numpy() # type: ignore[union-attr]
+ if isinstance(data if data is None else data.data, pint_array_type)
+ else data
+ )
+ data_unique, data_unique_inverse = np.unique(to_unique, return_inverse=True) # type: ignore[call-overload]
self._data_unique = data_unique
self._data_unique_index = np.arange(0, data_unique.size)
self._data_unique_inverse = data_unique_inverse
self._data_is_numeric = False if data is None else _is_numeric(data)
- def __repr__(self) ->str:
+ def __repr__(self) -> str:
with np.printoptions(precision=4, suppress=True, threshold=5):
- return f"""<_Normalize(data, width={self._width})>
-{self._data_unique} -> {self._values_unique}"""
+ return (
+ f"<_Normalize(data, width={self._width})>\n"
+ f"{self._data_unique} -> {self._values_unique}"
+ )
- def __len__(self) ->int:
+ def __len__(self) -> int:
return len(self._data_unique)
def __getitem__(self, key):
return self._data_unique[key]
@property
- def data_is_numeric(self) ->bool:
+ def data(self) -> DataArray | None:
+ return self._data
+
+ @property
+ def data_is_numeric(self) -> bool:
"""
Check if data is numeric.
@@ -376,25 +1440,48 @@ class _Normalize(Sequence):
>>> _Normalize(a).data_is_numeric
True
"""
- pass
+ return self._data_is_numeric
+
+ @overload
+ def _calc_widths(self, y: np.ndarray) -> np.ndarray: ...
- def _calc_widths(self, y: (np.ndarray | DataArray)) ->(np.ndarray |
- DataArray):
+ @overload
+ def _calc_widths(self, y: DataArray) -> DataArray: ...
+
+ def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray:
"""
Normalize the values so they're in between self._width.
"""
- pass
+ if self._width is None:
+ return y
+
+ xmin, xdefault, xmax = self._width
+
+ diff_maxy_miny = np.max(y) - np.min(y)
+ if diff_maxy_miny == 0:
+            # Use the default width if y is constant:
+ widths = xdefault + 0 * y
+ else:
+ # Normalize in between xmin and xmax:
+ k = (y - np.min(y)) / diff_maxy_miny
+ widths = xmin + k * (xmax - xmin)
+ return widths
+
+ @overload
+ def _indexes_centered(self, x: np.ndarray) -> np.ndarray: ...
+
+ @overload
+ def _indexes_centered(self, x: DataArray) -> DataArray: ...
- def _indexes_centered(self, x: (np.ndarray | DataArray)) ->(np.ndarray |
- DataArray):
+ def _indexes_centered(self, x: np.ndarray | DataArray) -> np.ndarray | DataArray:
"""
        Offset indexes so that they sit at the center of self.levels.
["a", "b", "c"] -> [1, 3, 5]
"""
- pass
+ return x * 2 + 1
@property
- def values(self) ->(DataArray | None):
+ def values(self) -> DataArray | None:
"""
Return a normalized number array for the unique levels.
@@ -428,10 +1515,20 @@ class _Normalize(Sequence):
Dimensions without coordinates: dim_0
"""
- pass
+ if self.data is None:
+ return None
+
+ val: DataArray
+ if self.data_is_numeric:
+ val = self.data
+ else:
+ arr = self._indexes_centered(self._data_unique_inverse)
+ val = self.data.copy(data=arr.reshape(self.data.shape))
+
+ return self._calc_widths(val)
@property
- def _values_unique(self) ->(np.ndarray | None):
+ def _values_unique(self) -> np.ndarray | None:
"""
Return unique values.
@@ -451,10 +1548,19 @@ class _Normalize(Sequence):
>>> _Normalize(a, width=(18, 36, 72))._values_unique
array([18., 27., 54., 72.])
"""
- pass
+ if self.data is None:
+ return None
+
+ val: np.ndarray
+ if self.data_is_numeric:
+ val = self._data_unique
+ else:
+ val = self._indexes_centered(self._data_unique_index)
+
+ return self._calc_widths(val)
@property
- def ticks(self) ->(np.ndarray | None):
+ def ticks(self) -> np.ndarray | None:
"""
Return ticks for plt.colorbar if the data is not numeric.
@@ -464,10 +1570,16 @@ class _Normalize(Sequence):
>>> _Normalize(a).ticks
array([1, 3, 5])
"""
- pass
+ val: None | np.ndarray
+ if self.data_is_numeric:
+ val = None
+ else:
+ val = self._indexes_centered(self._data_unique_index)
+
+ return val
@property
- def levels(self) ->np.ndarray:
+ def levels(self) -> np.ndarray:
"""
Return discrete levels that will evenly bound self.values.
["a", "b", "c"] -> [0, 2, 4, 6]
@@ -478,10 +1590,26 @@ class _Normalize(Sequence):
>>> _Normalize(a).levels
array([0, 2, 4, 6])
"""
- pass
+ return (
+ np.append(self._data_unique_index, np.max(self._data_unique_index) + 1) * 2
+ )
+
+ @property
+ def _lookup(self) -> pd.Series:
+ if self._values_unique is None:
+ raise ValueError("self.data can't be None.")
+
+ return pd.Series(dict(zip(self._values_unique, self._data_unique)))
+
+ def _lookup_arr(self, x) -> np.ndarray:
+ # Use reindex to be less sensitive to float errors. reindex only
+ # works with sorted index.
+ # Return as numpy array since legend_elements
+ # seems to require that:
+ return self._lookup.sort_index().reindex(x, method="nearest").to_numpy()
@property
- def format(self) ->FuncFormatter:
+ def format(self) -> FuncFormatter:
"""
Return a FuncFormatter that maps self.values elements back to
the original value as a string. Useful with plt.colorbar.
@@ -499,10 +1627,15 @@ class _Normalize(Sequence):
>>> aa.format(1)
'3.0'
"""
- pass
+ import matplotlib.pyplot as plt
+
+ def _func(x: Any, pos: None | Any = None):
+ return f"{self._lookup_arr([x])[0]}"
+
+ return plt.FuncFormatter(_func)
@property
- def func(self) ->Callable[[Any, None | Any], Any]:
+ def func(self) -> Callable[[Any, None | Any], Any]:
"""
Return a lambda function that maps self.values elements back to
the original value as a numpy array. Useful with ax.legend_elements.
@@ -520,13 +1653,95 @@ class _Normalize(Sequence):
>>> aa.func([0.16, 1])
array([0.5, 3. ])
"""
- pass
-
-def _guess_coords_to_plot(darray: DataArray, coords_to_plot: MutableMapping
- [str, Hashable | None], kwargs: dict, default_guess: tuple[str, ...]=(
- 'x',), ignore_guess_kwargs: tuple[tuple[str, ...], ...]=((),)
- ) ->MutableMapping[str, Hashable]:
+ def _func(x: Any, pos: None | Any = None):
+ return self._lookup_arr(x)
+
+ return _func
+
+
+def _determine_guide(
+ hueplt_norm: _Normalize,
+ sizeplt_norm: _Normalize,
+ add_colorbar: None | bool = None,
+ add_legend: None | bool = None,
+ plotfunc_name: str | None = None,
+) -> tuple[bool, bool]:
+ if plotfunc_name == "hist":
+ return False, False
+
+ if (add_colorbar) and hueplt_norm.data is None:
+ raise KeyError("Cannot create a colorbar when hue is None.")
+ if add_colorbar is None:
+ if hueplt_norm.data is not None:
+ add_colorbar = True
+ else:
+ add_colorbar = False
+
+ if add_legend and hueplt_norm.data is None and sizeplt_norm.data is None:
+ raise KeyError("Cannot create a legend when hue and markersize is None.")
+ if add_legend is None:
+ if (
+ not add_colorbar
+ and (hueplt_norm.data is not None and hueplt_norm.data_is_numeric is False)
+ or sizeplt_norm.data is not None
+ ):
+ add_legend = True
+ else:
+ add_legend = False
+
+ return add_colorbar, add_legend
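
The rules above reduce to: numeric hue -> colorbar, categorical hue or an explicit size mapping -> legend. A minimal sketch using the `_Normalize` wrapper defined earlier:

    import xarray as xr

    hue_norm = _Normalize(xr.DataArray([1.0, 2.0, 3.0]))  # numeric hue data
    size_norm = _Normalize(None)                          # no size mapping
    add_colorbar, add_legend = _determine_guide(hue_norm, size_norm)
    # -> (True, False): numeric hue defaults to a colorbar and no legend.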
+
+
+def _add_legend(
+ hueplt_norm: _Normalize,
+ sizeplt_norm: _Normalize,
+ primitive,
+ legend_ax,
+ plotfunc: str,
+):
+ primitive = primitive if isinstance(primitive, list) else [primitive]
+
+ handles, labels = [], []
+ for huesizeplt, prop in [
+ (hueplt_norm, "colors"),
+ (sizeplt_norm, "sizes"),
+ ]:
+ if huesizeplt.data is not None:
+            # Get legend handles and labels that display the
+            # values correctly. The order might differ because
+            # legend_elements uses np.unique instead of pd.unique;
+            # FacetGrid.add_legend might have trouble with this:
+ hdl, lbl = [], []
+ for p in primitive:
+ hdl_, lbl_ = legend_elements(p, prop, num="auto", func=huesizeplt.func)
+ hdl += hdl_
+ lbl += lbl_
+
+ # Only save unique values:
+ u, ind = np.unique(lbl, return_index=True)
+ ind = np.argsort(ind)
+ lbl = u[ind].tolist()
+ hdl = np.array(hdl)[ind].tolist()
+
+ # Add a subtitle:
+ hdl, lbl = _legend_add_subtitle(hdl, lbl, label_from_attrs(huesizeplt.data))
+ handles += hdl
+ labels += lbl
+ legend = legend_ax.legend(handles, labels, framealpha=0.5)
+ _adjust_legend_subtitles(legend)
+
+ return legend
+
+
+def _guess_coords_to_plot(
+ darray: DataArray,
+ coords_to_plot: MutableMapping[str, Hashable | None],
+ kwargs: dict,
+ default_guess: tuple[str, ...] = ("x",),
+ # TODO: Can this be normalized, plt.cbook.normalize_kwargs?
+ ignore_guess_kwargs: tuple[tuple[str, ...], ...] = ((),),
+) -> MutableMapping[str, Hashable]:
"""
Guess what coords to plot if some of the values in coords_to_plot are None which
happens when the user has not defined all available ways of visualizing
@@ -587,10 +1802,28 @@ def _guess_coords_to_plot(darray: DataArray, coords_to_plot: MutableMapping
... )
{'x': 'y', 'z': None, 'hue': 'z', 'size': 'x'}
"""
- pass
+ coords_to_plot_exist = {k: v for k, v in coords_to_plot.items() if v is not None}
+ available_coords = tuple(
+ k for k in darray.coords.keys() if k not in coords_to_plot_exist.values()
+ )
+
+    # If coords_to_plot[k] isn't defined then fill it with one of the available dims,
+    # unless one of the related mpl kwargs has been used. This should behave similarly
+    # to:
+    # * plt.plot(x, y) -> Multiple lines with different colors if y is 2d.
+    # * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d.
+ for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs):
+ if coords_to_plot.get(k, None) is None and all(
+ kwargs.get(ign_kw, None) is None for ign_kw in ign_kws
+ ):
+ coords_to_plot[k] = dim
+ for k, dim in coords_to_plot.items():
+ _assert_valid_xy(darray, dim, k)
-def _set_concise_date(ax: Axes, axis: Literal['x', 'y', 'z']='x') ->None:
+ return coords_to_plot
+
+
+def _set_concise_date(ax: Axes, axis: Literal["x", "y", "z"] = "x") -> None:
"""
Use ConciseDateFormatter which is meant to improve the
strings chosen for the ticklabels, and to minimize the
@@ -605,4 +1838,10 @@ def _set_concise_date(ax: Axes, axis: Literal['x', 'y', 'z']='x') ->None:
axis : Literal["x", "y", "z"], optional
Which axis to make concise. The default is "x".
"""
- pass
+ import matplotlib.dates as mdates
+
+ locator = mdates.AutoDateLocator()
+ formatter = mdates.ConciseDateFormatter(locator)
+ _axis = getattr(ax, f"{axis}axis")
+ _axis.set_major_locator(locator)
+ _axis.set_major_formatter(formatter)
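
For reference, the same locator/formatter pairing written out as a stand-alone matplotlib sketch on a hypothetical time axis:

    import matplotlib.dates as mdates
    import matplotlib.pyplot as plt
    import pandas as pd

    fig, ax = plt.subplots()
    times = pd.date_range("2000-01-01", periods=48, freq="h")
    ax.plot(times, range(48))
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(locator))
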
diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py
index 678f4a45..2a4c17e1 100644
--- a/xarray/testing/assertions.py
+++ b/xarray/testing/assertions.py
@@ -1,10 +1,13 @@
"""Testing functions exposed to the user API"""
+
import functools
import warnings
from collections.abc import Hashable
from typing import Union, overload
+
import numpy as np
import pandas as pd
+
from xarray.core import duck_array_ops, formatting, utils
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
@@ -15,8 +18,41 @@ from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_in
from xarray.core.variable import IndexVariable, Variable
+def ensure_warnings(func):
+ # sometimes tests elevate warnings to errors
+ # -> make sure that does not happen in the assert_* functions
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ __tracebackhide__ = True
+
+ with warnings.catch_warnings():
+ # only remove filters that would "error"
+ warnings.filters = [f for f in warnings.filters if f[0] != "error"]
+
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def _decode_string_data(data):
+ if data.dtype.kind == "S":
+ return np.char.decode(data, "utf-8", "replace")
+ return data
+
+
+def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=True):
+ if any(arr.dtype.kind == "S" for arr in [arr1, arr2]) and decode_bytes:
+ arr1 = _decode_string_data(arr1)
+ arr2 = _decode_string_data(arr2)
+ exact_dtypes = ["M", "m", "O", "S", "U"]
+ if any(arr.dtype.kind in exact_dtypes for arr in [arr1, arr2]):
+ return duck_array_ops.array_equiv(arr1, arr2)
+ else:
+ return duck_array_ops.allclose_or_equiv(arr1, arr2, rtol=rtol, atol=atol)
+
+
@ensure_warnings
-def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool=False):
+def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False):
"""
Two DataTrees are considered isomorphic if every node has the same number of children.
@@ -44,16 +80,43 @@ def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool=False):
assert_equal
assert_identical
"""
- pass
+ __tracebackhide__ = True
+ assert isinstance(a, type(b))
+
+ if isinstance(a, DataTree):
+ if from_root:
+ a = a.root
+ b = b.root
+
+ assert a.isomorphic(b, from_root=from_root), diff_datatree_repr(
+ a, b, "isomorphic"
+ )
+ else:
+ raise TypeError(f"{type(a)} not of type DataTree")
def maybe_transpose_dims(a, b, check_dim_order: bool):
"""Helper for assert_equal/allclose/identical"""
- pass
+ __tracebackhide__ = True
+ if not isinstance(a, (Variable, DataArray, Dataset)):
+ return b
+ if not check_dim_order and set(a.dims) == set(b.dims):
+ # Ensure transpose won't fail if a dimension is missing
+ # If this is the case, the difference will be caught by the caller
+ return b.transpose(*a.dims)
+ return b
+
+
+@overload
+def assert_equal(a, b): ...
+
+
+@overload
+def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...
@ensure_warnings
-def assert_equal(a, b, from_root=True, check_dim_order: bool=True):
+def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.
@@ -84,7 +147,33 @@ def assert_equal(a, b, from_root=True, check_dim_order: bool=True):
assert_identical, assert_allclose, Dataset.equals, DataArray.equals
numpy.testing.assert_array_equal
"""
- pass
+ __tracebackhide__ = True
+ assert (
+ type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates)
+ )
+ b = maybe_transpose_dims(a, b, check_dim_order)
+ if isinstance(a, (Variable, DataArray)):
+ assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
+ elif isinstance(a, Dataset):
+ assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals")
+ elif isinstance(a, Coordinates):
+ assert a.equals(b), formatting.diff_coords_repr(a, b, "equals")
+ elif isinstance(a, DataTree):
+ if from_root:
+ a = a.root
+ b = b.root
+
+ assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals")
+ else:
+ raise TypeError(f"{type(a)} not supported by assertion comparison")
+
+
+@overload
+def assert_identical(a, b): ...
+
+
+@overload
+def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ...
@ensure_warnings
@@ -115,12 +204,35 @@ def assert_identical(a, b, from_root=True):
--------
assert_equal, assert_allclose, Dataset.equals, DataArray.equals
"""
- pass
+ __tracebackhide__ = True
+ assert (
+ type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates)
+ )
+ if isinstance(a, Variable):
+ assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
+ elif isinstance(a, DataArray):
+ assert a.name == b.name
+ assert a.identical(b), formatting.diff_array_repr(a, b, "identical")
+ elif isinstance(a, (Dataset, Variable)):
+ assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical")
+ elif isinstance(a, Coordinates):
+ assert a.identical(b), formatting.diff_coords_repr(a, b, "identical")
+ elif isinstance(a, DataTree):
+ if from_root:
+ a = a.root
+ b = b.root
+
+ assert a.identical(b, from_root=from_root), diff_datatree_repr(
+ a, b, "identical"
+ )
+ else:
+ raise TypeError(f"{type(a)} not supported by assertion comparison")
@ensure_warnings
-def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True,
- check_dim_order: bool=True):
+def assert_allclose(
+ a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dim_order: bool = True
+):
"""Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects.
Raises an AssertionError if two objects are not equal up to desired
@@ -147,20 +259,93 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True,
--------
assert_identical, assert_equal, numpy.testing.assert_allclose
"""
- pass
+ __tracebackhide__ = True
+ assert type(a) == type(b)
+ b = maybe_transpose_dims(a, b, check_dim_order)
+
+ equiv = functools.partial(
+ _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes
+ )
+ equiv.__name__ = "allclose" # type: ignore[attr-defined]
+
+ def compat_variable(a, b):
+ a = getattr(a, "variable", a)
+ b = getattr(b, "variable", b)
+ return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data))
+
+ if isinstance(a, Variable):
+ allclose = compat_variable(a, b)
+ assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
+ elif isinstance(a, DataArray):
+ allclose = utils.dict_equiv(
+ a.coords, b.coords, compat=compat_variable
+ ) and compat_variable(a.variable, b.variable)
+ assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
+ elif isinstance(a, Dataset):
+ allclose = a._coord_names == b._coord_names and utils.dict_equiv(
+ a.variables, b.variables, compat=compat_variable
+ )
+ assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv)
+ else:
+ raise TypeError(f"{type(a)} not supported by assertion comparison")
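
These assertions are exposed publicly as `xarray.testing.assert_equal`, `assert_identical` and `assert_allclose`; a minimal usage sketch:

    import numpy as np
    import xarray as xr

    a = xr.DataArray(np.array([1.0, 2.0, 3.0]), dims="x")
    b = a + 1e-9  # well within the default rtol=1e-05 / atol=1e-08
    xr.testing.assert_allclose(a, b)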
+
+
+def _format_message(x, y, err_msg, verbose):
+ diff = x - y
+ abs_diff = max(abs(diff))
+ rel_diff = "not implemented"
+
+ n_diff = np.count_nonzero(diff)
+ n_total = diff.size
+
+ fraction = f"{n_diff} / {n_total}"
+ percentage = float(n_diff / n_total * 100)
+
+ parts = [
+ "Arrays are not equal",
+ err_msg,
+ f"Mismatched elements: {fraction} ({percentage:.0f}%)",
+ f"Max absolute difference: {abs_diff}",
+ f"Max relative difference: {rel_diff}",
+ ]
+ if verbose:
+ parts += [
+ f" x: {x!r}",
+ f" y: {y!r}",
+ ]
+
+ return "\n".join(parts)
@ensure_warnings
-def assert_duckarray_allclose(actual, desired, rtol=1e-07, atol=0, err_msg=
- '', verbose=True):
+def assert_duckarray_allclose(
+ actual, desired, rtol=1e-07, atol=0, err_msg="", verbose=True
+):
"""Like `np.testing.assert_allclose`, but for duckarrays."""
- pass
+ __tracebackhide__ = True
+
+ allclose = duck_array_ops.allclose_or_equiv(actual, desired, rtol=rtol, atol=atol)
+ assert allclose, _format_message(actual, desired, err_msg=err_msg, verbose=verbose)
@ensure_warnings
-def assert_duckarray_equal(x, y, err_msg='', verbose=True):
+def assert_duckarray_equal(x, y, err_msg="", verbose=True):
"""Like `np.testing.assert_array_equal`, but for duckarrays"""
- pass
+ __tracebackhide__ = True
+
+ if not utils.is_duck_array(x) and not utils.is_scalar(x):
+ x = np.asarray(x)
+
+ if not utils.is_duck_array(y) and not utils.is_scalar(y):
+ y = np.asarray(y)
+
+ if (utils.is_duck_array(x) and utils.is_scalar(y)) or (
+ utils.is_scalar(x) and utils.is_duck_array(y)
+ ):
+ equiv = (x == y).all()
+ else:
+ equiv = duck_array_ops.array_equiv(x, y)
+ assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose)
def assert_chunks_equal(a, b):
@@ -174,15 +359,166 @@ def assert_chunks_equal(a, b):
b : xarray.Dataset or xarray.DataArray
The second object to compare.
"""
- pass
-
-def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset,
- Variable], check_default_indexes: bool):
+ if isinstance(a, DataArray) != isinstance(b, DataArray):
+ raise TypeError("a and b have mismatched types")
+
+ left = a.unify_chunks()
+ right = b.unify_chunks()
+ assert left.chunks == right.chunks
+
+
+def _assert_indexes_invariants_checks(
+ indexes, possible_coord_variables, dims, check_default=True
+):
+ assert isinstance(indexes, dict), indexes
+ assert all(isinstance(v, Index) for v in indexes.values()), {
+ k: type(v) for k, v in indexes.items()
+ }
+
+ index_vars = {
+ k for k, v in possible_coord_variables.items() if isinstance(v, IndexVariable)
+ }
+ assert indexes.keys() <= index_vars, (set(indexes), index_vars)
+
+ # check pandas index wrappers vs. coordinate data adapters
+ for k, index in indexes.items():
+ if isinstance(index, PandasIndex):
+ pd_index = index.index
+ var = possible_coord_variables[k]
+ assert (index.dim,) == var.dims, (pd_index, var)
+ if k == index.dim:
+ # skip multi-index levels here (checked below)
+ assert index.coord_dtype == var.dtype, (index.coord_dtype, var.dtype)
+ assert isinstance(var._data.array, pd.Index), var._data.array
+ # TODO: check identity instead of equality?
+ assert pd_index.equals(var._data.array), (pd_index, var)
+ if isinstance(index, PandasMultiIndex):
+ pd_index = index.index
+ for name in index.index.names:
+ assert name in possible_coord_variables, (pd_index, index_vars)
+ var = possible_coord_variables[name]
+ assert (index.dim,) == var.dims, (pd_index, var)
+ assert index.level_coords_dtype[name] == var.dtype, (
+ index.level_coords_dtype[name],
+ var.dtype,
+ )
+ assert isinstance(var._data.array, pd.MultiIndex), var._data.array
+ assert pd_index.equals(var._data.array), (pd_index, var)
+                # check that all levels are in `indexes`
+ assert name in indexes, (name, set(indexes))
+ # index identity is used to find unique indexes in `indexes`
+ assert index is indexes[name], (pd_index, indexes[name].index)
+
+ if check_default:
+ defaults = default_indexes(possible_coord_variables, dims)
+ assert indexes.keys() == defaults.keys(), (set(indexes), set(defaults))
+ assert all(v.equals(defaults[k]) for k, v in indexes.items()), (
+ indexes,
+ defaults,
+ )
+
+
+def _assert_variable_invariants(var: Variable, name: Hashable = None):
+ if name is None:
+ name_or_empty: tuple = ()
+ else:
+ name_or_empty = (name,)
+ assert isinstance(var._dims, tuple), name_or_empty + (var._dims,)
+ assert len(var._dims) == len(var._data.shape), name_or_empty + (
+ var._dims,
+ var._data.shape,
+ )
+ assert isinstance(var._encoding, (type(None), dict)), name_or_empty + (
+ var._encoding,
+ )
+ assert isinstance(var._attrs, (type(None), dict)), name_or_empty + (var._attrs,)
+
+
+def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):
+ assert isinstance(da._variable, Variable), da._variable
+ _assert_variable_invariants(da._variable)
+
+ assert isinstance(da._coords, dict), da._coords
+ assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords
+ assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
+ da.dims,
+ {k: v.dims for k, v in da._coords.items()},
+ )
+ assert all(
+ isinstance(v, IndexVariable) for (k, v) in da._coords.items() if v.dims == (k,)
+ ), {k: type(v) for k, v in da._coords.items()}
+ for k, v in da._coords.items():
+ _assert_variable_invariants(v, k)
+
+ if da._indexes is not None:
+ _assert_indexes_invariants_checks(
+ da._indexes, da._coords, da.dims, check_default=check_default_indexes
+ )
+
+
+def _assert_dataset_invariants(ds: Dataset, check_default_indexes: bool):
+ assert isinstance(ds._variables, dict), type(ds._variables)
+ assert all(isinstance(v, Variable) for v in ds._variables.values()), ds._variables
+ for k, v in ds._variables.items():
+ _assert_variable_invariants(v, k)
+
+ assert isinstance(ds._coord_names, set), ds._coord_names
+ assert ds._coord_names <= ds._variables.keys(), (
+ ds._coord_names,
+ set(ds._variables),
+ )
+
+ assert type(ds._dims) is dict, ds._dims # noqa: E721
+ assert all(isinstance(v, int) for v in ds._dims.values()), ds._dims
+ var_dims: set[Hashable] = set()
+ for v in ds._variables.values():
+ var_dims.update(v.dims)
+ assert ds._dims.keys() == var_dims, (set(ds._dims), var_dims)
+ assert all(
+ ds._dims[k] == v.sizes[k] for v in ds._variables.values() for k in v.sizes
+ ), (ds._dims, {k: v.sizes for k, v in ds._variables.items()})
+
+ if check_default_indexes:
+ assert all(
+ isinstance(v, IndexVariable)
+ for (k, v) in ds._variables.items()
+ if v.dims == (k,)
+ ), {k: type(v) for k, v in ds._variables.items() if v.dims == (k,)}
+
+ if ds._indexes is not None:
+ _assert_indexes_invariants_checks(
+ ds._indexes, ds._variables, ds._dims, check_default=check_default_indexes
+ )
+
+ assert isinstance(ds._encoding, (type(None), dict))
+ assert isinstance(ds._attrs, (type(None), dict))
+
+
+def _assert_internal_invariants(
+ xarray_obj: Union[DataArray, Dataset, Variable], check_default_indexes: bool
+):
"""Validate that an xarray object satisfies its own internal invariants.
This exists for the benefit of xarray's own test suite, but may be useful
in external projects if they (ill-advisedly) create objects using xarray's
private APIs.
"""
- pass
+ if isinstance(xarray_obj, Variable):
+ _assert_variable_invariants(xarray_obj)
+ elif isinstance(xarray_obj, DataArray):
+ _assert_dataarray_invariants(
+ xarray_obj, check_default_indexes=check_default_indexes
+ )
+ elif isinstance(xarray_obj, Dataset):
+ _assert_dataset_invariants(
+ xarray_obj, check_default_indexes=check_default_indexes
+ )
+ elif isinstance(xarray_obj, Coordinates):
+ _assert_dataset_invariants(
+ xarray_obj.to_dataset(), check_default_indexes=check_default_indexes
+ )
+ else:
+ raise TypeError(
+ f"{type(xarray_obj)} is not a supported type for xarray invariant checks"
+ )
diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py
index 11355cbf..085b70e5 100644
--- a/xarray/testing/strategies.py
+++ b/xarray/testing/strategies.py
@@ -1,31 +1,46 @@
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Protocol, Union, overload
+
try:
import hypothesis.strategies as st
except ImportError as e:
raise ImportError(
- '`xarray.testing.strategies` requires `hypothesis` to be installed.'
- ) from e
+ "`xarray.testing.strategies` requires `hypothesis` to be installed."
+ ) from e
+
import hypothesis.extra.numpy as npst
import numpy as np
from hypothesis.errors import InvalidArgument
+
import xarray as xr
from xarray.core.types import T_DuckArray
+
if TYPE_CHECKING:
from xarray.core.types import _DTypeLikeNested, _ShapeLike
-__all__ = ['supported_dtypes', 'pandas_index_dtypes', 'names',
- 'dimension_names', 'dimension_sizes', 'attrs', 'variables',
- 'unique_subset_of']
-class ArrayStrategyFn(Protocol[T_DuckArray]):
+__all__ = [
+ "supported_dtypes",
+ "pandas_index_dtypes",
+ "names",
+ "dimension_names",
+ "dimension_sizes",
+ "attrs",
+ "variables",
+ "unique_subset_of",
+]
+
- def __call__(self, *, shape: '_ShapeLike', dtype: '_DTypeLikeNested'
- ) ->st.SearchStrategy[T_DuckArray]:
- ...
+class ArrayStrategyFn(Protocol[T_DuckArray]):
+ def __call__(
+ self,
+ *,
+ shape: "_ShapeLike",
+ dtype: "_DTypeLikeNested",
+ ) -> st.SearchStrategy[T_DuckArray]: ...
-def supported_dtypes() ->st.SearchStrategy[np.dtype]:
+def supported_dtypes() -> st.SearchStrategy[np.dtype]:
"""
Generates only those numpy dtypes which xarray can handle.
@@ -38,21 +53,43 @@ def supported_dtypes() ->st.SearchStrategy[np.dtype]:
--------
:ref:`testing.hypothesis`_
"""
- pass
-
-
-def pandas_index_dtypes() ->st.SearchStrategy[np.dtype]:
+ # TODO should this be exposed publicly?
+ # We should at least decide what the set of numpy dtypes that xarray officially supports is.
+ return (
+ npst.integer_dtypes(endianness="=")
+ | npst.unsigned_integer_dtypes(endianness="=")
+ | npst.floating_dtypes(endianness="=")
+ | npst.complex_number_dtypes(endianness="=")
+ # | npst.datetime64_dtypes()
+ # | npst.timedelta64_dtypes()
+ # | npst.unicode_string_dtypes()
+ )
+
+
+def pandas_index_dtypes() -> st.SearchStrategy[np.dtype]:
"""
Dtypes supported by pandas indexes.
Restrict datetime64 and timedelta64 to ns frequency till Xarray relaxes that.
"""
- pass
+ return (
+ npst.integer_dtypes(endianness="=", sizes=(32, 64))
+ | npst.unsigned_integer_dtypes(endianness="=", sizes=(32, 64))
+ | npst.floating_dtypes(endianness="=", sizes=(32, 64))
+ # TODO: unset max_period
+ | npst.datetime64_dtypes(endianness="=", max_period="ns")
+ # TODO: set max_period="D"
+ | npst.timedelta64_dtypes(endianness="=", max_period="ns")
+ | npst.unicode_string_dtypes(endianness="=")
+ )
-_readable_characters = st.characters(categories=['L', 'N'], max_codepoint=383)
+# TODO Generalize to all valid unicode characters once formatting bugs in xarray's reprs are fixed + docs can handle it.
+_readable_characters = st.characters(
+ categories=["L", "N"], max_codepoint=0x017F
+) # only use characters within the "Latin Extended-A" subset of unicode
-def names() ->st.SearchStrategy[str]:
+def names() -> st.SearchStrategy[str]:
"""
Generates arbitrary string names for dimensions / variables.
@@ -62,11 +99,19 @@ def names() ->st.SearchStrategy[str]:
--------
:ref:`testing.hypothesis`_
"""
- pass
-
-
-def dimension_names(*, name_strategy=names(), min_dims: int=0, max_dims: int=3
- ) ->st.SearchStrategy[list[Hashable]]:
+ return st.text(
+ _readable_characters,
+ min_size=1,
+ max_size=5,
+ )
+
+
+def dimension_names(
+ *,
+ name_strategy=names(),
+ min_dims: int = 0,
+ max_dims: int = 3,
+) -> st.SearchStrategy[list[Hashable]]:
"""
Generates an arbitrary list of valid dimension names.
@@ -81,12 +126,23 @@ def dimension_names(*, name_strategy=names(), min_dims: int=0, max_dims: int=3
max_dims
Maximum number of dimensions in generated list.
"""
- pass
-
-def dimension_sizes(*, dim_names: st.SearchStrategy[Hashable]=names(),
- min_dims: int=0, max_dims: int=3, min_side: int=1, max_side: Union[int,
- None]=None) ->st.SearchStrategy[Mapping[Hashable, int]]:
+ return st.lists(
+ elements=name_strategy,
+ min_size=min_dims,
+ max_size=max_dims,
+ unique=True,
+ )
+
+
+def dimension_sizes(
+ *,
+ dim_names: st.SearchStrategy[Hashable] = names(),
+ min_dims: int = 0,
+ max_dims: int = 3,
+ min_side: int = 1,
+ max_side: Union[int, None] = None,
+) -> st.SearchStrategy[Mapping[Hashable, int]]:
"""
Generates an arbitrary mapping from dimension names to lengths.
@@ -114,19 +170,38 @@ def dimension_sizes(*, dim_names: st.SearchStrategy[Hashable]=names(),
--------
:ref:`testing.hypothesis`_
"""
- pass
+ if max_side is None:
+ max_side = min_side + 3
-_readable_strings = st.text(_readable_characters, max_size=5)
+ return st.dictionaries(
+ keys=dim_names,
+ values=st.integers(min_value=min_side, max_value=max_side),
+ min_size=min_dims,
+ max_size=max_dims,
+ )
+
+
+_readable_strings = st.text(
+ _readable_characters,
+ max_size=5,
+)
_attr_keys = _readable_strings
-_small_arrays = npst.arrays(shape=npst.array_shapes(max_side=2, max_dims=2),
- dtype=npst.scalar_dtypes() | npst.byte_string_dtypes() | npst.
- unicode_string_dtypes())
+_small_arrays = npst.arrays(
+ shape=npst.array_shapes(
+ max_side=2,
+ max_dims=2,
+ ),
+ dtype=npst.scalar_dtypes()
+ | npst.byte_string_dtypes()
+ | npst.unicode_string_dtypes(),
+)
_attr_values = st.none() | st.booleans() | _readable_strings | _small_arrays
+
simple_attrs = st.dictionaries(_attr_keys, _attr_values)
-def attrs() ->st.SearchStrategy[Mapping[Hashable, Any]]:
+def attrs() -> st.SearchStrategy[Mapping[Hashable, Any]]:
"""
Generates arbitrary valid attributes dictionaries for xarray objects.
@@ -138,15 +213,25 @@ def attrs() ->st.SearchStrategy[Mapping[Hashable, Any]]:
--------
:ref:`testing.hypothesis`_
"""
- pass
+ return st.recursive(
+ st.dictionaries(_attr_keys, _attr_values),
+ lambda children: st.dictionaries(_attr_keys, children),
+ max_leaves=3,
+ )
@st.composite
-def variables(draw: st.DrawFn, *, array_strategy_fn: Union[ArrayStrategyFn,
- None]=None, dims: Union[st.SearchStrategy[Union[Sequence[Hashable],
- Mapping[Hashable, int]]], None]=None, dtype: st.SearchStrategy[np.dtype
- ]=supported_dtypes(), attrs: st.SearchStrategy[Mapping]=attrs()
- ) ->xr.Variable:
+def variables(
+ draw: st.DrawFn,
+ *,
+ array_strategy_fn: Union[ArrayStrategyFn, None] = None,
+ dims: Union[
+ st.SearchStrategy[Union[Sequence[Hashable], Mapping[Hashable, int]]],
+ None,
+ ] = None,
+ dtype: st.SearchStrategy[np.dtype] = supported_dtypes(),
+ attrs: st.SearchStrategy[Mapping] = attrs(),
+) -> xr.Variable:
"""
Generates arbitrary xarray.Variable objects.
@@ -228,13 +313,108 @@ def variables(draw: st.DrawFn, *, array_strategy_fn: Union[ArrayStrategyFn,
--------
:ref:`testing.hypothesis`_
"""
- pass
+
+ if not isinstance(dims, st.SearchStrategy) and dims is not None:
+ raise InvalidArgument(
+ f"dims must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dims)}. "
+ "To specify fixed contents, use hypothesis.strategies.just()."
+ )
+ if not isinstance(dtype, st.SearchStrategy) and dtype is not None:
+ raise InvalidArgument(
+ f"dtype must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(dtype)}. "
+ "To specify fixed contents, use hypothesis.strategies.just()."
+ )
+ if not isinstance(attrs, st.SearchStrategy) and attrs is not None:
+ raise InvalidArgument(
+ f"attrs must be provided as a hypothesis.strategies.SearchStrategy object (or None), but got type {type(attrs)}. "
+ "To specify fixed contents, use hypothesis.strategies.just()."
+ )
+
+ _array_strategy_fn: ArrayStrategyFn
+ if array_strategy_fn is None:
+ # For some reason if I move the default value to the function signature definition mypy incorrectly says the ignore is no longer necessary, making it impossible to satisfy mypy
+ _array_strategy_fn = npst.arrays # type: ignore[assignment] # npst.arrays has extra kwargs that we aren't using later
+ elif not callable(array_strategy_fn):
+ raise InvalidArgument(
+ "array_strategy_fn must be a Callable that accepts the kwargs dtype and shape and returns a hypothesis "
+ "strategy which generates corresponding array-like objects."
+ )
+ else:
+ _array_strategy_fn = (
+ array_strategy_fn # satisfy mypy that this new variable cannot be None
+ )
+
+ _dtype = draw(dtype)
+
+ if dims is not None:
+ # generate dims first then draw data to match
+ _dims = draw(dims)
+ if isinstance(_dims, Sequence):
+ dim_names = list(_dims)
+ valid_shapes = npst.array_shapes(min_dims=len(_dims), max_dims=len(_dims))
+ _shape = draw(valid_shapes)
+ array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype)
+ elif isinstance(_dims, (Mapping, dict)):
+ # should be a mapping of form {dim_names: lengths}
+ dim_names, _shape = list(_dims.keys()), tuple(_dims.values())
+ array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype)
+ else:
+ raise InvalidArgument(
+ f"Invalid type returned by dims strategy - drew an object of type {type(dims)}"
+ )
+ else:
+ # nothing provided, so generate everything consistently
+ # We still generate the shape first here just so that we always pass shape to array_strategy_fn
+ _shape = draw(npst.array_shapes())
+ array_strategy = _array_strategy_fn(shape=_shape, dtype=_dtype)
+ dim_names = draw(dimension_names(min_dims=len(_shape), max_dims=len(_shape)))
+
+ _data = draw(array_strategy)
+
+ if _data.shape != _shape:
+        raise ValueError(
+            "array_strategy_fn returned an array object with a different shape than it was passed. "
+            f"Passed {_shape}, but returned {_data.shape}. "
+            "Please either specify a consistent shape via the dims kwarg or ensure the array_strategy_fn callable "
+            "obeys the shape argument passed to it."
+        )
+ if _data.dtype != _dtype:
+        raise ValueError(
+            "array_strategy_fn returned an array object with a different dtype than it was passed. "
+            f"Passed {_dtype}, but returned {_data.dtype}. "
+            "Please either specify a consistent dtype via the dtype kwarg or ensure the array_strategy_fn callable "
+            "obeys the dtype argument passed to it."
+        )
+
+ return xr.Variable(dims=dim_names, data=_data, attrs=draw(attrs))
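
A minimal property-based test sketch using the strategy above; the fixed `dims` mapping and the round-trip property are illustrative choices:

    import hypothesis.strategies as st
    from hypothesis import given

    import xarray as xr
    import xarray.testing.strategies as xrst

    @given(xrst.variables(dims=st.just({"x": 3, "y": 2})))
    def test_transpose_roundtrip(var):
        # Transposing twice should give back an identical Variable.
        xr.testing.assert_identical(var, var.transpose().transpose())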
+
+
+@overload
+def unique_subset_of(
+ objs: Sequence[Hashable],
+ *,
+ min_size: int = 0,
+ max_size: Union[int, None] = None,
+) -> st.SearchStrategy[Sequence[Hashable]]: ...
+
+
+@overload
+def unique_subset_of(
+ objs: Mapping[Hashable, Any],
+ *,
+ min_size: int = 0,
+ max_size: Union[int, None] = None,
+) -> st.SearchStrategy[Mapping[Hashable, Any]]: ...
@st.composite
-def unique_subset_of(draw: st.DrawFn, objs: Union[Sequence[Hashable],
- Mapping[Hashable, Any]], *, min_size: int=0, max_size: Union[int, None]
- =None) ->Union[Sequence[Hashable], Mapping[Hashable, Any]]:
+def unique_subset_of(
+ draw: st.DrawFn,
+ objs: Union[Sequence[Hashable], Mapping[Hashable, Any]],
+ *,
+ min_size: int = 0,
+ max_size: Union[int, None] = None,
+) -> Union[Sequence[Hashable], Mapping[Hashable, Any]]:
"""
Return a strategy which generates a unique subset of the given objects.
@@ -268,4 +448,25 @@ def unique_subset_of(draw: st.DrawFn, objs: Union[Sequence[Hashable],
--------
:ref:`testing.hypothesis`_
"""
- pass
+ if not isinstance(objs, Iterable):
+ raise TypeError(
+ f"Object to sample from must be an Iterable or a Mapping, but received type {type(objs)}"
+ )
+
+ if len(objs) == 0:
+ raise ValueError("Can't sample from a length-zero object.")
+
+ keys = list(objs.keys()) if isinstance(objs, Mapping) else objs
+
+ subset_keys = draw(
+ st.lists(
+ st.sampled_from(keys),
+ unique=True,
+ min_size=min_size,
+ max_size=max_size,
+ )
+ )
+
+ return (
+ {k: objs[k] for k in subset_keys} if isinstance(objs, Mapping) else subset_keys
+ )
diff --git a/xarray/tutorial.py b/xarray/tutorial.py
index 9e891efe..82bb3940 100644
--- a/xarray/tutorial.py
+++ b/xarray/tutorial.py
@@ -5,27 +5,88 @@ Useful for:
* building tutorials in the documentation.
"""
+
from __future__ import annotations
+
import os
import pathlib
from typing import TYPE_CHECKING
+
import numpy as np
+
from xarray.backends.api import open_dataset as _open_dataset
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
+
if TYPE_CHECKING:
from xarray.backends.api import T_Engine
-_default_cache_dir_name = 'xarray_tutorial_data'
-base_url = 'https://github.com/pydata/xarray-data'
-version = 'master'
-external_urls = {}
-file_formats = {'air_temperature': 3, 'air_temperature_gradient': 4,
- 'ASE_ice_velocity': 4, 'basin_mask': 4, 'ersstv5': 4, 'rasm': 3,
- 'ROMS_example': 4, 'tiny': 3, 'eraint_uvz': 3}
-def open_dataset(name: str, cache: bool=True, cache_dir: (None | str | os.
- PathLike)=None, *, engine: T_Engine=None, **kws) ->Dataset:
+_default_cache_dir_name = "xarray_tutorial_data"
+base_url = "https://github.com/pydata/xarray-data"
+version = "master"
+
+
+def _construct_cache_dir(path):
+ import pooch
+
+ if isinstance(path, os.PathLike):
+ path = os.fspath(path)
+ elif path is None:
+ path = pooch.os_cache(_default_cache_dir_name)
+
+ return path
+
+
+external_urls = {} # type: dict
+file_formats = {
+ "air_temperature": 3,
+ "air_temperature_gradient": 4,
+ "ASE_ice_velocity": 4,
+ "basin_mask": 4,
+ "ersstv5": 4,
+ "rasm": 3,
+ "ROMS_example": 4,
+ "tiny": 3,
+ "eraint_uvz": 3,
+}
+
+
+def _check_netcdf_engine_installed(name):
+ version = file_formats.get(name)
+ if version == 3:
+ try:
+ import scipy # noqa
+ except ImportError:
+ try:
+ import netCDF4 # noqa
+ except ImportError:
+ raise ImportError(
+ f"opening tutorial dataset {name} requires either scipy or "
+ "netCDF4 to be installed."
+ )
+ if version == 4:
+ try:
+ import h5netcdf # noqa
+ except ImportError:
+ try:
+ import netCDF4 # noqa
+ except ImportError:
+ raise ImportError(
+ f"opening tutorial dataset {name} requires either h5netcdf "
+ "or netCDF4 to be installed."
+ )
+
+
+# idea borrowed from Seaborn
+def open_dataset(
+ name: str,
+ cache: bool = True,
+ cache_dir: None | str | os.PathLike = None,
+ *,
+ engine: T_Engine = None,
+ **kws,
+) -> Dataset:
"""
Open a dataset from the online repository (requires internet).
@@ -62,10 +123,51 @@ def open_dataset(name: str, cache: bool=True, cache_dir: (None | str | os.
open_dataset
load_dataset
"""
- pass
-
-
-def load_dataset(*args, **kwargs) ->Dataset:
+ try:
+ import pooch
+ except ImportError as e:
+ raise ImportError(
+ "tutorial.open_dataset depends on pooch to download and manage datasets."
+ " To proceed please install pooch."
+ ) from e
+
+ logger = pooch.get_logger()
+ logger.setLevel("WARNING")
+
+ cache_dir = _construct_cache_dir(cache_dir)
+ if name in external_urls:
+ url = external_urls[name]
+ else:
+ path = pathlib.Path(name)
+ if not path.suffix:
+ # process the name
+ default_extension = ".nc"
+ if engine is None:
+ _check_netcdf_engine_installed(name)
+ path = path.with_suffix(default_extension)
+ elif path.suffix == ".grib":
+ if engine is None:
+ engine = "cfgrib"
+ try:
+ import cfgrib # noqa
+ except ImportError as e:
+ raise ImportError(
+ "Reading this tutorial dataset requires the cfgrib package."
+ ) from e
+
+ url = f"{base_url}/raw/{version}/{path.name}"
+
+ # retrieve the file
+ filepath = pooch.retrieve(url=url, known_hash=None, path=cache_dir)
+ ds = _open_dataset(filepath, engine=engine, **kws)
+ if not cache:
+ ds = ds.load()
+ pathlib.Path(filepath).unlink()
+
+ return ds
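
Typical usage of the public entry point (requires internet access and the pooch package):

    import xarray as xr

    ds = xr.tutorial.open_dataset("air_temperature")  # cached on later calls
    print(ds)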
+
+
+def load_dataset(*args, **kwargs) -> Dataset:
"""
Open, load into memory, and close a dataset from the online repository
(requires internet).
@@ -102,10 +204,11 @@ def load_dataset(*args, **kwargs) ->Dataset:
open_dataset
load_dataset
"""
- pass
+ with open_dataset(*args, **kwargs) as ds:
+ return ds.load()
-def scatter_example_dataset(*, seed: (None | int)=None) ->Dataset:
+def scatter_example_dataset(*, seed: None | int = None) -> Dataset:
"""
Create an example dataset.
@@ -114,4 +217,28 @@ def scatter_example_dataset(*, seed: (None | int)=None) ->Dataset:
seed : int, optional
Seed for the random number generation.
"""
- pass
+ rng = np.random.default_rng(seed)
+ A = DataArray(
+ np.zeros([3, 11, 4, 4]),
+ dims=["x", "y", "z", "w"],
+ coords={
+ "x": np.arange(3),
+ "y": np.linspace(0, 1, 11),
+ "z": np.arange(4),
+ "w": 0.1 * rng.standard_normal(4),
+ },
+ )
+ B = 0.1 * A.x**2 + A.y**2.5 + 0.1 * A.z * A.w
+ A = -0.1 * A.x + A.y / (5 + A.z) + A.w
+ ds = Dataset({"A": A, "B": B})
+ ds["w"] = ["one", "two", "three", "five"]
+
+ ds.x.attrs["units"] = "xunits"
+ ds.y.attrs["units"] = "yunits"
+ ds.z.attrs["units"] = "zunits"
+ ds.w.attrs["units"] = "wunits"
+
+ ds.A.attrs["units"] = "Aunits"
+ ds.B.attrs["units"] = "Bunits"
+
+ return ds
diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py
index 65af5b80..3d52253e 100644
--- a/xarray/util/deprecation_helpers.py
+++ b/xarray/util/deprecation_helpers.py
@@ -1,16 +1,52 @@
+# For reference, here is a copy of the scikit-learn copyright notice:
+
+# BSD 3-Clause License
+
+# Copyright (c) 2007-2021 The scikit-learn developers.
+# All rights reserved.
+
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+
+# * Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+
+# * Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE
+
+
import inspect
import warnings
from functools import wraps
from typing import Callable, TypeVar
+
from xarray.core.utils import emit_user_level_warning
-T = TypeVar('T', bound=Callable)
+
+T = TypeVar("T", bound=Callable)
+
POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD
KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY
POSITIONAL_ONLY = inspect.Parameter.POSITIONAL_ONLY
EMPTY = inspect.Parameter.empty
-def _deprecate_positional_args(version) ->Callable[[T], T]:
+def _deprecate_positional_args(version) -> Callable[[T], T]:
"""Decorator for methods that issues warnings for positional arguments
Using the keyword-only argument syntax in pep 3102, arguments after the
@@ -39,13 +75,70 @@ def _deprecate_positional_args(version) ->Callable[[T], T]:
This function is adapted from scikit-learn under the terms of its license. See
licences/SCIKIT_LEARN_LICENSE
"""
- pass
+ def _decorator(func):
+ signature = inspect.signature(func)
-def deprecate_dims(func: T, old_name='dims') ->T:
+ pos_or_kw_args = []
+ kwonly_args = []
+ for name, param in signature.parameters.items():
+ if param.kind in (POSITIONAL_OR_KEYWORD, POSITIONAL_ONLY):
+ pos_or_kw_args.append(name)
+ elif param.kind == KEYWORD_ONLY:
+ kwonly_args.append(name)
+ if param.default is EMPTY:
+ # IMHO `def f(a, *, b):` does not make sense -> disallow it
+ # if removing this constraint -> need to add these to kwargs as well
+ raise TypeError("Keyword-only param without default disallowed.")
+
+ @wraps(func)
+ def inner(*args, **kwargs):
+ name = func.__name__
+ n_extra_args = len(args) - len(pos_or_kw_args)
+ if n_extra_args > 0:
+ extra_args = ", ".join(kwonly_args[:n_extra_args])
+
+ warnings.warn(
+ f"Passing '{extra_args}' as positional argument(s) to {name} "
+ f"was deprecated in version {version} and will raise an error two "
+ "releases later. Please pass them as keyword arguments."
+ "",
+ FutureWarning,
+ stacklevel=2,
+ )
+
+ zip_args = zip(kwonly_args[:n_extra_args], args[-n_extra_args:])
+ kwargs.update({name: arg for name, arg in zip_args})
+
+ return func(*args[:-n_extra_args], **kwargs)
+
+ return func(*args, **kwargs)
+
+ return inner
+
+ return _decorator
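
A sketch of how the decorator is applied; the decorated function and the version string are hypothetical:

    @_deprecate_positional_args("v2024.01.0")
    def interp(obj, *, method="linear"):
        return method

    interp("data", "nearest")          # FutureWarning; forwarded as method="nearest"
    interp("data", method="nearest")   # no warning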
+
+
+def deprecate_dims(func: T, old_name="dims") -> T:
"""
For functions that previously took `dims` as a kwarg, and have now transitioned to
`dim`. This decorator will issue a warning if `dims` is passed while forwarding it
to `dim`.
"""
- pass
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if old_name in kwargs:
+ emit_user_level_warning(
+ f"The `{old_name}` argument has been renamed to `dim`, and will be removed "
+ "in the future. This renaming is taking place throughout xarray over the "
+ "next few releases.",
+ # Upgrade to `DeprecationWarning` in the future, when the renaming is complete.
+ PendingDeprecationWarning,
+ )
+ kwargs["dim"] = kwargs.pop(old_name)
+ return func(*args, **kwargs)
+
+ # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing
+ # within the function.
+ return wrapper # type: ignore
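
And a sketch of the companion `deprecate_dims` wrapper on a hypothetical reducer:

    @deprecate_dims
    def reduce_over(obj, dim=None):
        return dim

    reduce_over("data", dims="x")  # warns and forwards the value as dim="x"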
diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py
index b2eadf0c..d93c94b1 100644
--- a/xarray/util/generate_aggregations.py
+++ b/xarray/util/generate_aggregations.py
@@ -12,10 +12,13 @@ The second run of pytest is deliberate, since the first will return an error
while replacing the doctests.
"""
+
import collections
import textwrap
from dataclasses import dataclass, field
-MODULE_PREAMBLE = """""\"Mixin classes with reduction operations.""\"
+
+MODULE_PREAMBLE = '''\
+"""Mixin classes with reduction operations."""
# This file was generated using xarray.util.generate_aggregations. Do not edit manually.
@@ -34,8 +37,10 @@ if TYPE_CHECKING:
from xarray.core.dataset import Dataset
flox_available = module_available("flox")
-"""
-NAMED_ARRAY_MODULE_PREAMBLE = """""\"Mixin classes with reduction operations.""\"
+'''
+
+NAMED_ARRAY_MODULE_PREAMBLE = '''\
+"""Mixin classes with reduction operations."""
# This file was generated using xarray.util.generate_aggregations. Do not edit manually.
from __future__ import annotations
@@ -45,7 +50,8 @@ from typing import Any, Callable
from xarray.core import duck_array_ops
from xarray.core.types import Dims, Self
-"""
+'''
+
AGGREGATIONS_PREAMBLE = """
class {obj}{cls}Aggregations:
@@ -62,6 +68,7 @@ class {obj}{cls}Aggregations:
**kwargs: Any,
) -> Self:
raise NotImplementedError()"""
+
NAMED_ARRAY_AGGREGATIONS_PREAMBLE = """
class {obj}{cls}Aggregations:
@@ -77,6 +84,8 @@ class {obj}{cls}Aggregations:
**kwargs: Any,
) -> Self:
raise NotImplementedError()"""
+
+
GROUPBY_PREAMBLE = """
class {obj}{cls}Aggregations:
@@ -100,6 +109,7 @@ class {obj}{cls}Aggregations:
**kwargs: Any,
) -> {obj}:
raise NotImplementedError()"""
+
RESAMPLE_PREAMBLE = """
class {obj}{cls}Aggregations:
@@ -123,18 +133,20 @@ class {obj}{cls}Aggregations:
**kwargs: Any,
) -> {obj}:
raise NotImplementedError()"""
-TEMPLATE_REDUCTION_SIGNATURE = """
+
+TEMPLATE_REDUCTION_SIGNATURE = '''
def {method}(
self,
dim: Dims = None,{kw_only}{extra_kwargs}{keep_attrs}
**kwargs: Any,
) -> Self:
- ""\"
+ """
Reduce this {obj}'s data by applying ``{method}`` along some dimension(s).
Parameters
- ----------"""
-TEMPLATE_REDUCTION_SIGNATURE_GROUPBY = """
+ ----------'''
+
+TEMPLATE_REDUCTION_SIGNATURE_GROUPBY = '''
def {method}(
self,
dim: Dims = None,
@@ -142,89 +154,104 @@ TEMPLATE_REDUCTION_SIGNATURE_GROUPBY = """
keep_attrs: bool | None = None,
**kwargs: Any,
) -> {obj}:
- ""\"
+ """
Reduce this {obj}'s data by applying ``{method}`` along some dimension(s).
Parameters
- ----------"""
+ ----------'''
+
TEMPLATE_RETURNS = """
Returns
-------
reduced : {obj}
New {obj} with ``{method}`` applied to its data and the
indicated dimension(s) removed"""
+
TEMPLATE_SEE_ALSO = """
See Also
--------
{see_also_methods}
:ref:`{docref}`
User guide on {docref_description}."""
+
TEMPLATE_NOTES = """
Notes
-----
{notes}"""
+
_DIM_DOCSTRING = """dim : str, Iterable of Hashable, "..." or None, default: None
Name of dimension[s] along which to apply ``{method}``. For e.g. ``dim="x"``
or ``dim=["x", "y"]``. If "..." or None, will reduce over all dimensions."""
+
_DIM_DOCSTRING_GROUPBY = """dim : str, Iterable of Hashable, "..." or None, default: None
Name of dimension[s] along which to apply ``{method}``. For e.g. ``dim="x"``
or ``dim=["x", "y"]``. If None, will reduce over the {cls} dimensions.
If "...", will reduce over all dimensions."""
+
_SKIPNA_DOCSTRING = """skipna : bool or None, optional
If True, skip missing values (as marked by NaN). By default, only
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or ``skipna=True`` has not been
implemented (object, datetime64 or timedelta64)."""
+
_MINCOUNT_DOCSTRING = """min_count : int or None, optional
The required number of valid values to perform the operation. If
fewer than min_count non-NA values are present the result will be
NA. Only used if skipna is set to True or defaults to True for the
array's dtype. Changed in version 0.17.0: if specified on an integer
array and skipna=True, the result will be a float array."""
+
_DDOF_DOCSTRING = """ddof : int, default: 0
“Delta Degrees of Freedom”: the divisor used in the calculation is ``N - ddof``,
where ``N`` represents the number of elements."""
+
_KEEP_ATTRS_DOCSTRING = """keep_attrs : bool or None, optional
If True, ``attrs`` will be copied from the original
object to the new one. If False, the new object will be
returned without attributes."""
+
_KWARGS_DOCSTRING = """**kwargs : Any
Additional keyword arguments passed on to the appropriate array
function for calculating ``{method}`` on this object's data.
These could include dask-specific kwargs like ``split_every``."""
-_NUMERIC_ONLY_NOTES = (
- 'Non-numeric variables will be removed prior to reducing.')
+
+_NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing."
+
_FLOX_NOTES_TEMPLATE = """Use the ``flox`` package to significantly speed up {kind} computations,
especially with dask arrays. Xarray will use flox by default if installed.
Pass flox-specific keyword arguments in ``**kwargs``.
See the `flox documentation <https://flox.readthedocs.io>`_ for more."""
-_FLOX_GROUPBY_NOTES = _FLOX_NOTES_TEMPLATE.format(kind='groupby')
-_FLOX_RESAMPLE_NOTES = _FLOX_NOTES_TEMPLATE.format(kind='resampling')
-ExtraKwarg = collections.namedtuple('ExtraKwarg', 'docs kwarg call example')
-skipna = ExtraKwarg(docs=_SKIPNA_DOCSTRING, kwarg=
- 'skipna: bool | None = None,', call='skipna=skipna,', example=
- """
-
+_FLOX_GROUPBY_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="groupby")
+_FLOX_RESAMPLE_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="resampling")
+
+ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example")
+skipna = ExtraKwarg(
+ docs=_SKIPNA_DOCSTRING,
+ kwarg="skipna: bool | None = None,",
+ call="skipna=skipna,",
+ example="""\n
Use ``skipna`` to control whether NaNs are ignored.
- >>> {calculation}(skipna=False)"""
- )
-min_count = ExtraKwarg(docs=_MINCOUNT_DOCSTRING, kwarg=
- 'min_count: int | None = None,', call='min_count=min_count,', example=
- """
-
+ >>> {calculation}(skipna=False)""",
+)
+min_count = ExtraKwarg(
+ docs=_MINCOUNT_DOCSTRING,
+ kwarg="min_count: int | None = None,",
+ call="min_count=min_count,",
+ example="""\n
Specify ``min_count`` for finer control over when NaNs are ignored.
- >>> {calculation}(skipna=True, min_count=2)"""
- )
-ddof = ExtraKwarg(docs=_DDOF_DOCSTRING, kwarg='ddof: int = 0,', call=
- 'ddof=ddof,', example=
- """
-
+ >>> {calculation}(skipna=True, min_count=2)""",
+)
+ddof = ExtraKwarg(
+ docs=_DDOF_DOCSTRING,
+ kwarg="ddof: int = 0,",
+ call="ddof=ddof,",
+ example="""\n
Specify ``ddof=1`` for an unbiased estimate.
- >>> {calculation}(skipna=True, ddof=1)"""
- )
+ >>> {calculation}(skipna=True, ddof=1)""",
+)
@dataclass
@@ -237,29 +264,36 @@ class DataStructure:
class Method:
-
- def __init__(self, name, bool_reduce=False, extra_kwargs=tuple(),
- numeric_only=False, see_also_modules=('numpy', 'dask.array'),
- min_flox_version=None):
+ def __init__(
+ self,
+ name,
+ bool_reduce=False,
+ extra_kwargs=tuple(),
+ numeric_only=False,
+ see_also_modules=("numpy", "dask.array"),
+ min_flox_version=None,
+ ):
self.name = name
self.extra_kwargs = extra_kwargs
self.numeric_only = numeric_only
self.see_also_modules = see_also_modules
self.min_flox_version = min_flox_version
if bool_reduce:
- self.array_method = f'array_{name}'
+ self.array_method = f"array_{name}"
self.np_example_array = """
... np.array([True, True, True, True, True, False], dtype=bool)"""
+
else:
self.array_method = name
- self.np_example_array = (
- '\n ... np.array([1, 2, 3, 0, 2, np.nan])')
+ self.np_example_array = """
+ ... np.array([1, 2, 3, 0, 2, np.nan])"""
@dataclass
class AggregationGenerator:
_dim_docstring = _DIM_DOCSTRING
_template_signature = TEMPLATE_REDUCTION_SIGNATURE
+
cls: str
datastructure: DataStructure
methods: tuple[Method, ...]
@@ -268,37 +302,217 @@ class AggregationGenerator:
example_call_preamble: str
definition_preamble: str
has_keep_attrs: bool = True
- notes: str = ''
+ notes: str = ""
preamble: str = field(init=False)
def __post_init__(self):
- self.preamble = self.definition_preamble.format(obj=self.
- datastructure.name, cls=self.cls)
+ self.preamble = self.definition_preamble.format(
+ obj=self.datastructure.name, cls=self.cls
+ )
+
+ def generate_methods(self):
+ yield [self.preamble]
+ for method in self.methods:
+ yield self.generate_method(method)
+
+ def generate_method(self, method):
+ has_kw_only = method.extra_kwargs or self.has_keep_attrs
+
+ template_kwargs = dict(
+ obj=self.datastructure.name,
+ method=method.name,
+ keep_attrs=(
+ "\n keep_attrs: bool | None = None,"
+ if self.has_keep_attrs
+ else ""
+ ),
+ kw_only="\n *," if has_kw_only else "",
+ )
+
+ if method.extra_kwargs:
+ extra_kwargs = "\n " + "\n ".join(
+ [kwarg.kwarg for kwarg in method.extra_kwargs if kwarg.kwarg]
+ )
+ else:
+ extra_kwargs = ""
+
+ yield self._template_signature.format(
+ **template_kwargs,
+ extra_kwargs=extra_kwargs,
+ )
+
+ for text in [
+ self._dim_docstring.format(method=method.name, cls=self.cls),
+ *(kwarg.docs for kwarg in method.extra_kwargs if kwarg.docs),
+ _KEEP_ATTRS_DOCSTRING if self.has_keep_attrs else None,
+ _KWARGS_DOCSTRING.format(method=method.name),
+ ]:
+ if text:
+ yield textwrap.indent(text, 8 * " ")
+
+ yield TEMPLATE_RETURNS.format(**template_kwargs)
+
+ # we want Dataset.count to refer to DataArray.count
+ # but we also want DatasetGroupBy.count to refer to Dataset.count
+ # The generic aggregations have self.cls == ''
+ others = (
+ self.datastructure.see_also_modules
+ if self.cls == ""
+ else (self.datastructure.name,)
+ )
+ see_also_methods = "\n".join(
+ " " * 8 + f"{mod}.{method.name}"
+ for mod in (method.see_also_modules + others)
+ )
+ # Fixes broken links mentioned in #8055
+ yield TEMPLATE_SEE_ALSO.format(
+ **template_kwargs,
+ docref=self.docref,
+ docref_description=self.docref_description,
+ see_also_methods=see_also_methods,
+ )
+
+ notes = self.notes
+ if method.numeric_only:
+ if notes != "":
+ notes += "\n\n"
+ notes += _NUMERIC_ONLY_NOTES
+
+ if notes != "":
+ yield TEMPLATE_NOTES.format(notes=textwrap.indent(notes, 8 * " "))
+
+ yield textwrap.indent(self.generate_example(method=method), "")
+ yield ' """'
+
+ yield self.generate_code(method, self.has_keep_attrs)
+
+ def generate_example(self, method):
+ created = self.datastructure.create_example.format(
+ example_array=method.np_example_array
+ )
+ calculation = f"{self.datastructure.example_var_name}{self.example_call_preamble}.{method.name}"
+ if method.extra_kwargs:
+ extra_examples = "".join(
+ kwarg.example for kwarg in method.extra_kwargs if kwarg.example
+ ).format(calculation=calculation, method=method.name)
+ else:
+ extra_examples = ""
+
+ return f"""
+ Examples
+ --------{created}
+ >>> {self.datastructure.example_var_name}
+
+ >>> {calculation}(){extra_examples}"""
class GroupByAggregationGenerator(AggregationGenerator):
_dim_docstring = _DIM_DOCSTRING_GROUPBY
_template_signature = TEMPLATE_REDUCTION_SIGNATURE_GROUPBY
+ def generate_code(self, method, has_keep_attrs):
+ extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call]
+
+ if self.datastructure.numeric_only:
+ extra_kwargs.append(f"numeric_only={method.numeric_only},")
+
+ # median isn't enabled yet, because it would break if a single group was present in multiple
+ # chunks. The non-flox code path will just rechunk every group to a single chunk and execute the median
+ method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod")
+ if method_is_not_flox_supported:
+ indent = 12
+ else:
+ indent = 16
+
+ if extra_kwargs:
+ extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), indent * " ")
+ else:
+ extra_kwargs = ""
+
+ if method_is_not_flox_supported:
+ return f"""\
+ return self.reduce(
+ duck_array_ops.{method.array_method},
+ dim=dim,{extra_kwargs}
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )"""
+
+ min_version_check = f"""
+ and module_available("flox", minversion="{method.min_flox_version}")"""
+
+ return (
+ """\
+ if (
+ flox_available
+ and OPTIONS["use_flox"]"""
+ + (min_version_check if method.min_flox_version is not None else "")
+ + f"""
+ and contains_only_chunked_or_numpy(self._obj)
+ ):
+ return self._flox_reduce(
+ func="{method.name}",
+ dim=dim,{extra_kwargs}
+ # fill_value=fill_value,
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )
+ else:
+ return self.reduce(
+ duck_array_ops.{method.array_method},
+ dim=dim,{extra_kwargs}
+ keep_attrs=keep_attrs,
+ **kwargs,
+ )"""
+ )
+
class GenericAggregationGenerator(AggregationGenerator):
- pass
-
-
-AGGREGATION_METHODS = Method('count', see_also_modules=('pandas.DataFrame',
- 'dask.dataframe.DataFrame')), Method('all', bool_reduce=True), Method('any'
- , bool_reduce=True), Method('max', extra_kwargs=(skipna,)), Method('min',
- extra_kwargs=(skipna,)), Method('mean', extra_kwargs=(skipna,),
- numeric_only=True), Method('prod', extra_kwargs=(skipna, min_count),
- numeric_only=True), Method('sum', extra_kwargs=(skipna, min_count),
- numeric_only=True), Method('std', extra_kwargs=(skipna, ddof),
- numeric_only=True), Method('var', extra_kwargs=(skipna, ddof),
- numeric_only=True), Method('median', extra_kwargs=(skipna,),
- numeric_only=True, min_flox_version='0.9.2'), Method('cumsum',
- extra_kwargs=(skipna,), numeric_only=True), Method('cumprod',
- extra_kwargs=(skipna,), numeric_only=True)
-DATASET_OBJECT = DataStructure(name='Dataset', create_example=
- """
+ def generate_code(self, method, has_keep_attrs):
+ extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call]
+
+ if self.datastructure.numeric_only:
+ extra_kwargs.append(f"numeric_only={method.numeric_only},")
+
+ if extra_kwargs:
+ extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), 12 * " ")
+ else:
+ extra_kwargs = ""
+ keep_attrs = (
+ "\n" + 12 * " " + "keep_attrs=keep_attrs," if has_keep_attrs else ""
+ )
+ return f"""\
+ return self.reduce(
+ duck_array_ops.{method.array_method},
+ dim=dim,{extra_kwargs}{keep_attrs}
+ **kwargs,
+ )"""
+
+
+AGGREGATION_METHODS = (
+ # Reductions:
+ Method("count", see_also_modules=("pandas.DataFrame", "dask.dataframe.DataFrame")),
+ Method("all", bool_reduce=True),
+ Method("any", bool_reduce=True),
+ Method("max", extra_kwargs=(skipna,)),
+ Method("min", extra_kwargs=(skipna,)),
+ Method("mean", extra_kwargs=(skipna,), numeric_only=True),
+ Method("prod", extra_kwargs=(skipna, min_count), numeric_only=True),
+ Method("sum", extra_kwargs=(skipna, min_count), numeric_only=True),
+ Method("std", extra_kwargs=(skipna, ddof), numeric_only=True),
+ Method("var", extra_kwargs=(skipna, ddof), numeric_only=True),
+ Method(
+ "median", extra_kwargs=(skipna,), numeric_only=True, min_flox_version="0.9.2"
+ ),
+ # Cumulatives:
+ Method("cumsum", extra_kwargs=(skipna,), numeric_only=True),
+ Method("cumprod", extra_kwargs=(skipna,), numeric_only=True),
+)
+
+
+DATASET_OBJECT = DataStructure(
+ name="Dataset",
+ create_example="""
>>> da = xr.DataArray({example_array},
... dims="time",
... coords=dict(
@@ -306,69 +520,137 @@ DATASET_OBJECT = DataStructure(name='Dataset', create_example=
... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])),
... ),
... )
- >>> ds = xr.Dataset(dict(da=da))"""
- , example_var_name='ds', numeric_only=True, see_also_modules=('DataArray',)
- )
-DATAARRAY_OBJECT = DataStructure(name='DataArray', create_example=
- """
+ >>> ds = xr.Dataset(dict(da=da))""",
+ example_var_name="ds",
+ numeric_only=True,
+ see_also_modules=("DataArray",),
+)
+DATAARRAY_OBJECT = DataStructure(
+ name="DataArray",
+ create_example="""
>>> da = xr.DataArray({example_array},
... dims="time",
... coords=dict(
... time=("time", pd.date_range("2001-01-01", freq="ME", periods=6)),
... labels=("time", np.array(["a", "b", "c", "c", "b", "a"])),
... ),
- ... )"""
- , example_var_name='da', numeric_only=False, see_also_modules=('Dataset',))
-DATASET_GENERATOR = GenericAggregationGenerator(cls='', datastructure=
- DATASET_OBJECT, methods=AGGREGATION_METHODS, docref='agg',
- docref_description='reduction or aggregation operations',
- example_call_preamble='', definition_preamble=AGGREGATIONS_PREAMBLE)
-DATAARRAY_GENERATOR = GenericAggregationGenerator(cls='', datastructure=
- DATAARRAY_OBJECT, methods=AGGREGATION_METHODS, docref='agg',
- docref_description='reduction or aggregation operations',
- example_call_preamble='', definition_preamble=AGGREGATIONS_PREAMBLE)
-DATAARRAY_GROUPBY_GENERATOR = GroupByAggregationGenerator(cls='GroupBy',
- datastructure=DATAARRAY_OBJECT, methods=AGGREGATION_METHODS, docref=
- 'groupby', docref_description='groupby operations',
- example_call_preamble='.groupby("labels")', definition_preamble=
- GROUPBY_PREAMBLE, notes=_FLOX_GROUPBY_NOTES)
-DATAARRAY_RESAMPLE_GENERATOR = GroupByAggregationGenerator(cls='Resample',
- datastructure=DATAARRAY_OBJECT, methods=AGGREGATION_METHODS, docref=
- 'resampling', docref_description='resampling operations',
- example_call_preamble='.resample(time="3ME")', definition_preamble=
- RESAMPLE_PREAMBLE, notes=_FLOX_RESAMPLE_NOTES)
-DATASET_GROUPBY_GENERATOR = GroupByAggregationGenerator(cls='GroupBy',
- datastructure=DATASET_OBJECT, methods=AGGREGATION_METHODS, docref=
- 'groupby', docref_description='groupby operations',
- example_call_preamble='.groupby("labels")', definition_preamble=
- GROUPBY_PREAMBLE, notes=_FLOX_GROUPBY_NOTES)
-DATASET_RESAMPLE_GENERATOR = GroupByAggregationGenerator(cls='Resample',
- datastructure=DATASET_OBJECT, methods=AGGREGATION_METHODS, docref=
- 'resampling', docref_description='resampling operations',
- example_call_preamble='.resample(time="3ME")', definition_preamble=
- RESAMPLE_PREAMBLE, notes=_FLOX_RESAMPLE_NOTES)
-NAMED_ARRAY_OBJECT = DataStructure(name='NamedArray', create_example=
- """
+ ... )""",
+ example_var_name="da",
+ numeric_only=False,
+ see_also_modules=("Dataset",),
+)
+DATASET_GENERATOR = GenericAggregationGenerator(
+ cls="",
+ datastructure=DATASET_OBJECT,
+ methods=AGGREGATION_METHODS,
+ docref="agg",
+ docref_description="reduction or aggregation operations",
+ example_call_preamble="",
+ definition_preamble=AGGREGATIONS_PREAMBLE,
+)
+DATAARRAY_GENERATOR = GenericAggregationGenerator(
+ cls="",
+ datastructure=DATAARRAY_OBJECT,
+ methods=AGGREGATION_METHODS,
+ docref="agg",
+ docref_description="reduction or aggregation operations",
+ example_call_preamble="",
+ definition_preamble=AGGREGATIONS_PREAMBLE,
+)
+DATAARRAY_GROUPBY_GENERATOR = GroupByAggregationGenerator(
+ cls="GroupBy",
+ datastructure=DATAARRAY_OBJECT,
+ methods=AGGREGATION_METHODS,
+ docref="groupby",
+ docref_description="groupby operations",
+ example_call_preamble='.groupby("labels")',
+ definition_preamble=GROUPBY_PREAMBLE,
+ notes=_FLOX_GROUPBY_NOTES,
+)
+DATAARRAY_RESAMPLE_GENERATOR = GroupByAggregationGenerator(
+ cls="Resample",
+ datastructure=DATAARRAY_OBJECT,
+ methods=AGGREGATION_METHODS,
+ docref="resampling",
+ docref_description="resampling operations",
+ example_call_preamble='.resample(time="3ME")',
+ definition_preamble=RESAMPLE_PREAMBLE,
+ notes=_FLOX_RESAMPLE_NOTES,
+)
+DATASET_GROUPBY_GENERATOR = GroupByAggregationGenerator(
+ cls="GroupBy",
+ datastructure=DATASET_OBJECT,
+ methods=AGGREGATION_METHODS,
+ docref="groupby",
+ docref_description="groupby operations",
+ example_call_preamble='.groupby("labels")',
+ definition_preamble=GROUPBY_PREAMBLE,
+ notes=_FLOX_GROUPBY_NOTES,
+)
+DATASET_RESAMPLE_GENERATOR = GroupByAggregationGenerator(
+ cls="Resample",
+ datastructure=DATASET_OBJECT,
+ methods=AGGREGATION_METHODS,
+ docref="resampling",
+ docref_description="resampling operations",
+ example_call_preamble='.resample(time="3ME")',
+ definition_preamble=RESAMPLE_PREAMBLE,
+ notes=_FLOX_RESAMPLE_NOTES,
+)
+
+NAMED_ARRAY_OBJECT = DataStructure(
+ name="NamedArray",
+ create_example="""
>>> from xarray.namedarray.core import NamedArray
>>> na = NamedArray(
... "x",{example_array},
- ... )"""
- , example_var_name='na', numeric_only=False, see_also_modules=(
- 'Dataset', 'DataArray'))
-NAMED_ARRAY_GENERATOR = GenericAggregationGenerator(cls='', datastructure=
- NAMED_ARRAY_OBJECT, methods=AGGREGATION_METHODS, docref='agg',
- docref_description='reduction or aggregation operations',
- example_call_preamble='', definition_preamble=
- NAMED_ARRAY_AGGREGATIONS_PREAMBLE, has_keep_attrs=False)
-if __name__ == '__main__':
+ ... )""",
+ example_var_name="na",
+ numeric_only=False,
+ see_also_modules=("Dataset", "DataArray"),
+)
+
+NAMED_ARRAY_GENERATOR = GenericAggregationGenerator(
+ cls="",
+ datastructure=NAMED_ARRAY_OBJECT,
+ methods=AGGREGATION_METHODS,
+ docref="agg",
+ docref_description="reduction or aggregation operations",
+ example_call_preamble="",
+ definition_preamble=NAMED_ARRAY_AGGREGATIONS_PREAMBLE,
+ has_keep_attrs=False,
+)
+
+
+def write_methods(filepath, generators, preamble):
+ with open(filepath, mode="w", encoding="utf-8") as f:
+ f.write(preamble)
+ for gen in generators:
+ for lines in gen.generate_methods():
+ for line in lines:
+ f.write(line + "\n")
+
+
+if __name__ == "__main__":
import os
from pathlib import Path
+
p = Path(os.getcwd())
- write_methods(filepath=p.parent / 'xarray' / 'xarray' / 'core' /
- '_aggregations.py', generators=[DATASET_GENERATOR,
- DATAARRAY_GENERATOR, DATASET_GROUPBY_GENERATOR,
- DATASET_RESAMPLE_GENERATOR, DATAARRAY_GROUPBY_GENERATOR,
- DATAARRAY_RESAMPLE_GENERATOR], preamble=MODULE_PREAMBLE)
- write_methods(filepath=p.parent / 'xarray' / 'xarray' / 'namedarray' /
- '_aggregations.py', generators=[NAMED_ARRAY_GENERATOR], preamble=
- NAMED_ARRAY_MODULE_PREAMBLE)
+ write_methods(
+ filepath=p.parent / "xarray" / "xarray" / "core" / "_aggregations.py",
+ generators=[
+ DATASET_GENERATOR,
+ DATAARRAY_GENERATOR,
+ DATASET_GROUPBY_GENERATOR,
+ DATASET_RESAMPLE_GENERATOR,
+ DATAARRAY_GROUPBY_GENERATOR,
+ DATAARRAY_RESAMPLE_GENERATOR,
+ ],
+ preamble=MODULE_PREAMBLE,
+ )
+ write_methods(
+ filepath=p.parent / "xarray" / "xarray" / "namedarray" / "_aggregations.py",
+ generators=[NAMED_ARRAY_GENERATOR],
+ preamble=NAMED_ARRAY_MODULE_PREAMBLE,
+ )
+ # filepath = p.parent / "core" / "_aggregations.py" # Run from script location
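The `generate_code` methods above assemble each emitted method body by splicing indented keyword fragments into a `self.reduce(...)` template. A small illustration of that splicing, using made-up inputs shaped like what the generic generator sees for `Dataset.sum` (the local variable names here are illustrative, not the generator's internals):

```python
import textwrap

array_method = "sum"
extra_kwargs = ["skipna=skipna,", "min_count=min_count,", "numeric_only=True,"]
has_keep_attrs = True

# indent each extra kwarg to the call-argument column, exactly as generate_code does
extra = textwrap.indent("\n" + "\n".join(extra_kwargs), 12 * " ")
keep_attrs = "\n" + 12 * " " + "keep_attrs=keep_attrs," if has_keep_attrs else ""

print(
    f"""\
        return self.reduce(
            duck_array_ops.{array_method},
            dim=dim,{extra}{keep_attrs}
            **kwargs,
        )"""
)
```

Running this prints the `return self.reduce(...)` block with `skipna`, `min_count`, `numeric_only` and `keep_attrs` threaded through, i.e. the body that ends up in the generated `_aggregations.py`.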
diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py
index 9818559a..a9f66cdc 100644
--- a/xarray/util/generate_ops.py
+++ b/xarray/util/generate_ops.py
@@ -6,33 +6,78 @@ Usage:
python xarray/util/generate_ops.py > xarray/core/_typed_ops.py
"""
+
+# Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some
+# background to some of the design choices made here.
+
from __future__ import annotations
+
from collections.abc import Iterator, Sequence
from typing import Optional
-BINOPS_EQNE = ('__eq__', 'nputils.array_eq'), ('__ne__', 'nputils.array_ne')
-BINOPS_CMP = ('__lt__', 'operator.lt'), ('__le__', 'operator.le'), ('__gt__',
- 'operator.gt'), ('__ge__', 'operator.ge')
-BINOPS_NUM = ('__add__', 'operator.add'), ('__sub__', 'operator.sub'), (
- '__mul__', 'operator.mul'), ('__pow__', 'operator.pow'), ('__truediv__',
- 'operator.truediv'), ('__floordiv__', 'operator.floordiv'), ('__mod__',
- 'operator.mod'), ('__and__', 'operator.and_'), ('__xor__', 'operator.xor'
- ), ('__or__', 'operator.or_'), ('__lshift__', 'operator.lshift'), (
- '__rshift__', 'operator.rshift')
-BINOPS_REFLEXIVE = ('__radd__', 'operator.add'), ('__rsub__', 'operator.sub'
- ), ('__rmul__', 'operator.mul'), ('__rpow__', 'operator.pow'), (
- '__rtruediv__', 'operator.truediv'), ('__rfloordiv__', 'operator.floordiv'
- ), ('__rmod__', 'operator.mod'), ('__rand__', 'operator.and_'), ('__rxor__'
- , 'operator.xor'), ('__ror__', 'operator.or_')
-BINOPS_INPLACE = ('__iadd__', 'operator.iadd'), ('__isub__', 'operator.isub'
- ), ('__imul__', 'operator.imul'), ('__ipow__', 'operator.ipow'), (
- '__itruediv__', 'operator.itruediv'), ('__ifloordiv__',
- 'operator.ifloordiv'), ('__imod__', 'operator.imod'), ('__iand__',
- 'operator.iand'), ('__ixor__', 'operator.ixor'), ('__ior__', 'operator.ior'
- ), ('__ilshift__', 'operator.ilshift'), ('__irshift__', 'operator.irshift')
-UNARY_OPS = ('__neg__', 'operator.neg'), ('__pos__', 'operator.pos'), (
- '__abs__', 'operator.abs'), ('__invert__', 'operator.invert')
-OTHER_UNARY_METHODS = ('round', 'ops.round_'), ('argsort', 'ops.argsort'), (
- 'conj', 'ops.conj'), ('conjugate', 'ops.conjugate')
+
+BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne"))
+BINOPS_CMP = (
+ ("__lt__", "operator.lt"),
+ ("__le__", "operator.le"),
+ ("__gt__", "operator.gt"),
+ ("__ge__", "operator.ge"),
+)
+BINOPS_NUM = (
+ ("__add__", "operator.add"),
+ ("__sub__", "operator.sub"),
+ ("__mul__", "operator.mul"),
+ ("__pow__", "operator.pow"),
+ ("__truediv__", "operator.truediv"),
+ ("__floordiv__", "operator.floordiv"),
+ ("__mod__", "operator.mod"),
+ ("__and__", "operator.and_"),
+ ("__xor__", "operator.xor"),
+ ("__or__", "operator.or_"),
+ ("__lshift__", "operator.lshift"),
+ ("__rshift__", "operator.rshift"),
+)
+BINOPS_REFLEXIVE = (
+ ("__radd__", "operator.add"),
+ ("__rsub__", "operator.sub"),
+ ("__rmul__", "operator.mul"),
+ ("__rpow__", "operator.pow"),
+ ("__rtruediv__", "operator.truediv"),
+ ("__rfloordiv__", "operator.floordiv"),
+ ("__rmod__", "operator.mod"),
+ ("__rand__", "operator.and_"),
+ ("__rxor__", "operator.xor"),
+ ("__ror__", "operator.or_"),
+)
+BINOPS_INPLACE = (
+ ("__iadd__", "operator.iadd"),
+ ("__isub__", "operator.isub"),
+ ("__imul__", "operator.imul"),
+ ("__ipow__", "operator.ipow"),
+ ("__itruediv__", "operator.itruediv"),
+ ("__ifloordiv__", "operator.ifloordiv"),
+ ("__imod__", "operator.imod"),
+ ("__iand__", "operator.iand"),
+ ("__ixor__", "operator.ixor"),
+ ("__ior__", "operator.ior"),
+ ("__ilshift__", "operator.ilshift"),
+ ("__irshift__", "operator.irshift"),
+)
+UNARY_OPS = (
+ ("__neg__", "operator.neg"),
+ ("__pos__", "operator.pos"),
+ ("__abs__", "operator.abs"),
+ ("__invert__", "operator.invert"),
+)
+# round method and numpy/pandas unary methods which don't modify the data shape,
+# so the result should still be wrapped in a Variable/DataArray/Dataset

+OTHER_UNARY_METHODS = (
+ ("round", "ops.round_"),
+ ("argsort", "ops.argsort"),
+ ("conj", "ops.conj"),
+ ("conjugate", "ops.conjugate"),
+)
+
+
required_method_binary = """
def _binary_op(
self, other: {other_type}, f: Callable, reflexive: bool = False
@@ -53,12 +98,14 @@ template_binop_overload = """
template_reflexive = """
def {method}(self, other: {other_type}) -> {return_type}:
return self._binary_op(other, {func}, reflexive=True)"""
+
required_method_inplace = """
def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self:
raise NotImplementedError"""
template_inplace = """
def {method}(self, other: {other_type}) -> Self:{type_ignore}
return self._inplace_binary_op(other, {func})"""
+
required_method_unary = """
def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError"""
@@ -72,21 +119,126 @@ unhashable = """
# When __eq__ is defined but __hash__ is not, then an object is unhashable,
# and it should be declared as follows:
__hash__: None # type:ignore[assignment]"""
+
+# For some methods we override return type `bool` defined by base class `object`.
+# We need to add "# type: ignore[override]"
+# Keep an eye out for:
+# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240
+# The type ignores might not be necessary anymore at some point.
+#
+# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray
+# In reality this returns NotImplemented, but this is not a valid type in python 3.9.
+# Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable)
+# TODO: change once python 3.10 is the minimum.
+#
+# Mypy seems to require that __iadd__ and __add__ have the same signature.
+# This requires some extra type: ignores[misc] in the inplace methods :/
+
+
+def _type_ignore(ignore: str) -> str:
+ return f" # type:ignore[{ignore}]" if ignore else ""
+
+
FuncType = Sequence[tuple[Optional[str], Optional[str]]]
OpsType = tuple[FuncType, str, dict[str, str]]
+
+
+def binops(
+ other_type: str, return_type: str = "Self", type_ignore_eq: str = "override"
+) -> list[OpsType]:
+ extras = {"other_type": other_type, "return_type": return_type}
+ return [
+ ([(None, None)], required_method_binary, extras),
+ (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}),
+ (
+ BINOPS_EQNE,
+ template_binop,
+ extras | {"type_ignore": _type_ignore(type_ignore_eq)},
+ ),
+ ([(None, None)], unhashable, extras),
+ (BINOPS_REFLEXIVE, template_reflexive, extras),
+ ]
+
+
+def binops_overload(
+ other_type: str,
+ overload_type: str,
+ return_type: str = "Self",
+ type_ignore_eq: str = "override",
+) -> list[OpsType]:
+ extras = {"other_type": other_type, "return_type": return_type}
+ return [
+ ([(None, None)], required_method_binary, extras),
+ (
+ BINOPS_NUM + BINOPS_CMP,
+ template_binop_overload,
+ extras
+ | {
+ "overload_type": overload_type,
+ "type_ignore": "",
+ "overload_type_ignore": "",
+ },
+ ),
+ (
+ BINOPS_EQNE,
+ template_binop_overload,
+ extras
+ | {
+ "overload_type": overload_type,
+ "type_ignore": "",
+ "overload_type_ignore": _type_ignore(type_ignore_eq),
+ },
+ ),
+ ([(None, None)], unhashable, extras),
+ (BINOPS_REFLEXIVE, template_reflexive, extras),
+ ]
+
+
+def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]:
+ extras = {"other_type": other_type}
+ return [
+ ([(None, None)], required_method_inplace, extras),
+ (
+ BINOPS_INPLACE,
+ template_inplace,
+ extras | {"type_ignore": _type_ignore(type_ignore)},
+ ),
+ ]
+
+
+def unops() -> list[OpsType]:
+ return [
+ ([(None, None)], required_method_unary, {}),
+ (UNARY_OPS, template_unary, {}),
+ (OTHER_UNARY_METHODS, template_other_unary, {}),
+ ]
+
+
+# We use short names T_DA and T_DS to keep below 88 lines so
+# ruff does not reformat everything. When reformatting, the
+# type-ignores end up in the wrong line :/
+
ops_info = {}
-ops_info['DatasetOpsMixin'] = binops(other_type='DsCompatible') + inplace(
- other_type='DsCompatible') + unops()
-ops_info['DataArrayOpsMixin'] = binops(other_type='DaCompatible') + inplace(
- other_type='DaCompatible') + unops()
-ops_info['VariableOpsMixin'] = binops_overload(other_type='VarCompatible',
- overload_type='T_DA') + inplace(other_type='VarCompatible', type_ignore
- ='misc') + unops()
-ops_info['DatasetGroupByOpsMixin'] = binops(other_type=
- 'Dataset | DataArray', return_type='Dataset')
-ops_info['DataArrayGroupByOpsMixin'] = binops(other_type='T_Xarray',
- return_type='T_Xarray')
-MODULE_PREAMBLE = """""\"Mixin classes with arithmetic operators.""\"
+ops_info["DatasetOpsMixin"] = (
+ binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops()
+)
+ops_info["DataArrayOpsMixin"] = (
+ binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops()
+)
+ops_info["VariableOpsMixin"] = (
+ binops_overload(other_type="VarCompatible", overload_type="T_DA")
+ + inplace(other_type="VarCompatible", type_ignore="misc")
+ + unops()
+)
+ops_info["DatasetGroupByOpsMixin"] = binops(
+ other_type="Dataset | DataArray", return_type="Dataset"
+)
+ops_info["DataArrayGroupByOpsMixin"] = binops(
+ other_type="T_Xarray", return_type="T_Xarray"
+)
+
+MODULE_PREAMBLE = '''\
+"""Mixin classes with arithmetic operators."""
# This file was generated using xarray.util.generate_ops. Do not edit manually.
@@ -107,18 +259,39 @@ from xarray.core.types import (
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
- from xarray.core.types import T_DataArray as T_DA"""
+ from xarray.core.types import T_DataArray as T_DA'''
+
+
CLASS_PREAMBLE = """{newline}
class {cls_name}:
__slots__ = ()"""
-COPY_DOCSTRING = ' {method}.__doc__ = {func}.__doc__'
+COPY_DOCSTRING = """\
+ {method}.__doc__ = {func}.__doc__"""
-def render(ops_info: dict[str, list[OpsType]]) ->Iterator[str]:
+
+def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]:
"""Render the module or stub file."""
- pass
+ yield MODULE_PREAMBLE
+
+ for cls_name, method_blocks in ops_info.items():
+ yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n")
+ yield from _render_classbody(method_blocks)
+
+
+def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]:
+ for method_func_pairs, template, extra in method_blocks:
+ if template:
+ for method, func in method_func_pairs:
+ yield template.format(method=method, func=func, **extra)
+
+ yield ""
+ for method_func_pairs, *_ in method_blocks:
+ for method, func in method_func_pairs:
+ if method and func:
+ yield COPY_DOCSTRING.format(method=method, func=func)
-if __name__ == '__main__':
+if __name__ == "__main__":
for line in render(ops_info):
print(line)
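`render()` expands each table of `(method, function)` pairs through these string templates. Expanding `template_reflexive` by hand over two `BINOPS_REFLEXIVE` entries, with the extras used for `DatasetOpsMixin`, gives a feel for the generated output (a minimal sketch; the real script also threads through overloads and type-ignores):

```python
# Template and op table copied from generate_ops.py above.
template_reflexive = """
    def {method}(self, other: {other_type}) -> {return_type}:
        return self._binary_op(other, {func}, reflexive=True)"""

for method, func in (("__radd__", "operator.add"), ("__rsub__", "operator.sub")):
    print(
        template_reflexive.format(
            method=method,
            func=func,
            other_type="DsCompatible",
            return_type="Self",
        )
    )
```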
diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py
index 586848d2..0b2e2b05 100755
--- a/xarray/util/print_versions.py
+++ b/xarray/util/print_versions.py
@@ -1,4 +1,5 @@
"""Utility functions for printing version information."""
+
import importlib
import locale
import os
@@ -10,7 +11,71 @@ import sys
def get_sys_info():
"""Returns system information as a dict"""
- pass
+
+ blob = []
+
+ # get full commit hash
+ commit = None
+ if os.path.isdir(".git") and os.path.isdir("xarray"):
+ try:
+ pipe = subprocess.Popen(
+ 'git log --format="%H" -n 1'.split(" "),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ )
+ so, _ = pipe.communicate()
+ except Exception:
+ pass
+ else:
+ if pipe.returncode == 0:
+ commit = so
+ try:
+ commit = so.decode("utf-8")
+ except ValueError:
+ pass
+ commit = commit.strip().strip('"')
+
+ blob.append(("commit", commit))
+
+ try:
+ (sysname, _nodename, release, _version, machine, processor) = platform.uname()
+ blob.extend(
+ [
+ ("python", sys.version),
+ ("python-bits", struct.calcsize("P") * 8),
+ ("OS", f"{sysname}"),
+ ("OS-release", f"{release}"),
+ # ("Version", f"{version}"),
+ ("machine", f"{machine}"),
+ ("processor", f"{processor}"),
+ ("byteorder", f"{sys.byteorder}"),
+ ("LC_ALL", f'{os.environ.get("LC_ALL", "None")}'),
+ ("LANG", f'{os.environ.get("LANG", "None")}'),
+ ("LOCALE", f"{locale.getlocale()}"),
+ ]
+ )
+ except Exception:
+ pass
+
+ return blob
+
+
+def netcdf_and_hdf5_versions():
+ libhdf5_version = None
+ libnetcdf_version = None
+ try:
+ import netCDF4
+
+ libhdf5_version = netCDF4.__hdf5libversion__
+ libnetcdf_version = netCDF4.__netcdf4libversion__
+ except ImportError:
+ try:
+ import h5py
+
+ libhdf5_version = h5py.version.hdf5_version
+ except ImportError:
+ pass
+ return [("libhdf5", libhdf5_version), ("libnetcdf", libnetcdf_version)]
def show_versions(file=sys.stdout):
@@ -21,8 +86,78 @@ def show_versions(file=sys.stdout):
file : file-like, optional
print to the given file-like object. Defaults to sys.stdout.
"""
- pass
+ sys_info = get_sys_info()
+
+ try:
+ sys_info.extend(netcdf_and_hdf5_versions())
+ except Exception as e:
+ print(f"Error collecting netcdf / hdf5 version: {e}")
+
+ deps = [
+ # (MODULE_NAME, f(mod) -> mod version)
+ ("xarray", lambda mod: mod.__version__),
+ ("pandas", lambda mod: mod.__version__),
+ ("numpy", lambda mod: mod.__version__),
+ ("scipy", lambda mod: mod.__version__),
+ # xarray optionals
+ ("netCDF4", lambda mod: mod.__version__),
+ ("pydap", lambda mod: mod.__version__),
+ ("h5netcdf", lambda mod: mod.__version__),
+ ("h5py", lambda mod: mod.__version__),
+ ("zarr", lambda mod: mod.__version__),
+ ("cftime", lambda mod: mod.__version__),
+ ("nc_time_axis", lambda mod: mod.__version__),
+ ("iris", lambda mod: mod.__version__),
+ ("bottleneck", lambda mod: mod.__version__),
+ ("dask", lambda mod: mod.__version__),
+ ("distributed", lambda mod: mod.__version__),
+ ("matplotlib", lambda mod: mod.__version__),
+ ("cartopy", lambda mod: mod.__version__),
+ ("seaborn", lambda mod: mod.__version__),
+ ("numbagg", lambda mod: mod.__version__),
+ ("fsspec", lambda mod: mod.__version__),
+ ("cupy", lambda mod: mod.__version__),
+ ("pint", lambda mod: mod.__version__),
+ ("sparse", lambda mod: mod.__version__),
+ ("flox", lambda mod: mod.__version__),
+ ("numpy_groupies", lambda mod: mod.__version__),
+ # xarray setup/test
+ ("setuptools", lambda mod: mod.__version__),
+ ("pip", lambda mod: mod.__version__),
+ ("conda", lambda mod: mod.__version__),
+ ("pytest", lambda mod: mod.__version__),
+ ("mypy", lambda mod: importlib.metadata.version(mod.__name__)),
+ # Misc.
+ ("IPython", lambda mod: mod.__version__),
+ ("sphinx", lambda mod: mod.__version__),
+ ]
+
+ deps_blob = []
+ for modname, ver_f in deps:
+ try:
+ if modname in sys.modules:
+ mod = sys.modules[modname]
+ else:
+ mod = importlib.import_module(modname)
+ except Exception:
+ deps_blob.append((modname, None))
+ else:
+ try:
+ ver = ver_f(mod)
+ deps_blob.append((modname, ver))
+ except Exception:
+ deps_blob.append((modname, "installed"))
+
+ print("\nINSTALLED VERSIONS", file=file)
+ print("------------------", file=file)
+
+ for k, stat in sys_info:
+ print(f"{k}: {stat}", file=file)
+
+ print("", file=file)
+ for k, stat in deps_blob:
+ print(f"{k}: {stat}", file=file)
-if __name__ == '__main__':
+if __name__ == "__main__":
show_versions()
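With `show_versions` now implemented, the report can also be captured rather than printed, for example (assuming xarray and its core dependencies are importable in the current environment):

```python
import io

import xarray as xr

buf = io.StringIO()
xr.show_versions(file=buf)   # write the report into a buffer instead of stdout
print(buf.getvalue()[:300])  # peek at the start: system info, then dependency versions
```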