Higher-level geophysical modelling

By Roman Nuterman, Dion Häfner, and Markus Jochum.
(Niels Bohr Institute, Copenhagen, Denmark)

This post is the display material for the vEGU 2021 abstract “Higher-level geophysical modelling” in session AS1.1 (Recent Developments in Numerical Earth System Modelling).

Therefore, it is a bit more technical than my usual blog posts, and assumes that you are somewhat familiar with the field of numerical modelling.

A paradigm shift

The year is 2021, but numerical modelling and high-performance computing are still bastions of low-level programming languages. Most (finite difference) models are written in Fortran or C, which have been around since the early days of computing.

This is not surprising at the very largest compute scales, such as the CMIP climate model ensembles that run on the world's largest supercomputers. There, even a small performance drop can end up consuming funding equivalent to several staff positions. But this is an extreme example; in most settings, human time is more valuable than computer time. (Just think of your poor PhD students trying to compile the model code or get their setup to work.)

On top of this, the efficiency of GPUs has increased dramatically in tandem with the recent machine learning boom. This has also led to more heterogeneous compute architectures than ever. For example, Finland's LUMI supercomputer will consist of 550 PFLOP/s worth of GPUs (on top of about 200,000 CPU cores).

We think that these two developments call for a new, flexible generation of geophysical models:

  1. Flexible to run. The same code needs to be able to run on CPU and GPU hardware stacks.
  2. Flexible to use. Simple to install and get started.
  3. Flexible to modify. Readable code with helpful abstractions. Easy to re-build.
  4. Flexible to integrate. Simple to interface with external libraries for plotting, post-processing, machine learning, other models, …

We call these flexible models “high-level”, and to meet these design goals, at least a part of them needs to be implemented in a high-level programming language.

The taxonomy of high-level modelling

Essentially, high-level modelling comes in three different flavors. Each of them is a good way forward, and definitely a step up from the status quo.

Let’s talk about performance

Model performance is the elephant in the room whenever we discuss high-performance computing in Python.

As we argue in the introduction, human time is often more valuable than computer time, so we don't think that performance should be prioritized at all costs. However, model performance is part of the user experience: a long feedback loop between designing an experiment and examining its results is catastrophic for overall productivity.

So, we have to care about performance to some degree. Luckily, this is largely a solved problem. It is already possible to match native Fortran performance in Python (within ±10%).

Pure Python / NumPy is of course nowhere close to Fortran — in our experience, a model written in NumPy is about 5x slower than its Fortran equivalent. But fortunately, there is a rich library ecosystem to accelerate Python code. Most of these libraries are geared towards machine learning, but nothing prevents us from using them for scientific computing instead. (Remember when we mentioned synergy as a major asset of Type III models?)

The following plots are from pyhpc-benchmarks, a repository we created to compare the performance of various Python frameworks on subroutines of our ocean model Veros:

Figure: Performance of various Python frameworks on the TEOS-10 equation of state benchmark. Numba performance is similar to that of the underlying Fortran code.
Figure: Performance of various Python frameworks on the Veros isoneutral mixing subroutine. Numba performance is similar to that of the underlying Fortran code.

There are two Python frameworks that show particularly strong performance: Numba and JAX. To give you an idea of how this works, here is the same code snippet written in Fortran, NumPy, Numba, and JAX:


Fortran

! clip negative surface TKE and diagnose the resulting correction flux
tke_surf_corr = 0.0
do j=js_pe,je_pe
  do i=is_pe,ie_pe
    if (tke(i,j,nz,taup1) < 0.0) then
      tke_surf_corr(i,j) = -tke(i,j,nz,taup1) * (0.5 * dzw(nz)) / dt_tke
      tke(i,j,nz,taup1) = 0.0
    endif
  enddo
enddo

NumPy

# vectorized version: the mask selects cells with negative surface TKE
mask = tke[2:-2, 2:-2, -1, taup1] < 0.0
tke_surf_corr = np.zeros_like(maskU[..., -1])
tke_surf_corr[2:-2, 2:-2] = np.where(
    mask,
    -tke[2:-2, 2:-2, -1, taup1] * 0.5 * dzw[-1] / dt_tke,
    0.
)

Numba

# explicit loops, meant to be compiled with Numba's JIT (see the sketch below)
tke_surf_corr = np.zeros((nx, ny))
for i in range(2, nx-2):
    for j in range(2, ny-2):
        if tke[i, j, -1, taup1] >= 0.:
            continue
        tke_surf_corr[i, j] = -tke[i, j, -1, taup1] * (0.5 * dzw[-1]) / dt_tke
        tke[i, j, -1, taup1] = 0.

JAX

# functional update: .at[].set() returns a new array instead of mutating in place
mask = tke[2:-2, 2:-2, -1, taup1] < 0.0
tke_surf_corr = jnp.zeros_like(maskU[..., -1])
tke_surf_corr = tke_surf_corr.at[2:-2, 2:-2].set(
    jnp.where(
        mask,
        -tke[2:-2, 2:-2, -1, taup1] * 0.5 * dzw[-1] / dt_tke,
        0.
    )
)

See how the Numba implementation is essentially the Fortran code translated one-to-one, including the explicit loops? Numba's JIT compiler is very efficient at transforming explicit loops like these, essentially giving us the same performance as the Fortran code after compilation.

Unfortunately, there is no way to use the same Numba implementation for both CPU and GPU, as efficient GPU code generation typically requires a vectorized approach instead of explicit loops.
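
For concreteness, here is a minimal sketch of how the loop version above could be compiled with Numba. The function name, argument list, and array shapes are our assumptions for illustration, not the actual Veros code.

import numba
import numpy as np

@numba.njit
def surface_tke_correction(tke, dzw, dt_tke, taup1):
    # same explicit loops as above, compiled to machine code on first call
    nx, ny = tke.shape[0], tke.shape[1]
    tke_surf_corr = np.zeros((nx, ny))
    for i in range(2, nx - 2):
        for j in range(2, ny - 2):
            if tke[i, j, -1, taup1] >= 0.0:
                continue
            tke_surf_corr[i, j] = -tke[i, j, -1, taup1] * (0.5 * dzw[-1]) / dt_tke
            tke[i, j, -1, taup1] = 0.0
    return tke_surf_corr

The first call triggers compilation for the given argument types; every subsequent call runs the generated machine code directly.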

JAX, on the other hand, reads very similarly to NumPy, but its JIT compiler generates code that is competitive with Numba / Fortran on CPU and has great performance on GPU (with speedups of 40x to 3000x over NumPy). The only major restrictions are that JAX arrays are immutable and all JAX functions have to be pure (i.e., free of side effects).
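
To show what purity and JIT compilation look like in practice, here is a sketch of how the snippet above might be wrapped; again, the function and argument names are our assumptions, not the actual Veros API.

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="taup1")
def surface_tke_correction(tke, maskU, dzw, dt_tke, taup1):
    # pure function: nothing is mutated in place, updated arrays are returned
    mask = tke[2:-2, 2:-2, -1, taup1] < 0.0
    tke_surf_corr = jnp.zeros_like(maskU[..., -1])
    tke_surf_corr = tke_surf_corr.at[2:-2, 2:-2].set(
        jnp.where(mask, -tke[2:-2, 2:-2, -1, taup1] * 0.5 * dzw[-1] / dt_tke, 0.0)
    )
    # "resetting" negative TKE means building an updated copy of the array
    tke = tke.at[2:-2, 2:-2, -1, taup1].set(
        jnp.where(mask, 0.0, tke[2:-2, 2:-2, -1, taup1])
    )
    return tke, tke_surf_corr

Because the function is pure, JAX can trace and compile it once and then execute the same code on CPU or GPU without changes.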

We have therefore decided on JAX as the new computational backend for Veros. First benchmarks of the full model show that, with JAX, high-resolution setups are ~10% slower than Fortran on CPU, and that a single high-end GPU is about as fast as 50 CPUs running the Fortran code.

We have not measured power consumption yet, but with some back-of-the-envelope math, this should yield a GPU model that is at least 2x more energy efficient than its Fortran equivalent.
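
For transparency, this is the kind of estimate we have in mind; the power figures below are rough assumptions for illustration, not measurements.

# back-of-the-envelope energy estimate (illustrative numbers only)
gpu_power_watts = 300.0        # assumed draw of one high-end GPU
cpu_cores_replaced = 50        # from the benchmark above
watts_per_cpu_core = 15.0      # assumed draw per core, incl. node overhead

cpu_power_watts = cpu_cores_replaced * watts_per_cpu_core   # 750 W
energy_ratio = cpu_power_watts / gpu_power_watts            # ~2.5x

print(f"Estimated energy advantage of the GPU run: ~{energy_ratio:.1f}x")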

We think that JAX shows exceptional promise to become the de-facto high-performance computational backend for Python, thanks to its user friendliness and consistently high performance on both CPU and GPU.

Abstraction is key

Surprisingly, the main problem with high-level models is not performance, but finding the right level of abstraction.

Just translating model code from Fortran to Python does not make it more readable.

In fact, it can be significantly less readable, because in Python there is a delicate balance between performance and readability: the most readable way to write the code will not be performant, and the most performant way will not be readable. In our experience, finding the right middle ground is extremely hard, and the different performance characteristics of each computational framework mean that there is no universal answer.

Additionally, people have been writing model code in Fortran since the 1970s, but are only starting to do so in Python / NumPy / JAX. This means that there are no established community standards on how to write models.

In our experience, more abstraction is needed at all levels of a Python model:

  1. At the lowest level, to separate numerics from physics.

    dydx(salt, order=1) represents the intent of the code much better than (salt[1:] - salt[:-1]) / dx. Numerical computations should also be aware of physical units and be able to convert between them (e.g. via pint). Additionally, the code representing the physics should look the same regardless of the computational backend used (NumPy, JAX, Fortran, or something else); see the sketch after this list.

  2. At the intermediate level, to encapsulate model state and define the data flow between model routines.

    There are several projects that address this, including sympl, xarray-simlab, and climlab. But none of them has been adopted by more than a handful of projects so far, which we interpret as evidence that they are not flexible or approachable enough just yet.

  3. At the highest level, to provide a common interface for setup specification, introspection, coupling, and interactive data analysis.

    Ideally, running the model should happen in the same environment and use the same tools as post-processing of its output. For example, every physical model could expose its state as a self-describing xarray dataset.
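
To make the first and third points more concrete, here is a minimal sketch of what a backend-agnostic derivative helper and an xarray export could look like. The names dydx, get_backend, and state_to_dataset are hypothetical, and the state layout is an assumption for illustration.

import numpy as np
import xarray as xr

try:
    import jax.numpy as jnp
except ImportError:   # JAX is optional in this sketch
    jnp = None


def get_backend(name="numpy"):
    # hypothetical helper: select the array module at runtime
    return jnp if (name == "jax" and jnp is not None) else np


def dydx(arr, dx, order=1, backend="numpy"):
    # numerics live here; the physics code only ever calls dydx(salt, dx)
    xp = get_backend(backend)
    return xp.diff(arr, n=order, axis=-1) / dx ** order


def state_to_dataset(state, coords):
    # hypothetical: expose the model state as a self-describing xarray dataset
    return xr.Dataset(
        {name: (dims, np.asarray(values)) for name, (dims, values) in state.items()},
        coords=coords,
    )

Physics code written against such helpers reads the same no matter which backend executes it, and the resulting dataset can be inspected with the usual post-processing tools while the model is running.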

We can address these points only through dialogue and collaboration. So if you have an idea or a project that can scratch one of these itches, please share.

Quo vadis?

While we think that high-performance modelling in Python has several decisive advantages, the future is still unclear.

To be fair, there is a movement to revive Fortran ("modern Fortran") as well as a modern LLVM-based Fortran compiler, LFortran, and we applaud these efforts. Fortran is still immensely powerful and a good language to write performant CPU models in.

But if the current trend continues, first-class GPU support will become more and more important. This alone means that there is no turning back (usability issues aside).

We are therefore convinced that Type II and Type III models will eventually take over, but there is still a long way to go. We as a community need to find a way to handle the increased complexity of more dynamic languages, but it can be done.

The future of high-level modelling is bright.


If you enjoyed this post, make sure to visit my vPICO presentation during vEGU 2021 (Tue, 27 Apr), and join the discussion afterwards.

I’m also happy to respond to your comments below.