A simple ODE Class

A small illustration of using the Armadillo C++ linear algebra library to solve an ordinary differential equation of the form \[ X'(t) = F(t, X(t), U(t)). \]

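The abstract base class `Solver` declares the method `solve` (which approximates the solution at user-defined time points) and implements `solveint` (which first interpolates the user-defined input functions onto a finer grid, solves on that grid and, by default, returns the solution only at the original time points). As an illustration, the classic fourth-order Runge-Kutta solver is derived in the class `RK4`.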

The first step is to define the ODE. Here it is a simple one-dimensional ODE \(X'(t) = \theta_0\theta_1\{U(t)-X(t)\}\) with a single input \(U(t)\), where the rate is the product of the two parameters `theta(0)` and `theta(1)`:

// Right-hand side of the example ODE:
//   X'(t) = theta(0) * theta(1) * (U(t) - X(t)),
// where U(t) is the second element of `input` and X(t) the first state.
rowvec dX(const rowvec &input,   // time (first element) and additional input variables
          const rowvec &x,       // state variables
          const rowvec &theta) { // parameters
  const double rate = theta(0) * theta(1);
  const double deviation = input(1) - x(0);  // U(t) - X(t)
  rowvec derivative = { rate * deviation };
  return derivative;
}

The ODE can then be solved with the following syntax:

// Wrap the right-hand side dX in a Runge-Kutta solver ...
odesolver::RK4 MyODE(dX);
// ... and integrate on the time grid in the first column of `input`,
// starting from the state `init`; row i of `res` is the state at input(i,0).
arma::mat res = MyODE.solve(input, init, theta);

where the step sizes are defined implicitly by `input` (the first column holds the time points and any further columns hold the optional input variables), and the initial condition is given by `init`.

The header file

#ifndef _ODESOLVER_H_
#define _ODESOLVER_H_

#ifndef RARMA
#define MATHLIB_STANDALONE
#include <armadillo>
#include "Rmath.h"
#endif
#if defined(RARMA)
#include <RcppArmadillo.h>
#endif

#include <functional>  // std::function (odefunc)
#include <utility>     // std::move

namespace odesolver {

  /*!
    Right-hand side of the ODE X'(t) = F(t, X(t), U(t)).
    Called as F(input, x, theta) with one row of the solver input (time in the
    first element, input variables after it), the current state x and the
    parameters theta; returns the derivative of the state.
  */
  using odefunc = std::function<arma::mat(arma::mat input, arma::mat x, arma::mat theta)>;

  /*!
    Abstract base class for ODE solvers. Stores the right-hand side F;
    derived classes implement solve().
  */
  class Solver {
  protected:
    odefunc F;  ///< Right-hand side of the ODE.

  public:
    Solver(odefunc f) : F(std::move(f)) {}
    virtual ~Solver() = default;

    /*!
      Approximate the solution at the time points in the first column of input.
      @param input  Time in the first column, optional input variables in the others.
      @param init   State at the first time point.
      @param theta  Parameters forwarded unchanged to F.
      @return One row per row of input, one column per state variable.
    */
    virtual arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta) = 0;

    /*!
      Resample input onto an equidistant grid with step tau (last observation
      carried forward), solve on that grid and, if reduce is true, return only
      the rows whose grid times are nearest to the original time points.
    */
    arma::mat solveint(const arma::mat &input, arma::mat init, arma::mat theta, double tau=1.0e-1, bool reduce=true);
  };

  /*!
    Classic fourth-order Runge-Kutta solver; the step sizes are the differences
    between consecutive time points in the first column of input.
  */
  class RK4 : public Solver { // Basic 4th order Runge-Kutta solver
  public:
    using Solver::Solver;

    arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta) override;
  };

  /*!
    For every element of newtime return an index into time (sorted ascending):
    type 0 = nearest, 1 = right (first time point >= query),
    2 = left (last time point <= query).
    Declared inside the namespace so it matches the definition in the .cpp file.
  */
  arma::uvec approx(const arma::mat &time,     // Sorted time points (Ascending)
                    const arma::mat &newtime,
                    unsigned type=0);          // (0: nearest, 1: right, 2: left)

  /*!
    Resample input (first column is time) onto an equidistant grid with step
    tau; the last grid row is the last input row. With locf=true the input
    columns are carried forward as a step function, otherwise they are
    linearly interpolated.
  */
  arma::mat interpolate(const arma::mat &input, // first column is time
                        double tau,             // Time-step
                        bool locf=false);       // Last-observation-carried forward, otherwise linear interpolation

} // namespace odesolver

#endif /* _ODESOLVER_H_ */
Code Snippet 1: Header file (odesolver.h)

Class cpp file

#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <vector>
#include "odesolver.h"

using namespace arma;

namespace odesolver {

  /*!
    Map every element of newtime to an index into time.

    @param time     Time points sorted in ascending order; must not be empty.
    @param newtime  Query time points.
    @param type     0: nearest time point (ties resolve to the later one),
                    1: right, i.e. the first time point >= the query,
                    2: left, i.e. the last time point <= the query.
                    Queries outside [time(0), time(last)] map to the first or
                    last index.
    @return One index per element of newtime.
    @throws std::invalid_argument if time is empty.
  */
  arma::uvec approx(const arma::mat &time,  // Sorted time points (Ascending)
                    const arma::mat &newtime,
                    unsigned type) { // (0: nearest, 1: right, 2: left)
    if (time.n_elem == 0)
      throw std::invalid_argument("approx: 'time' must not be empty");

    const uword last = time.n_elem - 1;
    const double vmax = time(last);
    uvec idx(newtime.n_elem);
    for (uword i = 0; i < newtime.n_elem; ++i) {
      const double t = newtime(i);
      uword pos = last;  // queries at or beyond the last time point
      if (t < vmax) {
        // First time point >= t; it exists because t < vmax.
        const auto it = std::lower_bound(time.begin(), time.end(), t);
        pos = static_cast<uword>(it - time.begin());
        // pos == 0 means t <= time(0): stay at 0 instead of stepping to index -1.
        if (pos > 0) {
          if (type == 0 && std::fabs(t - time(pos - 1)) < std::fabs(t - time(pos)))
            --pos;
          else if (type == 2 && t < time(pos))
            --pos;
        }
      }
      idx(i) = pos;
    }
    return idx;
  }

  /*!
    Resample input onto an equidistant time grid.

    @param input  First column is time (ascending), further columns are inputs.
    @param tau    Grid step; must be positive and finite.
    @param locf   true: last observation carried forward (step function);
                  false: linear interpolation between consecutive rows.
    @return ceil((t_n - t_0)/tau) + 1 rows on the grid t_0 + k*tau; the last
            row is the last row of input, so the final step may be shorter
            than tau.
    @throws std::invalid_argument for empty input, an invalid tau, or time
            points whose last value precedes the first (or is not finite).
  */
  arma::mat interpolate(const arma::mat &input, double tau, bool locf) {
    if (input.n_elem == 0)
      throw std::invalid_argument("interpolate: 'input' must not be empty");
    if (!(tau > 0.0) || !std::isfinite(tau))
      throw std::invalid_argument("interpolate: 'tau' must be positive and finite");

    const vec time = input.col(0);
    const uword n = time.n_elem;
    const double t0 = time(0);
    const double tn = time(n - 1);
    if (!(tn >= t0))
      throw std::invalid_argument("interpolate: time points must be in ascending order");
    const double steps = std::ceil((tn - t0) / tau);
    if (!std::isfinite(steps))
      throw std::invalid_argument("interpolate: time span must be finite");

    const uword N = static_cast<uword>(steps) + 1;
    mat input2(N, input.n_cols);

    // Rate of change per unit time on the segment [time(k), time(k+1)].
    auto segment_slope = [&](uword k) -> rowvec {
      return (input.row(k + 1) - input.row(k)) / (time(k + 1) - time(k));
    };

    // With LOCF only the time column advances; the inputs stay constant.
    rowvec slope(input.n_cols, fill::zeros);
    if (locf)
      slope(0) = 1.0;
    else if (n > 1)
      slope = segment_slope(0);

    uword cur = 0;  // last knot with time(cur) <= current grid time
    for (uword i = 0; i + 1 < N; ++i) {
      // Compute the grid time directly instead of accumulating tau (avoids drift).
      const double curtime = t0 + static_cast<double>(i) * tau;
      // Advance to the last knot at or before curtime, so an observation holds
      // from its own time point on; never move past the last row.
      while (cur + 1 < n && time(cur + 1) <= curtime) {
        ++cur;
        if (!locf && cur + 1 < n)
          slope = segment_slope(cur);
      }
      input2.row(i) = input.row(cur) + slope * (curtime - time(cur));
    }
    // Pin the final grid row to the last observation so the grid ends exactly at tn.
    input2.row(N - 1) = input.row(n - 1);
    return input2;
  }

  /*!
    Classic fourth-order Runge-Kutta integration on the grid input.col(0).
    Step i uses h = input(i+1,0) - input(i,0); the input variables are taken at
    the start, the midpoint (average of rows i and i+1) and the end of the step.

    @param input  Time in the first column, input variables in the others.
    @param init   Initial state at input(0,0); a row or a column vector.
    @param theta  Parameters forwarded unchanged to F.
    @return n x p matrix whose row i is the state at input(i,0); empty (0 rows)
            when input has no rows.
  */
  arma::mat RK4::solve(const arma::mat &input, arma::mat init, arma::mat theta) {
    const uword n = input.n_rows;
    // Flatten so that both row and column vectors are accepted as init.
    rowvec y = arma::vectorise(init).t();
    mat res(n, y.n_elem);
    if (n == 0)
      return res;  // nothing to integrate; avoids the n-1 underflow below

    res.row(0) = y;
    for (uword i = 0; i + 1 < n; ++i) {
      const rowvec cur = input.row(i);
      const rowvec dinput = input.row(i + 1) - cur;
      const double tau = dinput(0);
      const rowvec k1 = tau * F(cur,              y,          theta);
      const rowvec k2 = tau * F(cur + dinput / 2, y + k1 / 2, theta);
      const rowvec k3 = tau * F(cur + dinput / 2, y + k2 / 2, theta);
      const rowvec k4 = tau * F(cur + dinput,     y + k3,     theta);
      y += (k1 + 2 * k2 + 2 * k3 + k4) / 6;
      res.row(i + 1) = y;
    }
    return res;
  }

  /*!
    Solve on a refined grid: the inputs are resampled with step tau using
    last-observation-carried-forward, solve() is run on that grid and, when
    reduce is true, only the rows nearest to the original time points
    input.col(0) are returned.
  */
  arma::mat Solver::solveint(const arma::mat &input, arma::mat init, arma::mat theta, double tau, bool reduce) {
    const mat newinput = interpolate(input, tau, true);
    mat value = solve(newinput, init, theta);
    if (reduce) {
      const uvec idx = odesolver::approx(newinput.col(0), input.col(0), 0);
      value = value.rows(idx);
    }
    return value;
  }

} // namespace odesolver
Code Snippet 2: Class cpp file