Chris Holden (ceholden@gmail.com) - https://github.com/ceholden

Chapter 3: Plotting and visualizing your data with matplotlib

Introduction

matplotlib is a powerful plotting library for making visualizations for publications, personal use, or even web and desktop applications. It can create almost any two-dimensional visualization you can think of, including histograms, scatter plots, bivariate plots, and image displays. For some inspiration, check out the matplotlib example gallery, which includes the source code required to generate each example.

A great resource for learning matplotlib is available from J.R. Johansson.

matplotlib API - state-machine versus object-oriented

One part of matplotlib that may be initially confusing is that it offers two main ways of making plots: the object-oriented method and the state-machine method.

The library can be used in an object-oriented manner (i.e., you create an object representing the figure, then the figure can spawn objects representing the axes, etc.), but the most familiar usage of matplotlib for MATLAB users is the pyplot state-machine environment.

From the matplotlib usage FAQ:

Pyplot’s state-machine environment behaves similarly to MATLAB and should be most familiar to users with MATLAB experience.

A very good overview of the difference between the two usages is provided by Jake Vanderplas.

In general, you should only use the Pyplot state-machine environment when plotting data interactively or when developing visualizations for your data. The object-oriented API, while more complicated, is a much more powerful way of creating plots and should be used when developing more complicated visualizations.
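
As a rough illustration of the two styles, the sketch below draws the same line twice: once through pyplot's implicit "current" figure, and once through explicit figure and axes objects. The data and titles are made up for the example, and the `matplotlib.use("Agg")` line only reflects an assumption that the code runs as a standalone script without a display; inside a notebook you would leave it out.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: running as a script; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np

# Made-up data for illustration
x = np.arange(10)
y = x ** 2

# State-machine style: pyplot keeps track of the "current" figure and axes
plt.figure()
plt.plot(x, y)
plt.xlabel("x")
plt.title("pyplot state-machine")
state_axes = plt.gca()  # ask pyplot which axes it has been drawing on

# Object-oriented style: we hold explicit references to the figure and axes
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set_xlabel("x")
ax.set_title("object-oriented")
```

Both halves produce the same kind of plot. The difference is that the second keeps `fig` and `ax` around, which becomes convenient once a figure has several axes to manage.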

As this is a brief introduction to matplotlib, we will be using the Pyplot state-machine method for creating visualizations.

Image display

We will begin by reading our example image into a NumPy array and calculating NDVI, as shown in the previous chapters:

In [1]:
# Import the Python 3 print function
from __future__ import print_function

# Import the "gdal" and "gdal_array" submodules from within the "osgeo" module
from osgeo import gdal
from osgeo import gdal_array

# Import the NumPy module
import numpy as np

# Open a GDAL dataset
dataset = gdal.Open('../../example/LE70220491999322EDC01_stack.gtif', gdal.GA_ReadOnly)

# Allocate our array using the first band's datatype
image_datatype = dataset.GetRasterBand(1).DataType

image = np.zeros((dataset.RasterYSize, dataset.RasterXSize, dataset.RasterCount),
                 dtype=gdal_array.GDALTypeCodeToNumericTypeCode(image_datatype))

# Loop over all bands in dataset
for b in range(dataset.RasterCount):
    # Remember, GDAL band indices start at 1 but Python indices start at 0 -- so we add 1 for our GDAL calls
    band = dataset.GetRasterBand(b + 1)
    
    # Read in the band's data into the third dimension of our array
    image[:, :, b] = band.ReadAsArray()

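# Calculate NDVI from the NIR (index 3) and red (index 2) bands.
# Cast to float first so the integer sum and difference cannot overflow.
nir_float = image[:, :, 3].astype(np.float64)
red_float = image[:, :, 2].astype(np.float64)
ndvi = (nir_float - red_float) / (nir_float + red_float)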
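
One detail the NDVI calculation above glosses over is what happens when NIR + Red is zero, for example in fill pixels. The sketch below is one possible way to handle that case, assuming NaN is an acceptable "no data" marker. `calc_ndvi` and the tiny demo arrays are hypothetical names invented for this illustration; they are not part of GDAL or NumPy.

```python
import numpy as np


def calc_ndvi(nir, red):
    """Return NDVI as float64, with NaN wherever NIR + Red is zero.

    Hypothetical helper for illustration only.
    """
    # Convert to float so the arithmetic is not done in the integer datatype
    nir = np.asarray(nir, dtype=np.float64)
    red = np.asarray(red, dtype=np.float64)
    denominator = nir + red

    # Start from all-NaN and only divide where the denominator is non-zero
    result = np.full(denominator.shape, np.nan)
    np.divide(nir - red, denominator, out=result, where=denominator != 0)
    return result


# Made-up bands: one vegetated pixel and one empty (all zero) pixel
demo_nir = np.array([[5000, 0]], dtype=np.int16)
demo_red = np.array([[1000, 0]], dtype=np.int16)
demo_ndvi = calc_ndvi(demo_nir, demo_red)
```

The first pixel comes out as (5000 - 1000) / (5000 + 1000), and the second is NaN rather than raising a divide-by-zero warning.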

With the data read in and NDVI calculated, let's make some plots.

Basic plotting

The first thing to do is to import matplotlib into our namespace. I will be using a special feature of the IPython utility which allows me to "inline" matplotlib figures by entering the %matplotlib inline command. You might also want to try the nbagg backend, designed specifically for Jupyter notebooks, as this backend allows you to interact (pan, zoom, etc.) with the plot.

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

With matplotlib imported, we can summon up a figure and make our first plot:

In [3]:
# Array of 0 - 9
x = np.arange(10)
# 10 random integers from 0 through 9 (the upper bound of 10 is exclusive)
y = np.random.randint(0, 10, size=10)

# plot them as lines
plt.plot(x, y)
Out[3]:
[<matplotlib.lines.Line2D at 0x7f28b1a056a0>]
In [4]:
# plot them as just points -- specify "ls" ("linestyle") as a null string
plt.plot(x, y, 'ro', ls='')
Out[4]:
[<matplotlib.lines.Line2D at 0x7f28b14ab898>]
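
The 'ro' argument used above is a format string shorthand. The sketch below spells the same request out with keyword arguments, using made-up data, so the two forms can be compared side by side. It assumes it is run as a standalone script, hence the Agg backend line.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: running as a script; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np

# Made-up data for illustration
x = np.arange(10)
y = x % 3

plt.figure()
# Shorthand: 'r' is red and 'o' is a circle marker; with a marker but no
# line style in the format string, only the markers are drawn
short, = plt.plot(x, y, 'ro')
# The same request written out with keyword arguments
verbose, = plt.plot(x, y + 1, color='r', marker='o', linestyle='')
```

Each call to plot returns a list of Line2D objects, which is why the Out cells above show a list; unpacking with a trailing comma picks out the single line.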

Plotting 2D arrays

A common task is to plot one band against another. To do this, we need to flatten each band's two-dimensional array of values into a one-dimensional array:

In [5]:
print('Array shape before: {shp} (size is {sz})'.format(shp=image[:, :, 3].shape, sz=image[:, :, 3].size))

red = image[:, :, 2].flatten()
nir = image[:, :, 3].flatten()

print('Array shape after: {shp} (size is {sz})'.format(shp=nir.shape, sz=nir.size))
Array shape before: (250, 250) (size is 62500)
Array shape after: (62500,) (size is 62500)

We have retained the number of entries in each of these raster bands, but we have flattened them from 2 dimensions into 1.
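
NumPy offers two closely related ways to do this, and the difference occasionally matters: flatten always returns a copy, while ravel returns a view of the original data when the memory layout allows it. The sketch below shows this on a small made-up array rather than on the image itself.

```python
import numpy as np

# Small made-up 2D "band" for illustration
band = np.arange(12).reshape(3, 4)

flat_copy = band.flatten()  # always a new array
flat_view = band.ravel()    # a view here, because "band" is contiguous

# Editing the copy leaves the original untouched
flat_copy[0] = 99

# Editing the view changes the original array as well
flat_view[1] = -1
```

For plotting, either works. A copy is the safer default if you plan to modify the flattened values afterwards.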

Now we can plot them. Since we just want points, we can use scatter to make a scatterplot. Because a scatterplot has no lines, its syntax differs slightly from that of plot.

In [6]:
# Make the plot
plt.scatter(red, nir, color='r', marker='o')

# Add some axis labels
plt.xlabel('Red Reflectance')
plt.ylabel('NIR Reflectance')

# Add a title
plt.title('Tasseled Cap, eh?')
Out[6]:
<matplotlib.text.Text at 0x7f28b1483470>

If we want the two axes to have the same limits, we can calculate the limits and apply them:

In [7]:
# Make the plot
plt.scatter(red, nir, color='r', marker='o')

# Calculate min and max
plot_min = min(red.min(), nir.min())
plot_max = max(red.max(), nir.max())

plt.xlim((plot_min, plot_max))
plt.ylim((plot_min, plot_max))

# Add some axis labels
plt.xlabel('Red Reflectance')
plt.ylabel('NIR Reflectance')

# Add a title
plt.title('Tasseled Cap, eh?')
Out[7]:
<matplotlib.text.Text at 0x7f28b13ed860>
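
One reason to give both axes the same limits is that a 1:1 reference line then runs corner to corner, so points above it have higher NIR than red reflectance. The sketch below adds such a line. It uses randomly generated stand-in values (`demo_red`, `demo_nir`) rather than the real bands, and assumes it is run as a standalone script.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: running as a script; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np

# Made-up stand-ins for the flattened red and NIR bands
rng = np.random.default_rng(0)
demo_red = rng.integers(0, 3000, size=200)
demo_nir = rng.integers(0, 6000, size=200)

plt.figure()
plt.scatter(demo_red, demo_nir, color='r', marker='o')

# Shared limits, calculated the same way as above
plot_min = min(demo_red.min(), demo_nir.min())
plot_max = max(demo_red.max(), demo_nir.max())
plt.xlim((plot_min, plot_max))
plt.ylim((plot_min, plot_max))

# Dashed black 1:1 line from corner to corner
plt.plot([plot_min, plot_max], [plot_min, plot_max], 'k--')

# Draw one data unit the same length on both axes
plt.gca().set_aspect('equal')
```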

Plotting 2D arrays - images

With so much data available to look at, it can be hard to understand what is going on with the mess of points shown above. Luckily our datasets aren't just a mess of points - they have a spatial structure.

To show the spatial structure of our images, we could make an image plot of one of our bands using imshow to display an image on the axes:

In [8]:
# use "imshow" for an image -- nir at first
plt.imshow(image[:, :, 3])
Out[8]:
<matplotlib.image.AxesImage at 0x7f28b136c0f0>

Well, it looks like there is something going on - maybe a river in the center and some bright vegetation to the bottom left of the image. What's lacking is any knowledge of what the colors mean.

Luckily, matplotlib can provide us a colorbar.

In [9]:
# use "imshow" for the nir band again, this time adding a colorbar
plt.imshow(image[:, :, 3])
plt.colorbar()
Out[9]:
<matplotlib.colorbar.Colorbar at 0x7f28b02a8c88>
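
By default imshow stretches the colormap between the smallest and largest values in the array, so a few extreme pixels can wash out the rest. A possible refinement, sketched below on a made-up array, is to fix the color limits with vmin and vmax and to label the colorbar. The limits of 0 and 8000 and the label text are assumptions chosen for the example.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: running as a script; omit inside a notebook
import matplotlib.pyplot as plt
import numpy as np

# Made-up 50 x 50 "band" with values from 0 to 10000
demo_band = np.linspace(0, 10000, 50 * 50).reshape(50, 50)

plt.figure()
# Values at or below vmin get the first color, values at or above vmax the last
im = plt.imshow(demo_band, vmin=0, vmax=8000)

# Attach a colorbar to this image and say what the numbers mean
cbar = plt.colorbar(im)
cbar.set_label('NIR reflectance (scaled by 10,000)')
```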

If we want a greyscale image, we can manually specify a colormap:

In [10]:
# use "imshow" for an image -- nir in first subplot, red in second
plt.subplot(121)
plt.imshow(image[:, :, 3], cmap=plt.cm.Greys)
plt.colorbar()

# Now red band in the second subplot (indicated by last of the 3 numbers)
plt.subplot(122)
plt.imshow(image[:, :, 2], cmap=plt.cm.Greys)
plt.colorbar()
Out[10]:
<matplotlib.colorbar.Colorbar at 0x7f28b017cb70>
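
It is worth checking which way a colormap runs before reading an image. The sketch below queries two colormaps directly: a colormap object can be called with a number between 0 and 1 and returns the (R, G, B, A) color it assigns to that position. It suggests that Greys draws low values light and high values dark, while Greys_r is the reverse.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: running as a script; omit inside a notebook
import matplotlib.pyplot as plt

# Colors assigned to the bottom and top of the "Greys" colormap
low_greys = plt.cm.Greys(0.0)
high_greys = plt.cm.Greys(1.0)

# Color assigned to the bottom of the reversed colormap
low_reversed = plt.cm.Greys_r(0.0)

# Sum of the R, G, and B components as a crude measure of brightness
brightness_low = sum(low_greys[:3])
brightness_high = sum(high_greys[:3])
brightness_low_reversed = sum(low_reversed[:3])
```

If you expect bright pixels to mean high reflectance, Greys_r is the variant that matches that expectation.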

Plotting 3D arrays - multispectral images

Greyscale images are nice, but we learn the most by looking at the interplay among different bands. To accomplish this, we can map different spectral bands to the Red, Green, and Blue channels on our monitors.

Before we can do this, the matplotlib imshow help tells us that we need to normalize our bands into a 0 - 1 range. To do so, we will apply a simple linear stretch that maps 0 reflectance to 0 and 80% reflectance (a scaled value of 8,000) to 1, clipping anything larger or smaller.

Remember:

If we are going from an Int16 datatype (e.g., reflectance scaled by 10,000x) to a decimal between 0 and 1, we will need to use a floating point datatype!
In [11]:
# Extract the SWIR1, NIR, and Red bands (indexing with an array returns a copy, not a reference)
index = np.array([4, 3, 2])
colors = image[:, :, index].astype(np.float64)

max_val = 8000
min_val = 0

# Enforce maximum and minimum values
colors[colors > max_val] = max_val
colors[colors < min_val] = min_val

# Linearly rescale each band so min_val maps to 0 and max_val maps to 1
for b in range(colors.shape[2]):
    colors[:, :, b] = (colors[:, :, b] - min_val) / (max_val - min_val)

plt.subplot(121)
plt.imshow(colors)

# Show NDVI
plt.subplot(122)
plt.imshow(ndvi, cmap=plt.cm.Greys_r)
Out[11]:
<matplotlib.image.AxesImage at 0x7f28b0108940>
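
The clip-and-rescale step above can also be written without a loop. The sketch below wraps it in a small function using np.clip. `linear_stretch` is a hypothetical helper named for this example, and the five input values are made up to cover the cases below, inside, and above the stretch range.

```python
import numpy as np


def linear_stretch(band, min_val, max_val):
    """Map min_val to 0 and max_val to 1, clipping values outside that range.

    Hypothetical helper for illustration only.
    """
    # Work in float so the result can hold decimals between 0 and 1
    scaled = (band.astype(np.float64) - min_val) / (max_val - min_val)
    return np.clip(scaled, 0.0, 1.0)


# Made-up reflectance values (scaled by 10,000) stored as Int16
demo = np.array([-100, 0, 4000, 8000, 12000], dtype=np.int16)
stretched = linear_stretch(demo, 0, 8000)
```

Because the function works on whole arrays, it could be applied to a three-band array in a single call instead of band by band.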

Wrapup

We have seen how matplotlib can be combined with NumPy and GDAL to easily visualize and explore our remote sensing data. In the next chapter we will cover how to use GDAL's companion library - OGR - to open and read vector data.