#include "Markov.h"
#include "1stOrder.h"
#include <iostream.h>		// for error handling... ought to use exceptions...

// Definition of the static data member Markov::unity, initialized to 1.
// NOTE(review): presumably used as a default target for rate-coefficient
// pointers that have no external modulator -- confirm against Markov.h.
real Markov::unity( 1 );

// Alias for "pointer to real"; some compilers can't parse new (real *)[qty]
typedef real *realPtr;

// Construct a scheme with pQtyStates states.
// Allocates the state vector, the pQtyStates x pQtyStates rate matrix and the
// matching matrix of coefficient pointers, then zeroes every state value and
// sets every rate and coefficient to 0 via SetRate / SetCoeff.
// If any allocation reports failure, an error is written to cerr and the
// object is left with zero states.
Markov::Markov( const int pQtyStates )
{
	itsStateValues = new real[pQtyStates];
	itsRates       = new real[pQtyStates * pQtyStates];
	itsRateCoeff   = new realPtr[pQtyStates * pQtyStates];

	const bool allocated = itsStateValues && itsRates && itsRateCoeff;
	if (!allocated) {
		cerr << "Memory allocation error in Markov::Markov(" << pQtyStates
			 << ")." << endl;
		itsQtyStates = 0;
		return;
	}

	// itsQtyStates must be set before SetRate/SetCoeff are called
	itsQtyStates = pQtyStates;
	for (int row = 0; row < itsQtyStates; ++row) {
		itsStateValues[row] = 0;
		for (int col = 0; col < itsQtyStates; ++col) {
			SetRate ( row, col, 0 );
			SetCoeff( row, col, 0 );
		}
	}
}

// Release the state vector, rate matrix and coefficient-pointer matrix.
// delete[] on a null pointer is a no-op, so partially-failed construction
// needs no special casing here.
Markov::~Markov()
{
	delete[] itsStateValues;
	delete[] itsRates;
	delete[] itsRateCoeff;
}

void Markov::Step( const real dt )
{
	// for common low-order kinetic schemes, we can solve exactly
	switch (itsQtyStates) {
		case 0:
		case 1:
			return;
			
		case 2:				// two states, first-order reaction
			itsStateValues[0] = FirstOrder( itsStateValues[0], 
											itsRates[1*2+0] * *itsRateCoeff[1*2+0],
											itsRates[0*2+1] * *itsRateCoeff[0*2+1],
											dt );
			itsStateValues[1] = 1-itsStateValues[0];
			return;

		default:
			// higher-order schemes integrated with exponential Euler
			
			real *A = new real[itsQtyStates];	// sum of input flux per state
			real *B = new real[itsQtyStates];	// sum of output rates per state
			int i;
			
			// first, find the coefficients for all states
			for (i=0; i<itsQtyStates; i++) {
				A[i] = B[i] = 0;
				for (int j=0; j<itsQtyStates; j++) {
					A[i] += itsStateValues[j] * itsRates[j*itsQtyStates+i]
										 * *itsRateCoeff[j*itsQtyStates+i];
					B[i] += itsRates[i*itsQtyStates+j]
					 * *itsRateCoeff[i*itsQtyStates+j];
				}
			}
			
			// next, update all states, using the above coefficients
			// NOTE: for states with no output path, must use forward Euler
			real stateSum = 0.0;
			for (i=0; i<itsQtyStates; i++) {
				if (B[i]) {
					real expterm = exp(-B[i]*dt);
					stateSum += ( itsStateValues[i]
							= itsStateValues[i]*expterm + (A[i]/B[i])*(1-expterm) );
				}
				else stateSum += ( itsStateValues[i]
							= itsStateValues[i] + A[i]*dt );
			}

			// finally, apply a correction to keep the total equal to 1
			for (i=0; i<itsQtyStates; i++)
				itsStateValues[i] /= stateSum;

	}
}
