# Source code for stems.io.vrt

""" Create VRTs
"""
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
import logging
from pathlib import Path
import xml.etree.ElementTree as ET
from xml.dom import minidom

from osgeo import gdal
import rasterio
from rasterio.coords import BoundingBox
from rasterio.dtypes import _gdal_typename
from rasterio.transform import rowcol
import six

from ..gis.geom import (bounds_transform_union,
                        calculate_src_window,
                        calculate_dst_window)
from ..utils import cached_property, list_like, relative_to

# Ask the GDAL Python bindings to raise Python exceptions on errors
gdal.UseExceptions()

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


# Error message raised by ``VRTDataset`` when a property that is derived from
# its source bands (e.g., ``crs``, bounds, transform, shape) is requested
# before any band has been added
_NOBANDS_NOPROPS_ERROR_MSG = (
    'Cannot determine dataset properties without storing any '
    'bands. Add some via ``VRTDataset.add_band``'
)


class VRTDataset(object):
    """ Create a VRT from a band in one or more datasets

    Parameters
    ----------
    separate : bool, optional
        Put input bands in separate, stacked bands in the output
    """
    def __init__(self, separate=True):
        self.separate = separate
        self.root = ET.Element('VRTDataset')
        # Maps output VRT band index -> list of source bands for that band
        self._bands = defaultdict(list)

    @property
    def bands(self):
        """dict[int, list[VRTSourceBand]]: Bands organized by output VRT band
        """
        # Ensure is returned ordered by VRT bidx (the key)
        # Also ensure we don't return empty list of bands
        return OrderedDict(
            (k, self._bands[k])
            for k in sorted(self._bands)
            if self._bands[k]
        )

    @property
    def transform(self):
        """Affine: Affine transform for VRTDataset
        """
        _, transform, _ = self._get_bounds_transform()
        return transform

    @property
    def bounds(self):
        """BoundingBox: Bounding box of VRTDataset
        """
        bounds, _, _ = self._get_bounds_transform()
        return bounds

    @property
    def width(self):
        """ int: Number of columns
        """
        return self.shape[1]

    @property
    def height(self):
        """ int: Number of rows
        """
        return self.shape[0]

    @property
    def shape(self):
        """tuple[int, int]: Number of rows and columns
        """
        _, _, shape = self._get_bounds_transform()
        return shape

    @property
    def count(self):
        """int: Number of output bands in VRT
        """
        return len(self.bands)

    @property
    def crs(self):
        """CRS: VRTDataset coordinate reference system

        Taken from the first source band stored in the VRT.
        """
        if not self.bands:
            raise ValueError(_NOBANDS_NOPROPS_ERROR_MSG)
        bands_list = self._bands_to_list()
        return bands_list[0][1].crs

    @classmethod
    def from_bands(cls, paths, separate=True, bidx=1, **kwds):
        """ Create a VRTDataset from one band in each of many datasets

        Parameters
        ----------
        paths : str, Path, or list[str]
            Path, or list of paths, to open as datasets
        separate : bool, optional
            Put input bands in separate, stacked bands in the output
        bidx : int, list[int], optional
            Band indices of `datasets` to include. If ``int``, all ``paths``
            will use this band ``bidx``. Otherwise, pass a list of band
            indices for each path in ``paths``. Defaults to ``1``.
        kwds : dict
            Keywords to pass to ``VRTSourceBand()`` for each band. Pass a
            list or tuple as a value to specify different values for each
            band.

        Returns
        -------
        VRTDataset
            VRTDataset initialized from input paths
        """
        # A single path (string or ``Path``) is treated as a 1 item sequence
        if isinstance(paths, (six.string_types, Path)):
            paths = (paths, )
        if isinstance(bidx, int):
            bidx = (bidx, ) * len(paths)
        assert len(paths) == len(bidx), \
            '`paths` and `bidx` must be the same length'

        # Broadcast scalar keyword values so there is one value per band.
        # Built as a new dict so the caller's keywords aren't modified while
        # being iterated over.
        band_kwds = {}
        for k, v in kwds.items():
            if not list_like(v):
                logger.debug('Found scalar for "%s" keyword. Duplicating '
                             'for each band', k)
                v = (v, ) * len(paths)
            assert len(v) == len(paths), \
                'Keyword "{0}" must have one value per path'.format(k)
            band_kwds[k] = v

        # Honor the requested ``separate`` (stack vs. mosaic) setting
        vrt = cls(separate=separate)
        for i, (path, bidx_) in enumerate(zip(paths, bidx)):
            _kwds = {k: v[i] for k, v in band_kwds.items()}
            vrt.add_band(path, bidx_, **_kwds)

        return vrt

    def add_band(self, path, src_bidx=1, vrt_bidx=None, validate=True,
                 **band_kwds):
        """ Add a band to VRT dataset

        Parameters
        ----------
        path : str
            Path to the dataset containing the band
        src_bidx : int, optional
            Band index within the dataset at ``path``. Defaults to ``1``
        vrt_bidx : int or None
            Destination band in VRT for new band. Only used if
            ``self.separate`` is True
        validate : bool, optional
            Validate band forms to expected attributes before adding
        band_kwds : dict
            Additional keyword arguments passed onto
            :py:class:`VRTSourceBand`

        Returns
        -------
        vrt_bidx: int
            VRT band index

        Raises
        ------
        ValueError
            Raised if vrt_bidx is invalid
        """
        # Handle non-specified vrt_bidx
        n_band = len(self.bands)
        if self.separate:
            vrt_bidx = n_band + 1 if vrt_bidx is None else vrt_bidx
        else:
            if vrt_bidx is not None and vrt_bidx != 1:
                raise ValueError('`vrt_bidx` must be `1` if not stacking into '
                                 'separate bands (see `self.separate`)')
            vrt_bidx = 1

        if vrt_bidx <= 0:
            raise ValueError('`vrt_bidx` must be greater than 0')

        vrtband = VRTSourceBand(path, src_bidx, **band_kwds)

        # Validate if not 1st band
        if n_band > 0 and validate:
            self._validate(vrtband)

        # Append
        self._bands[vrt_bidx].append(vrtband)

        return vrt_bidx

    def write(self, path=None, relative=False):
        """ Save VRT XML data to a filename

        Parameters
        ----------
        path : str, optional
            Save VRT to this filename. If ``None``, returns the XML text
        relative : bool, optional
            Reference VRT sources relative to the VRT. Requires ``path``

        Returns
        -------
        str
            The filename (``path``) if one was given, otherwise the XML text

        Raises
        ------
        ValueError
            Raised if ``relative`` is requested without providing ``path``,
            or if the VRT has no bands
        """
        if relative:
            # Sources can't be made relative to a VRT that has no location
            if path is None:
                raise ValueError('Cannot reference sources relative to the '
                                 'VRT without a `path` for the VRT')
            relative_to_vrt = str(path)
        else:
            relative_to_vrt = None

        # Calculate the union of the bands once for both shape & transform
        _, transform, shape = self._get_bounds_transform()
        xml_ele = _make_vrt_element(shape[1], shape[0],
                                    transform, self.crs,
                                    self.bands,
                                    relative_to_vrt=relative_to_vrt)
        xmlstr = _make_vrt_str(xml_ele)

        if path is None:
            return xmlstr

        with open(str(path), 'w') as fid:
            fid.write(xmlstr)
        return path

    def close(self):
        """ Close any opened VRTSourceBand(s)
        """
        for source_bands in self.bands.values():
            for band in source_bands:
                band.close()

    def _bands_to_list(self):
        # () -> [(vrt_bidx, VRTSourceBand), ...]
        return [(k, band)
                for k, source_bands in self.bands.items()
                for band in source_bands]

    def _validate(self, test_band):
        # Validate suitability of newly added band
        if test_band.crs != self.crs:
            raise ValueError('All bands must have same ``crs``')

    def _get_bounds_transform(self):
        # Union of all source bands' bounds, using the first band's transform
        if not self.bands:
            raise ValueError(_NOBANDS_NOPROPS_ERROR_MSG)
        bands_list = self._bands_to_list()
        bounds_ = [b.bounds for _, b in bands_list]
        transforms_ = [b.transform for _, b in bands_list]
        return bounds_transform_union(bounds_, transforms_[0])
