## A simple symbolic differentiator. The first argument, the
## expression to be differentiated, is assumed to be a "call" object,
## a "name" object or a constant. The second argument can be a name
## object or a string (i.e. a character vector of length 1). Only
## very little error checking is done.
##
## The result is an unevaluated expression: a call, a name or a
## numeric constant. For example, d(quote(sin(x) * x), "x") returns
## the call cos(x) * x + sin(x).
d <- function(e, v) {
    if (is.expression(e))
        stop("expression objects are not supported; use 'quote'")
    if (is.call(e)) {
        dCall(e[[1]], as.list(e)[-1], v)
    } else if (is.name(e)) {
        ## dv/dv is 1; any other variable is treated as a constant.
        if (e == v) 1 else 0
    } else {
        0 ## everything else is assumed to be a constant
    }
}

## dCall handles differentiation of calls. 'fun' is the function
## position of the call, 'args' the list of its arguments and 'v' the
## variable of differentiation. Unary +, - and ( are handled here; any
## other unary function is looked up in the derivative database via
## dUnaryCall. Binary +, -, *, / and ^ are supported.
dCall <- function(fun, args, v) {
    fname <- as.character(fun)
    ## A call such as base::sin(x) has a call, not a name, in function
    ## position; as.character() then yields several strings, which
    ## switch() cannot dispatch on.
    if (length(fname) != 1)
        stop("only calls to named functions are supported")
    if (length(args) == 1) {
        switch(fname,
               "+" = d(args[[1]], v),
               "-" = makeUminus(d(args[[1]], v)),
               "(" = d(args[[1]], v),
               dUnaryCall(fun, args[[1]], v))
    } else if (length(args) == 2) {
        switch(fname,
               "+" = dAdd(args, v),
               "-" = dSub(args, v),
               "*" = dMul(args, v),
               "/" = dDiv(args, v),
               "^" = dPow(args, v),
               stop("function ", sQuote(fname), " is not supported yet"))
    } else {
        stop("only unary and binary operations supported for now")
    }
}

## The following functions handle differentiation rules for specific
## functions. Each takes the two-element argument list of a binary
## call and the variable of differentiation.

## Sum rule: (a + b)' = a' + b'
dAdd <- function(args, v)
    makeAdd(d(args[[1]], v), d(args[[2]], v))

## Difference rule: (a - b)' = a' - b'
dSub <- function(args, v)
    makeSub(d(args[[1]], v), d(args[[2]], v))

## Product rule: (a * b)' = a' * b + a * b'
dMul <- function(args, v) {
    a <- args[[1]]
    b <- args[[2]]
    makeAdd(makeMul(d(a, v), b), makeMul(a, d(b, v)))
}

## Quotient rule: (a / b)' = (a' * b - a * b') / b^2
dDiv <- function(args, v) {
    a <- args[[1]]
    b <- args[[2]]
    makeDiv(makeSub(makeMul(d(a, v), b), makeMul(a, d(b, v))),
            makePow(b, 2))
}

## General power rule:
##   (a^b)' = b * a^(b - 1) * a' + a^b * log(a) * b'
## When b does not depend on v, b' is 0 and only the first term is
## built.
dPow <- function(args, v) {
    a <- args[[1]]
    b <- args[[2]]
    ## Exponent minus one, folded when the exponent is a numeric constant.
    bm1 <- if (is.scalar(b)) b - 1 else makeSub(b, 1)
    ## a^(b - 1), written as a when b - 1 is 1 and as 1 when b - 1 is 0
    ## (in R, y^0 is 1 for every y).
    aPow <- if (is.scalar(bm1) && isTRUE(bm1 == 1)) {
        a
    } else if (is.scalar(bm1) && isTRUE(bm1 == 0)) {
        1
    } else {
        makePow(a, bm1)
    }
    powerTerm <- makeMul(makeMul(b, aPow), d(a, v))
    db <- d(b, v)
    expTerm <- if (is.scalar(db) && isTRUE(db == 0)) {
        0
    } else {
        makeMul(makeMul(makePow(a, b), makeCall("log", list(a))), db)
    }
    makeAdd(powerTerm, expTerm)
}

