Thursday, January 12, 2012

Numerical Simulation

The other day I came across a neat problem that I'd used to teach the basics of numerical simulation to some undergraduates a few summers ago. I thought I'd re-post it here. The goal is to solve the following third-order differential equation:
  • y''' + 5y'' + 8y' + 4y = 1
Use a timestep of dt = 200 µs and run the simulation for t = 10 seconds. The initial conditions are:
  • y(0) = 5
  • y'(0) = 1
  • y''(0) = 0
This problem is fun because you can derive a closed-form solution and then compare that hand-solved answer to the simulated version. The simulation can be run using Matlab, C, or a combination. The combination is neat because it uses Matlab to create the data and plot the solution, but uses C code (compiled in Matlab as a mex-function) as the super-fast solution engine. Mex-functions are a bit tricky to learn but can often lead to valuable speed-ups in simulations.

The third-order differential equation can be solved by hand. You should get the following function:
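Working it through by hand gives the result below. This is a sketch of the derivation, and the names A, B, and C for the integration constants are introduced here for illustration. The characteristic polynomial has a single root at -1 and a double root at -2, and the constant forcing term contributes a particular solution of 1/4:

```latex
\begin{aligned}
r^3 + 5r^2 + 8r + 4 &= (r+1)(r+2)^2 \\
y(t) &= \tfrac{1}{4} + A e^{-t} + (B + C t)\, e^{-2t} \\
y(0) = 5,\; y'(0) = 1,\; y''(0) = 0 \;&\Rightarrow\; A = 23,\; B = -\tfrac{73}{4},\; C = -\tfrac{25}{2} \\
y(t) &= \tfrac{1}{4} + 23 e^{-t} - \left(\tfrac{73}{4} + \tfrac{25}{2}\, t\right) e^{-2t}
\end{aligned}
```

As a quick check, y(0) = 1/4 + 23 - 73/4 = 5, and the solution settles to 1/4 as t grows.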

The plot for this function is shown below along with the iterated solution that I generated using C++:
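The two curves can also be cross-checked numerically. The sketch below assumes the closed form y(t) = 1/4 + 23e^(-t) - (73/4 + 25t/2)e^(-2t), which comes from working the problem by hand and is not quoted from the original figure. The function names closedForm and maxEulerError are made up for this example. It steps the equation with forward Euler and records the largest gap between the simulated and hand-solved values.

```cpp
#include <cassert>
#include <cmath>

// Closed-form solution with hand-derived coefficients (an assumption of
// this sketch): y(t) = 1/4 + 23 e^{-t} - (73/4 + 25/2 t) e^{-2t}
double closedForm(double t) {
    return 0.25 + 23.0 * std::exp(-t)
         - (18.25 + 12.5 * t) * std::exp(-2.0 * t);
}

// Forward-Euler integration of y''' + 5y'' + 8y' + 4y = 1 with
// y(0) = 5, y'(0) = 1, y''(0) = 0. Returns the largest absolute
// difference from closedForm() seen at any step in (0, tMax].
double maxEulerError(double dt, double tMax) {
    assert(dt > 0.0 && tMax > 0.0);
    const long nSteps = std::lround(tMax / dt);
    double y = 5.0, a = 1.0, b = 0.0;   // y, y', y''
    double worst = 0.0;
    for (long i = 1; i <= nSteps; ++i) {
        // Derivatives evaluated at the previous step
        const double dy = a;
        const double da = b;
        const double db = -8.0 * a - 5.0 * b - 4.0 * y + 1.0;
        y += dy * dt;
        a += da * dt;
        b += db * dt;
        const double err = std::fabs(y - closedForm(i * dt));
        if (err > worst) worst = err;
    }
    return worst;
}
```

Calling maxEulerError(200e-6, 10.0) gives the worst-case gap for the step size used here. Forward Euler is first-order accurate, so halving dt should roughly halve that number.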

Simulated Solutions
Here I present three numerical (i.e. simulated) solutions to the third order differential equation. All three yield the same simulated plot (see blue curve above).
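All three implementations rely on the same rewrite, which the code leaves implicit: introduce a = y' and b = y'' so that the third-order equation becomes three first-order ones, then advance each with a forward-Euler step. Written out, with the subscript n for the sample index as notation added here:

```latex
\begin{aligned}
y' &= a, \qquad a' = b, \qquad b' = 1 - 4y - 8a - 5b \\
y_{n+1} &= y_n + a_n\,\Delta t \\
a_{n+1} &= a_n + b_n\,\Delta t \\
b_{n+1} &= b_n + \left(1 - 4y_n - 8a_n - 5b_n\right)\Delta t
\end{aligned}
```

All three right-hand sides use the values from step n, which is why the code computes dy, da, and db before updating anything.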

Matlab-Only Solution

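dt = 200e-6;
tMax = 10;
nSamples = round(tMax/dt);   % round so the sample count is an exact integer

y = zeros(nSamples,1);
a = 1;      % a tracks y',  a(0) = y'(0)  = 1
b = 0;      % b tracks y'', b(0) = y''(0) = 0
y(1) = 5;   % y(0) = 5

for i=2:nSamples
    % derivatives evaluated at the previous sample
    dy = a;
    da = b;
    db = -8*a - 5*b - 4*y(i-1) + 1;

    % forward-Euler update
    y(i) = y(i-1) + dy*dt;
    a    = a      + da*dt;
    b    = b      + db*dt;
end

t = (0:nSamples-1)*dt;   % y(1) is the initial condition, so time starts at 0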

Using the tic/toc commands in Matlab, I determined that the Matlab-only solution took an average of 12ms, including the time to create the plot.

C++ Only Solution

The C++ Only solution is fast, but we need a way to pass the data back to Matlab in order to plot it. In this solution, C++ writes the data to a binary file. Then a separate Matlab script reads the data from the file and plots it.

#include <cmath>
#include <fstream>
#include <vector>
using namespace std;

int main()
{
    double dt   = 200e-6;
    double tmax = 10;
    int    nSamples = (int)floor(tmax/dt + 0.5);   // round to the nearest sample count
    double da, db, dy;
    double a = 1;                 // y'(0)
    double b = 0;                 // y''(0)
    vector<double> y(nSamples);   // heap storage; a 50000-element stack array is risky
    ofstream out("data2.bin", ios::out | ios::binary);

    y[0] = 5;                     // y(0)
    for (int i = 1; i < nSamples; i++) {
        dy = a;
        da = b;
        db = -8*a - 5*b - 4*y[i-1] + 1;
        y[i] = y[i-1] + dy*dt;
        a    = a      + da*dt;
        b    = b      + db*dt;
    }

    // File layout: dt (one double), followed by nSamples doubles of y
    out.write((char *)&dt, sizeof(double));
    out.write((char *)y.data(), nSamples*sizeof(double));
    return 0;
}

This is the Matlab plot code:

fid = fopen('data2.bin','rb');
dt  = fread(fid,1,'double');   % the first value in the file is the timestep
y   = fread(fid,'double');     % the rest of the file is y
fclose(fid);
t = (0:length(y)-1)*dt;
plot(t,y);
xlabel('time (s)');
title('Solution to 3rd Order Diff-Eq');
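The file layout is simple enough to read back in C++ as well, which can be handy for a sanity check before switching to Matlab. The sketch below assumes the layout written above (one double for dt, then the y samples until end of file) and that the file is read on a machine with the same byte order and double format as the one that wrote it. The names readSimFile and sampleTime are made up for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <fstream>
#include <string>
#include <vector>

// Reads a file laid out as: one double (dt), then doubles (y) until EOF.
// Returns false if the file cannot be opened or the dt header is missing.
bool readSimFile(const std::string& path, double& dt, std::vector<double>& y) {
    std::ifstream in(path, std::ios::in | std::ios::binary);
    if (!in) return false;
    if (!in.read(reinterpret_cast<char*>(&dt), sizeof(double))) return false;
    y.clear();
    double sample;
    // Each successful read yields one complete sample
    while (in.read(reinterpret_cast<char*>(&sample), sizeof(double))) {
        y.push_back(sample);
    }
    return true;
}

// Time stamp of sample i, matching t = (0:length(y)-1)*dt in the Matlab script.
double sampleTime(std::size_t i, double dt) {
    assert(dt > 0.0);
    return static_cast<double>(i) * dt;
}
```

After a successful readSimFile("data2.bin", dt, y), the first sample y[0] should be the initial condition 5 and y.size() should match the sample count used in the simulation.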

C++ / Matlab / Mex Solution
While the previous solution works, it requires switching back and forth between the Matlab and C++ environments, which is inherently inefficient; it also requires generating a large data file just to shuttle the data from one to the other. A better solution is to create a mex-file. A mex-file is a C/C++ file that is compiled directly within Matlab. The compiled function can be called directly from Matlab, and data can be passed back and forth between Matlab and the compiled code without the intermediate step of dumping it in a file. This solution requires only one programming environment: Matlab. The coding is a bit more complicated, but the solution is ultimately more elegant. The complexity is primarily due to the way Matlab passes data to C/C++. The Matlab data comes in structures, with pointers everywhere, and working with them takes some getting used to. There are good references for this process here and here.

#include <mex.h>
#include <string.h>
#include <math.h>

// This is the subroutine that actually performs the simulation
void runSim(double **py, double **pt, double dt, double tMax, int *nSamples){
    double da, db, dy;
    double a=1;     // y'(0)
    double b=0;     // y''(0)
    int i;
    double *y, *t;

    *nSamples = (int)floor(tMax/dt + 0.5);  // round to the nearest sample count

    // Allocate the output arrays; the caller must delete [] them
    *py = new double[*nSamples];
    y = *py;
    *pt = new double[*nSamples];
    t = *pt;

    y[0] = 5;       // y(0)
    t[0] = 0;       // y[0] is the initial condition, so time starts at 0

    for (i=1; i<*nSamples; i++){
        dy = a;
        da = b;
        db = -8*a - 5*b - 4* y[i-1] + 1;
        y[i] = y[i-1] + dy*dt;
        a    = a      + da*dt;
        b    = b      + db*dt;
        t[i] = t[i-1] + dt;
    }
}

// ****************************************************
// ******************** START HERE ********************
// ****************************************************
// The entry point of a mex routine must always be named "mexFunction"
// Here, the input data is imported from Matlab. Then the actual function
// is executed (in this case "runSim"), and finally the output data is
// exported back to Matlab
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    // Step 1: Import input variables from Matlab
    if (nrhs < 2)
        mexErrMsgIdAndTxt("diff_eq3:nrhs", "Usage: [y,t] = diff_eq3(dt, tMax)");
    double dt   = *(double *)mxGetData(prhs[0]);
    double tMax = *(double *)mxGetData(prhs[1]);

    // Step 2: Declare variables
    int nSamples;
    double *y,*t;

    // Step 3: Run the simulation
    runSim(&y, &t, dt, tMax, &nSamples);

    // Step 4: Export the results back to Matlab
    double *output;
    if (nlhs>=1){
        plhs[0] = mxCreateDoubleMatrix(nSamples, 1, mxREAL);
        output = mxGetPr(plhs[0]);
        memcpy(output, y, nSamples*sizeof(double));
    }
    if (nlhs>=2){
        plhs[1] = mxCreateDoubleMatrix(nSamples, 1, mxREAL);
        output = mxGetPr(plhs[1]);
        memcpy(output, t, nSamples*sizeof(double));
    }

    // Step 5: Housekeeping
    delete [] y;
    delete [] t;
}

Once the C/C++ code has been written and compiled, the function is called from Matlab in the same way as any other function. In this case, the source file was called "diff_eq3.cpp", so the compiled function is called using "diff_eq3". The file needs a .cpp extension so that mex compiles it as C++; the code uses new and delete, which a C compiler rejects. This simulation took an average of 2ms, including plotting time. That is 6x faster than the Matlab-only solution!

clear; clf;
dt = 200e-6;
tmax = 10;
[y,t] = diff_eq3(dt,tmax);


  1. Hi,
    Thank you so much for your great description of the three different kinds of implementations side by side. I'm trying to create a mex-file and I started to learn from your code, but when I tried to compile it, I got some errors. It would be awesome if you could give some hints.
    I'm using Matlab 7.12.0 (R2011), win 7 and Compiler: Microsoft Visual C++ 2010.
    And here is what I got:
    >> mex diff_eq3.c
    diff_eq3.c(14) : warning C4244: '=' : conversion from 'double' to 'int', possible loss of data
    diff_eq3.c(15) : error C2065: 'new' : undeclared identifier
    diff_eq3.c(15) : warning C4047: '=' : 'double *' differs in levels of indirection from 'int'
    diff_eq3.c(15) : error C2143: syntax error : missing ';' before 'type'
    diff_eq3.c(18) : error C2065: 'new' : undeclared identifier
    diff_eq3.c(18) : warning C4047: '=' : 'double *' differs in levels of indirection from 'int'
    diff_eq3.c(18) : error C2143: syntax error : missing ';' before 'type'
    diff_eq3.c(54) : error C2143: syntax error : missing ';' before 'type'
    diff_eq3.c(57) : error C2065: 'output' : undeclared identifier
    diff_eq3.c(57) : warning C4047: '=' : 'int' differs in levels of indirection from 'double *'
    diff_eq3.c(58) : error C2065: 'output' : undeclared identifier
    diff_eq3.c(58) : warning C4022: 'memcpy' : pointer mismatch for actual parameter 1
    diff_eq3.c(62) : error C2065: 'output' : undeclared identifier
    diff_eq3.c(62) : warning C4047: '=' : 'int' differs in levels of indirection from 'double *'
    diff_eq3.c(63) : error C2065: 'output' : undeclared identifier
    diff_eq3.c(63) : warning C4022: 'memcpy' : pointer mismatch for actual parameter 1
    diff_eq3.c(67) : error C2065: 'delete' : undeclared identifier
    diff_eq3.c(67) : error C2059: syntax error : ']'
    diff_eq3.c(68) : error C2065: 'delete' : undeclared identifier
    diff_eq3.c(68) : error C2059: syntax error : ']'

  2. Are you sure you are compiling in C++ and not C (your compiler probably has the option to let you do one or the other)?