class VRTSourceBand(object):
    """ A VRT band originating from some other file

    Note that all properties on this object return information used for
    VRT XML generation, but not XML elements (e.g., returns the SubElement
    name and this new element's text value).

    Parameters
    ----------
    path : str
        Filename of dataset containing band
    src_bidx : int
        Source band index (begins on 1)
    description : str, optional
        Override band description
    nodata : float or int, optional
        Override NoDataValue
    keep_open : bool, optional
        Keep dataset open
    """
    def __init__(self, path, src_bidx, description=None, nodata=None,
                 keep_open=False):
        self.path = path
        self.src_bidx = src_bidx
        self._desc = description
        self._ndv = nodata
        self.keep_open = keep_open
        self._ds = None

    @contextmanager
    def open(self):
        """ Context manager yielding the opened source dataset

        The dataset is opened if needed. Unless ``keep_open`` is set, the
        dataset is closed when leaving the context, but only by the context
        that actually opened it. Nested uses (e.g., accessing a property
        inside of a ``with band.open():`` block) therefore don't close the
        dataset while an outer context is still using it.

        Yields
        ------
        rasterio.DatasetReader
            Opened dataset
        """
        opened_here = getattr(self._ds, 'closed', True)
        self.start()
        try:
            yield self._ds
        finally:
            # Clean up even if the body of the ``with`` raised
            if opened_here and not self.keep_open:
                self.close()

    def start(self):
        """ Open dataset, if closed
        """
        # ``None`` has no ``closed`` attribute, so counts as closed
        if getattr(self._ds, 'closed', True):
            logger.debug('Opening dataset for VRTSourceBand')
            self._ds = rasterio.open(str(self.path), 'r')

    def close(self):
        """ Close dataset reference, if open
        """
        if self._ds is not None:
            self._ds.close()
            self._ds = None

    @cached_property
    def crs(self):
        with self.open() as ds:
            return ds.crs

    @cached_property
    def transform(self):
        with self.open() as ds:
            return ds.transform

    @cached_property
    def bounds(self):
        with self.open() as ds:
            return ds.bounds

    @cached_property
    def width(self):
        with self.open() as ds:
            return ds.width

    @cached_property
    def height(self):
        with self.open() as ds:
            return ds.height

    @cached_property
    def shape(self):
        # Hold the dataset open so height & width share one open dataset
        with self.open():
            return (self.height, self.width)

    @cached_property
    def dtype(self):
        # Per-band sequences are 0-indexed, but ``src_bidx`` begins on 1
        with self.open() as ds:
            return ds.dtypes[self.src_bidx - 1]

    @cached_property
    def blockxsize(self):
        with self.open() as ds:
            return ds.block_shapes[self.src_bidx - 1][1]

    @cached_property
    def blockysize(self):
        with self.open() as ds:
            return ds.block_shapes[self.src_bidx - 1][0]

    @cached_property
    def nodata(self):
        # User provided override wins over the value stored in the dataset
        if self._ndv is not None:
            return self._ndv
        with self.open() as ds:
            return ds.nodatavals[self.src_bidx - 1]

    @cached_property
    def description(self):
        # User provided override wins over the value stored in the dataset
        if self._desc is not None:
            return self._desc
        with self.open() as ds:
            return ds.descriptions[self.src_bidx - 1]

    @cached_property
    def colorinterp(self):
        with self.open() as ds:
            return ds.colorinterp[self.src_bidx - 1]