## Constructors for result expressions have been broken out into
## functions of their own to reduce repetition and make the code more
## readable. These include some simplifications for cases with scalar
## arguments.
## TRUE for a length-one numeric (integer or double) value, i.e. a
## constant that the constructors below may fold or simplify.
is.scalar <- function(x) is.numeric(x) && length(x) == 1

## TRUE only when x is a scalar constant equal to 'val'. isTRUE() makes
## NA/NaN constants count as "not equal" instead of letting if() fail
## with "missing value where TRUE/FALSE needed".
isScalarValue <- function(x, val) is.scalar(x) && isTRUE(x == val)

## x + y, folding two constants and dropping a zero operand.
makeAdd <- function(x, y) {
    if (is.scalar(x) && is.scalar(y)) x + y
    else if (isScalarValue(x, 0)) y
    else if (isScalarValue(y, 0)) x
    else makeCall("+", list(x, y))
}

## x - y, folding two constants; 0 - y becomes -y and x - 0 becomes x.
makeSub <- function(x, y) {
    if (is.scalar(x) && is.scalar(y)) x - y
    else if (isScalarValue(x, 0)) makeUminus(y)
    else if (isScalarValue(y, 0)) x
    else makeCall("-", list(x, y))
}

## x * y, folding two constants; a factor of 1 is dropped and a factor
## of 0 makes the whole product 0.
makeMul <- function(x, y) {
    if (is.scalar(x) && is.scalar(y))
        return(x * y)
    if (isScalarValue(x, 1)) return(y)
    if (isScalarValue(x, 0)) return(0)
    if (isScalarValue(y, 1)) return(x)
    if (isScalarValue(y, 0)) return(0)
    makeCall("*", list(x, y))
}

## x / y; division by the constant 1 is dropped.
makeDiv <- function(x, y) {
    if (isScalarValue(y, 1)) x
    else makeCall("/", list(x, y))
}

## x ^ y, folded when both are constants.
makePow <- function(x, y) {
    if (is.scalar(x) && is.scalar(y)) x ^ y
    else makeCall("^", list(x, y))
}

## Unary minus, folded for a constant.
makeUminus <- function(x) {
    if (is.scalar(x)) -x
    else makeCall("-", list(x))
}

## Build the call fun(args[[1]], args[[2]], ...) from a function name
## (a string) and a list of argument expressions.
makeCall <- function(fun, args)
    as.call(c(list(as.name(fun)), args))

## General calls to unary functions are handled by a data base of
## rules stored in an environment. The rules are functions of one
## argument that return an expression for the derivative of the
## associated function at the specified argument; dUnaryCall applies
## the chain rule by multiplying that by the derivative of the argument.
dUnaryCall <- function(fun, arg, v) {
    handler <- lookupDeriv(fun)
    makeMul(handler(arg), d(arg, v))
}

## The empty parent means only registered rules can ever be found here.
derivsDB <- new.env(hash = TRUE, parent = emptyenv())

## Return the registered rule for 'fun' (a name or string), or signal an
## error if there is none.
lookupDeriv <- function(fun) {
    fname <- as.character(fun)
    if (exists(fname, envir = derivsDB, inherits = FALSE))
        get(fname, envir = derivsDB, inherits = FALSE)
    else stop("function ", sQuote(fname), " is not supported yet")
}

## Register (or replace) the derivative rule for the function 'fname'.
registerDeriv <- function(fname, handler)
    assign(fname, handler, envir = derivsDB)

## Argument expressions are wrapped in list() so makeCall always
## receives a list of arguments, as it does from every other caller.
registerDeriv("sin", function(arg) makeCall("cos", list(arg)))
registerDeriv("cos", function(arg) makeUminus(makeCall("sin", list(arg))))
registerDeriv("exp", function(arg) makeCall("exp", list(arg)))
## One-argument log is the natural logarithm: log(u)' = u' / u.
registerDeriv("log", function(arg) makeDiv(1, arg))
## sqrt(u)' = u' / (2 * sqrt(u)).
registerDeriv("sqrt", function(arg) makeDiv(0.5, makeCall("sqrt", list(arg))))