Making Watershed Maps in Python

This post builds off of earlier posts by Jon Lamontagne and Jon Herman on making global maps in Python using matplotlib and basemap. However, rather than making a global map, I’ll show how to zoom into a particular region, here the Red River basin in East Asia. To make these maps, you’ll need to have basemap installed (from github here, or using a Windows installer here).

The first step is to create a basemap. Both Jons used the ‘robin’ global projection to do this in their posts. Since I’m only interested in a particular region, I just specify the bounding box using the lower and upper latitudes and longitudes of the region I’d like to plot. As Jon H points out, you can also specify the resolution (‘f’ = full, ‘h’ =high, ‘i’ = intermediate, ‘l’ = low, ‘c’ = crude), and you can even use different ArcGIS images for the background (see here). I use ‘World_Shaded_Relief’. It’s also possible to add a lot of features such as rivers, countries, coastlines, counties, etc. I plot countries and rivers. The argument ‘zorder’ specifies the order of the layering from 1 to n, where 1 is the bottom layer and n the top.
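If you already have coordinates for the features you want to show, the bounding box does not have to be typed in by hand. The helper below is a minimal sketch and not part of the original workflow: `bounding_box` is a hypothetical name, the 0.5 degree padding is an arbitrary assumption, and the second example point is made up. The returned keys match Basemap's corner keyword arguments, so the result could be unpacked as `Basemap(resolution='h', **bbox)`.

```python
def bounding_box(lons, lats, pad=0.5):
    """Return Basemap-style corner keywords enclosing the given points.

    lons, lats: sequences of longitudes and latitudes in decimal degrees.
    pad: margin in degrees added on every side (an arbitrary choice here).
    """
    return {
        'llcrnrlon': min(lons) - pad,              # lower-left corner longitude
        'llcrnrlat': max(min(lats) - pad, -90.0),  # lower-left corner latitude
        'urcrnrlon': max(lons) + pad,              # upper-right corner longitude
        'urcrnrlat': min(max(lats) + pad, 90.0),   # upper-right corner latitude
    }

# Hanoi plus a made-up point further up the basin
bbox = bounding_box([105.8342, 100.1], [21.0278, 25.5])
```

The latitude corners are clamped to the poles; longitudes are left alone, so a region crossing the antimeridian would need extra handling.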

from mpl_toolkits.basemap import Basemap
from matplotlib import pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111)

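# plot basemap, rivers and countries
m = Basemap(llcrnrlat=19.5, urcrnrlat=26.0, llcrnrlon=99.6, urcrnrlon=107.5, resolution='h')
m.arcgisimage(service='World_Shaded_Relief') # ArcGIS background image
m.drawrivers()
m.drawcountries()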

The above code makes the following image (it takes some time, since I’m using high resolution):

Now let’s add a shaded outline of the Red River basin. To do this, you need a shapefile of the basin. The FAO provides a shapefile of major watersheds in the world, from which you can extract the watershed you’re interested in using ArcGIS (see instructions here). In this shapefile, the Red River is labeled by its name in Vietnamese, ‘Song Hong.’ I chose not to draw the bounds of the basin in my map because it would be too busy with the country borders. Instead, I shaded the region gray (facecolor=’0.33’) with a slightly lighter border (edgecolor=’0.5’; matplotlib grayscale strings run from ‘0’ for black to ‘1’ for white) and slight transparency (alpha=0.5). To do that, I had to collect all of the patches associated with the shapefile (which I called ‘Basin’ when reading it in) that needed to be shaded.
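The patch-collecting step does not depend on basemap itself, so it can be tried on made-up geometry first. The sketch below uses two hypothetical polygons and a hypothetical attribute list standing in for what `readshapefile` stores in `m.Basin` and `m.Basin_info`; only the matplotlib calls are real.

```python
import matplotlib
matplotlib.use('Agg')  # render off-screen so no display is needed
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

# stand-ins for m.Basin_info and m.Basin (made-up values)
infos = [{'OBJECTID': 1}, {'OBJECTID': 2}]
shapes = [[(100, 20), (104, 20), (104, 24), (100, 24)],
          [(105, 21), (107, 21), (106, 23)]]

fig, ax = plt.subplots()
patches = []
for info, shape in zip(infos, shapes):
    if info['OBJECTID'] == 1:  # keep only the shape(s) that should be shaded
        patches.append(Polygon(np.array(shape), closed=True))

# gray fill, lighter gray edge, half transparent
collection = PatchCollection(patches, facecolor='0.33', edgecolor='0.5', alpha=0.5)
ax.add_collection(collection)
ax.set_xlim(99, 108)
ax.set_ylim(19, 25)
```

Styling is set on the collection rather than on each `Polygon`, because a `PatchCollection` ignores the colors of the individual patches unless told otherwise.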

import numpy as np
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection

# plot Red River basin
m.readshapefile('RedRiverBasin_WGS1984', 'Basin', drawbounds=False)
patches = []
for info, shape in zip(m.Basin_info, m.Basin):
    if info['OBJECTID'] == 1: # attribute in attribute table of shapefile
        patches.append(Polygon(np.array(shape), closed=True))

ax.add_collection(PatchCollection(patches, facecolor='0.33', edgecolor='0.5', alpha=0.5))

This creates the following image:

Now let’s add the locations of major dams and cities in the basin using ‘scatter’. You could again do this by adding a shapefile, but I’m just going to add their locations manually, either by loading their latitude and longitude coordinates from a .csv file or by passing them directly.
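The column handling is the part that is easy to get backwards, so here is a small sketch of it in isolation. The file contents are made up (they are not the real DamLocations.csv); the layout of name, latitude, longitude is an assumption inferred from `usecols=[1,2]` and the `m(longitude, latitude)` call in the snippet that follows.

```python
import io
import numpy as np

# made-up file contents: name, latitude, longitude
csv_text = """Name,Lat,Long
DamA,20.8,105.3
DamB,21.5,103.9
"""

# skip the header row and read only the latitude and longitude columns
lat_long = np.loadtxt(io.StringIO(csv_text), delimiter=',', skiprows=1, usecols=[1, 2])
lats = lat_long[:, 0]
lons = lat_long[:, 1]
# a Basemap instance m would then be called as m(lons, lats): longitude first
```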

import numpy as np

# plot dams
damsLatLong = np.loadtxt('DamLocations.csv', delimiter=',', skiprows=1, usecols=[1,2])
x, y = m(damsLatLong[:,1], damsLatLong[:,0]) # m(longitude, latitude)
m.scatter(x, y, c='k', s = 150, marker = '^')

# plot Hanoi
x, y = m(105.8342, 21.0278)
m.scatter(x, y, facecolor='darkred', edgecolor='darkred', s=150)

This makes the following image:

If we want to label the dams and cities, we can add text specifying where on the map we’d like them to be located. This may require some guess-and-check work to determine the best place (comment if you know a better way!). I temporarily added gridlines to the map to aid in this process using ‘drawparallels‘ and ‘drawmeridians‘.
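One possible alternative to guess-and-check, offered as a sketch rather than what was used for this map: anchor each label to its marker with `annotate` and give the offset in points, so only a small nudge needs tuning. It is shown on plain matplotlib axes; 'Dam A' and its coordinates are placeholders, and with a non-default basemap projection the positions would first have to go through `m(lon, lat)`.

```python
import matplotlib
matplotlib.use('Agg')  # render off-screen so no display is needed
from matplotlib import pyplot as plt

fig, ax = plt.subplots()
# (longitude, latitude) pairs; Hanoi's are the ones used in this post, 'Dam A' is made up
places = {'Hanoi': (105.8342, 21.0278), 'Dam A': (105.3, 20.8)}

labels = []
for name, (lon, lat) in places.items():
    ax.scatter(lon, lat, c='k', s=150, marker='^')
    # the text sits 12 points above the marker, whatever the map extent is
    labels.append(ax.annotate(name, xy=(lon, lat), xycoords='data',
                              xytext=(0, 12), textcoords='offset points',
                              ha='center', va='bottom', fontsize=18))
```

Labels placed this way can still overlap each other or other map features, so some manual adjustment of the offsets may remain necessary.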

# label dams and Hanoi
plt.text(104.8, 21.0, 'Hoa Binh', fontsize=18, ha='center', va='center', color='k')
plt.text(104.0, 21.7, 'Son La', fontsize=18, ha='center', va='center', color='k')
plt.text(105.0, 21.95, 'Thac Ba', fontsize=18, ha='center', va='center', color='k')
plt.text(105.4, 22.55, 'Tuyen Quang', fontsize=18, ha='center', va='center', color='k')
plt.text(105.8, 21.2, 'Hanoi', fontsize=18, ha='center', va='center', color='k')

Now our map looks like this:

That looks nice, but it would be helpful to add some context as to where in the world the Red River basin is located. To illustrate this, we can create an inset of the greater geographical area by adding another set of axes with its own basemap. This one can be at a lower resolution.

from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

# plot inset of greater geographic area
axins = zoomed_inset_axes(ax, 0.1, loc=1) # zoom factor of 0.1; loc=1 puts the inset in the upper right
axins.set_xlim(90, 115) # longitude boundaries of inset map
axins.set_ylim(8, 28) # latitude boundaries of inset map

# remove tick marks from inset axes
axins.xaxis.set_visible(False)
axins.yaxis.set_visible(False)

# add basemap to inset map
m2 = Basemap(llcrnrlat=8.0, urcrnrlat=28.0, llcrnrlon=90.0, urcrnrlon=115.0, resolution='l', ax=axins)
m2.drawcountries(color='k', linewidth=0.5)

The map now looks like this:

Now let’s highlight a country of interest (Vietnam) in green and also add the Red River basin in light gray again.

# plot Vietnam green in inset
m2.readshapefile('VN_borders_only_WGS1984', 'Vietnam', drawbounds=False)
patches2 = []
for info, shape in zip(m2.Vietnam_info, m2.Vietnam):
    if info['Joiner'] == 1:
        patches2.append(Polygon(np.array(shape), closed=True))

axins.add_collection(PatchCollection(patches2, facecolor='forestgreen', edgecolor='0.5', alpha=0.5))

# shade Red River basin gray in inset
axins.add_collection(PatchCollection(patches, facecolor='0.33', edgecolor='0.5', alpha=0.5))

Now our map looks like this:

Finally, let’s label the countries in the inset. Some of the countries are too small to fit their name inside, so we’ll have to create arrows pointing to them using ‘annotate‘. In this function, ‘xy’ specifies where the arrow points to and ‘xytext’ where the text is written relative to where the arrow points.

# label countries
plt.text(107.5, 25.5, 'China', fontsize=11, ha='center', va='center', color='k')
plt.text(102.5, 20.2, 'Laos', fontsize=11, ha='center', va='center', color='k')
plt.text(101.9, 15.5, 'Thailand', fontsize=11, ha='center', va='center', color='k')
plt.text(96.5, 21.0, 'Myanmar', fontsize=11, ha='center', va='center', color='k')

# add arrows to label Vietnam and Cambodia
plt.annotate('Vietnam', xy=(108.0, 14.0), xycoords='data', xytext=(5.0, 20.0), textcoords='offset points',
    color='k', arrowprops=dict(arrowstyle='-'), fontsize=11)
plt.annotate('Cambodia', xy=(104.5, 12.0), xycoords='data', xytext=(-60.0, -25.0), textcoords='offset points',
    color='k', arrowprops=dict(arrowstyle='-'), fontsize=11)

Now our map looks like this:

I think that’s pretty good, so let’s save it ;). See below for all the code used to make this map, with all the import statements at the beginning rather than sporadically inserted throughout the code!

If you’re looking for any other tips on how to make different types of maps using basemap, I recommend browsing through the basemap toolkit documentation and this basemap tutorial, where I learned how to do most of what I showed here.

from mpl_toolkits.basemap import Basemap
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
import numpy as np

# set-up Vietnam basemap
fig = plt.figure()
fig.set_size_inches([17.05, 8.15])
ax = fig.add_subplot(111)

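# plot basemap, rivers and countries
m = Basemap(llcrnrlat=19.5,urcrnrlat=26.0,llcrnrlon=99.6,urcrnrlon=107.5,resolution='h')
m.arcgisimage(service='World_Shaded_Relief') # ArcGIS background image
m.drawrivers()
m.drawcountries()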

# plot Red River basin
m.readshapefile('RedRiverBasin_WGS1984', 'Basin', drawbounds=False)
patches = []
for info, shape in zip(m.Basin_info, m.Basin):
    if info['OBJECTID'] == 1:
        patches.append(Polygon(np.array(shape), closed=True))

ax.add_collection(PatchCollection(patches, facecolor='0.33',edgecolor='0.5',alpha=0.5))

# plot dams
damsLatLong = np.loadtxt('DamLocations.csv',delimiter=',',skiprows=1,usecols=[1,2])
x, y = m(damsLatLong[:,1], damsLatLong[:,0])
m.scatter(x, y, c='k', s=150, marker='^')

# plot Hanoi
x, y = m(105.8342, 21.0278)
m.scatter(x, y, facecolor='darkred', edgecolor='darkred', s=150)

# label reservoirs and Hanoi
plt.text(104.8, 21.0, 'Hoa Binh', fontsize=18, ha='center',va='center',color='k')
plt.text(104.0, 21.7, 'Son La', fontsize=18, ha='center', va='center', color='k')
plt.text(105.0, 21.95, 'Thac Ba', fontsize=18, ha='center', va='center', color='k')
plt.text(105.4, 22.55, 'Tuyen Quang', fontsize=18, ha='center', va='center', color='k')
plt.text(105.8, 21.2, 'Hanoi', fontsize=18, ha='center', va='center', color='k')

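# plot inset of greater geographic area
axins = zoomed_inset_axes(ax, 0.1, loc=1)
axins.set_xlim(90, 115)
axins.set_ylim(8, 28)

# remove tick marks from inset axes
axins.xaxis.set_visible(False)
axins.yaxis.set_visible(False)

# add basemap to inset map
m2 = Basemap(llcrnrlat=8.0,urcrnrlat=28.0,llcrnrlon=90.0,urcrnrlon=115.0,resolution='l',ax=axins)
m2.drawcountries(color='k', linewidth=0.5)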

# plot Vietnam green in inset
m2.readshapefile('VN_borders_only_WGS1984', 'Vietnam', drawbounds=False)
patches2 = []
for info, shape in zip(m2.Vietnam_info, m2.Vietnam):
    if info['Joiner'] == 1:
        patches2.append(Polygon(np.array(shape), closed=True))

axins.add_collection(PatchCollection(patches2, facecolor='forestgreen',edgecolor='0.5',alpha=0.5))

# shade Red River basin gray in inset
axins.add_collection(PatchCollection(patches, facecolor='0.33',edgecolor='0.5',alpha=0.5))

# label countries
plt.text(107.5, 25.5, 'China', fontsize=11, ha='center',va='center',color='k')
plt.text(102.5, 20.2, 'Laos', fontsize=11, ha='center', va='center', color='k')
plt.text(101.9, 15.5, 'Thailand', fontsize=11, ha='center', va='center', color='k')
plt.text(96.5, 21.0, 'Myanmar', fontsize=11, ha='center', va='center', color='k')

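# add arrows to label Vietnam and Cambodia
plt.annotate('Vietnam', xy=(108.0,14.0), xycoords='data', xytext=(5.0,20.0), textcoords='offset points',
    color='k', arrowprops=dict(arrowstyle='-'), fontsize=11)
plt.annotate('Cambodia', xy=(104.5,12.0), xycoords='data', xytext=(-60.0,-25.0), textcoords='offset points',
    color='k', arrowprops=dict(arrowstyle='-'), fontsize=11)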


Plotting geographic data from geojson files using Python


Hi folks,

I’m writing today about plotting geojson files with Matplotlib’s Basemap.  In a previous post I laid out how to plot shapefiles using Basemap.

GeoJSON is an open file format for representing geographical data, based on JavaScript Object Notation (JSON).  GeoJSON files are composed of points, lines, and polygons, or ‘multiple’ versions of these (e.g. multipolygons composed of several polygons), with accompanying properties.  The basic structure is one of names and values, where names are always strings and values may be strings, numbers, objects, arrays, or logical literals.

The geojson structure we will be considering here is a collection of features, where each feature must contain a geometry and properties.  Properties could be things like country name, country code, state, etc.  The geometry must contain a type (point, line, polygon, etc.) and coordinates (an array of longitude-latitude pairs). Below is an excerpt of a geojson file specifying Agro-Ecological Zones (AEZs) within the various GCAM regions.

"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },

"features": [
{ "type": "Feature", "id": 1, "properties": { "ID": 1.000000, "GRIDCODE": 11913.000000, "CTRYCODE": 119.000000, "CTRYNAME": "Russian Fed", "AEZ": 13.000000, "GCAM_ID": "Russian Fed-13" }, "geometry": { "type": "MultiPolygon", "coordinates": [ [ [ [ 99.5, 78.5 ], [ 98.33203125, 78.735787391662598 ], [ 98.85723876953125, 79.66796875 ], [ 99.901641845703125, 79.308036804199219 ], [ 99.5, 78.5 ] ] ] ] } },
{ "type": "Feature", "id": 2, "properties": { "ID": 2.000000, "GRIDCODE": 11913.000000, "CTRYCODE": 119.000000, "CTRYNAME": "Russian Fed", "AEZ": 13.000000, "GCAM_ID": "Russian Fed-13" }, "geometry": { "type": "MultiPolygon", "coordinates": [ [ [ [ 104.5, 78.0 ], [ 104.0, 78.0 ], [ 99.5, 78.0 ], [ 99.5, 78.5 ], [ 100.2957763671875, 78.704218864440918 ], [ 102.13778686523437, 79.477890968322754 ], [ 104.83050537109375, 78.786871910095215 ], [ 104.5, 78.0 ] ] ] ] } },
{ "type": "Feature", "id": 3, "properties": { "ID": 3.000000, "GRIDCODE": 2713.000000, "CTRYCODE": 27.000000, "CTRYNAME": "Canada", "AEZ": 13.000000, "GCAM_ID": "Canada-13" }, "geometry": { "type": "MultiPolygon", "coordinates": [ [ [ [ -99.5, 77.5 ], [ -100.50860595703125, 77.896504402160645 ], [ -101.76053619384766, 77.711499214172363 ], [ -104.68202209472656, 78.563323974609375 ], [ -105.71781158447266, 79.692866325378418 ], [ -99.067413330078125, 78.600395202636719 ], [ -99.5, 77.5 ] ] ] ] } }

Now that we have some understanding of the geojson structure, plotting the information therein should be as straightforward as traversing that structure and tying geometries to data.  We do the former using the geojson python package and the latter using pretty basic python manipulation.  To do the actual plotting, we’ll use PolygonPatches from the descartes library and recycle most of the code from my previous post.
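Before bringing in the plotting libraries, the traversal itself can be sketched with nothing but the standard-library `json` module. The snippet below is a trimmed-down stand-in for the excerpt above (coordinates rounded and shortened for illustration, only two properties kept); it is not the code used for the figures in this post.

```python
import json

# a small FeatureCollection with the same layout as the AEZ excerpt
text = """
{"type": "FeatureCollection",
 "features": [
  {"type": "Feature", "id": 1,
   "properties": {"CTRYNAME": "Russian Fed", "AEZ": 13.0},
   "geometry": {"type": "MultiPolygon",
                "coordinates": [[[[99.5, 78.5], [98.3, 78.7], [98.9, 79.7], [99.5, 78.5]]]]}},
  {"type": "Feature", "id": 3,
   "properties": {"CTRYNAME": "Canada", "AEZ": 13.0},
   "geometry": {"type": "MultiPolygon",
                "coordinates": [[[[-99.5, 77.5], [-100.5, 77.9], [-101.8, 77.7], [-99.5, 77.5]]]]}}
 ]}
"""
collection = json.loads(text)

summary = []
for feature in collection['features']:
    props = feature['properties']
    geom = feature['geometry']
    # MultiPolygon nesting: polygons -> rings -> [longitude, latitude] points
    outer_ring = geom['coordinates'][0][0]
    summary.append((props['CTRYNAME'], geom['type'], len(outer_ring)))
```

The four levels of brackets in a MultiPolygon are what make the indexing look heavy: the first index picks a polygon, the second picks a ring of that polygon, and the third picks a point.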

We start by importing the necessary libraries and then open the geojson file.

import geojson
from descartes import PolygonPatch
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
import numpy as np

with open("aez-w-greenland.geojson") as json_file:
    json_data = geojson.load(json_file)

We then define a Matplotlib Figure, and generate a Basemap object as a ‘canvas’ to draw the geojson geometries on.

ax = plt.figure(figsize=(10,10)).add_subplot(111)#fig.gca()

m = Basemap(projection='robin', lon_0=0,resolution='c')
m.drawmapboundary(fill_color='white', zorder=-1)
m.drawparallels(np.arange(-90.,91.,30.), labels=[1,0,0,1], dashes=[1,1], linewidth=0.25, color='0.5',fontsize=14)
m.drawmeridians(np.arange(0., 360., 60.), labels=[1,0,0,1], dashes=[1,1], linewidth=0.25, color='0.5',fontsize=14)
m.drawcoastlines(color='0.6', linewidth=1)

Next, we iterate over the nested features in this file and pull out the coordinate list defining each feature’s geometry (line 2).  In lines 4-5 we also pull out the feature’s name and AEZ, which I can tie to GCAM data.

for i in range(2799):
    coordlist = json_data.features[i]['geometry']['coordinates'][0]
    if i < 2796:
        name = json_data.features[i]['properties']['CTRYNAME']
        aez =  json_data.features[i]['properties']['AEZ']

    for j in range(len(coordlist)):
        for k in range(len(coordlist[j])):
            coordlist[j][k][0], coordlist[j][k][1] = m(coordlist[j][k][0], coordlist[j][k][1])

    poly = {"type":"Polygon","coordinates":coordlist}#coordlist
    ax.add_patch(PolygonPatch(poly, fc=[0,0.5,0], ec=[0,0.3,0], zorder=0.2 ))


Line 9 is used to convert the coordinate list from lat/long units to meters.  Depending on what projection you’re working in and what units your inputs are in, you may or may not need to do this step.

The final lines add the polygon to the figure, with a green face color and a dark green border for each polygon, which generates the following figure:


To get a bit more fancy, we could tie the data to a colormap and then link that to the facecolor of the polygons.  For instance, the following figure shows the improvement in maize yields over the next century in the shared socio-economic pathway 1 (SSP 1), relative to a reference scenario (SSP 2).
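A minimal sketch of that idea, under stated assumptions: the region names and values below are made up, the colormap and its limits are arbitrary choices, and the data behind the maize-yield figure is not reproduced here. Each value is normalized to the 0-1 range and passed through a colormap to obtain an RGBA tuple.

```python
import matplotlib
matplotlib.use('Agg')  # render off-screen so no display is needed
from matplotlib import pyplot as plt
from matplotlib.colors import Normalize

# made-up values per region (e.g. a relative change in yield)
values = {'Region A': -0.2, 'Region B': 0.0, 'Region C': 0.35}

norm = Normalize(vmin=-0.4, vmax=0.4)   # map the data range onto [0, 1]
cmap = plt.get_cmap('RdYlGn')           # red for losses, green for gains

# one RGBA tuple per region
facecolors = {name: cmap(norm(value)) for name, value in values.items()}
```

Inside the plotting loop, the fixed `fc=[0,0.5,0]` would then be replaced by the color looked up for that feature, for example `fc=facecolors[name]`, assuming the dictionary keys match the feature names.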


Fitting Multivariate Normal Distributions

In water resources modeling, we often want to generate synthetic multivariate data for a simulation model in a way that preserves their marginal and joint behavior. The multivariate normal (MVN) distribution is a common model choice for these simulations because 1) it often arises naturally due to the Central Limit Theorem, 2) it has many useful properties that make data manipulation convenient, and 3) data can often be transformed to MVN even if that is not their underlying distribution.

As an example, let X be a K-dimensional random vector with K×1 mean vector μ and K×K covariance matrix [Σ]. If X is multivariate normal, i.e. X~MVN(μ,[Σ]), its probability density function is the following:

f(\mathbf{x}) = \frac{1}{(2\pi)^{K/2}\sqrt{\det\left[\Sigma\right]}} \exp\left[-\frac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^{T}\left[\Sigma\right]^{-1}(\mathbf{x}-\boldsymbol{\mu})\right]

where det(·) denotes the determinant. The term (x − μ)ᵀ[Σ]⁻¹(x − μ) is called the squared Mahalanobis distance, which measures how far away an observation x is from its distribution’s mean, scaled by a multi-dimensional measure of spread, [Σ]. It is therefore a measure of how far away from the mean each data vector is in a statistical sense, as opposed to Euclidean distance, which only measures distance in a physical sense. The two measures are related, though. If all of the K dimensions of X are independent, every non-diagonal element of [Σ] will be equal to 0, and the diagonals equal to the variances in each dimension, σ_k², where k ∈ {1, 2, …, K}. In that case, the squared Mahalanobis distance is equal to the squared Euclidean distance from the mean after scaling each component of the data vector by its standard deviation.
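The relationship between the two distances is easy to check numerically. The following is a small sketch with made-up numbers; `squared_mahalanobis` is a hypothetical helper name, not a function from the code further down.

```python
import numpy as np

def squared_mahalanobis(x, mu, cov):
    """Squared Mahalanobis distance of vector x from mean mu under covariance cov."""
    diff = np.asarray(x, dtype=float) - np.asarray(mu, dtype=float)
    # solve(cov, diff) gives cov^{-1} diff without forming the inverse explicitly
    return float(diff @ np.linalg.solve(cov, diff))

mu = np.array([1.0, 2.0])
sigma = np.array([2.0, 0.5])     # standard deviations (made-up values)
cov = np.diag(sigma**2)          # independent dimensions: diagonal covariance
x = np.array([3.0, 1.0])

d2 = squared_mahalanobis(x, mu, cov)
# squared Euclidean distance after dividing each deviation by its standard deviation
d2_scaled = float(np.sum(((x - mu) / sigma)**2))
```

Both `d2` and `d2_scaled` come out as 5 here: the deviations are 2 and −1, which are 1 and 2 standard deviations respectively, and 1² + 2² = 5.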

So how do we fit MVN distributions? Well, the MVN distribution has “four handy properties” (Wilks, 2011) that we can test. Here I will discuss two of them and how we can use these properties to test for multivariate normality. See Chapter 11 of Wilks (2011) for additional tests for multivariate normality.

Let X be a set of n joint observations of K variables. Denote each of the n observations x_i = [x_i1, x_i2, …, x_iK], where i ∈ {1, 2, …, n}, and each of the K marginals X_k = [x_1k, x_2k, …, x_nk], where k ∈ {1, 2, …, K}. If X~MVN(μ,[Σ]), the following two properties (among others) hold:

  1. All marginal distributions of X are univariate normal, i.e. X_k~N(μ_k, σ_k²).
  2. The squared Mahalanobis distances, D_i² = (x_i − μ)ᵀ[Σ]⁻¹(x_i − μ), have a χ² distribution with K degrees of freedom.

So if we want to fit a MVN distribution to X, each of these will have to be true. Let’s look at an example where X is the standard deviation in daily flows during all historical Septembers at five different sites within the same basin. In this case, K=5 for the 5 sites and n=51, as there are 51 years in the historical record. To fit a MVN distribution to X, we’ll first want to ensure that the marginal distributions of the standard deviations in daily September flows are normal at each of the K sites. Let’s inspect these distributions visually with a histogram, first:


Clearly these distributions are not normal, as they are positively skewed. But that’s okay, we can transform the data so that they look more normal. The Box-Cox transformation is commonly used to transform non-normal data, X, to normal data, Y (see Chapter 3 of Wilks (2011) for more details):

y = \begin{cases} \frac{x^{\lambda}-1}{\lambda}, & \lambda \neq 0 \\ \ln(x), & \lambda = 0 \end{cases}
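As a quick illustration, the transformation can be written in a few lines. This is a sketch with made-up input values; `box_cox` is a hypothetical helper, separate from the fitting code at the end of this post, which simply takes `np.log` of the data.

```python
import numpy as np

def box_cox(x, lam):
    """Box-Cox transform of positive data x with exponent lam."""
    x = np.asarray(x, dtype=float)
    if lam == 0:
        return np.log(x)           # limiting case as lam -> 0
    return (x**lam - 1.0) / lam

x = np.array([1.0, 2.0, 4.0])      # made-up positive values
y_log = box_cox(x, 0)              # log transform
y_sqrt = box_cox(x, 0.5)           # a milder, square-root-like transform
```

The transform is only defined for positive data. `scipy.stats.boxcox` offers the same transformation and can also estimate λ from the data.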


Using λ=0 (a log-transform), our transformed data look like this:


These look much better! We can perform a formal hypothesis test to confirm that each of these 5 transformed data series is not inconsistent with the normal distribution using a number of tests, such as the Shapiro-Wilk test, the Kolmogorov-Smirnov test, and the Filliben Q-Q correlation test, which I use here (see Chapter 5 of Wilks, 2011 for a description of other tests). The Filliben Q-Q test finds the correlation between the sample data quantiles and the theoretical quantiles of the distribution being fit. I’ve plotted these below; the correlation coefficients at these 5 sites are [0.9922, 0.9951, 0.9822, 0.9909, 0.9945].


Rejection regions for the Filliben Q-Q test for the normal distribution are tabulated for different significance levels and sample sizes based on Monte Carlo results. The relevant section of the table is copied below. For a sample size of n≈50, we fail to reject the null hypothesis that the data are normal at the 10% level even at the site with the lowest correlation (Site 3: 0.9822), as the rejection region is ρ≤0.981. This means that if the data were normal, there would be a 10% chance that a data series of length n=50 would have a correlation coefficient below 0.981.
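The tabulated value can be approximated with a short Monte Carlo experiment. The sketch below is an approximation under stated assumptions: it uses the plotting positions (m − 0.5)/n from the code at the end of this post rather than those in the published table, and only 2,000 simulations, so its result will differ slightly from the tabulated 0.981. The function names are hypothetical.

```python
import numpy as np
from scipy import stats

def qq_correlation(sample):
    """Correlation between sorted data and standard normal quantiles."""
    n = len(sample)
    p = (np.arange(1, n + 1) - 0.5) / n   # plotting positions used in this post's code
    return np.corrcoef(np.sort(sample), stats.norm.ppf(p))[0, 1]

def critical_value(n, alpha=0.10, nsim=2000, seed=0):
    """Monte Carlo estimate of the correlation below which normality is rejected."""
    rng = np.random.default_rng(seed)
    corrs = np.array([qq_correlation(rng.standard_normal(n)) for _ in range(nsim)])
    return np.quantile(corrs, alpha)      # alpha-quantile of the null distribution

r_crit = critical_value(50)
```

A sample whose Q-Q correlation falls below `r_crit` would be rejected as non-normal at the chosen significance level; because the correlation is unaffected by shifting and rescaling the data, simulating from the standard normal is sufficient.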


So now we know that none of the marginal distributions at each site is inconsistent with the normal distribution, but that does not guarantee that the joint distribution across sites will be multivariate normal. There could be multivariate outliers, i.e. points which are not extreme within any particular site’s distribution, but are extreme within the context of the overall covariance structure. We test this by confirming that the squared Mahalanobis distances are not inconsistent with a χ² distribution with K degrees of freedom. Again, this can be done by comparing the sample data quantiles to the theoretical data quantiles (figure below). Here the correlation coefficient is 0.9964.


Because the rejection regions will depend not only on the sample size (n) and significance levels, but also the number of degrees of freedom (K), there are no tabulated critical values for this test (there would need to be a separate table for every possible K). Instead of using a table, one has to perform a Monte Carlo simulation to calculate the critical region for the specific application. In this case, I did that by generating 10,000 random samples of length n=51 from a χ² distribution with K=5 degrees of freedom. Of the generated samples, 97.8% had correlation coefficients less than the observed value of 0.9964, suggesting that this sample is very consistent with a χ² distribution with 5 degrees of freedom.

So now that we know the MVN is a good fit for the log-transformed standard deviations in daily September flows, we can estimate the model parameters. This part is easy! The mean vector μ is estimated by the sample mean vector x̄ = [x̄_1, x̄_2, …, x̄_K] of the data (in this case, the log-transformed data), while the covariance [Σ] is estimated by the sample covariance \left[S\right]= \frac{1}{N-1}\left[X^{'}\right]^{T}\left[X^{'}\right], where \left[X^{'}\right] = \left[X\right] - \frac{1}{N}\left[1\right]\left[X\right], with [1] being an N×N matrix of 1s and [X] an N×K matrix of the data (log transformed here). Here N is the number of observations (n above). Strictly speaking, the maximum likelihood estimator of [Σ] divides by N rather than N−1; the N−1 version used here is the unbiased estimator.
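The estimation step, and the synthetic generation that motivates it, can be sketched in a few lines of numpy. The data matrix here is random stand-in data, not the streamflow record, and the variable names are chosen only to mirror the notation above.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(51, 5))              # stand-in for the N x K (log-transformed) data matrix

N = X.shape[0]
xbar = X.mean(axis=0)                     # sample mean vector
Xprime = X - np.ones((N, N)) @ X / N      # anomalies: each column minus its mean
S = Xprime.T @ Xprime / (N - 1)           # sample covariance matrix

# draw synthetic observations from the fitted MVN distribution
synthetic = rng.multivariate_normal(xbar, S, size=1000)
```

The matrix of ones is only there to mirror the formula; `X - xbar` gives the same anomalies, and `np.cov(X, rowvar=False)` gives the same `S`. If the data were log-transformed before fitting, the synthetic values need to be exponentiated to return to the original units.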

Below is Python code for all of the fitting and plotting done here.

from __future__ import division
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns # seaborn.apionly was removed in later seaborn releases
from scipy import stats

def fitMVN():
    # set plotting style

    # load streamflow data
    Qdaily = np.loadtxt('data/Qdaily.txt')
    Nyears = int(np.shape(Qdaily)[0]/365)
    Nsites = np.shape(Qdaily)[1]
    Months = ['May','June','July','August','September','October','November','December','January','February','March','April']

    # calculate standard deviation in daily flows each month and squared Mahalanobis distances
    StdMonthly = calc_monthly_std(Qdaily, Nyears, Nsites)
    D2 = calcD2(Nyears, Nsites, np.log(StdMonthly))

    # calculate theoretical quantiles for a chi^2 distribution with dof = Nsites, and for the standard normal distribution
    m = np.array(range(1,Nyears+1))
    p = (m-0.5)/Nyears
    chi2 = stats.chi2.ppf(p,Nsites)
    norm = stats.norm.ppf(p,0,1)

    # initialize matrices to store correlation coefficients and significance levels for marginal normal distributions and chi^2 distributions
    normCorr = np.zeros([Nsites,12])
    norm_sigLevel = np.zeros([Nsites,12])
    chi2Corr = np.zeros([12])
    chi2_sigLevel = np.zeros([12])

    for i in range(len(Months)):
        # plot histograms of standard deviation of daily flows each month, and of their logs
        plotHistograms(Nsites, StdMonthly[:,:,i], 'Standard Deviation of Daily ' + Months[i] + ' Flows', Months[i] + 'Hist.png')
        plotHistograms(Nsites, np.log(StdMonthly[:,:,i]), 'log(Standard Deviation of Daily ' + Months[i] + ' Flows)', \
            'Log' + Months[i] + 'Hist.png')

        # plot QQ plots of standard deviation of daily flows each month, and of their logs
        plotNormQQ(Nsites, StdMonthly[:,:,i], norm, 'Standard Deviation of Daily ' + Months[i] + ' Flows', Months[i] + 'QQ.png')
        normCorr[:,i] = plotNormQQ(Nsites, np.log(StdMonthly[:,:,i]), norm, 'log(Standard Deviation of Daily ' + Months[i] + ' Flows)', 'Log' + Months[i] + 'QQ.png')

        # plot QQ plot of Chi Squared distribution of log of standard deviation in daily flows each month
        chi2Corr[i] = plotChi2QQ(Nsites, D2[:,i], chi2, r'D$\mathregular{^2}\!$ of log(Standard Deviation of Daily ' + Months[i] + ' Flows)', \
            'Log' + Months[i] + 'Chi2QQ.png')

        # find significance levels
        chi2_sigLevel[i] = chi2_MC(Nsites,Nyears,chi2,chi2Corr[i])
        norm_sigLevel[:,i] = norm_MC(Nsites,Nyears,norm,normCorr[:,i])


    return None

def calc_monthly_std(Qdaily, Nyears, Nsites):
    Nmonths = 12
    # first month = May (1st month of water year)
    DaysPerMonth = np.array([31, 30, 31, 31, 30, 31, 30, 31, 31, 28, 31, 30])

    Qmonthly = np.zeros([Nsites, Nyears, Nmonths])
    StdMonthly = np.zeros([Nsites, Nyears, Nmonths])
    for year in range(Nyears):
        for month in range(Nmonths):
            start = year*365 + np.sum(DaysPerMonth[0:month])

            for i in range(Nsites):
                # find total flow each month
                Qmonthly[i,year,month] = 86400*np.sum(Qdaily[start:start+DaysPerMonth[month],i])

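            # find standard deviation in daily flows each month
            for i in range(Nsites):
                for j in range(DaysPerMonth[month]):
                    # accumulate squared deviations of each daily volume from the monthly mean daily volume
                    StdMonthly[i,year,month] = StdMonthly[i,year,month] + \
                        (86400*Qdaily[start+j,i] - Qmonthly[i,year,month]/DaysPerMonth[month])**2

                StdMonthly[i,year,month] = np.sqrt((1/(DaysPerMonth[month]-1))*StdMonthly[i,year,month])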

    return StdMonthly

def plotHistograms(Nsites, data, xlabel, filename):
    fig = plt.figure()
    for i in range(Nsites):
        ax = fig.add_subplot(1,Nsites,i+1)
        ax.hist(data[i,:]) # histogram of this site's values across all years
        ax.set_title('Site ' + str(i+1),fontsize=16)

    fig.text(0.1, 0.5, 'Frequency', va='center', rotation='vertical', fontsize=14)
    fig.text(0.5, 0.04, xlabel, ha='center', fontsize=14)
    fig.savefig('Hists/' + filename)

    return None

def plotNormQQ(Nsites, data, norm, title, filename):
    corr = np.zeros([Nsites])
    fig = plt.figure()
    for i in range(Nsites):
        corr[i] = np.corrcoef(np.sort(data[i,:]),norm)[0,1]
        z = (data[i,:] - np.mean(data[i,:]))/np.std(data[i,:])
        ax = fig.add_subplot(1,Nsites,i+1)
        ax.scatter(norm, np.sort(z)) # standardized sample quantiles vs. theoretical quantiles
        ax.set_title('Site ' + str(i+1),fontsize=16)

    fig.text(0.1, 0.5, 'Sample Quantiles', va='center', rotation='vertical', fontsize=14)
    fig.text(0.5, 0.04, 'Theoretical Quantiles', ha='center', fontsize=14)
    fig.suptitle('Normal Q-Q Plot of ' + title,fontsize=16)
    fig.savefig('QQplots/' + filename)

    return corr

def calcD2(Nyears, Nsites, data):
    D2 = np.zeros([Nyears,12])
    X = np.zeros([Nyears, Nsites])
    Xprime = np.zeros([Nyears,Nsites])
    S = np.zeros(Nsites)
    for i in range(12):
        # fill data matrix, X, for ith month
        for j in range(Nsites):
            X[:,j] = data[j,:,i]

        # calculate covariance matrix, S, for ith month
        Xprime = X - (1/Nyears)*np.dot(np.ones([Nyears,Nyears]),X)
        S = (1/(Nyears-1))*np.dot(np.transpose(Xprime),Xprime)

        #calculate Mahalanobis distance, D2, for each year's ith month
        for j in range(Nyears):
            D2[j,i] = np.dot(np.dot((X[j,:] - np.mean(X,0)),np.linalg.inv(S)),(np.transpose(X[j,:] - np.mean(X,0))))

    return D2

def plotChi2QQ(Nsites, data, chi2, title, filename):
    corr = np.corrcoef(np.sort(data),chi2)[0,1]
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.scatter(chi2, np.sort(data)) # sorted D^2 values vs. theoretical chi^2 quantiles
    ax.set_xlabel('Theoretical Quantiles',fontsize=16)
    ax.set_xlim([0, 1.1*np.max(chi2)])
    ax.set_ylabel('Sample Quantiles',fontsize=16)
    ax.set_title(r'$\chi^2$' + ' Q-Q Plot of ' + title,fontsize=16)
    fig.savefig('QQplots/' + filename)

    return corr

def chi2_MC(Nsites,Nyears,theoretical,dataCorr):
    corr = np.zeros(10000)
    for i in range(10000): # 10,000 MC simulations
        simulated = stats.chi2.rvs(Nsites,size=Nyears)
        corr[i] = np.corrcoef(np.sort(simulated),theoretical)[0,1]

    # find significance levels
    corr = np.sort(corr)
    sigLevel = 0.0 # stays 0 if dataCorr is below every simulated correlation
    for i in range(10000):
        if dataCorr > corr[i]:
            sigLevel = (i+0.5)/10000

    return sigLevel

def norm_MC(Nsites,Nyears,theoretical,dataCorr):
    sigLevel = np.zeros(Nsites)
    corr = np.zeros([10000])
    for i in range(10000): # 10,000 MC simulations
        simulated = stats.norm.rvs(0,1,size=Nyears)
        corr[i] = np.corrcoef(np.sort(simulated),theoretical)[0,1]

    # find significance levels
    corr = np.sort(corr)
    for i in range(10000):
        for j in range(Nsites):
            if dataCorr[j] > corr[i]:
                sigLevel[j] = (i+0.5)/10000

    return sigLevel


Easy vectorized parallel plots for multiple data sets

I will share a very quick and straightforward solution to generate parallel plots in Python of multiple groups of data.  The idea is to transition from the parallel axis plot tool to a method that enables the plots to be exported as a vectorized image.  You can also take a look at Matt’s Python code, available on GitHub.

This is the type of figure that you will get:


The previous figure was generated with the following lines of code:

import numpy as np
import pandas as pd
from pandas.plotting import parallel_coordinates
import matplotlib.pyplot as plt
import seaborn

data = pd.read_csv('sample_data.csv')

parallel_coordinates(data,'Name', color= ['#225ea8','#7fcdbb','#1d91c0'], linewidth=5, alpha=.8)
plt.ylabel('Direction of Preference $\\rightarrow$', fontsize=12)


Lines 1-5 import the libraries.  Seaborn (line 5) is only there to give the plot its gray background and is not necessary.  In the parallel_coordinates function, you need to specify the data, the name of the column that holds the group labels (‘Name’ here), and the colors of the different groups.  You can replace the color argument with colormap and specify the colormap that you wish to use (e.g. colormap=’YlGnBu’).  I also specified an alpha for transparency to see overlapping lines. If you want to learn more, you can take a look at the parallel_coordinates source code.  I found this stack overflow link very useful; it shows some examples on editing the source code to enable other capabilities.

Finally, the following snippet shows the format of the input data (the sample_data.csv file that is read in line 7):


In columns A-G, the different categories to be plotted are specified (e.g. the objectives of a problem), and in column H the names of the different data groups are specified.  And there you have it, I hope you find this plotting alternative useful.
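Since the input snippet is an image, here is a self-contained sketch of the same layout built directly in pandas. The column names and values are made up (three objective columns instead of seven, to keep it short); only the ‘Name’ column and the `parallel_coordinates` call follow the code above. The figure is written to an in-memory buffer as SVG to show the vectorized export.

```python
import io
import matplotlib
matplotlib.use('Agg')  # render off-screen so no display is needed
import pandas as pd
import matplotlib.pyplot as plt
from pandas.plotting import parallel_coordinates

# made-up stand-in for sample_data.csv: objective columns plus a 'Name' group column
data = pd.DataFrame({
    'Obj1': [0.1, 0.4, 0.8, 0.3],
    'Obj2': [0.9, 0.5, 0.2, 0.6],
    'Obj3': [0.3, 0.7, 0.6, 0.1],
    'Name': ['Group A', 'Group A', 'Group B', 'Group C'],
})

ax = parallel_coordinates(data, 'Name', color=['#225ea8', '#7fcdbb', '#1d91c0'],
                          linewidth=5, alpha=.8)
ax.set_ylabel('Direction of Preference $\\rightarrow$', fontsize=12)

# export as a vector graphic (use a filename such as 'plot.svg' to write to disk)
buffer = io.BytesIO()
ax.figure.savefig(buffer, format='svg')
```

Colors are assigned to groups in the order the group names first appear in the ‘Name’ column, so the list needs at least one color per group.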

Visualization strategies for multidimensional data

This is the first part of a series of blog posts on multidimensional data visualization strategies.   The main objectives of this first part are:

  1. Show you how to expand plotting capabilities by modifying matplotlib source code.
  2. Generate a tailored 6-D Pareto front plot with completely customized legends.
  3. Provide a glimpse of a recently developed Pareto front video repository in R.

1. Expanding matplotlib capabilities

Keeping in mind that matplotlib is an open-source project developed in the contributors’ free time, there is no guarantee that features contributors propose will be added straightaway.  In my case, I needed marker rotation capabilities in a 3D scatter plot.  Luckily, someone had already figured out how to do so and started a pull request in the matplotlib GitHub repository, but this change has not yet been merged.  Since I couldn’t wait for the changes to happen, here’s the straightforward solution that I found:

Here’s  the link to the  pull request that I am referring to.

First, I located where Matplotlib lives in my computer, the path in my case is:


Then, I located the files that the contributor changed.  The files’ paths are circled in red in the following snippets of the pull request:



I located those files in my local matplotlib folder, which in my case are:



In the previous snippets, the lines of code that were added to the original script are highlighted in green and those that were removed are highlighted in red.  Hence, to access the clean version I clicked on the view button, selected the entire script, and copied and pasted it into my local matplotlib code.  For this exercise I ended up changing only a couple of scripts.

NOTE:  If you ever need to undertake this type of solution, make sure you paste the lines of code in the right places; do this part carefully.  Also, it’s always a good idea to make backups of the original files in case something goes irreversibly wrong.  Or you can always uninstall and reinstall, no big deal.
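Both the locating and the backing up can be scripted. The helpers below are a hypothetical sketch using only the standard library; `package_dir` and `backup_file` are names made up for this example, and the demonstration uses the `json` package only because every Python installation has it. For the case described here you would call `package_dir('matplotlib')` instead.

```python
import importlib.util
import os
import shutil

def package_dir(name):
    """Return the directory an installed package is imported from."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:
        raise ImportError('package not found: ' + name)
    return os.path.dirname(spec.origin)

def backup_file(path):
    """Copy path to path + '.bak' unless a backup already exists; return the backup path."""
    backup = path + '.bak'
    if not os.path.exists(backup):
        shutil.copy2(path, backup)
    return backup

json_dir = package_dir('json')   # stand-in for package_dir('matplotlib')
```

Refusing to overwrite an existing backup is deliberate: running the script a second time, after the file has been edited, should not replace the pristine copy.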

2. Generate a tailored 6D Pareto front plot with customized legends.

Matplotlib allows visualization of 5 objectives quite easily, but scaling to 6 or more objectives can be a bit tricky.  So, let’s walk through our 6-D plots in Matplotlib. We will learn how to make one of the following plots:

Pie Day  Plot:


St. Patrick’s Day  Plot:


2.1. Required libraries:

The following are the only libraries that you’ll need.   I import seaborn sometimes because it looks fancy but it’s totally unnecessary in this case, which is why it is commented out.

import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import matplotlib.pyplot as plt
#import seaborn

2.2. Importing data:

The data file that I used consists of 6 space-separated columns. If your data uses another delimiter, you can just add it like so:   data = np.loadtxt('sample_data.txt', delimiter=',').  I am also multiplying the first five columns by -1 because I want to remove the negatives; this is specific to my data, so you may not need to do so.

data= np.loadtxt('sample_data.txt')

#Organizing the data by objectives
obj1 = data[:,0]*-1
obj2 = data[:,1]*-1
obj3 = data[:,2]*-1
obj4 = data[:,3]*-1
obj5 = data[:,4]*-1
obj6 = data[:,5]
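
As a small illustration of the delimiter option and the sign flip, here is a sketch that reads a comma-separated stand-in for the data file from memory. The numbers are made up, and flipping all five columns in one slice is just an alternative to the column-by-column version above.

```python
import io
import numpy as np

# Made-up stand-in for a comma-separated 'sample_data.txt' (two solutions, six objectives)
sample = io.StringIO("-0.2,-0.4,-0.6,-0.8,-1.0,0.5\n-0.1,-0.3,-0.5,-0.7,-0.9,0.25\n")

data = np.loadtxt(sample, delimiter=",")

# Flip the sign of the first five columns in one step; the sixth is left as is
objectives = data.copy()
objectives[:, :5] *= -1

# One array per objective, in column order
obj1, obj2, obj3, obj4, obj5, obj6 = objectives.T
```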

2.3. Object-based plotting:

To allow more customization, we need to move to a more object-based way to make the plots.  That is, storing  elements of the  plots in variables.

fig = plt.figure() # create a figure object
ax = fig.add_subplot(111, projection='3d') # create an axes object in the figure

2.4. Setting marker options:

Any mathtext symbol can be used as a marker.  In order to use rotation to represent an additional objective  it’s preferable if the marker has a single axis of symmetry so that the rotation is distinguishable.  Here are some marker options:

pie=r'$\pi$' #pie themed option
arrow = u'$\u2193$' # Arrows
clover=r'$\clubsuit$' #Saint Patrick's theme
heart=r'$\heartsuit$' # Valentine's theme
marker=pie #this is where you provide the marker options

More marker options can be found in :

2.4.  Scatter 6D plot:

The first three objectives are plotted in a 3-D scatter plot, on the x, y, and z axes respectively.  The fourth objective is represented by color, the fifth by size and the sixth by rotation.  Note that the rotation is scaled in degrees.  This is the step where I had to modify matplotlib source code to enable the ‘angles’ option shown below.  Also, it may be necessary to scale the size objective to get the desired marker size in your plot.  You can also plot the ideal point by adding a second scatter plot specifying the ideal values for each objective.  Finally, we assign the size objective to “objs” and the rotation objective to “objr”; this will be useful later on when setting up the legend for these two objectives.

rot_angle=180 #rotation angle multiplier
scale=2000 #size objective multiplier
#Plotting 6 objectives:
im= ax.scatter(obj1, obj2, obj3, c=obj4, s= obj5*scale, marker=marker, angles=obj6*rot_angle, alpha=1) #'angles' only works with the patched matplotlib
ax.scatter(1,1,0, marker=pie, c='seagreen', s=scale, alpha=1)
objs=obj5 #size objective
objr=obj6 #rotation objective
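
If an objective does not already lie between 0 and 1, a plain multiplier may give markers that are too small or rotations outside the intended range. The sketch below shows one possible alternative, a linear rescale onto a chosen range; the `rescale` helper, the sample values and the target ranges are assumptions made for illustration.

```python
import numpy as np


def rescale(values, new_min, new_max):
    """Linearly map `values` onto [new_min, new_max]; constant input maps to new_min."""
    values = np.asarray(values, dtype=float)
    span = values.max() - values.min()
    if span == 0:
        return np.full_like(values, new_min)
    return new_min + (values - values.min()) / span * (new_max - new_min)


obj5 = np.array([0.1, 0.4, 0.9])   # made-up size objective
obj6 = np.array([0.0, 0.5, 1.0])   # made-up rotation objective

sizes = rescale(obj5, 50, 2000)    # marker areas in points^2, as used by scatter's `s`
angles = rescale(obj6, 0, 180)     # rotation in degrees
```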

2.5.  Main axis labels and limits:

This is extremely straightforward, you can set the x,y, and z labels and specify their limits as follows:

#Main axis labels:
ax.set_xlabel('Objective 1')
ax.set_ylabel('Objective 2')
ax.set_zlabel('Objective 3')
#Axis limits:
ax.set_zlim3d(0, 1)

2.6.  Color bar options:

The colorbar limits and labels can also be specified, as shown in the code below.  There are many colormap options in matplotlib, some of the most popular ones are: jet, hsv and spectral.   As an example, if you want to change the colormap in the code shown in part 2.4, do cmap=  To reverse the colormap attach an ‘_r ‘ like so: cmap=  There is also a color brewer package for the more artistic plotter.

# Set the color limits.. not necessary here, but good to know how.
im.set_clim(0.0, 1.0)

#Colorbar label:
cbar = plt.colorbar(im)
cbar.set_label('Objective 4')
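
To see what the ‘_r’ suffix does, the short sketch below looks up a colormap and its reversed twin by name and compares their end points. It uses ‘jet’ only because it is one of the colormaps named above.

```python
import numpy as np
import matplotlib.pyplot as plt

cmap = plt.get_cmap("jet")       # one of the colormaps named above
cmap_r = plt.get_cmap("jet_r")   # the same colormap, reversed

# The low end of the reversed map matches the high end of the original
low_of_reversed = np.array(cmap_r(0.0))
high_of_original = np.array(cmap(1.0))
```

Either object can be passed to a plotting call through its cmap argument.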

2.6.  Size and rotation legends:

This is where it gets interesting.  The first couple of lines get the labels for the legend and choose which ones to display.  This allows for much flexibility when creating the legends.  As you can see in the code below, you can show markers that correspond to the maximum and the minimum objective values to orient the reader.  You can assign the spacing between lines in the legend, the title, whether you want to frame your legend or not, the location in the figure, etc.  The code also shows how to add more than one legend to the same axes.  There are many options for an entirely customized legend in the legend documentation, which you can explore further.

handles, labels = ax.get_legend_handles_labels()
display = (0,1,2)

#Code for size and rotation legends begins here for Objectives 5 and 6:

#Legend marker sizes. Scatter sizes are areas (points^2) while Line2D markersize is a
#width (points), hence the square root. These two definitions are an assumption; adjust to taste.
max_size = np.sqrt(np.amax(objs)*scale)
min_size = np.sqrt(np.amin(objs)*scale)

#Custom size legend:
size_max = plt.Line2D((0,1),(0,0), color='k', marker=marker, markersize=max_size,linestyle='')
size_min = plt.Line2D((0,1),(0,0), color='k', marker=marker, markersize=min_size,linestyle='')
legend1= ax.legend([handle for i,handle in enumerate(handles) if i in display]+[size_max,size_min],
[label for i,label in enumerate(labels) if i in display]+["%.2f"%(np.amax(objs)), "%.2f"%(np.amin(objs))], labelspacing=1.5, title='Objective 5', loc=1, frameon=True, numpoints=1, markerscale=1)

#Custom rotation legend
rotation_max = plt.Line2D((0,1),(0,0),color='k',marker=r'$\Uparrow$', markersize=15, linestyle='')
rotation_min = plt.Line2D((0,1),(0,0),color='k', marker=r'$\Downarrow$', markersize=15, linestyle='')
ax.legend([handle for i,handle in enumerate(handles) if i in display]+[rotation_max,rotation_min],
[label for i,label in enumerate(labels) if i in display]+["%.2f"%(np.amax(objr)), "%.2f"%(np.amin(objr))], labelspacing=1.5, title='Objective 6',loc=2, frameon=True, numpoints=1, markerscale=1)

#A second call to ax.legend replaces the first legend, so add the first one back
ax.add_artist(legend1)


You can find the full code for the previous example in the following github repository:

3. Generate 6D Pareto front and runtime videos in R.

And last but not least, let me direct everyone to Calvin’s repository:  Where  you can find the paretoMovieFront6D.R script which enables the exploration of  the evolution of a  6D Pareto front.   It is an extremely flexible tool and it has around 50 customization options to adapt your video or your plot to your visual needs, all you need is your runtime output, so check it out.  I made the tiniest contribution to this repository so I feel totally entitled to talk about it.   Here is a snippet of the video:


Making Movies of Time-Evolving Global Maps with Python


Hi All,

These past few months I’ve been working with the Global Change Assessment Model (GCAM) which is an integrated assessment model (IAM) that combines models of the global climate and economic systems. I wrote an earlier post on compiling GCAM on a Unix cluster.  This post discusses some visualization tools I’ve developed for GCAM output.

GCAM models energy and agriculture systems at a regional level, where the world is composed of 32 regions.  We’re interested in tracking statistics (like the policy cost of stabilization) over time and across regions.  This required three things:

  1. The ability to draw a global map.
  2. The ability to shade individual political units on that map.
  3. The ability to animate this map.

Dr. Jon Herman has already posted a good example of how to do (1) in python using matplotlib’s Basemap.  We’ll appropriate some of his example for this example.  The Basemap has the option to draw coastlines and boundaries, but these boundaries are not tied to shapes, meaning that you can’t assign different colors to individual countries (task (2) above).  To do that, we need a shapefile containing information about political boundaries.  You can find these for free from a number of sources online, but I like Natural Earth.  They provide data on many different scales. For this application I downloaded their coarsest data set.  To give each country a shade which is tied to data, we use matplotlib’s color map.  The basic plan is to generate a colored map for each time-step in our data, and then to animate the maps using the convert linux command.

Now that we’ve described roughly how we’ll proceed, a word about the data we’re dealing with and how I’ve handled it.  GCAM has 32 geo-political regions, some of which are individual countries (like the USA or China), while others are groups of countries (like Australia & New Zealand). I stored this information using a list of lists (i.e. a 32-element list, where each element is a list of countries in that region). I’ve creatively named this variable list_list in this example (see code below). For each of the regions GCAM produces a time series of policy costs as a fraction of GDP every 5 years from 2020-2100. I’ve creatively named this variable data. We want to tie the color of a country in each time to its policy cost relative to costs across countries and times.  To do this, I wrote the following (clumsy!) Python function, which I explain below.

def world_plot(data,idx,MN,MX):
 from mpl_toolkits.basemap import Basemap
 import matplotlib.pyplot as plt
 from matplotlib.patches import Polygon
 from matplotlib.collections import PatchCollection
 import matplotlib.cm as cm
 import matplotlib as mpl
 import numpy as np

 norm = mpl.colors.Normalize(vmin=MN, vmax=MX)
 cmap = cm.coolwarm
 colors=cm.ScalarMappable(norm=norm, cmap=cmap)
 colors.set_array(data) # older matplotlib versions need this before a colorbar can be drawn
 a = np.zeros([32,4])
 a = colors.to_rgba(data)

 fig = plt.figure(figsize=(10,10))
 ax = fig.add_subplot(111)

 m = Basemap(projection='robin', lon_0=0,resolution='c')
 m.drawmapboundary(fill_color='white', zorder=-1)
 m.drawparallels(np.arange(-90.,91.,30.), labels=[1,0,0,1], dashes=[1,1], linewidth=0.25, color='0.5',fontsize=14)
 m.drawmeridians(np.arange(0., 360., 60.), labels=[1,0,0,1], dashes=[1,1], linewidth=0.25, color='0.5',fontsize=14)

 year = [1990,2005,2010,2015,2020,2025,2030,2035,2040,2045,2050,2055,2060,2065,2070,2075,2080,2085,2090,2095,2100]
 GCAM_32 = ['PRI','USA','VIR']
 GCAM_1 = ['BDI','COM','DJI','ERI','ETH','KEN','MDG','MUS','REU','RWA','SDS','SDN','SOM','UGA','SOL']
 GCAM_2 = ['DZA','EGY','ESH','LBY','MAR','TUN','SAH']
 GCAM_3 = ['AGO','BWA','LSO','MOZ','MWI','NAM','SWZ','TZA','ZMB','ZWE']
 GCAM_4 = ['BEN','BFA','CAF','CIV','CMR','COD','COG','CPV','GAB','GHA','GIN','GMB','GNB','GNQ','LBR','MLI','MRT','NER','NGA','SEN','SLE','STP','TCD','TGO']
 GCAM_6 = ['AUS','NZL']
 GCAM_7 = ['BRA']
 GCAM_8 = ['CAN']
 GCAM_9 = ['ABW','AIA','ANT','ATG','BHS','BLZ','BMU','BRB','CRI','CUB','CYM','DMA','DOM','GLP','GRD','GTM','HND','HTI','JAM','KNA','LCA','MSR','MTQ','NIC','PAN','SLV','TTO','VCT']
 GCAM_10 = ['ARM','AZE','GEO','KAZ','KGZ','MNG','TJK','TKM','UZB']
 GCAM_11 = ['CHN','HKG','MAC']
 GCAM_13 = ['BGR','CYP','CZE','EST','HUN','LTU','LVA','MLT','POL','ROM','SVK','SVN']
 GCAM_14 = ['AND','AUT','BEL','CHI','DEU','DNK','ESP','FIN','FLK','FRA','FRO','GBR','GIB','GRC','GRL','IMN','IRL','ITA','LUX','MCO','NLD','PRT','SHN','SMR','SPM','SWE','TCA','VAT','VGB','WLF']
 GCAM_15 = ['BLR','MDA','UKR']
 GCAM_16 = ['ALB','BIH','HRV','MKD','MNE','SCG','SRB','TUR','YUG']
 GCAM_17 = ['CHE','ISL','LIE','NOR','SJM']
 GCAM_18 = ['IND']
 GCAM_19 = ['IDN']
 GCAM_20 = ['JPN']
 GCAM_21 = ['MEX']
 GCAM_22 = ['ARE','BHR','IRN','IRQ','ISR','JOR','KWT','LBN','OMN','PSE','QAT','SAU','SYR','YEM']
 GCAM_23 = ['PAK']
 GCAM_24 = ['RUS']
 GCAM_25 = ['ZAF']
 GCAM_26 = ['GUF','GUY','SUR','VEN']
 GCAM_27 = ['BOL','CHL','ECU','PER','PRY','URY']
 GCAM_28 = ['AFG','ASM','BGD','BTN','LAO','LKA','MDV','NPL']
 GCAM_29 = ['KOR']
 GCAM_30 = ['BRN','CCK','COK','CXR','FJI','FSM','GUM','KHM','KIR','MHL','MMR','MNP','MYS','MYT','NCL','NFK','NIU','NRU','PCI','PCN','PHL','PLW','PNG','PRK','PYF','SGP','SLB','SYC','THA','TKL','TLS','TON','TUV','VNM','VUT','WSM']
 GCAM_31 = ['TWN']
 GCAM_5 = ['ARG']
 GCAM_12 = ['COL']

 list_list = [GCAM_1,GCAM_2,GCAM_3,GCAM_4,GCAM_5,GCAM_6,GCAM_7,GCAM_8,GCAM_9,GCAM_10,GCAM_11,GCAM_12,GCAM_13,GCAM_14,GCAM_15,GCAM_16,GCAM_17,GCAM_18,GCAM_19,GCAM_20,GCAM_21,GCAM_22,GCAM_23,GCAM_24,GCAM_25,GCAM_26,GCAM_27,GCAM_28,GCAM_29,GCAM_30,GCAM_31,GCAM_32]
 num = len(list_list)
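 # Read the Natural Earth shapefile; the path is a placeholder, point it at your own download
 m.readshapefile('path/to/natural_earth_countries', 'comarques', drawbounds=False)
 for info, shape in zip(m.comarques_info,m.comarques):
     for i in range(num):
         if info['adm0_a3'] in list_list[i]:
             # shade this administrative unit with the color of its GCAM region
             patches1 = []
             patches1.append( Polygon(np.array(shape), closed=True) )
             ax.add_collection(PatchCollection(patches1, facecolor=a[i,:], edgecolor='k', linewidths=1., zorder=2))
 ax.set_title('Policy Cost',fontsize=25,y=1.01)#GDP Adjusted Policy Cost#Policy Cost#Policy Cost Reduction from Technology
 plt.annotate('%s'%year[idx],xy=(0.1,0.2),xytext=(0.1,0.2),xycoords='axes fraction',fontsize=30)
 cb = m.colorbar(colors,'right')
 filename = "out/map_%s.png" %(str(idx).rjust(3,"0"))
 fig.savefig(filename)
 plt.close(fig)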

The function’s name is world_plot and its inputs are:

  1. The raw data for a specific time step.
  2. The index of the time step for the map we are working with (e.g. idx=0 for 2020).
  3. The minimum and maximum of the data across countries and time.

(1) is plotted, (2) is used to name the resulting png figure (line 73), and (3) is used to scale the colormap (line 11).  On lines 2-8 we import the necessary Python packages, which in this case are pretty standard Matplotlib packages and numpy.  On lines 11-16 we generate a numpy array which contains the rgba color code for each of the data points in data.  In lines 18-19 we create the pyplot figure object.

On lines 21-24 we create and format the Basemap object.  Note that on line 21 I’ve selected the Robinson projection, but that the Basemap provides many options.

Lines 26-60 are specific for this application, and certainly could have been handled more compactly if I wanted to invest the time.  year is a list of time steps for our GCAM experiment, and lines 27-58 contain lists of three letter ID codes for each GCAM region, which are assembled into a list of lists (creatively called list_list) on line 60.

On line 61 we read the data from the shapefile database which was downloaded from Natural Earth. From lines 63-68 we loop through the info and shape attributes of the shapefile database, and determine which of the GCAM geo-political units each of the administrative units in the database is associated with.  Once this is determined, the polygon associated with that administrative unit is given the correct color (lines 66-68).
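
The lookup of which region a country belongs to does not have to be a nested loop. As a sketch of an alternative (not what the function above does), the list of lists can be turned into a dictionary once, after which each country code is resolved in a single step. The region lists are shortened here and `region_index` is a made-up helper name.

```python
# Shortened, illustrative region lists (the full lists appear in world_plot above)
GCAM_6 = ['AUS', 'NZL']
GCAM_7 = ['BRA']
GCAM_11 = ['CHN', 'HKG', 'MAC']
list_list = [GCAM_6, GCAM_7, GCAM_11]

# Map each three-letter code to the index of its region in list_list
region_of = {code: i for i, region in enumerate(list_list) for code in region}


def region_index(code):
    """Return the region index for a country code, or None if it is not in any region."""
    return region_of.get(code)
```

Inside the shapefile loop, `region_index(info['adm0_a3'])` would then give the row of the color array to use, or None for units that belong to no region.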

Lines 69-72 are doing some final formatting and labeling, and in lines 73-75 we are giving the file a unique name (tied to the time step plotted) and saving the images to some output directory.
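
A driver loop for this function might look like the sketch below. The color limits are computed once over all regions and time steps so that every frame shares the same scale. The `render_frames` name and the random data are assumptions, and the plotting function is passed in as an argument so that the sketch runs without Basemap; in practice it would be world_plot.

```python
import numpy as np


def render_frames(all_data, plot_fn):
    """Call plot_fn(data, idx, MN, MX) once per time step with shared color limits.

    all_data has shape (n_regions, n_steps); plot_fn stands in for world_plot.
    """
    MN = float(np.nanmin(all_data))
    MX = float(np.nanmax(all_data))
    for idx in range(all_data.shape[1]):
        plot_fn(all_data[:, idx], idx, MN, MX)
    return MN, MX


# Made-up policy costs for 32 regions and 17 time steps (2020-2100 every 5 years)
rng = np.random.default_rng(0)
costs = rng.random((32, 17)) * 0.05

# A stub that records its arguments instead of drawing a map
calls = []
limits = render_frames(costs, lambda d, i, mn, mx: calls.append((i, d.shape, mn, mx)))
```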

When we put this function into a loop over time, we generate a sequence of figures looking something like this:


To convert this sequence of PNGs to a gif file, we use the convert command in linux (or in my case Cygwin).  So, we go to the command line and cd into the directory where we’ve saved our figures and type:

convert -delay 45 -loop 0 *.png globe_Cost_Reduction_faster.gif

Here the delay flag sets the pause between frames (in hundredths of a second, so 45 is 0.45 seconds), the loop flag controls how many times the gif repeats (0 means it loops forever), the wildcard includes all of the pngs in the output directory, and the final argument is the name of the resulting gif. The final product:


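
The wildcard hands the frames to convert in alphabetical order, which is why the file names are zero-padded. The sketch below illustrates this and shows one way the same command could be assembled from Python. The helper names and the output name `globe.gif` are made up, and the final step assumes ImageMagick is installed.

```python
import glob
import subprocess


def frame_name(idx, folder="out"):
    """Zero-padded file name, matching the pattern used in world_plot."""
    return "%s/map_%s.png" % (folder, str(idx).rjust(3, "0"))


def convert_command(frames, gif_name, delay=45):
    """Build the argument list for ImageMagick's convert; -loop 0 repeats forever."""
    return ["convert", "-delay", str(delay), "-loop", "0"] + list(frames) + [gif_name]


# With zero padding, alphabetical order equals time order
names = [frame_name(i) for i in (0, 2, 10, 1)]
cmd = convert_command(sorted(names), "globe.gif")

if __name__ == "__main__":
    # Requires ImageMagick on the PATH and PNG frames in ./out
    frames = sorted(glob.glob("out/map_*.png"))
    if frames:
        subprocess.check_call(convert_command(frames, "globe.gif"))
```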

Scenario discovery in Python

The purpose of this blog post is to demonstrate how one can do scenario discovery in python. This blogpost will use the exploratory modeling workbench available on github. I will demonstrate how we can perform PRIM in an interactive way, as well as briefly show how to use CART, which is also available in the exploratory modeling workbench. There is ample literature on both CART and PRIM and their relative merits for use in scenario discovery, so I won’t be discussing that here in any detail. This blog was first written as an ipython notebook, which can be found here

The workbench is meant as a one-stop shop for doing exploratory modeling, scenario discovery, and (multi-objective) robust decision making. To support this, the workbench is split into several packages. The most important packages are expWorkbench, which contains the support for setting up and executing computational experiments or (multi-objective) optimization with models; the connectors package, which contains connectors to vensim (system dynamics modeling package), netlogo (agent based modeling environment), and excel; and the analysis package, which contains a wide range of techniques for visualization and analysis of the results from series of computational experiments. Here, we will focus on the analysis package. In some future blog post, I plan to demonstrate the use of the workbench for performing computational experimentation and multi-objective (robust) optimization.

The workbench can be found on github and downloaded from there. At present, the workbench is only available for python 2.7. There is a separate branch where I am working on making a version of the workbench that works under both python 2.7 and 3. The workbench depends on various scientific python libraries. If you have a standard scientific python distribution, like anaconda, installed, the main dependencies will be met. In addition to the standard scientific python libraries, the workbench also depends on deap for genetic algorithms. There are also some optional dependencies. These include seaborn and mpld3 for nicer and interactive visualizations, and jpype for controlling models implemented in Java, like netlogo, from within the workbench.

In order to demonstrate the use of the exploratory modeling workbench for scenario discovery, I am using a published example: the data used in the original article by Ben Bryant and Rob Lempert where they first introduced scenario discovery. Ben Bryant kindly made this data available for my use. The data comes as a csv file. We can import the data easily using pandas. Columns 2 up to and including 10 contain the experimental design, while the classification is presented in column 15.

import pandas as pd

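data = pd.DataFrame.from_csv('./data/bryant et al 2010 data.csv', index_col=False) # index_col=False is an assumption: it keeps every file column so the positions below match the text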
x = data.ix[:, 2:11]
y = data.ix[:, 15]

The exploratory modeling workbench is built on top of numpy rather than pandas. This is partly a path dependency issue. The earliest version of prim in the workbench is from 2012, when pandas was still under heavy development. Another problem is that pandas does not contain explicit information on the datatypes of the columns. The implementation of prim in the exploratory workbench is however datatype aware, in contrast to the scenario discovery toolkit in R. That is, it will handle categorical data differently than continuous data. Internally, prim uses a numpy structured array for x, and a numpy array for y. We can easily transform the pandas dataframe to either.

x = x.to_records()
y = y.values
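
For readers who have not used structured arrays, the sketch below builds a tiny one by hand. The column names and values are made up, and the categorical column uses a fixed-width string type here, which may differ from what to_records produces for text columns.

```python
import numpy as np

# Made-up experimental design with one continuous and one categorical column
x = np.array(
    [(0.25, "low"), (0.75, "high"), (0.50, "low")],
    dtype=[("elasticity", float), ("scenario", "U8")],
)
y = np.array([1, 0, 1])

# Each field keeps its own data type, which is what lets an algorithm
# treat continuous and categorical columns differently
kinds = {name: x.dtype[name].kind for name in x.dtype.names}
```

Here `kinds` maps the continuous column to the kind code 'f' (float) and the categorical one to 'U' (unicode string).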

The exploratory modeling workbench comes with a separate analysis package. This analysis package contains prim. So let’s import prim. The workbench also has its own logging functionality. We can turn this on to get some more insight into prim while it is running.

from analysis import prim
from expWorkbench import ema_logging
ema_logging.log_to_stderr(ema_logging.INFO)

Next, we need to instantiate the prim algorithm. To mimic the original work of Ben Bryant and Rob Lempert, we set the peeling alpha to 0.1. The peeling alpha determines how much data is peeled off in each iteration of the algorithm. The lower the value, the less data is removed in each iteration. The minimum coverage threshold that a box should meet is set to 0.8. Next, we can use the instantiated algorithm and find a first box.

prim_alg = prim.Prim(x, y, threshold=0.8, peel_alpha=0.1)
box1 = prim_alg.find_box()
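
To give a feel for what the peeling alpha does, here is a deliberately simplified sketch of a single peeling step on made-up data. It is not the workbench’s implementation: it only handles continuous columns, uses the mean of y as the objective, and the `peel_once` name is invented for this illustration.

```python
import numpy as np


def peel_once(x, y, alpha):
    """One simplified peeling step: try trimming the lowest or highest
    alpha fraction along each column and keep the trim that leaves the
    highest mean of y. Returns a boolean mask of the rows that remain."""
    best_mask, best_mean = None, -np.inf
    for j in range(x.shape[1]):
        lo, hi = np.quantile(x[:, j], [alpha, 1 - alpha])
        for mask in (x[:, j] >= lo, x[:, j] <= hi):
            if 0 < mask.sum() < len(y) and y[mask].mean() > best_mean:
                best_mask, best_mean = mask, y[mask].mean()
    return best_mask


# Made-up data: cases of interest (y == 1) sit at high values of column 0
rng = np.random.default_rng(1)
x = rng.random((200, 2))
y = (x[:, 0] > 0.6).astype(int)

mask = peel_once(x, y, alpha=0.1)
```

With alpha set to 0.1, each step removes about 10% of the remaining points; repeating the step on the surviving rows shrinks the box further while the share of cases of interest inside it goes up.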

Let’s investigate this first box in some detail. A first thing to look at is the trade off between coverage and density. The box has a convenience function for this called show_tradeoff. To support working in the ipython notebook, this method returns a matplotlib figure with some additional information that can be used by mpld3.

import matplotlib.pyplot as plt

box1.show_tradeoff()
plt.show()



The notebook contains an mpld3 version of the same figure with interactive pop ups. Let’s look at point 21, just as in the original paper. For this, we can use the inspect method. By default this will display two tables, but we can also make a nice graph instead that contains the same information.

box1.inspect(21, style='graph')

This first displays two tables, followed by a figure

coverage    0.752809
density     0.770115
mass        0.098639
mean        0.770115
res dim     4.000000
Name: 21, dtype: float64

                            box 21
                               min         max     qp values
Demand elasticity        -0.422000   -0.202000  1.184930e-16
Biomass backstop price  150.049995  199.600006  3.515113e-11
Total biomass           450.000000  755.799988  4.716969e-06
Cellulosic cost          72.650002  133.699997  1.574133e-01

fig 2
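
The three ratios in the first table follow directly from case counts, as the sketch below shows. The counts used here are not stated in this post; they were chosen because they reproduce the printed values, so treat them as an inference.

```python
def box_stats(n_total, n_interest, n_in_box, n_interest_in_box):
    """Coverage, density and mass of a box from simple case counts."""
    coverage = n_interest_in_box / float(n_interest)   # share of all cases of interest captured
    density = n_interest_in_box / float(n_in_box)      # share of cases in the box that are of interest
    mass = n_in_box / float(n_total)                   # share of all cases that fall in the box
    return coverage, density, mass


# Counts consistent with the ratios printed above (inferred, not given in the post):
# 882 cases in total, 89 of interest, 87 inside the box, 67 of those of interest
coverage, density, mass = box_stats(882, 89, 87, 67)
```

High coverage means the box captures most cases of interest, while high density means few other cases are mixed in. The trade off curve shows how one is exchanged for the other as peeling proceeds.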

If one were to do a detailed comparison with the results reported in the original article, one would see small numerical differences. These differences arise out of subtle differences in implementation. The most important difference is that the exploratory modeling workbench uses a custom objective function inside prim which is different from the one used in the scenario discovery toolkit. Other differences have to do with details about the hill climbing optimization that is used in prim, and in particular how ties are handled in selecting the next step. The differences between the two implementations are only numerical, and don’t affect the overarching conclusions drawn from the analysis.

Let’s select box 21 and get a more detailed view of what the box looks like. Following Bryant et al., we can use scatter plots for this.

box1.select(21)
fig = box1.show_pairs_scatter()


We have now found a first box that explains close to 80% of the cases of interest. Let’s see if we can find a second box that explains the remainder of the cases.

box2 = prim_alg.find_box()

The logging will inform us in this case that no additional box can be found. The best coverage we can achieve is 0.35, which is well below the specified 0.8 threshold. Let’s look at the final overall results from interactively fitting PRIM to the data. For this, we can use two convenience functions that transform the stats and boxes to pandas data frames.

print prim_alg.stats_to_dataframe()
print prim_alg.boxes_to_dataframe()
       coverage   density      mass  res_dim
box 1  0.752809  0.770115  0.098639        4
box 2  0.247191  0.027673  0.901361        0
                             box 1              box 2
                               min         max    min         max
Demand elasticity        -0.422000   -0.202000   -0.8   -0.202000
Biomass backstop price  150.049995  199.600006   90.0  199.600006
Total biomass           450.000000  755.799988  450.0  997.799988
Cellulosic cost          72.650002  133.699997   67.0  133.699997

For comparison, we can also use CART for doing scenario discovery. This is readily supported by the exploratory modeling workbench.

from analysis import cart
cart_alg = cart.CART(x,y, 0.05)
cart_alg.build_tree()

Now that we have trained CART on the data, we can investigate its results. Just like PRIM, we can use stats_to_dataframe and boxes_to_dataframe to get an overview.

print cart_alg.stats_to_dataframe()
print cart_alg.boxes_to_dataframe()
       coverage   density      mass  res dim
box 1  0.011236  0.021739  0.052154        2
box 2  0.000000  0.000000  0.546485        2
box 3  0.000000  0.000000  0.103175        2
box 4  0.044944  0.090909  0.049887        2
box 5  0.224719  0.434783  0.052154        2
box 6  0.112360  0.227273  0.049887        3
box 7  0.000000  0.000000  0.051020        3
box 8  0.606742  0.642857  0.095238        2
                       box 1                  box 2               box 3  \
                         min         max        min         max     min
Cellulosic yield        80.0   81.649994  81.649994   99.900002  80.000
Demand elasticity       -0.8   -0.439000  -0.800000   -0.439000  -0.439
Biomass backstop price  90.0  199.600006  90.000000  199.600006  90.000   

                                         box 4                box 5  \
                               max         min         max      min
Cellulosic yield         99.900002   80.000000   99.900002   80.000
Demand elasticity        -0.316500   -0.439000   -0.316500   -0.439
Biomass backstop price  144.350006  144.350006  170.750000  170.750   

                                      box 6                  box 7  \
                               max      min         max        min
Cellulosic yield         99.900002  80.0000   89.050003  89.050003
Demand elasticity        -0.316500  -0.3165   -0.202000  -0.316500
Biomass backstop price  199.600006  90.0000  148.300003  90.000000   

                                         box 8
                               max         min         max
Cellulosic yield         99.900002   80.000000   99.900002
Demand elasticity        -0.202000   -0.316500   -0.202000
Biomass backstop price  148.300003  148.300003  199.600006

Alternatively, we might want to look at the classification tree directly. For this, we can use the show_tree method. This returns an image that we can either save, or display.


If we look at the results of CART and PRIM, we can see that in this case PRIM produces a better description of the data. The best box found by CART has a coverage and density of a little above 0.6. In contrast, PRIM produces a box with coverage and density above 0.75.
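
The comparison can also be checked mechanically. The sketch below copies the coverage and density values from the two result tables above and picks the CART box with the highest coverage; the variable names are made up.

```python
# (coverage, density) per box, copied from the CART statistics table above
cart_stats = {
    "box 1": (0.011236, 0.021739),
    "box 2": (0.000000, 0.000000),
    "box 3": (0.000000, 0.000000),
    "box 4": (0.044944, 0.090909),
    "box 5": (0.224719, 0.434783),
    "box 6": (0.112360, 0.227273),
    "box 7": (0.000000, 0.000000),
    "box 8": (0.606742, 0.642857),
}
prim_box = (0.752809, 0.770115)  # PRIM box 1, from the PRIM results above

# CART box with the highest coverage, and whether PRIM beats it on both measures
best_name = max(cart_stats, key=lambda name: cart_stats[name][0])
best_cart = cart_stats[best_name]
prim_wins = prim_box[0] > best_cart[0] and prim_box[1] > best_cart[1]
```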