# ----------------------------------------------------------------------------
# XML
def _make_vrt_str(root):
    """Serialize an XML element as a pretty-printed XML document string"""
    document = minidom.parseString(ET.tostring(root))
    return document.toprettyxml(indent=' ')


def _make_vrt_element(vrt_width, vrt_height, vrt_transform, vrt_crs,
                      vrt_bands, relative_to_vrt=None):
    """Build the XML element tree describing a VRT

    Parameters
    ----------
    vrt_width : int
        Number of columns
    vrt_height : int
        Number of rows
    vrt_transform : Affine
        Transform of output VRT
    vrt_crs : CRS
        CRS of output VRT
    vrt_bands : dict[int, list[VRTSourceBand]]
        Source bands keyed by the output VRT band index they belong to
        (e.g., ``{1: [VRTSourceBand], 2: [VRTSourceBand]}`` when stacking
        separately, or ``{1: [VRTSourceBand, VRTSourceBand]}`` when
        mosaicing)
    relative_to_vrt : str or Path
        Reference VRT sources relative to the VRT at this location

    Returns
    -------
    xml.etree.ElementTree
        Root "VRTDataset" element, populated with bands and their sources
    """
    # The bounds follow from the transform and size, so they're derived here
    # instead of being asked for as another argument
    left = vrt_transform.c
    top = vrt_transform.f
    bottom = top + vrt_transform.e * vrt_height
    right = left + vrt_transform.a * vrt_width
    vrt_bounds = BoundingBox(left, bottom, right, top)

    root = _make_vrt_root(vrt_width, vrt_height, vrt_transform, vrt_crs)

    # First pass: one <VRTRasterBand> per non-empty output band, in band
    # order. The first source of each band supplies band level metadata
    # (e.g., description and NDV).
    # TODO: Allow overrides to pass to `_make_band`!
    xml_bands = []
    for bidx in sorted(vrt_bands):
        sources = vrt_bands[bidx]
        if not sources:
            continue
        first = sources[0]
        with first.open():  # keep open while reading its metadata
            xml_bands.append((sources, _make_band(root, first, bidx)))

    # Second pass: attach every source to its output band
    for sources, xml_band in xml_bands:
        for src in sources:
            with src.open():  # keep open while reading its metadata
                _make_source(xml_band, src, vrt_bounds, vrt_transform,
                             relative_to_vrt=relative_to_vrt)

    return root


