"""Source code for stems.io.encoding"""

from collections import OrderedDict
from functools import singledispatch
import logging

import numpy as np
import xarray as xr

from ..utils import list_like
from .chunk import chunks_to_chunksizes

logger = logging.getLogger(__name__)


#: NumPy string types (str, bytes, unicode)
_NP_STRING_TYPES = (np.str_, np.bytes_, np.unicode_, )


@singledispatch
def netcdf_encoding(data, dtype=None, chunks=None, zlib=True, complevel=4,
                    nodata=None, **encoding_kwds):
    """ Return "good" NetCDF encoding information for some data

    The returned encoding is the default or "good" known standard for data
    used in ``stems``. Each default determined in this function is given as
    a keyword argument to allow overriding, and you can also pass additional
    encoding items via ``**encoding``. You may pass one override for all
    data, or overrides for each data variable (as a dict).

    For more information, see the NetCDF4 documentation for the
    ``createVariable`` [1].

    Parameters
    ----------
    data : xr.DataArray or xr.Dataset
        Define encoding for this data. If ``xr.Dataset``, map function
        across all ``xr.DataArray`` in ``data.data_vars``
    dtype : np.dtype, optional
        The data type used for the encoded data. Defaults to the input
        data type(s), but can be set to facilitate discretization based
        compression (typically alongside scale_factor and _FillValue)
    chunks : None, tuple or dict, optional
        Chunksizes used to encode NetCDF. If given as a `tuple`, chunks
        should be given for each dimension. Chunks for dimensions not
        specified when given as a `dict` will default to 1. Passing
        ``False`` will not use chunks.
    zlib : bool, optional
        Use compression
    complevel : int, optional
        Compression level
    nodata : int, float, or sequence, optional
        NoDataValue(s). Specify one for each ``DataArray`` in ``data``
        if a :py:class:`xarray.Dataset`. Used for ``_FillValue``
    encoding_kwds : dict
        Additional encoding data to pass

    Returns
    -------
    dict
        Dict mapping band name (e.g., variable name) to relevant encoding
        information

    See Also
    --------
    xarray.Dataset.to_netcdf
        Encoding information designed to be passed to
        :py:meth:`xarray.Dataset.to_netcdf`.

    References
    ----------
    .. [1] http://unidata.github.io/netcdf4-python/#netCDF4.Dataset.createVariable
    """
    # Base implementation of the single-dispatch generic: only reached when
    # ``data`` is neither an xr.DataArray nor an xr.Dataset (the registered
    # implementations). The ``dtype`` parameter is declared here so the
    # generic signature matches both registered implementations and the
    # docstring above.
    raise TypeError(f'Unknown type for input ``data`` "{type(data)}"')
@netcdf_encoding.register(xr.DataArray)
def _netcdf_encoding_dataarray(data, dtype=None, chunks=None, zlib=True,
                               complevel=4, nodata=None, **encoding_kwds):
    # Build the per-variable encoding dict for a single DataArray.
    # Keyed by the name xarray will use when writing (see ``encoding_name``).
    name = encoding_name(data)
    encoding = {name: {}}

    # dtype: determine & guard/fixup. ``guard_dtype`` drops the dtype entry
    # for datetime data (xarray manages that encoding itself).
    # NOTE(review): ``dtype or encoding_dtype(data)`` expects a dict like
    # ``{'dtype': ...}`` when ``dtype`` is truthy — a bare np.dtype here
    # would trip the isinstance assert below; confirm intended input form.
    dtype_ = guard_dtype(data, dtype or encoding_dtype(data))
    assert isinstance(dtype_, dict)
    encoding[name].update(dtype_)

    # chunksizes: determine and guard/fixup. ``chunks=False`` means
    # "do not chunk" and skips this entirely.
    if chunks is not False:
        chunks = encoding_chunksizes(data, chunks=chunks)
        # Clip chunks to dim sizes, then fix up for str/object dtypes
        chunks = guard_chunksizes_str(data, guard_chunksizes(data, chunks))
        assert isinstance(chunks, tuple)
        assert all(isinstance(i, int) for i in chunks)
        # Empty tuple (e.g., object dtype or unchunked data) => no entry
        if chunks:
            encoding[name]['chunksizes'] = tuple(chunks)

    # _FillValue: only set when the caller provided a NoData value
    if nodata is not None:
        assert isinstance(nodata, (float, int, np.integer, np.floating))
        encoding[name]['_FillValue'] = nodata

    # complevel & zlib: compression settings are always written
    encoding[name]['complevel'] = complevel
    encoding[name]['zlib'] = zlib

    # Fill in user input last so explicit kwds override the defaults above
    encoding[name].update(encoding_kwds)

    return encoding


@netcdf_encoding.register(xr.Dataset)
def _netcdf_encoding_dataset(data, dtype=None, chunks=None, zlib=True,
                             complevel=4, nodata=None, **encoding_kwds):
    # Map the DataArray implementation over every data variable, unpacking
    # per-variable overrides when an argument was given as a dict keyed by
    # variable name.
    encoding = OrderedDict()
    for var in data.data_vars:
        # Allow user to specify local (key exists) or global (KeyError) kwds
        kwds = encoding_kwds.get(var, encoding_kwds)
        # Construct, unpacking if needed: dict-valued args are indexed by
        # variable name, scalar args are broadcast to every variable
        var_encoding = _netcdf_encoding_dataarray(
            data[var],
            dtype=dtype[var] if _is_dict(dtype) else dtype,
            chunks=chunks[var] if _has_key(chunks, var) else chunks,
            zlib=zlib[var] if _is_dict(zlib) else zlib,
            complevel=complevel[var] if _is_dict(complevel) else complevel,
            nodata=nodata[var] if _is_dict(nodata) else nodata,
            **kwds
        )
        encoding[var] = var_encoding[var]
    return encoding


