# A simple ODE Class

A small illustration of using the Armadillo C++ linear algebra library to solve an ordinary differential equation of the form $X'(t) = F(t, X(t), U(t))$.

The abstract base class `Solver` declares the methods `solve` (approximates the solution at user-defined time points) and `solveint` (first interpolates the user-defined input functions onto a finer grid, then solves). As an illustration, a classic fourth-order Runge-Kutta solver is derived in the class `RK4`.

The first step is to define the ODE, here the simple one-dimensional ODE $X'(t) = \theta\cdot\{U(t)-X(t)\}$ with a single input $U(t)$; in the code below the rate $\theta$ is the product `theta(0)*theta(1)` of the first two parameters:

// Right-hand side of X'(t) = theta(0)*theta(1)*(U(t) - X(t)) for a scalar state X.
rowvec dX(const rowvec &input,  // time (first element) and additional input variables
          const rowvec &x,      // state variables
          const rowvec &theta)  // parameters
{
  const double rate = theta(0) * theta(1);  // the two parameters enter only as a product
  const double u = input(1);                // the single input U(t); input(0) is the time
  return rowvec{ rate * (u - x(0)) };
}


The ODE may then be solved using the following syntax:

// Wrap the right-hand side dX in a fourth-order Runge-Kutta solver.
odesolver::RK4 MyODE(dX);
// input: time in column 0, input variables in the following columns; init: initial
// state; theta: parameters. res has one row (the state) per row of input.
arma::mat res = MyODE.solve(input, init, theta);


where the step sizes are defined implicitly by the rows of `input` (the first column is the time variable and the following columns hold the optional input variables), the initial conditions by `init`, and the parameters by `theta`.

#ifndef _ODESOLVER_H_
#define _ODESOLVER_H_

#include <functional>
#include <utility>

#ifndef RARMA
#define MATHLIB_STANDALONE
#include "Rmath.h"
#include <armadillo>
#else
// NOTE(review): with RARMA defined, this header expects the Armadillo headers
// (presumably via RcppArmadillo) to be included by the translation unit — confirm.
#endif

namespace odesolver {

/*!
  Right-hand side of the ODE X'(t) = F(t, X(t), U(t)).
  Called as F(input, x, theta): `input` is one row holding the time (first
  element) followed by the input variables, `x` is the current state and
  `theta` the parameter vector.
*/
using odefunc = std::function<arma::mat(arma::mat input, arma::mat x, arma::mat theta)>;

/*!
  Abstract base class for ODE solvers.
*/
class Solver {
protected:
  odefunc F;  // right-hand side of the ODE

public:
  Solver(odefunc f) : F(std::move(f)) {}
  virtual ~Solver() = default;

  /*!
    Approximate the solution at the time points in the first column of
    `input` (the remaining columns hold the input variables at those times),
    starting from the initial state `init` with parameters `theta`.
  */
  virtual arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta) = 0;

  /*!
    Resample the inputs onto a grid with step at most `tau` (last observation
    carried forward), solve on that grid and, if `reduce` is true, keep only
    the rows nearest to the original time points.
  */
  arma::mat solveint(const arma::mat &input, arma::mat init, arma::mat theta,
                     double tau = 1.0e-1, bool reduce = true);
};

/*!
  Classic fourth-order Runge-Kutta solver.
*/
class RK4 : public Solver {
public:
  using Solver::Solver;

  arma::mat solve(const arma::mat &input, arma::mat init, arma::mat theta) override;
};

// These helpers are declared inside the namespace because that is where
// their definitions live; global declarations would never be resolved.

/*!
  For each element of `newtime`, the index of a matching element of the
  ascending-sorted `time`: 0 = nearest, 1 = right neighbour (first time point
  >= newtime), 2 = left neighbour (last time point <= newtime).
*/
arma::uvec approx(const arma::mat &time,  // Sorted time points (Ascending)
                  const arma::mat &newtime,
                  unsigned type = 0);     // (0: nearest, 1: right, 2: left)

/*!
  Resample `input` (first column is time) onto a grid with step `tau`.
*/
arma::mat interpolate(const arma::mat &input,  // first column is time
                      double tau,              // Time-step
                      bool locf = false);      // Last-observation-carried forward, otherwise linear interpolation

}  // namespace odesolver

#endif /* _ODESOLVER_H_ */

Code Snippet 1: Header file (odesolver.h)

## Class cpp file

#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <vector>
#include "odesolver.h"

using namespace arma;