def _make_vrt_root(width, height, transform, crs):
    # Creates <VRTDataset> with its size, geotransform, and SRS
    vrt = ET.Element('VRTDataset')
    for attr, size in (('rasterXSize', width), ('rasterYSize', height)):
        vrt.set(attr, str(size))
    _make_geotransform(vrt, transform)
    _make_crs(vrt, crs)
    return vrt


def _make_crs(root, crs):
    # Creates <SRS> ... WKT ... </SRS>
    return _make_subelement(root, 'SRS', crs.wkt)


def _make_geotransform(root, transform, precision=9):
    # Creates <GeoTransform> from the GDAL ordered coefficients of the
    # output VRT transform, rounded to ``precision`` decimals
    coefficients = ', '.join(str(round(value, precision))
                             for value in transform.to_gdal())
    return _make_subelement(root, 'GeoTransform', coefficients)


def _make_band(root, source_band, vrt_bidx, description=None, vrt_ndv=None):
    # Creates <VRTRasterBand>
    xml_band = ET.SubElement(root, 'VRTRasterBand')
    xml_band.set('dataType', _gdal_typename(source_band.dtype))
    xml_band.set('band', str(vrt_bidx))

    # Optional subelements
    # TODO: ColorTable, GDALRasterAttributeTable, UnitType,
    #       Offset, Scale, CategoryNames
    interp_name = gdal.GetColorInterpretationName(
        source_band.colorinterp.value)
    _make_subelement(xml_band, 'ColorInterp', interp_name)

    # Fall back to the source band's NDV when none is given
    if vrt_ndv is None:
        vrt_ndv = source_band.nodata
    if vrt_ndv is not None:
        _make_subelement(xml_band, 'NoDataValue', str(vrt_ndv))

    # Fall back to the source band's description when none is given
    if not description:
        description = source_band.description
    if description:
        xml_band.set('Description', description)

    return xml_band


