Chris Holden (ceholden@gmail.com) - https://github.com/ceholden
matplotlib
matplotlib is a very powerful plotting library for making amazing visualizations for publications, personal use, or even web and desktop applications. matplotlib can create almost any two-dimensional visualization you can think of, including histograms, scatter plots, bivariate plots, and image displays. For some inspiration, check out the matplotlib example gallery, which includes the source code required to generate each example.
A great resource for learning matplotlib is available from J.R. Johansson.
One part of matplotlib that may be initially confusing is that it offers two main ways of making plots: the object-oriented method and the state-machine method.
While the library can be used in an object-oriented manner (i.e., you create an object representing the figure, then the figure can spawn objects representing the axes, etc.), the most familiar usage of matplotlib for MATLAB users is the pyplot state-machine environment.
From the matplotlib usage FAQ:
Pyplot’s state-machine environment behaves similarly to MATLAB and should be most familiar to users with MATLAB experience.
A very good overview of the difference between the two usages is provided by Jake Vanderplas.
In general, you should only use the pyplot state-machine environment when plotting data interactively or when developing visualizations for your data. The object-oriented API, while more complicated, is a much more powerful way of creating plots and should be used when developing more complicated visualizations.
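To make the distinction concrete, here is a minimal sketch that draws the same curve both ways. The sine data and the variable names (`pyplot_axes`, `fig`, `ax`) are made up for illustration and are not part of the tutorial's dataset.

```python
import numpy as np
import matplotlib.pyplot as plt

# Illustrative data only: one period of a sine wave
x = np.linspace(0, 2 * np.pi, 50)
y = np.sin(x)

# State-machine style: pyplot keeps track of the "current" figure and axes,
# and each call acts on whatever is current
plt.figure()
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title('pyplot state-machine')
pyplot_axes = plt.gca()  # grab the current axes so we can inspect it later

# Object-oriented style: we hold explicit references to the figure and axes
# and call methods on them directly
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel('x')
ax.set_ylabel('sin(x)')
ax.set_title('object-oriented')
```

Both halves produce equivalent plots. The difference is that the second half never relies on pyplot's notion of a "current" axes, which tends to matter once a figure has several subplots.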
As this is a brief introduction to matplotlib, we will be using the pyplot state-machine method for creating visualizations.
# Import the Python 3 print function
from __future__ import print_function
# Import the "gdal" and "gdal_array" submodules from within the "osgeo" module
from osgeo import gdal
from osgeo import gdal_array
# Import the NumPy module
import numpy as np
# Open a GDAL dataset
dataset = gdal.Open('../../example/LE70220491999322EDC01_stack.gtif', gdal.GA_ReadOnly)
# Allocate our array using the first band's datatype
image_datatype = dataset.GetRasterBand(1).DataType
image = np.zeros((dataset.RasterYSize, dataset.RasterXSize, dataset.RasterCount),
                 dtype=gdal_array.GDALTypeCodeToNumericTypeCode(image_datatype))
# Loop over all bands in dataset
for b in range(dataset.RasterCount):
    # Remember, GDAL index is on 1, but Python is on 0 -- so we add 1 for our GDAL calls
    band = dataset.GetRasterBand(b + 1)
    # Read in the band's data into the third dimension of our array
    image[:, :, b] = band.ReadAsArray()
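# Calculate NDVI = (NIR - red) / (NIR + red), casting to float before any
# arithmetic so the integer band math cannot overflow
nir_band = image[:, :, 3].astype(np.float64)
red_band = image[:, :, 2].astype(np.float64)
ndvi = (nir_band - red_band) / (nir_band + red_band)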
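The NDVI calculation above divides by NIR + red, which is zero wherever both bands are zero (for example, in fill pixels). One possible way to guard against that is sketched below. The helper name `safe_ndvi` and the tiny demo arrays are hypothetical, and marking undefined pixels as NaN is an assumption rather than something this tutorial prescribes.

```python
import numpy as np


def safe_ndvi(red, nir):
    """Return (nir - red) / (nir + red), with NaN where the denominator is 0."""
    red = np.asarray(red, dtype=np.float64)
    nir = np.asarray(nir, dtype=np.float64)
    denominator = nir + red
    # Start from an all-NaN array and only fill in pixels we can divide safely
    result = np.full(denominator.shape, np.nan)
    np.divide(nir - red, denominator, out=result, where=denominator != 0)
    return result


# Made-up 2x2 bands, stored as Int16 like scaled reflectance
demo_red = np.array([[1000, 0], [3000, 500]], dtype=np.int16)
demo_nir = np.array([[4000, 0], [3000, 4500]], dtype=np.int16)
demo_ndvi = safe_ndvi(demo_red, demo_nir)
```

In this toy input, the pixel where both bands are zero comes back as NaN instead of triggering a divide-by-zero warning.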
With the data read in and NDVI calculated, let's make some plots.
The first thing to do is to import matplotlib into our namespace. I will be using a special feature of the IPython utility which allows me to "inline" matplotlib figures by entering the %matplotlib inline command. You might also want to try the nbagg backend, designed specifically for Jupyter notebooks, as this backend allows you to interact (pan, zoom, etc.) with the plot.
import matplotlib.pyplot as plt
%matplotlib inline
With matplotlib imported, we can summon up a figure and make our first plot:
# Array of 0 - 9
x = np.arange(10)
# 10 random integers from 0 to 9 (the upper bound of 10 is exclusive)
y = np.random.randint(0, 10, size=10)
# plot them as lines
plt.plot(x, y)
# plot them as just points -- specify "ls" ("linestyle") as a null string
plt.plot(x, y, 'ro', ls='')
One typical thing that we might want to do would be to plot one band against another. In order to do this, we will need to transform, or flatten, our 2-dimensional arrays of each band's values into 1-dimensional arrays:
print('Array shape before: {shp} (size is {sz})'.format(shp=image[:, :, 3].shape, sz=image[:, :, 3].size))
red = image[:, :, 2].flatten()
nir = image[:, :, 3].flatten()
print('Array shape after: {shp} (size is {sz})'.format(shp=nir.shape, sz=nir.size))
We have retained the number of entries in each of these raster bands, but we have flattened them from 2 dimensions into 1.
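As a small aside, NumPy offers more than one way to flatten an array, and they differ in whether the result is a copy. The sketch below uses a made-up 3x4 array (`band`) rather than the real image.

```python
import numpy as np

# Made-up 3x4 "band" for illustration
band = np.arange(12, dtype=np.int16).reshape(3, 4)

flat_copy = band.flatten()       # always returns a new copy
flat_view = band.ravel()         # returns a view when the memory layout allows
flat_reshape = band.reshape(-1)  # same idea as ravel, spelled as a reshape

# Changing the copy leaves the original 2D array untouched
flat_copy[0] = 99
```

For a contiguous array like this one, `ravel` shares memory with the original. A slice such as `image[:, :, 3]` is generally not contiguous, so there `ravel` would have to copy as well.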
Now we can plot them. Since we just want points, we can use scatter for a scatterplot. Since there are no lines in a scatterplot, it has a slightly different syntax.
# Make the plot
plt.scatter(red, nir, color='r', marker='o')
# Add some axis labels
plt.xlabel('Red Reflectance')
plt.ylabel('NIR Reflectance')
# Add a title
plt.title('Tasseled Cap, eh?')
If we want the two axes to have the same limits, we can calculate the limits and apply them:
# Make the plot
plt.scatter(red, nir, color='r', marker='o')
# Calculate min and max
plot_min = min(red.min(), nir.min())
plot_max = max(red.max(), nir.max())
plt.xlim((plot_min, plot_max))
plt.ylim((plot_min, plot_max))
# Add some axis labels
plt.xlabel('Red Reflectance')
plt.ylabel('NIR Reflectance')
# Add a title
plt.title('Tasseled Cap, eh?')
With so much data available to look at, it can be hard to understand what is going on with the mess of points shown above. Luckily our datasets aren't just a mess of points - they have a spatial structure.
To show the spatial structure of our images, we could make an image plot of one of our bands using imshow to display an image on the axes:
# use "imshow" for an image -- nir at first
plt.imshow(image[:, :, 3])
Well, it looks like there is something going on - maybe a river in the center and some bright vegetation to the bottom left of the image. What's lacking is any knowledge of what the colors mean.
Luckily, matplotlib can provide us a colorbar.
# use "imshow" for an image -- nir at first
plt.imshow(image[:, :, 3])
plt.colorbar()
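A colorbar is easier to read when it has a label and the color limits are fixed. The sketch below shows one way this might be done. The random `fake_band` array stands in for a real band, and the 0 to 8000 limits are an assumed display range, not a property of the data.

```python
import numpy as np
import matplotlib.pyplot as plt

# Stand-in for a reflectance band scaled by 10,000 (random values, illustration only)
rng = np.random.default_rng(0)
fake_band = rng.integers(0, 10000, size=(50, 50))

plt.figure()
# vmin / vmax pin the ends of the color scale instead of using the data min / max
img = plt.imshow(fake_band, vmin=0, vmax=8000)
# Passing the image to colorbar ties the bar to that specific image
cbar = plt.colorbar(img)
cbar.set_label('Reflectance (scaled by 10,000)')
```

Pinning `vmin` and `vmax` keeps the colors comparable across several images, since each image would otherwise be stretched to its own minimum and maximum.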
If we want a greyscale image, we can manually specify a colormap:
# use "imshow" for an image -- nir in first subplot, red in second
plt.subplot(121)
plt.imshow(image[:, :, 3], cmap=plt.cm.Greys)
plt.colorbar()
# Now red band in the second subplot (indicated by last of the 3 numbers)
plt.subplot(122)
plt.imshow(image[:, :, 2], cmap=plt.cm.Greys)
plt.colorbar()
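The three digits passed to `plt.subplot` are the number of rows, the number of columns, and the index of the active subplot, so `plt.subplot(121)` is the same as `plt.subplot(1, 2, 1)`. A rough object-oriented equivalent is sketched below on a made-up gradient array (`demo`). It also contrasts `Greys`, where low values are drawn light and high values dark, with the reversed `Greys_r`.

```python
import numpy as np
import matplotlib.pyplot as plt

# Made-up 10x10 gradient so the colormap direction is easy to see
demo = np.arange(100).reshape(10, 10)

# One row, two columns of axes: the same layout as subplot(121) and subplot(122)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))

for ax, cmap_name in zip(axes, ['Greys', 'Greys_r']):
    im = ax.imshow(demo, cmap=cmap_name)
    ax.set_title(cmap_name)
    # Attach a colorbar to this particular subplot
    fig.colorbar(im, ax=ax)
```

If bright pixels should mean high reflectance, as on a conventional greyscale image, `Greys_r` is probably the colormap to reach for.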
Greyscale images are nice, but the most information we can receive comes from looking at the interplay among different bands. To accomplish this, we can map different spectral bands to the Red, Green, and Blue channels on our monitors.
Before we can do this, the matplotlib imshow help tells us that we need to normalize our bands into a 0 - 1 range. To do so, we will perform a simple linear scale fitting 0 reflectance to 0 and 80% reflectance to 1, clipping anything larger or smaller.
Remember:
If we are going from an Int16 datatype (e.g., reflectance scaled by 10,000x) to a decimal between 0 and 1, we will need to use a Float!
# Extract the SWIR1, NIR, and Red bands (indexing with an array returns a copy)
index = np.array([4, 3, 2])
colors = image[:, :, index].astype(np.float64)
max_val = 8000
min_val = 0
# Enforce maximum and minimum values
colors[colors > max_val] = max_val
colors[colors < min_val] = min_val
# Linearly rescale each band so that min_val maps to 0 and max_val maps to 1
for b in range(colors.shape[2]):
    colors[:, :, b] = (colors[:, :, b] - min_val) / (max_val - min_val)
plt.subplot(121)
plt.imshow(colors)
# Show NDVI
plt.subplot(122)
plt.imshow(ndvi, cmap=plt.cm.Greys_r)
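The clip-and-scale steps used for the color composite can also be written as a small reusable function. This is only a sketch: the name `linear_stretch` and the five-value demo array are hypothetical, and the defaults simply mirror the 0 and 8000 limits used above.

```python
import numpy as np


def linear_stretch(bands, min_val=0, max_val=8000):
    """Linearly map min_val to 0 and max_val to 1, clipping values outside."""
    # Convert to float first so the result can hold fractions between 0 and 1
    scaled = (bands.astype(np.float64) - min_val) / (max_val - min_val)
    return np.clip(scaled, 0.0, 1.0)


# Made-up Int16 values: below range, at both limits, in the middle, and above range
demo = np.array([-100, 0, 4000, 8000, 12000], dtype=np.int16)
stretched = linear_stretch(demo)
```

Because the function works elementwise, it can be applied to a whole rows x columns x bands array at once, with no loop over the bands.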