# Test the sexpm functions that are used for fast matrix exponentials
#
library(Matrix)
|
|
library(survival)
|
|
# Value-only comparison: both arguments are passed through as.vector before
#  all.equal, so that e.g. a matrix and a plain vector holding the same
#  numbers compare as equal.  Any further arguments are forwarded to
#  all.equal.  The result is whatever all.equal returns (TRUE, or a
#  character description of the differences).
aeq <- function(x, y, ...) {
    target  <- as.vector(x)
    current <- as.vector(y)
    all.equal(target, current, ...)
}
|
|
|
|
# Bookkeeping for the checks that follow.
#  eps:   step size handed to dtest; sqrt(eps) is used as the tolerance
#         when the derivatives are compared
#  times: the time points at which each exponential is evaluated
#  test1, test2: logical result tables, one row per element of sexpm and
#         one column per time point, initially all FALSE
eps   <- 1e-8
times <- c(.1, 1, 5)
nfun  <- length(sexpm)

test1 <- matrix(FALSE, nrow = nfun, ncol = length(times),
                dimnames = list(names(sexpm), paste0("t=", times)))
test2 <- test1
|
|
|
|
# Check an analytic derivative against a forward-difference approximation.
#
#  x:    a list-like object; this function uses x$nstate, x$nonzero and
#        x$deriv (called as x$deriv(tmat, time))
#  tmat: optional matrix of dimension x$nstate by x$nstate.  If omitted, one
#        is generated with uniform(.1, 3) values at the positions x$nonzero.
#        In either case the diagonal is then adjusted so each row sums to 0.
#  time: multiplier applied to the matrix before expm
#  eps:  step size of the forward difference
#
# Returns list(d1, d2).  d1 is the value of x$deriv(tmat, time); d2 has the
#  same shape, and its j-th slice d2[,,j] is
#  (expm(perturbed * time) - expm(tmat * time)) / eps, where "perturbed" is
#  tmat with eps added to element x$nonzero[j] and the diagonal re-balanced.
dtest <- function(x, tmat, time= 1.3, eps=1e-8) {
    if (missing(tmat)) {
        tmat <- matrix(0, x$nstate, x$nstate)
        tmat[x$nonzero] <- runif(length(x$nonzero), .1, 3)
        diag(tmat) <- -rowSums(tmat)
    } else {
        # all.equal returns a character vector, not FALSE, on a mismatch;
        #  it has to be wrapped in isTRUE before it can be negated
        if (!isTRUE(all.equal(dim(tmat), c(x$nstate, x$nstate))))
            stop("invalid tmat")
        diag(tmat) <- diag(tmat) - rowSums(tmat)
    }

    d1 <- x$deriv(tmat, time)
    d2 <- 0*d1     # same shape as d1, filled in one slice at a time below
    nz <- x$nonzero
    exp1 <- as.matrix(expm(tmat* time))

    # seq_along: do nothing (rather than iterate over 1, 0) if nz is empty
    for (j in seq_along(nz)) {
        temp <- tmat
        temp[nz[j]] <- temp[nz[j]] + eps
        diag(temp) <- diag(temp) - rowSums(temp)   # restore zero row sums
        exp2 <- as.matrix(expm(temp* time))
        d2[,,j] <- (exp2 - exp1)/eps
    }
    list(d1=d1, d2=d2)
}
|
|
|
|
# For every element of sexpm, build a random matrix and compare against
#  expm at each of the time points:
#   test1[i,k]: does j$mexp(tmat, time) agree with expm(tmat*time)?
#   test2[i,k]: do the two derivative arrays returned by dtest agree, to
#               within a tolerance of sqrt(eps)?
for (i in seq_len(nfun)) {
    j <- sexpm[[i]]
    # uniform(.1, 4) values at the positions j$nonzero, then the diagonal
    #  is set to minus the row sums
    tmat <- matrix(0, j$nstate, j$nstate)
    tmat[j$nonzero] <- runif(length(j$nonzero), .1, 4)
    diag(tmat) <- -rowSums(tmat)

    for (k in seq_along(times)) {
        m1 <- j$mexp(tmat, times[k])
        m2 <- expm(tmat*times[k])
        test1[i,k] <- isTRUE(aeq(m1, as.matrix(m2)))
        # now derivatives
        temp <- dtest(j, tmat, times[k], eps)
        test2[i,k] <- isTRUE(aeq(temp$d1, temp$d2, tolerance=sqrt(eps)))
    }
}

all(test1)  # should all be TRUE
all(test2)
|