def _make_source(xml_band, source_band, vrt_bounds, vrt_transform,
                 relative_to_vrt=None):
    # Creates <ComplexSource> describing one source band of a VRT band
    xml_source = ET.SubElement(xml_band, 'ComplexSource')

    _make_source_path(xml_source, source_band.path,
                      relative_to_vrt=relative_to_vrt)
    _make_source_band(xml_source, source_band.src_bidx)
    _make_source_props(xml_source, source_band)

    # SrcRect and DstRect
    _make_src_rect(xml_source, source_band.bounds, source_band.transform,
                   vrt_bounds)
    _make_dst_rect(xml_source, source_band.bounds, vrt_transform)

    # NODATA
    if source_band.nodata is not None:
        _make_subelement(xml_source, 'NODATA', str(source_band.nodata))

    return xml_source


def _make_source_path(xml_parent, path, relative_to_vrt=None):
    # Creates <SourceFilename>, either relative to the VRT or absolute
    xml_path = ET.SubElement(xml_parent, 'SourceFilename')
    if not relative_to_vrt:
        filename = Path(path).absolute()
    else:
        filename = relative_to(path, relative_to_vrt)
        xml_path.set('relativeToVRT', '1')
    xml_path.text = str(filename)
    return xml_path


def _make_source_band(xml_parent, src_bidx):
    # Creates <SourceBand> ... a number ... <SourceBand/> tag
    return _make_subelement(xml_parent, 'SourceBand', src_bidx)


def _make_source_props(xml_parent, source_band):
    """Creates <SourceProperties ... /> tag

    Parameters
    ----------
    xml_parent : xml.etree.Element.SubElement
        Parent XML element, like a "ComplexSource" or "SimpleSource"
    source_band : VRTSourceBand
        The source band

    Returns
    -------
    xml.etree.Element.SubElement
        "SourceProperties" subelement
    """
    props = ET.SubElement(xml_parent, 'SourceProperties')
    props.set('RasterXSize', str(source_band.width))
    props.set('RasterYSize', str(source_band.height))
    props.set('DataType', _gdal_typename(source_band.dtype))
    for attr, blocksize in (('BlockXSize', source_band.blockxsize),
                            ('BlockYSize', source_band.blockysize)):
        props.set(attr, str(blocksize))
    return props


def _set_rect_attrs(xml_rect, window):
    # Copies a window's offsets and size onto a SrcRect/DstRect element
    xml_rect.set('xOff', str(window.col_off))
    xml_rect.set('yOff', str(window.row_off))
    xml_rect.set('xSize', str(window.width))
    xml_rect.set('ySize', str(window.height))
    return xml_rect


def _make_src_rect(xml_parent, src_bounds, src_transform, dst_bounds):
    # Creates <SrcRect>, the window to read from in the source
    window, _ = calculate_src_window(src_bounds, src_transform, dst_bounds)
    return _set_rect_attrs(ET.SubElement(xml_parent, 'SrcRect'), window)


def _make_dst_rect(xml_parent, src_bounds, dst_transform):
    # Creates <DstRect>, the window the source occupies in the VRT
    window = calculate_dst_window(src_bounds, dst_transform)
    return _set_rect_attrs(ET.SubElement(xml_parent, 'DstRect'), window)


def _make_subelement(root, name, text):
    # Creates <name>text</name> under ``root``, with text coerced to str
    child = ET.SubElement(root, name)
    child.text = str(text)
    return child