# ----------------------------------------------------------------------------
# Encoding components for DataArray(s)
def encoding_name(xarr):
    """ Return the name of the variable to provide encoding for

    Either returns the name of the DataArray, or the name that XArray will
    assign it when writing to disk
    (:py:data:`xarray.backends.api.DATAARRAY_VARIABLE`).

    Parameters
    ----------
    xarr : xarray.DataArray
        Provide the name of this DataArray used for encoding

    Returns
    -------
    str
        Encoding variable name
    """
    if xarr.name:
        return xarr.name
    # Unnamed DataArrays get xarray's placeholder variable name on write
    return xr.backends.api.DATAARRAY_VARIABLE
def encoding_dtype(xarr):
    """Get dtype encoding info

    Parameters
    ----------
    xarr : xarray.DataArray or np.ndarray
        DataArray to consider

    Returns
    -------
    dict[str, np.dtype]
        Datatype information for encoding (e.g., ``{'dtype': np.float32}``)
    """
    return dict(dtype=xarr.dtype)
def encoding_chunksizes(xarr, chunks=None):
    """ Find/resolve chunksize for a DataArray

    Parameters
    ----------
    xarr : xarray.DataArray
        DataArray to consider
    chunks : tuple[int] or Mapping[str, int]
        Chunks per dimension

    Returns
    -------
    tuple[int]
        Chunksizes per dimension
    """
    if chunks is None:
        # Grab chunks from DataArray
        chunksize = chunks_to_chunksizes(xarr)
    elif isinstance(chunks, dict):
        # Default to 1 chunk per dimension (i.e., the full dimension length)
        # if none found. ``xarr.sizes`` works for dimensions without
        # coordinates, unlike indexing ``xarr.coords``.
        chunksize = tuple(chunks.get(dim, xarr.sizes[dim])
                          for dim in xarr.dims)
    else:
        # Already a per-dimension sequence (tuple/list) — normalize to tuple.
        # Previously this case fell through and raised UnboundLocalError.
        chunksize = tuple(chunks)
    return chunksize
# ---------------------------------------------------------------------------- # Encoding checks, safeguards, and fixes
def guard_chunksizes(xarr, chunksizes):
    """ Guard chunksize to be <= dimension sizes

    Parameters
    ----------
    xarr : xarray.DataArray
        DataArray to consider
    chunksizes : tuple[int]
        Chunks per dimension

    Returns
    -------
    tuple[int]
        Guarded chunksizes
    """
    assert isinstance(xarr, xr.DataArray) and list_like(chunksizes)
    # Either one chunk size per dimension, or an empty sequence
    assert len(chunksizes) == xarr.ndim or not chunksizes

    guarded = []
    for csize, dim, size in zip(chunksizes, xarr.dims, xarr.shape):
        if csize == -1:
            # shortcut for length, but doesn't work in encoding
            guarded.append(size)
        elif csize > size:
            # Oversized chunks are clipped to the dimension length
            logger.warning(f'Chunk size for dim {dim} is larger than dim '
                           f'size ({size}). Resetting ({csize}->{size})')
            guarded.append(size)
        else:
            guarded.append(csize)

    return tuple(guarded)
def guard_chunksizes_str(xarr, chunksizes):
    """Guard chunk sizes for str datatypes

    Chunks for ``str`` need to include string length dimension since
    python-netcdf represents, for example, 1d char array as 2d array

    Parameters
    ----------
    xarr : xarray.DataArray
        DataArray to consider
    chunksizes : tuple[int]
        Chunk sizes per dimension

    Returns
    -------
    tuple[int]
        Guarded chunk sizes
    """
    is_string = xarr.dtype.type in _NP_STRING_TYPES
    if is_string and chunksizes is not None:
        # Only extend when the chunk tuple hasn't already been augmented
        if len(chunksizes) == xarr.ndim:
            logger.debug('Adding chunk size to `str` dtype variable '
                         f'"{xarr.name}" corresponding to the string length')
            chunksizes = chunksizes + (xarr.dtype.itemsize, )
    elif xarr.dtype.kind == 'O':
        # Object dtype has no fixed item size, so chunking is dropped
        logger.debug('Removing chunks for `dtype=object` variable '
                     f'"{xarr.name}"')
        chunksizes = ()
    return chunksizes
def guard_dtype(xarr, dtype_):
    """Guard dtype encoding for datetime datatypes

    Parameters
    ----------
    xarr : xarray.DataArray or np.ndarray
        DataArray to consider
    dtype_ : dict[str, np.dtype]
        Datatype information for encoding (e.g., ``{'dtype': np.float32}``)

    Returns
    -------
    dict[str, np.dtype]
        Datatype information for encoding (e.g., ``{'dtype': np.float32}``),
        if valid. Otherwise returns empty dict
    """
    # Don't encode datatype for datetime types, since xarray changes it
    return {} if xarr.dtype.kind == 'M' else dtype_
def _is_dict(x): return isinstance(x, dict) def _has_key(d, k): return _is_dict(d) and k in d