A step-by-step tutorial for scenario discovery with gradient boosted trees

Our recently published eBook, Addressing Uncertainty in Multisector Dynamics Research, provides several interactive tutorials for hands-on training in model diagnostics and uncertainty characterization. The purpose of this post is to expand upon these trainings by providing a tutorial demonstrating gradient boosted trees for scenario discovery. I’ll first provide some brief background on scenario discovery and gradient boosted trees, then demonstrate a Python implementation on a water supply planning problem. All code here is written in Python, but the workflow is model-agnostic and can be paired with simulation models in any language. I’ve included my code within the text below, but all code and data for this post can also be found in this git repository.

Scenario discovery and gradient boosted trees

In water resources planning and management, decision makers are often faced with uncertainty about how their system will change in the future. Traditionally, planners have confronted this uncertainty by developing prespecified narrative scenarios, which reduce the multitude of possible future conditions into a small subset of important future states of the world (a prominent example is the ‘scenario matrix framework’ used to evaluate climate change (O’Neill et al., 2014)). While this approach provides intuitive appeal, it may increase system vulnerability if future conditions do not evolve as decision makers expect (for a detailed critique of scenario based planning see Reed et al., 2022). This vulnerability is especially apparent for systems facing deep uncertainty, where decision makers do not know or cannot agree upon the probability density functions of key system inputs (Kwakkel et al., 2016).

Scenario discovery (Groves and Lempert, 2007) is an exploratory modeling centered approach that seeks to discover consequential future scenarios using computational experiments rather than relying on prespecified information. To perform scenario discovery, decision makers first identify a set of relevant uncertainties and their plausible ranges. Next, an ensemble of these uncertainties is developed by sampling across parameter ranges. Candidate policies are then evaluated across this ensemble and machine learning or data mining algorithms are used to examine which combinations of uncertainties generate vulnerability in the system. These combinations can then be used to develop narrative scenarios to inform implementation and monitoring efforts or new policy development.

A core element of the scenario discovery process is the algorithm used to classify future states of the world. Popular algorithms include the Patient Rule Induction Method (PRIM), Classification and Regression Trees (CART) and logistic regression. Recently, gradient boosted trees have been applied as an alternative classification algorithm. Gradient boosted trees have advantages over other scenario discovery algorithms because they can easily capture nonlinear and non-differentiable boundaries in the uncertainty space (which are particularly prevalent in water supply planning problems that have discrete capacity expansion options), are highly resistant to overfitting and provide a clear means of ranking the importance of uncertain factors (Trindade et al., 2020). For a comprehensive overview of gradient boosted trees, see Bernardo’s post here.

Test case: the Sedento Valley

To demonstrate gradient boosted trees for scenario discovery we’ll use the Sedento Valley water supply planning test case (Trindade et al., 2020). In the Sedento Valley, three water utilities seek to discover cooperative water supply management and infrastructure investment portfolios to meet several conflicting objectives in a system facing deep uncertainty. In this post, we’ll investigate how these deep uncertainties (which include demand growth, the efficacy of water use restrictions, financial variables and parameters governing infrastructure permitting and construction time) impact a utility’s ability to maintain three performance criteria: keeping reliability > 98%, restriction frequency < 20% and worst case cost less than 10% of annual revenue. For simplicity, we’ll focus on one regional water utility named Watertown.

Step 1: create a sample of deeply uncertain states of the world

To start the scenario discovery process, we generate an ensemble of deep uncertainties that represent future states of the world (SOWs). Here, we’ll use Latin Hypercube Sampling with an implementation I found in the Surrogate Modeling Toolbox.

import numpy as np
from smt.sampling_methods import LHS

# This script will generate 1000 Latin Hypercube Samples (LHS)
# of deeply uncertain system parameters for the Sedento Valley

# create an array storing the ranges of deeply uncertain parameters
DU_factor_limits = np.array([
    [0.9, 1.1], # Watertown restriction efficacy 
    [0.9, 1.1], # Dryville restriction efficacy
    [0.9, 1.1], # Fallsland restriction efficacy
    [0.5, 2.0], # Demand growth rate multiplier
    [1.0, 1.2], # Bond term
    [0.6, 1.0], # Bond interest rate
    [0.6, 1.4], # Discount rate
    [0.75, 1.5], # New River Reservoir permitting time
    [1.0, 1.2], # New River Reservoir construction time
    [0.75, 1.5], # College Rock Reservoir (low) permitting time
    [1.0, 1.2], # College Rock Reservoir (low) construction time
    [0.75, 1.5], # College Rock Reservoir (high) permitting time
    [1.0, 1.2], # College Rock Reservoir (high) construction time
    [0.75, 1.5], # Water Reuse permitting time
    [1.0, 1.2], # Water Reuse construction time
    [0.8, 1.2], # Inflow amplitude
    [0.2, 0.5], # Inflow frequency
    [-1.57, 1.57]]) # Inflow phase

# Use the smt package to set up the LHS sampling
sampling = LHS(xlimits=DU_factor_limits)

# We will create 1000 samples
num = 1000

# Create the actual sample
x = sampling(num)

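# save to a csv file (a header row is written because the loading code in Step 3 skips the first row)
header = ','.join('DU_{}'.format(i) for i in range(x.shape[1]))
np.savetxt('DU_factors.csv', x, delimiter=',', header=header, comments='')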
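
To make the idea behind Latin Hypercube Sampling concrete, here is a minimal sketch in plain Python (standard library only). It is not the smt implementation used above; the function name `latin_hypercube` and the two-factor example are assumptions for illustration. Each factor's range is split into as many equal-width strata as there are samples, one point is drawn per stratum, and the strata are shuffled independently for each factor.

```python
import random

def latin_hypercube(limits, n_samples, seed=None):
    """Draw n_samples points with exactly one point per stratum in every dimension.

    limits: list of (low, high) pairs, one per uncertain factor
    """
    rng = random.Random(seed)
    columns = []
    for low, high in limits:
        # one uniform draw inside each of n_samples equal-width strata on [0, 1)
        strata = [(k + rng.random()) / n_samples for k in range(n_samples)]
        # shuffle so that strata are paired at random across factors
        rng.shuffle(strata)
        # rescale from the unit interval to the factor's range
        columns.append([low + u * (high - low) for u in strata])
    # transpose: one row per sampled state of the world
    return [list(row) for row in zip(*columns)]

# hypothetical two-factor example using two of the ranges listed above
limits = [(0.9, 1.1), (0.5, 2.0)]
sample = latin_hypercube(limits, n_samples=10, seed=1)
```

Because every stratum of every factor is hit exactly once, even a modest sample covers the full range of each uncertainty, which is why LHS is a common choice for building SOW ensembles.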

Step 2: Evaluate performance across SOWs

Next, we’ll evaluate the performance of our policy across the LHS sample of DU factors. For the Sedento Valley test case, we use WaterPaths, an open-source simulation system for integrated water supply portfolio management and infrastructure investment planning (for more see Trindade et al., 2020). This step is not included in the git repository as it requires high-performance computing for this system, but results can be found in the “Model_output.csv” file. For simulation details, see Gold et al., 2022.

Step 3: Convert model output into a boolean array for classification

To perform classification, we need to convert the results of our simulations to a binary array classifying each SOW as a “success” or “failure” based on whether the policy met the performance criteria under the SOW. First, we define a small function to determine if an SOW meets a set of criteria, then we apply this function to our results. We also load the DU factor LHS sample.

# First, define a function to check whether performance criteria are met
def check_criteria(objectives, crit_objs, crit_vals):
    """
    Determines if an objective meets a given set of criteria for a set of SOWs

    inputs:
        objectives: np array of all objectives across a set of SOWs
        crit_objs: list containing the column index of the objective in question
        crit_vals: an array containing [min, max] of the values

    returns:
        meets_criteria: a numpy array of booleans marking the SOWs that meet both min and max criteria
    """

    # check max and min criteria for each objective
    meet_low = objectives[:, crit_objs] >= crit_vals[0]
    meet_high = objectives[:, crit_objs] <= crit_vals[1]

    # check if max and min criteria are met at the same time
    meets_criteria = np.hstack((meet_low, meet_high)).all(axis=1)

    return meets_criteria

##### Load data and pre-process #####

# load objectives and create input array of boolean values for SD input
Reeval_objectives = np.loadtxt('Model_output.csv', skiprows=1, delimiter=',')
REL = check_criteria(Reeval_objectives, [0], [.979, 1])
RF = check_criteria(Reeval_objectives, [1], [0, 0.10])
WCC = check_criteria(Reeval_objectives, [2], [0, 0.10])
SD_input = np.vstack((REL, RF, WCC)).all(axis=0)

# load DU factors
DU_factors = np.loadtxt('DU_factors.csv', skiprows=1, delimiter=',')
DU_names = ['Watertown Rest. Eff.', 'Dryville Rest. Eff.', 'Fallsland Rest. Eff.',
            'Demand Growth Rate', 'Bond Term', 'Bond Interest',
            'Discount Rate', 'NRR Perm', 'NRR Const', 'CRR L Perm',
            'CRR L Const.',	'CRR H Perm.', 'CRR H Const.', 'WR1 Perm.',
             'WR1 Const.', 'Inflows A', 'Inflows m','Inflows p']
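
To see what the success/failure labelling does, here is a small worked example in plain Python. The three rows of objective values are made up for illustration, and `is_success` is a hypothetical helper that mirrors the logic above without numpy; the bounds are the same values passed to check_criteria.

```python
# hypothetical objective values for three SOWs:
# (reliability, restriction frequency, worst-case cost)
toy_objectives = [
    (0.995, 0.05, 0.02),   # meets all three criteria
    (0.960, 0.05, 0.02),   # reliability too low
    (0.990, 0.05, 0.30),   # worst-case cost too high
]

# (min, max) bounds per objective, same values as the check_criteria calls above
bounds = [(0.979, 1.0), (0.0, 0.10), (0.0, 0.10)]

def is_success(sow, bounds):
    """True only if every objective lies within its (min, max) bounds."""
    return all(low <= value <= high for value, (low, high) in zip(sow, bounds))

toy_labels = [is_success(sow, bounds) for sow in toy_objectives]
print(toy_labels)  # [True, False, False]
```

An SOW counts as a success only when all criteria hold at once, so a single violated criterion is enough to label it a failure.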

Step 4: Fit a boosted trees classifier

After we’ve formatted the data, we’re ready to perform boosted trees classification. There are several packages for boosted trees in Python; here we’ll use the implementation from scikit-learn. We’ll use an ensemble of 200 trees with depth 3 and a learning rate of 0.1. These parameters need to be tuned for the individual problem; I found this nice post that goes into detail on parameter tuning.

##### Boosted Tree Classification #####

from sklearn.ensemble import GradientBoostingClassifier

# create a gradient boosted classifier object
gbc = GradientBoostingClassifier(n_estimators=200,
                                 learning_rate=0.1,
                                 max_depth=3)

# fit the classifier
gbc.fit(DU_factors, SD_input)
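
Since these parameters are problem specific, one way to tune them is a cross-validated grid search. The sketch below is only an illustration under stated assumptions: it runs on synthetic stand-in data rather than the Sedento Valley results, and the candidate values in `param_grid` are arbitrary choices, not recommendations.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

# synthetic stand-in data (NOT the Sedento Valley results): 300 SOWs, 3 factors,
# labelled a "success" whenever the first factor stays below 1.7
rng = np.random.default_rng(0)
X_demo = rng.uniform(low=[0.5, 0.9, 0.9], high=[2.0, 1.1, 1.1], size=(300, 3))
y_demo = X_demo[:, 0] < 1.7

# small illustrative grid; the candidate values are assumptions
param_grid = {
    'n_estimators': [50, 200],
    'max_depth': [2, 3],
    'learning_rate': [0.05, 0.1],
}

# 3-fold cross-validation over every combination in the grid
search = GridSearchCV(GradientBoostingClassifier(random_state=0),
                      param_grid, cv=3, scoring='accuracy')
search.fit(X_demo, y_demo)

print(search.best_params_)
print(round(search.best_score_, 3))
```

To apply this to the real problem, swap X_demo and y_demo for DU_factors and SD_input. If failures are rare in your ensemble, a scoring metric other than accuracy may be a better fit.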

Step 5: Examine which DU factors have the most impact on performance criteria

Now we’re ready to examine the results of our classification. First, we’ll examine how important each DU factor is to the classification results generated by boosted trees. To rank the importance of each DU factor, we examine the percentage decrease in impurity of the ensemble of trees that is associated with each factor. In plain English, this is a measure of how helpful each DU factor is to correctly classifying SOWs. This information is generated during the fit of the classifier above and is easily accessible as an attribute of our scikit-learn classifier.

For our example, one deep uncertainty, demand growth rate, clearly stands out as the most influential, as shown in the figure below. A second, the restriction efficacy for Watertown (the utility we’re focusing on), also stands out with a higher level of importance than the remaining factors. All other DU factors have little impact on the classification in this case.

##### Factor Ranking #####

from copy import deepcopy
import matplotlib.pyplot as plt

# Extract the feature importances
feature_importances = deepcopy(gbc.feature_importances_)

# rank the feature importances and plot
importances_sorted_idx = np.argsort(feature_importances)
sorted_names = [DU_names[i] for i in importances_sorted_idx]

fig = plt.figure(figsize=(8,8))
ax = fig.gca()
ax.barh(np.arange(len(feature_importances)), feature_importances[importances_sorted_idx])
ax.set_yticks(np.arange(len(feature_importances)))
ax.set_yticklabels(sorted_names)
ax.set_xlabel('Feature Importance')

Step 6: Create factor maps

Finally, we visualize the results of our classification through factor mapping. In the plot below, we show the uncertainty space projected onto the two most influential factors, demand growth rate and restriction efficacy. Each point represents a sampled SOW: red points represent SOWs that resulted in failure, while white points represent SOWs that resulted in success. The color in the background shows the predicted regions of success and failure from the boosted trees classification.

Here we observe that high levels of demand growth are the primary source of vulnerability for the water utility. When restriction efficacy is lower than our estimate (multiplier < 1), the utility faces failures at demand growth levels of about 1.7 times the estimated values. When restriction efficacy is at or above estimates, the acceptable scaling of demand growth rises to about 1.8.

Taken as a whole, these results provide valuable insights for decision makers. From our original 18 deep uncertainties, we have discovered that two are critical for the success of our water supply management policy. Further, we have defined thresholds within the uncertainty space that define scenarios that lead to failure. We can use this information to inform monitoring efforts for the water supply policy, or to inform a new problem formulation that tailors actions to mitigate these vulnerabilities.

##### Factor Mapping #####

# Select the top two factors discovered above
selected_factors = DU_factors[:, [3,0]]

# Fit a classifier using only these two factors
gbc_2_factors = GradientBoostingClassifier(n_estimators=200,
                                           learning_rate=0.1,
                                           max_depth=3)
gbc_2_factors.fit(selected_factors, SD_input)

# plot prediction contours
x_data = selected_factors[:,0]
y_data = selected_factors[:,1]

x_min, x_max = (x_data.min(), x_data.max())
y_min, y_max = (y_data.min(), y_data.max())

# create a grid to make predictions on
xx, yy = np.meshgrid(np.arange(x_min, x_max * 1.001, (x_max - x_min) / 100),
                        np.arange(y_min, y_max * 1.001, (y_max - y_min) / 100))
dummy_points = list(zip(xx.ravel(), yy.ravel()))

z = gbc_2_factors.predict_proba(dummy_points)[:, 1]
z[z < 0] = 0.
z = z.reshape(xx.shape)

# plot the factor map        
fig = plt.figure(figsize=(10,8))
ax = fig.gca()
ax.contourf(xx, yy, z, [0, 0.5, 1.], cmap='RdBu',
                alpha=.6, vmin=0.0, vmax=1)
ax.scatter(selected_factors[:,0], selected_factors[:,1],\
            c=SD_input, cmap='Reds_r', edgecolor='grey', 
            alpha=.6, s= 100, linewidth=.5)
ax.set_xlim([.5, 2])
ax.set_xlabel('Demand Growth Multiplier')
ax.set_ylabel('Restriction Eff. Multiplier')


Gold, D. F., Reed, P. M., Gorelick, D. E., & Characklis, G. W. (2022). Power and Pathways: Exploring Robustness, Cooperative Stability, and Power Relationships in Regional Infrastructure Investment and Water Supply Management Portfolio Pathways. Earth’s Future, 10(2), e2021EF002472.

Groves, D. G., & Lempert, R. J. (2007). A new analytic method for finding policy-relevant scenarios. Global Environmental Change, 17(1), 73-85.

Kwakkel, J. H., Walker, W. E., & Haasnoot, M. (2016). Coping with the wickedness of public policy problems: approaches for decision making under deep uncertainty. Journal of Water Resources Planning and Management, 142(3), 01816001.

O’Neill, B. C., Kriegler, E., Riahi, K., Ebi, K. L., Hallegatte, S., Carter, T. R., … & van Vuuren, D. P. (2014). A new scenario framework for climate change research: the concept of shared socioeconomic pathways. Climatic change, 122(3), 387-400.

Reed, P.M., Hadjimichael, A., Malek, K., Karimi, T., Vernon, C.R., Srikrishnan, V., Gupta, R.S., Gold, D.F., Lee, B., Keller, K., Thurber, T.B., & Rice, J.S. (2022). Addressing Uncertainty in Multisector Dynamics Research [Book]. Zenodo. https://doi.org/10.5281/zenodo.6110623

Trindade, B. C., Gold, D. F., Reed, P. M., Zeff, H. B., & Characklis, G. W. (2020). Water pathways: An open source stochastic simulation system for integrated water supply portfolio management and infrastructure investment planning. Environmental Modelling & Software, 132, 104772.

Easy batch parallelization of code in any language using mpi4py

The simplest form of parallel computing is what’s known as “embarrassingly” parallel processes. These processes involve fully independent runs of a model or script where little or no communication is needed across parallel processes. A common example is Monte Carlo evaluation, when we run a model over an ensemble of inputs. To parallelize an embarrassingly parallel application we simply need to send a set of commands to the cluster telling it to run each sample on a different core (or set of cores). For small applications, this can be done by submitting each run individually. For larger applications, SLURM Job Arrays (which are nicely detailed in Antonia’s post, here) can efficiently batch large number of function calls to independent computing cores. While this method is efficient and effective, I find it sometimes can be hard to keep track of, as you may be submitting tens or hundreds of jobs at a time. An alternative approach to submitting embarrassingly parallel tasks is to utilize MPI with Python to dispatch and organize jobs.

I like the MPI / Python combo because it consolidates all parallel applications into a single job, meaning you have one job to keep track of on a cluster at a time, and one output file generated by the batch set of model runs. I also find Python slightly easier to edit and debug than Bash scripts (which are used to create job arrays). Additionally, it’s very easy to assign each computing core a set of function evaluations to run (this can also be done with Job arrays, but again, I find Python easier to work with). Though Python is the language used to coordinate parallel tasks, we can use it to parallelize code in any language, as I’ll demonstrate below.

In this post I’ll first provide some background on MPI and its Python implementation, mpi4py. Next I’ll provide an example I’ve developed to demonstrate how to batch run a Matlab code on a cluster. The examples presented here are derived from some of Bernardo’s code in his post on Parallel programming in C/C++, which you can find here.

A very light introduction to MPI

MPI stands for “Message Passing Interface” and is the standard library for distributed memory parallelization (for background, see this post). To understand how MPI works, it’s helpful to define some of its basic components.

  1. Tasks: I’ll use the term task to define a processor (or group of processors) assigned to perform a specific set of instructions. These instructions may be a single evaluation of a function, or a set of function evaluations.
  2. Communicators: A communicator is a group of MPI task units that are permitted to communicate with each other. In advanced MPI applications you may have multiple communicators, but for embarrassingly parallel applications we’ll only use one. The default communicator is called “MPI_COMM_WORLD” (I don’t know why, if anyone does please feel free to share in the comments), and that’s what I’ll work with here.
  3. Ranks: Each MPI task is assigned a unique identifier within the communicator called a rank. The processors running each task can access their own rank number, which will play an important role in how we use MPI for embarrassingly parallel applications.

An example schematic of the MPI_COMM_WORLD communicator with six tasks and their associated ranks is shown below.


MPI is implemented in Python with the mpi4py library. When we run an MPI code on a cluster, MPI creates the communicator and assigns each task a rank, then each task independently loads the script. The processor(s) associated with a task can then access their own unique rank.

The following snippet of code loads this library, accesses the communicator and stores the rank of the given process:

# load the mpi4py library
from mpi4py import MPI

# access the MPI COMM WORLD communicator and assign it to a variable
comm = MPI.COMM_WORLD

# get the rank of the current process (different for each process on the cluster)
rank = comm.Get_rank()

Example of using mpi4py to batch parallel jobs

Here, I’ll parallelize the submission of a Matlab script called demoScript.m. This script reads an input file from a specific file location and prints out the contents of that file. For example purposes I’ve created 20 input files, each in its own folder. The folders are called “input_sample_0”, “input_sample_1”, etc. Each input_sample folder contains a file called “sample_data.txt”, which contains one line of text reading: “This is data for run <sample_number>”.

All code for this example can be found on Github, here: https://github.com/davidfgold/mpi4py_blog.git

Batching runs of demoScript.m involves three components:

  1. Write demoScript.m so that it reads the sample number from the input.
  2. Write a Python script that will use mpi4py to distribute calls of demoScript.m. Here I’ll call this script “callDemoScript.py”
  3. Write a Bash script that sets up your MPI run and calls the Python function. Here I’ll call this script “submitDemoScript.sh”

1. demoScript.m

The demo Matlab script is found below. It reads in two arguments that are called from the command line. The first argument is the rank, which will vary for each task, and the second is the sample number, which will specify which input folder to read from.

% demoScript.m
% reads an input file from a given sample number (specified via command line)
% prints output from the sample file associated with the sample number
% also prints the rank for demonstration purposes

% read in command line input
arg_list = argv();
rank = arg_list{1,1}; % rank is the first argument
sample = arg_list{2, 1}; % sample number is the second argument

% Create a string that contains the location of the proper sample directory
sample_out = fileread(strcat("input_sample_", sample, "/sample_data.txt"));

% create a string to print the rank number
rank_call = strcat("This is rank_", rank, ", receiving the following input: \n");

% format the output and print
output = strcat(rank_call, sample_out);
printf("%s\n", output);

2. callDemoScript.py

The second component is a Python script that uses mpi4py to call demoScript.m many times across different tasks. Each task will run a number of samples equal to a variable called “N_SAMPLES_PER_TASK” which will be fed to this script when it is called.


# callDemoScript.py
# Called to batch demoScript.m across multiple MPI tasks
#
# Reads in the total tasks and number of samples per task from command line.

# load necessary libraries
from mpi4py import MPI
import numpy as np
import sys
import os
import time

# locate the COMM WORLD communicator
comm = MPI.COMM_WORLD

# get the number of the current rank
rank = comm.Get_rank()

# read in arguments from the submission script
TOTAL_TASKS = int(sys.argv[1]) # number of MPI processes
N_SAMPLES_PER_TASK = int(sys.argv[2]) # number of runs per/task

# loop through samples assigned to current rank
for i in range(N_SAMPLES_PER_TASK):
	sample= rank + TOTAL_TASKS * i

	# write the command that will be sent to the terminal (rank and sample will replace the {})
	terminal_command = "octave-cli ./demoScript.m {} {} ".format(rank, sample)

	# write the terminal command to the process
	os.system(terminal_command)

	# sleep before submitting the next command
	time.sleep(1) # optional, for memory intensive submissions
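
The indexing in the loop above interleaves samples across ranks, so no two tasks ever run the same sample. As a quick check, here is a plain-Python sketch (no MPI required) that lists which samples each rank would run for the 4 tasks and 5 samples per task used in this example. The helper name `samples_for_rank` is hypothetical.

```python
def samples_for_rank(rank, total_tasks, n_samples_per_task):
    """Sample numbers handled by one rank, using the same indexing as the loop above."""
    return [rank + total_tasks * i for i in range(n_samples_per_task)]

# the configuration used in this example: 4 MPI tasks, 5 samples per task
assignment = {rank: samples_for_rank(rank, 4, 5) for rank in range(4)}
for rank, samples in assignment.items():
    print(rank, samples)
# 0 [0, 4, 8, 12, 16]
# 1 [1, 5, 9, 13, 17]
# 2 [2, 6, 10, 14, 18]
# 3 [3, 7, 11, 15, 19]
```

Together the four ranks cover samples 0 through 19 exactly once, which matches the 20 input folders created for this demo.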



3. submitDemoScript.sh

The final component is a Bash script that will send this MPI job to the cluster. Here I’ll use SLURM to create 4 MPI tasks across 2 nodes (each node will have 2 associated tasks), and each task will be assigned 5 samples to run.

I wrote this for a local cluster at Cornell. Note that I had to load two modules to run Python and a third to run Octave (which is used to run Matlab scripts on Linux). I’ll call the Python script with mpirun, and then specify the total number of MPI tasks before making the function call. The output of the script is printed to a text file called demoOutput.txt.

#!/bin/bash

# Set up your parallel runs
SAMPLES_PER_TASK=5 # number of runs for each MPI task
N_NODES=2 # number of nodes
TASKS_PER_NODE=2 # number of tasks per node

TOTAL_TASKS=$((N_NODES * TASKS_PER_NODE)) # total number of tasks

# Submit the parallel job (the variables above are expanded before sbatch reads the script)
sbatch <<EOF
#!/bin/bash
#SBATCH --nodes=${N_NODES}
#SBATCH --ntasks-per-node=${TASKS_PER_NODE}
#SBATCH --time=0:01:00
#SBATCH --job-name=demoMPI4py
#SBATCH --output=output/demo.out
#SBATCH --error=output/demo.err
#SBATCH --exclusive
module load py3-mpi4py
module load py3-numpy
module load octave/6.3.0

mpirun -np ${TOTAL_TASKS} python3 callDemoScript.py ${TOTAL_TASKS} ${SAMPLES_PER_TASK} > demoOutput.txt
EOF

Additional resources

Putting some thought into how you design a set of parallel runs can save you a lot of time and headache. The example above has worked well for me when submitting sets of embarrassingly parallel tasks, but each application will be different, so take the time to find the procedure that works best for you. Our blog and the internet are full of resources that can help you parallelize your code. Below are some suggestions:

Performing Experiments on HPC Systems

Scaling experiments: how to measure the performance of parallel code on HPC systems

Parallel processing with R on Windows

How to automate scripts on a cluster

Parallelization of C/C++ and Python on Clusters

Developing parallelised code with MPI for dummies, in C (Part 1/2)

Cornell CAC glossary of HPC terms: https://cvw.cac.cornell.edu/main/glossary

A great MPI tutorial I found online: https://mpitutorial.com/tutorials/

Make LaTeX easier with custom commands

LaTeX is a powerful tool for creating professional looking documents. Its ability to easily format mathematical equations, citations and complex figures makes LaTeX especially useful for developing peer-reviewed journal articles and scientific reports. LaTeX is highly customizable, which allows you to create documents that are not carbon copies of generic Microsoft Word templates.

Using LaTeX does have its drawbacks: instead of simply typing on a page, you construct the document by writing LaTeX code. Once you’ve written your code, a compiler translates it into a finished and formatted document. This can sometimes result in high overhead time for fixing bugs and managing format. But coding a document also has advantages. In addition to the vast array of existing LaTeX libraries and commands, you can create your own custom commands that speed up the writing and formatting process. Below I’ll give an overview of the basics of creating custom LaTeX commands and provide some illustrative examples.

Commands with no arguments

If you have an equation or a complex sequence of text that you know you’ll be repeating, you can create a custom command to produce it. For example, if I’m constantly referencing the equation for an Ordinary Least Squares (OLS) estimator, I can make a new command that produces it:
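
A definition along these lines would do the job. This is a minimal sketch that assumes the standard matrix form of the OLS estimator; the symbols inside the command body are an assumption for illustration. Wrapping the body in \ensuremath lets the command be used both in running text and inside math mode.

```latex
% Sketch: assumes the standard matrix form of the OLS estimator
\newcommand{\OLS}{\ensuremath{\hat{\beta} = (X^{T}X)^{-1}X^{T}Y}}
```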


There are three parts to defining this command, as shown in the figure below:

  1. Tell LaTeX you are defining a new command by specifying “\newcommand”
  2. Name the command (make sure to include the backslash)
  3. Specify the output of the new command

Example LaTeX code that calls the OLS command:

I can store complex terms using a predefined command: \OLS

Compiled output:

Commands with basic arguments

Single argument

You can also define commands that accept arguments. For example, if you want to make commands to assist tracking changes in a document, you can create a command that formats a section of added text so it has the color blue and is bolded:

\newcommand{\addtxt}[1]{{\color{blue} \textbf{#1}}} % Highlight text that has been added

The command defined above accepts one argument (shown as the “[1]”) and calls that argument using “#1”, as highlighted in the figure below:

Example LaTeX code using my command:

Demonstrating my custom commands with arguments: \addtxt{This text has been inserted into this sentence}

Example compiled output:

Multiple arguments

You can also define commands with multiple arguments. For example, you can create a template sentence that provides an update on the timing of a project:

\newcommand{\projReportA}[2]{The project was planned to finish on \textbf{#1}, after reviewing current progress, we have determined that it will likely finish on \textbf{#2}} % insert a date for when a project was planned to be completed and when a project is likely to be completed

Here, argument #1 is the date when the project was planned on being completed, and argument #2 is the date that the project will likely be completed.

Example use of this function:

Another way you can use an argument:\\ \\
\projReportA{September $9^{th}$}{October $1^{st}$}

Example compiled output:

Commands with optional arguments

The project report command above can be modified to accept a default completion date with an option to include an updated date.

\newcommand{\progReportB}[2][September 9th]{The project was planned to finish on \textbf{#2}, after reviewing current progress, we have determined that it will likely finish on \textbf{#1}}

To create an optional argument, specify the default value of the first argument in a new set of brackets. Note that in basic LaTeX this only works for a single optional argument; for more defaults you can use a package such as xparse.
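
As a rough sketch of the xparse route, the hypothetical command \progReportC below takes two optional arguments with defaults and one mandatory argument. The command name and the default values are assumptions for illustration only. In the argument specification, O{...} marks an optional argument with a default and m marks a mandatory one.

```latex
\usepackage{xparse} % not needed on recent LaTeX kernels, where \NewDocumentCommand is built in

% Hypothetical command: two optional arguments (O{default}) and one mandatory argument (m)
\NewDocumentCommand{\progReportC}{O{September 9th} O{the project team} m}{%
  The project was planned to finish on \textbf{#3}, after reviewing current progress,
  #2 has determined that it will likely finish on \textbf{#1}}

% usage
\progReportC{September 9th}                                      % both defaults
\progReportC[October 1st]{September 9th}                         % override the first default
\progReportC[October 1st][the steering committee]{September 9th} % override both defaults
```

The usage examples that follow go back to the \progReportB command defined with \newcommand.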

Here’s an example using the \progReportB command with the default argument:

Here I'll use the command without the optional argument, so it will print the default: \\
\progReportB{September $9^{th}$}

This will compile to:

Here’s an example with the optional argument specified:

Now I'll add the optional argument, which will be added in place of the default: \\ \\
\progReportB[October $1^{st}$]{September $9^{th}$}

This will compile to:

Concluding thoughts

These simple examples only scratch the surface of what you can do with LaTeX commands. I should also note that while custom commands are useful, LaTeX also contains a large suite of packages with predefined commands that can be easily imported into your document.

Helpful LaTeX resources:

Measuring the parallel performance of the Borg MOEA

In most applications, parallel computing is used to improve code efficiency and expand the scope of problems that can be addressed computationally (for background on parallelization, see references listed at the bottom of this post). For the Borg Many Objective Evolutionary Algorithm (MOEA) however, parallelization can also improve the quality and reliability of many objective search by enabling a multi-population search. The multi-population implementation of Borg is known as Multi-Master Borg, details can be found here. To measure the performance of Multi-Master Borg, we need to go beyond basic parallel efficiency (discussed in my last post, here), which measures the efficiency of computation but not the quality of the many objective search. In this post, I’ll discuss how we measure the performance of Multi-Master Borg using two metrics: hypervolume speedup and reliability.

Hypervolume speedup

In my last post, I discussed traditional parallel efficiency, which measures the improvement in speed and efficiency that can be achieved through parallelization. For many objective search, speed and efficiency of computation are important, but we are more interested in the speed and efficiency with which the algorithm produces high quality solutions. We often use the hypervolume metric to measure the quality of an approximation set as it captures both convergence and diversity (for a thorough explanation of hypervolume, see this post). Using hypervolume as a measure of search quality, we can then evaluate hypervolume speedup, defined as:

Hypervolume speedup = \frac{T_S^H}{T_P^H}

where T_S^H is the time it takes the serial version of the MOEA to achieve a given hypervolume threshold, and T_P^H is the time it takes the parallel implementation to reach the same threshold. Figure 1 below, adapted from Hadka and Reed (2015), shows the hypervolume speedup across different parallel implementations of the Borg MOEA for the five objective LRGV test problem run on 16,384 processors (in this work the parallel epsilon-NSGA II algorithm is used as a baseline rather than a serial implementation). Results from Figure 1 reveal that the Multi-Master implementations of Borg are able to reach each hypervolume threshold faster than the baseline algorithm and the master-worker implementation. For high hypervolume thresholds, the 16 and 32 Master implementations achieve the hypervolume thresholds 10 times faster than the baseline.

Figure 1: Hypervolume speedup for the five objective LRGV test problem across implementations of the Borg MOEA (epsilon-NSGA II, another algorithm, is used as the baseline here). This figure is adapted from Hadka and Reed (2015).
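
To make the definition concrete, here is a minimal sketch of how hypervolume speedup could be computed from runtime traces. The traces are made-up numbers, not results from Hadka and Reed, and the function names are hypothetical. Each trace is a list of (time, hypervolume) pairs sorted by time.

```python
def time_to_threshold(trace, threshold):
    """First recorded time at which hypervolume reaches the threshold, or None."""
    for t, hv in trace:
        if hv >= threshold:
            return t
    return None

def hypervolume_speedup(baseline_trace, parallel_trace, threshold):
    """T_S^H / T_P^H for one threshold (None if either run never reaches it)."""
    t_s = time_to_threshold(baseline_trace, threshold)
    t_p = time_to_threshold(parallel_trace, threshold)
    if t_s is None or t_p is None:
        return None
    return t_s / t_p

# made-up (time in seconds, hypervolume) traces, sorted by time
baseline = [(10, 0.2), (100, 0.5), (400, 0.8), (1000, 0.9)]
parallel = [(10, 0.4), (50, 0.8), (100, 0.9), (200, 0.95)]

print(hypervolume_speedup(baseline, parallel, 0.8))   # 8.0
print(hypervolume_speedup(baseline, parallel, 0.95))  # None
```

In this toy case the parallel run reaches a hypervolume of 0.8 after 50 seconds versus 400 for the baseline, a speedup of 8. The baseline never reaches 0.95, so no speedup is defined at that threshold.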


Reliability

MOEAs are inherently stochastic algorithms; they are initialized with random seeds, which may speed up or slow down the search process. To ensure high quality Pareto approximate sets, it’s standard practice to run an MOEA across many random seeds and extract the best solutions across all seeds as the final solution set. Reliability is a measure of the probability that each seed will achieve a high quality set of solutions. Algorithms that have higher reliability allow users to run fewer random seeds, which saves computational resources and speeds up the search process. Salazar et al. (2017) examined the performance of 17 configurations of Borg on the Lower Susquehanna River Basin (LSRB) for a fixed 10 hour runtime. Figure 2 shows the performance of each configuration across 50 random seeds. A configuration that is able to achieve the best hypervolume across all seeds would be represented as a blue bar that extends to the top of the plot. The algorithmic configurations are shown in the plot to the right. These results show that though configuration D, which has a high core count and low master count, achieves the best overall hypervolume, it does not do so reliably. Configuration H, which has two masters, is able to achieve nearly the same hypervolume, but has a much higher reliability. Configuration L, which has four masters, achieves a lower maximum hypervolume, but has very little variance across random seeds.

Figure 2: Reliability of search adapted from Salazar et al., (2017). Each letter represents a different algorithmic configuration (shown right) for the many objective LSRB problem across 10 hours of runtime. The color represents the probability that each configuration was able to attain a given level of hypervolume across 50 seeds.
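
As a rough illustration of how reliability values like these could be computed, the sketch below takes the final hypervolume reached by each seed and reports the fraction of seeds that attain each threshold. The function name attainment_probability and the hypervolume values are hypothetical and are not taken from Salazar et al. (2017).

```python
import numpy as np


def attainment_probability(final_hv, thresholds):
    """Fraction of seeds whose final hypervolume meets each threshold."""
    final_hv = np.asarray(final_hv, dtype=float)
    thresholds = np.asarray(thresholds, dtype=float)
    # compare every seed (rows) against every threshold (columns)
    attained = final_hv[:, None] >= thresholds[None, :]
    return attained.mean(axis=0)


# hypothetical final hypervolumes for five random seeds (illustrative only)
seed_hv = [0.62, 0.80, 0.75, 0.91, 0.78]
levels = [0.5, 0.7, 0.9]
probs = attainment_probability(seed_hv, levels)
print(probs)
```

In this made-up example every seed reaches the lowest threshold, but only one in five reaches the highest, which is the kind of drop-off the color scale in a reliability plot is meant to reveal.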

These results can be further examined by looking at the quality of search across its runtime. In Figure 3, Salazar et al. (2017) compare the performance of the three algorithmic configurations highlighted above (D, H and L). The hypervolumes achieved by the maximum and minimum seeds bound the shaded areas, and the median hypervolume is shown with each line. Figure 3 clearly demonstrates how improved reliability can enhance search. Though the Multi-Master implementation is able to perform fewer function evaluations in the 10 hour runtime, it has very low variance across seeds. The Master-worker implementation on the other hand achieves better performance with its best seed (it gets lucky), but its median performance never achieves the hypervolume of the two or four master configurations.

Figure 3: Runtime hypervolume dynamics for the LSRB problem by Salazar et al. (2017). The reduction in variance in the Multi-Master implementations of Borg demonstrates the benefits of improved reliability.

Concluding thoughts

The two measures introduced above allow us to quantify the benefits of parallelizing the Multi-Master Borg MOEA. The improvements to search quality not only allow us to reduce the time and resources that we need to expend on many objective search, but may also allow us to discover solutions that would be missed by the serial or Master-Worker implementations of the algorithm. In many objective optimization contexts, this improvement may fundamentally alter our understanding of what is possible in challenging environmental optimization problems.

Parallel computing resources


Hadka, D., & Reed, P. (2015). Large-scale parallelization of the Borg multiobjective evolutionary algorithm to enhance the management of complex environmental systems. Environmental Modelling & Software, 69, 353-369.

Salazar, J. Z., Reed, P. M., Quinn, J. D., Giuliani, M., & Castelletti, A. (2017). Balancing exploration, uncertainty and computational demands in many objective reservoir optimization. Advances in water resources, 109, 196-210.

Scaling experiments: how to measure the performance of parallel code on HPC systems

Parallel computing allows us to speed up code by performing multiple tasks simultaneously across a distributed set of processors. On high performance computing (HPC) systems, an efficient parallel code can accomplish in minutes what might take days or even years to perform on a single processor. Not all code will scale well on HPC systems however. Most code has inherently serial components that cannot be divided among processors. If the serial component is a large segment of a code, the speedup gained from parallelization will greatly diminish. Memory and communication bottlenecks also present challenges to parallelization, and their impact on performance may not be readily apparent.

To measure the parallel performance of a code, we perform scaling experiments. Scaling experiments are useful as 1) a diagnostic tool to analyze the performance of your code and 2) evidence of code performance that can be used when requesting allocations on HPC systems (for example, NSF’s XSEDE program requires scaling information when requesting allocations). In this post I’ll outline the basics for performing scaling analysis of your code and discuss how these results are often presented in allocation applications.

Amdahl’s law and strong scaling

One way to measure the performance of a parallel code is through what is known as “speedup”, which measures the ratio of computational time in serial to the time in parallel:

speedup = \frac{t_s}{t_p}

Where t_s is the serial time and t_p is the parallel time.

The maximum speedup of any code is limited by the portion of code that is inherently serial. In the 1960s, Gene Amdahl formalized this limitation by presenting what is now known as Amdahl’s law:

Speedup = \frac{t_s}{t_p} = \frac{1}{s+(1-s)/p} < \frac{1}{s}

Where p is the number of processors, and s is the fraction of work that is serial.

On its face, Amdahl’s law seems like a severe limitation for parallel performance. If just 10% of your code is inherently serial, then the maximum speedup you can achieve is a factor of 10 (s = 0.10, 1/0.10 = 10). This means that even if you run your code over 1,000 processors, the code will run at most 10 times faster, and the returns from each additional processor diminish quickly. Luckily, in water resources applications the inherently serial fraction of many codes is very small (think ensemble model runs or MOEA function evaluations).
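
To make the limit concrete, here is a minimal sketch that evaluates Amdahl’s law for a code that is 10% serial. The function name amdahl_speedup is my own and is only for illustration.

```python
def amdahl_speedup(serial_fraction, processors):
    """Theoretical speedup from Amdahl's law for a given serial fraction."""
    return 1.0 / (serial_fraction + (1.0 - serial_fraction) / processors)


# speedup for a code that is 10% serial across a range of processor counts
for p in [1, 10, 100, 1000]:
    print(p, round(amdahl_speedup(0.10, p), 2))
```

Under these assumptions the speedup is roughly 5.3 on 10 processors and 9.9 on 1,000 processors. It approaches the bound 1/s = 10 but never reaches it.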

Experiments that measure the speedup of parallel code are known as “strong scaling” experiments. To perform a strong scaling experiment, you fix the amount of work for the code to do (i.e., run 10,000 MOEA function evaluations) and examine how long it takes to finish with varying processor counts. Ideally, your speedup will increase linearly with the number of processors. Agencies that grant HPC allocations (like NSF XSEDE) like to see the results of strong scaling experiments visually. Below, I’ve adapted a figure from an XSEDE training on how to assess performance and scaling:

Plots like this are easy for funding agencies to assess. Good scaling performance can be observed in the lower left corner of the plot, where the speedup increases linearly with the number of processors. When the gains in speedup start to diminish, but the curve has not leveled off, the scaling is likely acceptable. The peak of the curve represents poor scaling. Note that this will actually be the fastest runtime, but does not represent an efficient use of the parallel system.

Gustafson’s law and weak scaling

Many codes will not show acceptable scaling performance when analyzed with strong scaling due to inherently serial sections of code. While this is obviously not a desirable attribute, it does not necessarily mean that parallelization is useless. An alternative measure of parallel performance is to measure the amount of additional work that can be completed when you increase the number of processors. For example, if you have a model that needs to read a large amount of input data, the code may perform poorly if you only run it for a short simulation, but much better if you run a long simulation.

In the 1980s, John Gustafson proposed a relationship that relates the parallel performance to the amount of work a code can accomplish. This relationship has since been termed Gustafson’s law:

speedup = s+p*N

Where s and p are once again the portions of the code that are serial and parallel respectively (so that s + p = 1) and N is the number of cores.
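
For comparison with Amdahl’s law, the sketch below evaluates Gustafson’s law for the same 10% serial fraction. The function name gustafson_speedup is hypothetical, and the sketch assumes s + p = 1.

```python
def gustafson_speedup(serial_fraction, n_cores):
    """Scaled speedup from Gustafson's law, assuming s + p = 1."""
    parallel_fraction = 1.0 - serial_fraction
    return serial_fraction + parallel_fraction * n_cores


# scaled speedup for a code that is 10% serial
for n in [1, 10, 100, 1000]:
    print(n, gustafson_speedup(0.10, n))
```

Because the amount of work grows with the number of cores, the scaled speedup keeps growing roughly linearly instead of saturating at 1/s.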

Gustafson’s law removes the inherent limits from serial sections of code and allows for another type of scaling analysis, termed “weak scaling”. Weak scaling is often measured by “efficiency” rather than speedup. Efficiency is calculated by proportionally scaling the amount of work with the number of processors and measuring the ratio of completion times:

efficiency = \frac{t_1}{t_N}

Ideally, efficiency will be close to one (the time it takes one processor to do one unit of work is the same time it takes N processors to do N units of work). For resource allocations, it is again advisable to visualize the results of weak scaling experiments by creating plots like the one shown below (again adapted from the XSEDE training).
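
As a small sketch of how efficiency might be tabulated from a weak scaling experiment, the code below divides the single processor time by the time at each processor count. The function name and the timings are hypothetical and are not measurements from any real system.

```python
def weak_scaling_efficiency(timings):
    """Return {N: t_1 / t_N} from {N: runtime}, where work grows with N."""
    t_1 = timings[1]
    return {n: t_1 / t_n for n, t_n in timings.items()}


# hypothetical wall-clock times in seconds (N processors doing N units of work)
timings = {1: 100.0, 2: 104.0, 4: 110.0, 8: 125.0}
efficiency = weak_scaling_efficiency(timings)
for n, e in efficiency.items():
    print(n, round(e, 2))
```

With these made-up numbers the efficiency falls from 1.0 on one processor to 0.8 on eight, which is the kind of trend a weak scaling plot would show.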

Final thoughts

Scaling experiments will help you understand how your code will scale and give you a realistic idea of computation requirements for large experiments. Unfortunately, however, they will not diagnose the source of poor scaling. To improve scaling performance, it often helps to improve the serial version of your code as much as possible. A helpful first step is to profile your code. Other useful tips are to reduce the frequency of data input/output and (if using compiled code) to check the flags on your compiler (see some other tips here).
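
For Python codes, one way to take that first profiling step is the standard library’s cProfile module. The sketch below profiles a hypothetical stand-in function, slow_sum, and prints the most expensive calls. It is only an illustration, and a real model would replace the stand-in.

```python
import cProfile
import io
import pstats


def slow_sum(n):
    """Hypothetical stand-in for an expensive model function."""
    total = 0.0
    for i in range(n):
        total += i ** 0.5
    return total


# record timing information while the function runs
profiler = cProfile.Profile()
profiler.enable()
result = slow_sum(200_000)
profiler.disable()

# collect the five most expensive calls, sorted by cumulative time
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
stats.print_stats(5)
report = stream.getvalue()
print(report)
```

The functions at the top of the report are where the serial code spends most of its time, so they are natural first targets for optimization.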

Writing sharable Python code part II: Unit Testing

When writing Python functions that may be used by others, it is important to demonstrate that your functions work as intended. In my last post I discussed how a proper function specification establishes a sort of contract between developers and users of functions that delineates errors in implementation (user error) from bugs in the code (developer error). While this contract provides an important chain of accountability, it does not actually contain any proof that the function works as advertised. This is where unit testing comes in. A unit test simply runs a function over a suite of test cases (sets of inputs that produce known results) to verify performance. As its name implies, a single unit test is meant to test one basic component of a code. Large or complex codes will have many sets of unit tests, each testing different elements (usually individual functions). Unit testing provides users with proof that your code works as intended, and serves as a powerful tool for finding and removing any errors in your code.

In this post I’ll provide guidance on the development of unit tests and demonstrate an example for a simple python function. Material in this post represents my interpretation of content taught by Professor Walker White at Cornell, in a Python Fundamentals course that I proctored in the summer of 2020.

An example script

To illustrate unit testing I’ll examine a function called “check_satisficing” (stored in check_satisficing.py), which tests whether a set of performance objectives meets a set of satisficing criteria (for background on satisficing, see this post). The function returns a boolean and has three parameters:

  • objs: a numpy array containing the two objective values
  • criteria: a numpy array containing the two criteria
  • directions: a list of strings that specify whether the criteria is meant to be a lower or upper bound.

Note that the actual code for this function only takes eight lines, the rest is the function specification and precondition checks.

import numpy as np


def check_satisficing(objs, criteria, directions):
    """
    Returns whether a set of objectives meets a set of satisficing criteria

    Value return has type bool

    Parameters:
        objs: objective values from a given solution
        criteria: satisficing criteria for each objective
        directions: a list of strings containing "ge" or "le" for each
            criteria, where ge indicates satisfaction for values greater than
            or equal to the criteria and le indicates satisfaction for values
            less than or equal to the criteria

    Examples:
        objs = [.5, .5], criteria = [0.4, 0.4], directions = ['ge', 'ge']
            returns True
        objs = [.5, .5], criteria = [0.4, 0.4], directions = ['le', 'le']
            returns False
        objs = [.4, .4], criteria = [0.5, 0.5], directions = ['le', 'le']
            returns True
        objs = [.5, .5], criteria = [0.4, 0.4], directions = ['ge', 'le']
            returns False
        objs = [.5, .5], criteria = [0.4, 0.4], directions = ['le', 'ge']
            returns False
        objs = [.5, .5], criteria = [0.5, 0.5], directions = ['ge', 'ge']
            returns True
        objs = [.5, .5], criteria = [0.5, 0.5], directions = ['le', 'le']
            returns True

    Preconditions:
        objs is a numpy array of floats with length 2
        criteria is a numpy array of floats with length 2
        directions is a list of strings with length 2 containing either
            "ge" or "le"
    """
    # check preconditions
    assert len(objs) == 2, 'objs has length ' + repr(len(objs)) + ', should be 2'
    assert len(criteria) == 2, 'criteria has length ' + repr(len(criteria)) + \
    ', should be 2'
    assert len(directions) == 2, 'directions has length ' + \
    repr(len(directions)) + ', should be 2'
    # check to make sure each element has the correct type and value
    for i in range(2):
        assert type(objs[i])== np.float64, 'objs element ' + str(i) + ': ' + \
        repr(objs[i]) + ', is not a numpy array of floats'
        assert type(criteria[i])== np.float64, 'criteria element ' + str(i) + \
        ': ' + repr(criteria[i]) + ', is not a numpy array of floats'
        assert type(directions[i])== str, 'directions element ' + str(i) + \
        ': ' + repr(directions[i]) + ', is not a string'
        assert directions[i] == 'ge' or directions[i] == 'le', 'directions ' + \
        str(i) + ' is ' + repr(directions[i]) + ', should be either "ge" or "le"'
    # loop through objectives and check if each meets the criteria
    meets_criteria = True
    for i in range(2):
        if directions[i] == 'ge':
            meets_criteria = meets_criteria and (objs[i] >= criteria[i])
        else:
            meets_criteria = meets_criteria and (objs[i] <= criteria[i])
    return meets_criteria

Developing test cases

If you read my last post, you may notice that this specification includes an extra section called “Examples”. These examples show the user how the function is supposed to perform. They also represent the suite of test cases used to validate the function. Creating test cases is more of an art than a science, and test cases will be unique to each function you write. However, there is a set of basic rules you can follow to guide your implementation of unit testing, which I’ll outline below.

  1. Pick the simplest cases first. In this example, the simplest cases are when both objectives are above the criteria or both are below
  2. Move to more complex cases. In this example, a more complex case could be when one objective is above and the other below, or vice versa
  3. Think about “weird” possibilities. One “weird” case for this code could be when one or both objectives are equal to the criteria
  4. Never test a precondition violation. Precondition violations are outside the scope of the function and should not be included in unit tests

Test cases should be exhaustive and even simple codes may have many test cases. In my example above I provide seven. Can you think of any more that are applicable to this code?

Implementing a testing strategy

After you’ve developed your test cases, it’s time to implement your unit test. For demonstration purposes I’ve written my own unit test code which can be used to test the cases developed above. This function simply utilizes assert statements to check if each test input generates the correct output. A danger of writing your own testing function is that the test function itself may have errors. In practice, it’s easiest to use an established tool such as PyTest to perform unit testing (for in-depth coverage of PyTest see Bernardo’s post from 2019).

def test_check_satisficing():
    """Unit test for the function check_satisficing"""
    import numpy as np
    from check_satisficing import check_satisficing
    print("Testing check_satisficing")
    # test case 1:
    objs1 = np.array([0.5, 0.5])
    criteria1 = np.array([0.4, 0.4])
    directions1 = ['ge','ge']
    result1 = True
    assert (check_satisficing(objs1, criteria1, directions1)) == result1, \
    'Test 1 failed ' + repr(objs1) + ', ' + repr(criteria1) + ', ' + \
    repr(directions1) + ' returned False, should be True'
    # test case 2:
    objs2 = np.array([0.5, 0.5])
    criteria2 = np.array([0.4, 0.4])
    directions2 = ['le','le']
    result2 = False
    assert (check_satisficing(objs2, criteria2, directions2)) == result2, \
    'Test 2 failed ' + repr(objs2) + ', ' + repr(criteria2) + ', ' + \
    repr(directions2) + ' returned True, should be False'
    # test case 3:
    objs3 = np.array([0.4, 0.4])
    criteria3 = np.array([0.5, 0.5])
    directions3 = ['le','le']
    result3= True
    assert (check_satisficing(objs3, criteria3, directions3)) == result3, \
    'Test 3 failed ' + repr(objs3) + ', ' + repr(criteria3) + ', ' + \
    repr(directions3) + ' returned False, should be True'
    # test case 4:
    objs4 = np.array([0.5, 0.5])
    criteria4 = np.array([0.4, 0.4])
    directions4 = ['ge','le']
    result4 = False
    assert (check_satisficing(objs4, criteria4, directions4)) == result4, \
    'Test 4 failed ' + repr(objs4) + ', ' + repr(criteria4) + ', ' + \
    repr(directions4) + ' returned True, should be False'
    # test case 5    
    objs5 = np.array([0.5, 0.5])
    criteria5 = np.array([0.4, 0.4])
    directions5 = ['le','ge']
    result5 = False
    assert (check_satisficing(objs5, criteria5, directions5)) == result5, \
    'Test 5 failed ' + repr(objs5) + ', ' + repr(criteria5) + ', ' + \
    repr(directions5) + ' returned True, should be False'
    # test case 6: 
    objs6 = np.array([0.5, 0.5])
    criteria6 = np.array([0.5, 0.5])
    directions6 = ['ge','ge']
    result6 = True
    assert (check_satisficing(objs6, criteria6, directions6)) == result6, \
    'Test 6 failed ' + repr(objs6) + ', ' + repr(criteria6) + ', ' + \
    repr(directions6) + ' returned False, should be True'
    # test case 7: 
    objs7 = np.array([0.5, 0.5])
    criteria7 = np.array([0.5, 0.5])
    directions7 = ['le','le']
    result7 = True
    assert (check_satisficing(objs7, criteria7, directions7)) == result7, \
    'Test 7 failed ' + repr(objs7) + ', ' + repr(criteria7) + ', ' + \
    repr(directions7) + ' returned False, should be True'
    print("check_satisficing has passed all tests!")
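
The same seven cases can also be written as a table, which keeps each case on one line and makes duplicates easier to spot. The sketch below is only an illustration: meets is a compact, hypothetical stand-in for check_satisficing that skips the precondition checks and accepts plain lists. PyTest collects functions whose names start with “test”, so a function like this could be picked up without any extra code.

```python
def meets(objs, criteria, directions):
    """Compact stand-in for check_satisficing (no precondition checks)."""
    return all(
        o >= c if d == 'ge' else o <= c
        for o, c, d in zip(objs, criteria, directions)
    )


# (objs, criteria, directions, expected) for the seven documented examples
CASES = [
    ([0.5, 0.5], [0.4, 0.4], ['ge', 'ge'], True),
    ([0.5, 0.5], [0.4, 0.4], ['le', 'le'], False),
    ([0.4, 0.4], [0.5, 0.5], ['le', 'le'], True),
    ([0.5, 0.5], [0.4, 0.4], ['ge', 'le'], False),
    ([0.5, 0.5], [0.4, 0.4], ['le', 'ge'], False),
    ([0.5, 0.5], [0.5, 0.5], ['ge', 'ge'], True),
    ([0.5, 0.5], [0.5, 0.5], ['le', 'le'], True),
]


def test_meets():
    """Check every case in the table and report the first one that fails."""
    for objs, criteria, directions, expected in CASES:
        result = meets(objs, criteria, directions)
        assert result == expected, (
            repr((objs, criteria, directions)) + ' returned ' + repr(result)
        )


test_meets()
print("all cases passed")
```

Adding a new test case then only requires adding one more row to the table.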


Concluding thoughts

When developing a new Python function, it’s good practice to code the test cases while you’re writing the function and test each component during the development cycle. Coding in this manner will allow you to identify and fix errors early and can save a lot of time and headache. Creating test cases during code development may also improve the quality of your code by helping you conceptualize and avoid potential bugs.

Writing sharable Python code

Most of us in academia do not have formal training in computer science, yet for many of us writing and sharing code is an essential part of our research. While the quality of code produced by many graduate students can be quite impressive, many of us never learned the basic CS principles that allow programs to be sharable and inheritable by other programmers. Last summer I proctored an online class in Python fundamentals. The course was for beginner programmers and, though it covered material simpler than the scripts I write for my PhD, I learned a lot about how to properly document and structure Python functions to make them usable by others. In this post I’ll discuss two important concepts I took away from this course: writing proper function specifications and enforcing preconditions using assert statements.

Function Specifications

One of the most important aspects of writing a quality Python function is proper specification. While the term may sound generic, a specification actually has a very precise definition and implementation for a Python function. In practice, a specification is a docstring, a “string literal” that occurs as the first statement in a function, module, class or method, formed by a bracketed set of “””. Here is an example of a simple function I wrote with a specification:

def radians_to_degrees(theta):
    """
    Returns: theta converted to degrees
    Value return has type float
    Parameter theta: the angle in radians
    Precondition: theta is a float
    """
    return theta * (180.0/3.14159)

The function specification is everything between the sets of “””. When Python sees this docstring at the front of a function definition, it is automatically stored as the “__doc__” associated with the function. With this specification in place, any user that loads this function can access its __doc__ by typing help(radians_to_degrees), which will print the following to the terminal:

Help on function radians_to_degrees in module __main__:

    Returns: theta converted to degrees
    Value return has type float
    Parameter theta: the angle in radians
    Precondition: theta is a float
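
Besides help, the docstring can be read programmatically. The sketch below repeats the function (with blank lines in the docstring for readability) and shows two ways to retrieve its specification: the raw __doc__ attribute and the standard library’s inspect.getdoc, which strips the common indentation.

```python
import inspect


def radians_to_degrees(theta):
    """
    Returns: theta converted to degrees

    Value return has type float

    Parameter theta: the angle in radians
    Precondition: theta is a float
    """
    return theta * (180.0/3.14159)


# the raw docstring, exactly as written (including indentation)
raw_doc = radians_to_degrees.__doc__
# inspect.getdoc removes the leading blank line and common indentation
clean_doc = inspect.getdoc(radians_to_degrees)
print(clean_doc)
```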

The help function will print anything in the docstring at the start of a function, but it is good practice for the specification to have the following elements:

  1. A one-line summary of the function at the very beginning. If the function is a “fruitful function” (meaning it returns something), this line should tell what it returns. In my example I note that my function returns theta converted to degrees.
  2. A more detailed description of the function. In my example this simply provides the type of the return value.
  3. A description of the function’s parameter(s)
  4. Any preconditions that are necessary for the code to run properly (more on this later)

I should note that officially the Python programming language has a more flexible set of requirements for function specifications, which can be found here, but the attributes above are a good starting point for writing clear specifications.

Properly specifying a Python function will clarify the function’s intended use and provide instructions for how new users can utilize it. It will also help you document your code for formal release if you ever publish it. Google any of your favorite Python functions and you’ll likely be brought to a page that has a fancy looking version of the function’s specification. These pages can be automatically generated by tools such as Sphinx that create them right from the function’s definition.

Aside from clarifying and providing instructions for your function, specifications provide a means of creating a chain of accountability for any problems with your code. This chain of accountability is created through precondition statements (element four above). A precondition statement dictates requirements for the function to run properly. Preconditions may specify the type of parameter input (i.e. x is a float) or a general statement about the parameter (x < 0).

For large teams of many developers and users of functions, precondition statements create a chain of accountability for code problems. If the preconditions are violated and a code crashes, then it is the responsibility of the user, who did not use the code properly. On the other hand, if the preconditions were met and the code crashes, it is the responsibility of the developer, who did not properly specify the code.

Asserting preconditions to catch errors early

One way to improve the usability of a code is to catch precondition violations with statements that throw specific errors. In Python, this may be done using assert statements. Assert statements evaluate a boolean expression and can display a custom error message if the expression returns false. For example, I can modify my radians_to_degrees function with the following assert statement (the line just before the return statement):

def radians_to_degrees(theta):
    """
    Returns: theta converted to degrees
    Value return has type float
    Parameter theta: the angle in radians
    Precondition: theta is a float
    """
    assert type(theta) == float, repr(theta) + ' is not float'
    return theta * (180.0/3.14159)

A helpful function in my assert statement above is “repr”, which will return a printable representation of a given object (e.g., for a string it will keep the quotation marks around it). Now if I were to enter ‘a’ into my function, I would get the following error:

AssertionError: 'a' is not float

This is trivial for my example, but can save a lot of time when running large and complex functions. Rather than seeing 30 lines of “traceback…”, the user will be immediately brought to the source of the error. In general, you should try to make assert statements as precise as possible to pinpoint the violation.

It’s important to note that not every conceivable precondition violation requires an assert statement. There is a tradeoff between code efficiency and number of assert statements (checking a precondition takes time). Determining which preconditions to strictly enforce with assert statements is a balancing act and will be different for each application. It’s important to remember though, that even if you do not have an assert statement for a precondition, you should still list all preconditions in the function specification to preserve the chain of accountability.
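
To see the behavior end to end, the sketch below calls the function with a string and catches the resulting error so the message can be inspected. One caveat to keep in mind: Python skips assert statements entirely when it is run with the -O flag, so asserts are best treated as a development aid rather than the only line of defense.

```python
def radians_to_degrees(theta):
    """
    Returns: theta converted to degrees
    Value return has type float
    Parameter theta: the angle in radians
    Precondition: theta is a float
    """
    assert type(theta) == float, repr(theta) + ' is not float'
    return theta * (180.0/3.14159)


# violate the precondition on purpose and capture the custom message
try:
    radians_to_degrees('a')
except AssertionError as err:
    message = str(err)
    print('AssertionError:', message)
```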

Further resources

This post concerns specifications for Python functions. However, the use of docstrings is also important when coding modules, classes and public methods. Guidance on how these docstrings may be formatted can be found here.

While this post focused on Python, informative function specifications are important for all programming languages. Some resources on documentation in other languages can be found below:

A video training on Rhodium

A few weeks ago I filmed a video training guide to the Rhodium framework for the annual meeting of the Society for Decision Making Under Deep Uncertainty. Rhodium is a Python library that facilitates Many Objective Robust Decision Making. The training walks through a demonstration of Rhodium using the Lake Problem. The training introduces a live Jupyter notebook Antonia and I created using Binder.

To follow the training:

  1. Watch the demo video below
  2. Access the Binder Hub at this link: https://mybinder.org/v2/gh/dgoldri25/Rhodium/7982d8fcb1de9a84f074cc
  3. Click on the file called “DMDU_Rhodium_Demo.ipynb” to open the live demo
  4. Begin using Rhodium!

Helpful Links

More on subplots with Matplotlib

I got a little ahead of myself with the title of my last post, “Everything you want to know about subplots in Python’s Matplotlib”. While the information in that post can allow you to do quite a lot, there is in fact more that you might want to know. In this post I’ll go over two more subplot tools that are helpful for designing informative and attractive subplots. First, I’ll discuss working with colorbars, a seemingly minor task that can be quite time consuming. Then I’ll discuss a brand new feature of Matplotlib, the subplot_mosaic interface. This feature is still in its testing phase, but will likely be the new standard for making subplots in the future.

Formatting colorbars with constrained_layout

Colorbars play an important role in data visualization. Often, you may use a common color to link multiple views of the same data set, or contrast two data sets. Making a colorbar in matplotlib is fairly easy, but unless you use the right tools, making the colorbar fit into the overall graphic can be unexpectedly difficult. Here’s an example of creating a single colorbar for four different subplots. The fig.colorbar() function allows you to easily add a colorbar to the set of subplots. Unfortunately, with the default settings, this code will shrink two subplots disproportionately.

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np

# create some sample data
x = np.linspace(-10.0, 10.0, 1000)
y = np.linspace(-10.0, 10.0, 1000)
X, Y = np.meshgrid(x, y)
Z1 = X + Y
Z2 = X - Y
Z3 = X*Y
Z4 = X**2+Y**2

# create a figure object
fig, axes = plt.subplots(2,2, figsize=(8,8))
im1 = axes[0,0].imshow(Z1, cmap='BrBG')
im2 = axes[0,1].imshow(Z2, cmap='BrBG')
im3 = axes[1,0].imshow(Z3, cmap='BrBG')
im4 = axes[1,1].imshow(Z4, cmap='BrBG')

cbar = fig.colorbar(im1, ax=axes[:, 1], shrink=0.8)

One way you may attempt to fix this could be to use the tight_layout() function, which can help align subplots. This will make things worse however, because the colorbar confuses the algorithm that tight_layout uses to arrange the axes objects. The result will look like this (and produce a warning output):

Luckily, there is an alternative to tight_layout, called constrained_layout, which uses a constrained solver to optimize subplot placement. Constrained layout should be called during the creation of the figure object, as demonstrated below. Applying constrained layout to this set of subplots fixes the error and creates a nice looking set of plots.

# create a figure object
fig, axes = plt.subplots(2,2, figsize=(8,8), constrained_layout=True)
im1 = axes[0,0].imshow(Z1, cmap='BrBG')
im2 = axes[0,1].imshow(Z2, cmap='BrBG')
im3 = axes[1,0].imshow(Z3, cmap='BrBG')
im4 = axes[1,1].imshow(Z4, cmap='BrBG')

cbar = fig.colorbar(im1, ax=axes[:, 1], shrink=0.8)

I should note that we can edit the colorbar axes object just like any of the other axes objects, adding labels, adjusting the tick marks etc. For more on colorbars, see the documentation here.

The subplot_mosaic interface

In my last post I discussed the Gridspec interface, which allows you to manually configure a custom grid of subplots. While Gridspec can create complex configurations of subplots, manually adjusting the grid can become complicated for complex layouts. Matplotlib recently introduced a new feature, subplot_mosaic, which allows a more intuitive interface for configuring subplots. The subplot_mosaic function has a simple and streamlined interface that allows you to easily lay out subplots, then stores these subplots as a dictionary. It’s important to note that this feature is still in testing and (at the time of this posting) is not currently supported by many distributions such as Anaconda. For more examples and full documentation, see the official release page.
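
Here is a minimal sketch of what a subplot_mosaic layout might look like, assuming a Matplotlib version that includes the feature (3.3 or later). The panel labels are arbitrary names I chose for illustration, and the Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs anywhere
from matplotlib import pyplot as plt

# each inner list is a row of the layout; repeated labels span multiple cells
layout = [["left", "top right"],
          ["left", "bottom right"]]
fig, axd = plt.subplot_mosaic(layout, figsize=(8, 6))

# axd is a dictionary mapping each label to its Axes object
for label, ax in axd.items():
    ax.text(0.5, 0.5, label, ha='center')
```

Because the “left” label appears in both rows, that subplot spans the full height of the figure, while the two right-hand panels each occupy a single cell.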

Everything you want to know about subplots in Python’s Matplotlib

Sometimes it’s helpful to get back to basics. I recently created a summer course on data visualization with Python, and the experience made me realize that the workings of Python’s main visualization library, Matplotlib, are often left out of formal Python courses. Over the course of my graduate career I’ve taken several courses on programming and Python, and each time visualization or plotting was viewed as an afterthought. I don’t think my experience is uncommon. Creating basic visualizations in Python is fairly straightforward, and a quick Google search will yield tutorials on how to make most common plot types. While constructing the summer course, I realized how much time I could have saved if I had learned how Matplotlib worked earlier in my PhD. With this in mind, I’m writing this post to serve as a guide for those new to plotting with Matplotlib, and to help fill in some gaps for those who are already experienced Matplotlib users.

I’ve chosen to devote this post to making subplots, a task often necessary for scientific visualizations that can be surprisingly frustrating if you don’t understand Matplotlib’s structure. Since Python is an open source language, there are multiple ways of creating and working with subplots, so in this post I’ll outline a few ways that work for me, and provide some context about how things work behind the scenes in Matplotlib.

A brief introduction to Matplotlib’s object oriented syntax

Let’s start with some history. As its name suggests, Matplotlib was originally created as a means of replicating Matlab style plotting functionality in Python. As such, one way we can create visualizations using Matplotlib is through “Matlab style” syntax which is contained within Matplotlib’s flagship module, Pyplot. For example, to create a simple line plot we can use the following code:

import matplotlib.pyplot as plt
import numpy as np
x = np.arange(0,10)
y = np.arange(0, 10)
plt.plot(x, y)

Which will generate this plot:

While the Matlab style syntax is easy to use, it is actually quite limiting in its ability to create custom visualizations. Luckily, the Matlab style syntax is simply hiding an object oriented code structure that is at the core of Matplotlib. By directly accessing and working with this object oriented structure we can create highly customized visualizations.

For the purposes of this blog, there are two Matplotlib objects that are important: Figure objects and Axes objects. Figure objects represent the containers that hold a visualization. I think of this as the blank canvas that the plots will be generated on. Figure objects can contain one or multiple Axes objects, each representing an individual chart (this can be a big source of confusion when learning Matplotlib: the term “axes object” is inherited from Matlab and refers to an entire chart rather than an x or y axis). To create the plot above using Matplotlib’s object oriented syntax we need to make a few small modifications to the original code. First, we’ll generate a Figure object using Pyplot. Next we access an Axes object using the Figure object’s “gca” method, which stands for “get current axes” (if the figure does not yet contain any axes, this call creates one behind the scenes). Finally, we’ll use the Axes object to generate the plot:

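fig = plt.figure()
# get the figure's current axes (one is created if none exists yet)
ax = fig.gca()
# draw the line plot on that axes object
ax.plot(x, y)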

This will make the same plot as above:

Note that we can use the Axes object to create any kind of plot that can be created with the Matlab style syntax. Examples include ax.plot (line plot), ax.scatter (scatter plot), ax.bar (bar plot), ax.imshow (heatmaps and images), etc.
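As a quick illustration of that point, here is a minimal sketch that draws four chart types, each from its own Axes object. The data arrays are made up for demonstration, and the matplotlib.use("Agg") line is only an assumption so the code runs without a display; remove it to see the figures interactively.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, remove for interactive use
import matplotlib.pyplot as plt
import numpy as np

# made-up demonstration data
x = np.arange(0, 10)
y = x ** 2
grid = np.arange(100).reshape(10, 10)  # 2D array for the heatmap

# line plot
fig_line = plt.figure()
ax_line = fig_line.gca()
ax_line.plot(x, y)

# scatter plot
fig_scatter = plt.figure()
ax_scatter = fig_scatter.gca()
ax_scatter.scatter(x, y)

# bar plot
fig_bar = plt.figure()
ax_bar = fig_bar.gca()
ax_bar.bar(x, y)

# heatmap
fig_image = plt.figure()
ax_image = fig_image.gca()
ax_image.imshow(grid)
```

In every case the pattern is the same: get an Axes object, then call the plotting method on it.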

Creating Basic subplots

Now that we have some insight into the object oriented structure of Matplotlib, we can start making some subplots. There are two main ways we can make subplots with Matplotlib. The first is to use Pyplot’s “subplots” function, which allows us to specify the number of rows and columns of plots in the figure. This function will return both a Figure object and a NumPy array of Axes objects. Note that the array of Axes objects is arranged in the same way as the subplots are within the figure (i.e. array entry [0,0] is the top left subplot):

fig, [[ax0, ax1],[ax2, ax3]] = plt.subplots(nrows=2, ncols=2)
ax0.text(0.5, 0.5, "This is axes object 0", ha='center')
ax1.text(0.5, 0.5, "This is axes object 1", ha='center')
ax2.text(0.5, 0.5, "This is axes object 2", ha='center')
ax3.text(0.5, 0.5, "This is axes object 3", ha='center')

Which will make the following set of subplots:

Instead of giving each Axes object its own name, we can store them all in a single named array and access them via indices like this:

fig, axes = plt.subplots(nrows=2, ncols=2)
axes[0,0].text(0.5, 0.5, "This is axes object 0", ha='center')
axes[0,1].text(0.5, 0.5, "This is axes object 1", ha='center')
axes[1,0].text(0.5, 0.5, "This is axes object 2", ha='center')
axes[1,1].text(0.5, 0.5, "This is axes object 3", ha='center')
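Because the Axes objects come back in a NumPy array, we can also loop over them instead of indexing each one by hand. The following is a small sketch of that idea using the array’s “flat” iterator; the matplotlib.use("Agg") line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, remove for interactive use
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2)

# axes.flat walks the 2x2 array row by row: left to right, top to bottom
for i, ax in enumerate(axes.flat):
    ax.text(0.5, 0.5, f"This is axes object {i}", ha='center')
```

This produces the same four labels as the indexed version above, and it scales to larger grids without extra lines of code.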

An alternative way to create subplots is to first make the Figure object on its own, and then add subplots one at a time using the Figure object’s “add_subplot” method. This method can take a single three-digit number as its argument. The first digit represents the number of rows, the second represents the number of columns and the third represents the placement of the individual subplot (counted from left to right, top to bottom, with the first placement being the top left and the last being the lower right). Oddly, the placement is 1-indexed, unlike most indexing in Python:

fig = plt.figure()
ax0 = fig.add_subplot(221)
ax0.text(0.5, 0.5, "This is axes object 0", ha='center')
ax1 = fig.add_subplot(222)
ax1.text(0.5, 0.5, "This is axes object 1", ha='center')
ax2 = fig.add_subplot(223)
ax2.text(0.5, 0.5, "This is axes object 2", ha='center')
ax3 = fig.add_subplot(224)
ax3.text(0.5, 0.5, "This is axes object 3", ha='center')

This script will make the identical set of subplots as shown above:
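The add_subplot method also accepts the number of rows, the number of columns and the placement as three separate integers, which is handy inside a loop. Here is a minimal sketch of that variant; the variable names are my own, and the matplotlib.use("Agg") line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, remove for interactive use
import matplotlib.pyplot as plt

fig = plt.figure()
nrows, ncols = 2, 2

axes = []
for i in range(nrows * ncols):
    # placements are 1-indexed, so shift the 0-indexed loop counter by one
    ax = fig.add_subplot(nrows, ncols, i + 1)
    ax.text(0.5, 0.5, f"This is axes object {i}", ha='center')
    axes.append(ax)
```

Here axes is an ordinary Python list, filled in the same left to right, top to bottom order as the placements.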

Giving your plots some room to breathe

You may have noticed that the plots above are very cramped. By default, there is very little room between adjacent subplots. One remedy is to use Matplotlib’s tight_layout function, which will automatically fit your subplots into the figure. Just add this one line after creating your subplots. For demonstration, I’ll also add a title to each subplot:

fig, [[ax0, ax1],[ax2, ax3]] = plt.subplots(nrows=2, ncols=2)
ax0.text(0.5, 0.5, "This is axes object 0", ha='center')
ax0.set_title('Title 0')
ax1.text(0.5, 0.5, "This is axes object 1", ha='center')
ax1.set_title('Title 1')
ax2.text(0.5, 0.5, "This is axes object 2", ha='center')
ax2.set_title('Title 2')
ax3.text(0.5, 0.5, "This is axes object 3", ha='center')
ax3.set_title('Title 3')
# fit the subplots, titles and tick labels into the figure without overlap
plt.tight_layout()

This will create the following set of plots:

If we want to add some extra padding between subplots, we can add some arguments to override the default tight_layout parameters. The argument “pad” adds padding between the subplots and the figure borders, while “h_pad” and “w_pad” add height and width padding between subplots. The units of this padding are fractions of the font size (so pad=1.5 means 1.5 times the font size).

plt.tight_layout(pad=1.5, h_pad=2.5, w_pad=2)

Note that to add the necessary padding, tight_layout will shrink the subplots themselves. To compensate, we can increase the size of the figure using the “figsize” argument when we create the Figure object. This argument takes the figure’s (width, height) in inches:

fig, [[ax0, ax1],[ax2, ax3]] = plt.subplots(nrows=2, ncols=2, figsize=(8,6))
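To see this trade-off in numbers, the sketch below measures the width of the top left subplot after tight_layout for two padding values and two figure sizes. The helper function is hypothetical (written just for this illustration), it uses the Figure method fig.tight_layout rather than plt.tight_layout, and the matplotlib.use("Agg") line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, remove for interactive use
import matplotlib.pyplot as plt


def top_left_width_inches(figsize, pad):
    """Hypothetical helper: width in inches of the top left subplot after tight_layout."""
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=figsize)
    fig.tight_layout(pad=pad, h_pad=2.5, w_pad=2)
    # get_position() is in fractions of the figure, so scale by the figure width
    width = axes[0, 0].get_position().width * fig.get_size_inches()[0]
    plt.close(fig)
    return width


small_pad = top_left_width_inches((6.4, 4.8), pad=1.5)
large_pad = top_left_width_inches((6.4, 4.8), pad=6)
larger_fig = top_left_width_inches((8, 6), pad=6)

print(f"6.4 x 4.8 in figure, pad=1.5: {small_pad:.2f} in")
print(f"6.4 x 4.8 in figure, pad=6:   {large_pad:.2f} in")
print(f"8 x 6 in figure,     pad=6:   {larger_fig:.2f} in")
```

More padding leaves a narrower subplot at a fixed figure size, and a larger figsize wins some of that width back.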

Before moving on, I should note that the function plt.subplots_adjust() offers similar functionality to tight_layout and lets you set the left, right, bottom and top margins individually. For the sake of brevity I’m omitting it here.

Getting fancy: creating subplots of different sizes

So far we’ve been creating subplots of uniform size. In practice, however, it can be helpful to generate plots of varying sizes. We can do this using the GridSpec class. GridSpec will create a grid of subplot locations within a Figure object. When generating subplots, we can assign each a location and size in GridSpec coordinates. Importantly, a single subplot can span multiple rows or columns of GridSpec coordinates. Below, I’ll make a 2×3 GridSpec and use it to create different sized subplots.

fig = plt.figure(figsize=(8,6))
gspec = fig.add_gridspec(nrows=2, ncols=3)

# the first subplot will span one row and two columns
# it will start at the top left
ax0 = fig.add_subplot(gspec[0,:2])
ax0.text(0.5, 0.5, "This is axes object 0, \ngspec coordinates [0,:2]", ha='center')

# the second subplot will span one row and one column
# it will start at the bottom left
ax1 = fig.add_subplot(gspec[1,0])
ax1.text(0.5, 0.5, "This is axes object 1, \ngspec coordinates [1,0]", ha='center')

# the third subplot will span one row and one column
# it will start at the bottom middle
ax2 = fig.add_subplot(gspec[1,1])
ax2.text(0.5, 0.5, "This is axes object 2, \ngspec coordinates [1,1]", ha='center')

# the fourth subplot will span two rows and one column
# it will start at the top right
ax3 = fig.add_subplot(gspec[:,2])
ax3.text(0.5, 0.5, "This is axes object 3, \ngspec coordinates [:,2]", ha='center')


Which will make this plot:

We can also customize the height and width of each Gridspec coordinate using the arguments “height_ratios” and “width_ratios” respectively when creating the Gridspec object.

fig = plt.figure(figsize=(8,6))
# parameters to specify the width and height ratios between rows and columns
widths = [1, 1.5, 2]
heights = [1, .5]

gspec = fig.add_gridspec(ncols=3, nrows=2, width_ratios=widths, height_ratios=heights)

# the first subplot will span one row and two columns
# it will start at the top left
ax0 = fig.add_subplot(gspec[0,:2])
ax0.text(0.5, 0.5, "This is axes object 0, \ngspec coordinates [0,:2]", ha='center')

# the second subplot will span one row and one column
# it will start at the bottom left
ax1 = fig.add_subplot(gspec[1,0])
ax1.text(0.5, 0.5, "This is axes object 1, \ngspec coordinates [1,0]", ha='center')

# the third subplot will span one row and one column
# it will start at the bottom middle
ax2 = fig.add_subplot(gspec[1,1])
ax2.text(0.5, 0.5, "This is axes object 2, \ngspec coordinates [1,1]", ha='center')

# the fourth subplot will span two rows and one column
# it will start at the top right
ax3 = fig.add_subplot(gspec[:,2])
ax3.text(0.5, 0.5, "This is axes object 3, \ngspec coordinates [:,2]", ha='center')
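To check that the ratios behave as expected, we can compare the sizes of the resulting Axes objects. The sketch below rebuilds the same grid and reads each subplot’s size with get_position(); the variable names for the measured sizes are my own, and the matplotlib.use("Agg") line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, remove for interactive use
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 6))
widths = [1, 1.5, 2]
heights = [1, .5]
gspec = fig.add_gridspec(ncols=3, nrows=2, width_ratios=widths, height_ratios=heights)

ax0 = fig.add_subplot(gspec[0, :2])  # top row, first two columns
ax1 = fig.add_subplot(gspec[1, 0])   # bottom left
ax2 = fig.add_subplot(gspec[1, 1])   # bottom middle
ax3 = fig.add_subplot(gspec[:, 2])   # right column, both rows

# sizes are in fractions of the figure
col0_width = ax1.get_position().width
col1_width = ax2.get_position().width
col2_width = ax3.get_position().width
row0_height = ax0.get_position().height
row1_height = ax1.get_position().height

print("column 1 / column 0 width:", round(col1_width / col0_width, 2))
print("column 2 / column 0 width:", round(col2_width / col0_width, 2))
print("row 0 / row 1 height:", round(row0_height / row1_height, 2))
```

The printed values should match the requested ratios: the middle column is 1.5 times as wide as the left one, the right column is twice as wide, and the top row is twice as tall as the bottom row.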


Final thoughts

There are many more ways we can customize subplots in Matplotlib, but the material in this post suits my needs 99% of the time. For further reading, check out the Matplotlib documentation and examples: https://matplotlib.org/gallery/subplots_axes_and_figures/subplots_demo.html#sphx-glr-gallery-subplots-axes-and-figures-subplots-demo-py