# 60 lines
# 1.7 KiB
# Plaintext
# Raw Normal View History

# 2025-01-12 00:52:51 +08:00
# Test the sexpm functions that are used for fast matrix exponentials
#
library(Matrix)
library(survival)
# Approximate-equality helper: both arguments are flattened with as.vector()
# before being handed to all.equal(), so matrices and arrays are compared
# element by element.  Any further arguments are forwarded to all.equal().
# Like all.equal(), the result is TRUE on agreement and otherwise a character
# description of the differences, so wrap it in isTRUE() when a logical is
# required.
aeq <- function(x, y, ...) {
    target  <- as.vector(x)
    current <- as.vector(y)
    all.equal(target, current, ...)
}
# Step size used for the finite-difference derivative check
eps <- 1e-8

# Time points at which each matrix exponential is evaluated
times <- c(0.1, 1, 5)

# Result tables: one row per entry of sexpm, one column per time point.
# Every cell starts out FALSE and is filled in by the test loop; test1 holds
# the matrix exponential checks and test2 the derivative checks.
nfun <- length(sexpm)
test1 <- matrix(FALSE, nrow = nfun, ncol = length(times),
                dimnames = list(names(sexpm), paste0("t=", times)))
test2 <- test1
# Check an analytic derivative of the matrix exponential against a forward
# finite difference.
#
# Arguments:
#   x     a list with components nstate (the rate matrix is nstate x nstate),
#         nonzero (indices of the rate matrix elements that are perturbed)
#         and deriv (a function of (tmat, time) returning a 3-d array whose
#         last index runs over the elements of nonzero)
#   tmat  optional nstate x nstate matrix.  If omitted, one is generated with
#         uniform(.1, 3) values at the nonzero positions.  If supplied, its
#         diagonal is reduced by the row sums so that every row sums to zero.
#   time  time at which the exponential exp(tmat * time) is evaluated
#   eps   step size of the finite difference
#
# Returns list(d1, d2): d1 is the result of x$deriv and d2 the finite
# difference approximation, with the same dimensions as d1.
# Stops with "invalid tmat" if a supplied tmat is not nstate x nstate.
dtest <- function(x, tmat, time= 1.3, eps=1e-8) {
    if (missing(tmat)) {
        tmat <- matrix(0, x$nstate, x$nstate)
        tmat[x$nonzero] <- runif(length(x$nonzero), .1, 3)
        diag(tmat) <- -rowSums(tmat)
    }
    else {
        # Dimensions are integers, so compare them exactly.  (Negating the
        # result of all.equal() directly fails on a mismatch, because it
        # then returns a character string rather than FALSE.)
        if (!identical(dim(tmat), rep(as.integer(x$nstate), 2L)))
            stop("invalid tmat")
        diag(tmat) <- diag(tmat) - rowSums(tmat)
    }

    d1 <- x$deriv(tmat, time)
    d2 <- 0*d1     # same shape as d1, filled in below
    nz <- x$nonzero
    exp1 <- as.matrix(expm(tmat* time))

    # seq_along() so that an empty nonzero vector simply skips the loop
    for (j in seq_along(nz)) {
        temp <- tmat
        temp[nz[j]] <- temp[nz[j]] + eps
        # re-balance the diagonal so the perturbed rows again sum to zero
        diag(temp) <- diag(temp) - rowSums(temp)
        exp2 <- as.matrix(expm(temp* time))
        d2[,,j] <- (exp2 - exp1)/eps
    }
    list(d1=d1, d2=d2)
}
# For each entry of sexpm, build a random rate matrix (uniform(.1, 4) values
# at the nonzero positions, diagonal equal to minus the row sums) and then,
# at every time point, check
#   test1: the entry's mexp function against Matrix::expm
#   test2: the entry's deriv function against the finite difference in dtest
# seq_len()/seq_along() keep the loops from running when there is nothing to
# iterate over.
for (i in seq_len(nfun)) {
    j <- sexpm[[i]]
    tmat <- matrix(0, j$nstate, j$nstate)
    tmat[j$nonzero] <- runif(length(j$nonzero), .1, 4)
    diag(tmat) <- -rowSums(tmat)
    # NOTE(review): dtemp is not read anywhere in the visible part of this
    # file; confirm it is unused before removing it
    dtemp <- logical(length(j$nonzero))

    for (k in seq_along(times)) {
        m1 <- j$mexp(tmat, times[k])
        m2 <- expm(tmat*times[k])
        test1[i,k] <- isTRUE(aeq(m1, as.matrix(m2)))

        # now derivatives; a forward difference is only accurate to about
        # sqrt(eps), hence the looser tolerance
        temp <- dtest(j, tmat, times[k], eps)
        test2[i,k] <- isTRUE(aeq(temp$d1, temp$d2, tolerance=sqrt(eps)))
    }
}
all(test1)   # should all be TRUE
all(test2)   # should also be TRUE