namespace odesolver {

/*!
  Map each element of `newtime` to an index into the ascending-sorted `time`.
    type 0: nearest time point (a tie goes to the later point);
    type 2: left neighbour, the last time point <= newtime;
    any other value (documented as 1): right neighbour, the first time point
    >= newtime.
  Values at or beyond the last time point map to the last index; values
  before the first time point map to index 0 for every type.
  Throws std::invalid_argument if `time` is empty.
*/
arma::uvec approx(const arma::mat &time,  // Sorted time points (Ascending)
                  const arma::mat &newtime,
                  unsigned type) {        // (0: nearest, 1: right, 2: left)
  if (time.is_empty()) {
    throw std::invalid_argument("approx: 'time' must not be empty");
  }
  const double *first = time.begin();
  const double *last = time.end();
  const uword ilast = time.n_elem - 1;
  uvec idx(newtime.n_elem);
  for (uword i = 0; i < newtime.n_elem; ++i) {
    const double t = newtime(i);
    uword pos;
    if (t >= time(ilast)) {
      pos = ilast;
    } else {
      // First time point >= t; it exists because t < time(ilast).
      const double *it = std::lower_bound(first, last, t);
      pos = static_cast<uword>(it - first);
      // Index 0 has no left neighbour, so it is kept for every type.
      if (pos > 0) {
        if (type == 0 && std::fabs(t - time(pos - 1)) < std::fabs(t - time(pos))) {
          pos -= 1;
        } else if (type == 2 && t < *it) {
          pos -= 1;  // strictly between two points: step back to the left one
        }
      }
    }
    idx(i) = pos;
  }
  return idx;
}

/*!
  Resample `input` (first column: ascending time points t0..tn; other columns:
  input variables) onto the grid t0, t0 + tau, t0 + 2*tau, ... strictly before
  tn, followed by a final row equal to the last row of `input`.
  With `locf` true the input variables are carried forward from the last
  observation at or before each grid time; otherwise they are linearly
  interpolated between neighbouring observations. The time column always
  holds the grid time itself.
  If tn == t0 (including a single row), the last row of `input` is returned.
  Throws std::invalid_argument for empty input, non-positive tau, or tn < t0.
*/
arma::mat interpolate(const arma::mat &input, double tau, bool locf) {
  if (input.is_empty()) {
    throw std::invalid_argument("interpolate: 'input' must not be empty");
  }
  if (!(tau > 0)) {  // written this way to also reject NaN
    throw std::invalid_argument("interpolate: 'tau' must be positive");
  }
  const vec time = input.col(0);
  const uword n = time.n_elem;
  const double t0 = time(0);
  const double tn = time(n - 1);
  if (tn < t0) {
    throw std::invalid_argument("interpolate: time points must be sorted ascending");
  }
  if (tn == t0) {
    // Zero time span: there is no grid to build.
    return arma::mat(input.row(n - 1));
  }

  // Grid points strictly before tn, plus one final row for tn itself.
  const uword N = static_cast<uword>(std::ceil((tn - t0) / tau)) + 1;
  mat input2(N, input.n_cols);

  // Rate of change per unit time on the current segment. The time column
  // always has rate 1 so that it reproduces the grid time.
  rowvec slope(input.n_cols, fill::zeros);
  if (locf) {
    slope(0) = 1;
  } else {
    slope = (input.row(1) - input.row(0)) / (time(1) - time(0));
  }

  uword cur = 0;  // last observation at or before the current grid time
  for (uword i = 0; i + 1 < N; ++i) {
    // Computed from i rather than accumulated, so rounding does not drift.
    const double curtime = t0 + static_cast<double>(i) * tau;
    // `<=` lets an observation apply from its own time point onwards; the
    // bound on cur keeps time(cur + 1) in range.
    while (cur + 1 < n && time(cur + 1) <= curtime) {
      ++cur;
      if (!locf && cur + 1 < n) {
        slope = (input.row(cur + 1) - input.row(cur)) / (time(cur + 1) - time(cur));
      }
    }
    input2.row(i) = input.row(cur) + slope * (curtime - time(cur));
  }
  // Pin the final row to the last observation exactly.
  input2.row(N - 1) = input.row(n - 1);
  return input2;
}

/*!
  Classic fourth-order Runge-Kutta integration over the rows of `input`.
  Row i holds the time (first element) and the input variables. The step from
  row i to row i + 1 evaluates F at the start, twice at the linearly
  interpolated midpoint, and at the end of the step.
  Returns an input.n_rows x init.n_elem matrix. Row 0 is `init` and row i + 1
  is the state at the time of row i + 1. Empty `input` yields 0 rows.
*/
arma::mat RK4::solve(const arma::mat &input, arma::mat init, arma::mat theta) {
  const uword n = input.n_rows;
  mat res(n, init.n_elem);
  if (n == 0) {
    return res;  // no time points: nothing to integrate
  }
  // Accept the initial state as either a row or a column vector.
  rowvec y = vectorise(init).t();
  res.row(0) = y;
  for (uword i = 0; i + 1 < n; ++i) {
    const rowvec start = input.row(i);
    const rowvec dinput = input.row(i + 1) - start;
    const double tau = dinput(0);  // step length in time
    const rowvec mid = start + dinput / 2;
    const rowvec f1 = tau * F(start, y, theta);
    const rowvec f2 = tau * F(mid, y + f1 / 2, theta);
    const rowvec f3 = tau * F(mid, y + f2 / 2, theta);
    const rowvec f4 = tau * F(start + dinput, y + f3, theta);
    y += (f1 + 2 * f2 + 2 * f3 + f4) / 6;
    res.row(i + 1) = y;
  }
  return res;
}

/*!
  Solve on a refined grid: the inputs are resampled with step `tau`, carrying
  the last observation forward. If `reduce` is true, only the rows whose grid
  time is nearest to each original time point are returned.
*/
arma::mat Solver::solveint(const arma::mat &input, arma::mat init, arma::mat theta, double tau, bool reduce) {
  mat newinput = interpolate(input, tau, true);
  mat value = solve(newinput, init, theta);
  if (reduce) {
    const uvec idx = odesolver::approx(newinput.col(0), input.col(0), 0);
    value = value.rows(idx);
  }
  return value;
}

}  // namespace odesolver

Code Snippet 2: Class cpp file