Usage of org.matheclipse.parser.client.math.ArithmeticMathException in the project symja_android_library by axkr.
Class DoubleEvaluator, method derivative.
/**
 * Symbolically differentiates {@code node} with respect to the variable {@code var}.
 *
 * <p>Supported rules:
 * <ul>
 * <li>any subtree that is free of {@code var} differentiates to {@code 0.0};</li>
 * <li>the variable {@code var} itself differentiates to {@code 1.0};</li>
 * <li>{@code Exp(u)}, {@code Sin(u)} and {@code Cos(u)};</li>
 * <li>{@code Power(u, r)} where {@code r} is free of {@code var}, and {@code Power(a, u)} where
 * {@code a} is free of {@code var};</li>
 * <li>{@code Plus(...)} (sum rule) and {@code Times(...)} (product rule).</li>
 * </ul>
 * For the unary and power rules, the outer derivative is combined with the derivative of the
 * inner argument via {@code getDerivativeResult} (presumably the chain-rule product; confirm
 * against that helper). Function nodes that match none of the rules above are returned
 * unevaluated as {@code D(node, var)}.
 *
 * TODO: add more derivation rules
 *
 * @param node the expression to differentiate
 * @param var the variable to differentiate with respect to
 * @return the derivative expression, or an unevaluated {@code D(node, var)} node if no rule applies
 * @throws ArithmeticMathException if {@code node} is not a function, symbol or number node
 */
public ASTNode derivative(final ASTNode node, SymbolNode var) {
    if (node.isFree(var)) {
        // Constants (numbers, other symbols, var-free subexpressions) have derivative 0.
        return new DoubleNode(0.0);
    }
    if (node instanceof FunctionNode) {
        FunctionNode f = (FunctionNode) node;
        if (f.size() > 1 && f.getNode(0) instanceof SymbolNode) {
            SymbolNode head = (SymbolNode) f.getNode(0);
            if (f.size() == 2) {
                // Unary functions: outer derivative combined with the inner derivative u'.
                ASTNode arg1Derived = derivative(f.getNode(1), var);
                if (isSymbol(head, "Exp")) {
                    // d/du Exp(u) = Exp(u)
                    FunctionNode fun = new FunctionNode(fASTFactory.createSymbol("Exp"));
                    fun.add(f.getNode(1));
                    return getDerivativeResult(arg1Derived, fun);
                }
                if (isSymbol(head, "Cos")) {
                    // d/du Cos(u) = -Sin(u)  (Sin, not Cos)
                    FunctionNode fun = new FunctionNode(fASTFactory.createSymbol("Times"));
                    fun.add(new DoubleNode(-1.0));
                    fun.add(new FunctionNode(fASTFactory.createSymbol("Sin"), f.getNode(1)));
                    return getDerivativeResult(arg1Derived, fun);
                }
                if (isSymbol(head, "Sin")) {
                    // d/du Sin(u) = Cos(u)
                    FunctionNode fun = new FunctionNode(fASTFactory.createSymbol("Cos"));
                    fun.add(f.getNode(1));
                    return getDerivativeResult(arg1Derived, fun);
                }
            } else if (f.size() == 3 && isSymbol(head, "Power")) {
                if (f.get(2).isFree(var)) {
                    // derive u^r: r * u^(r-1), combined with u'
                    ASTNode arg1Derived = derivative(f.getNode(1), var);
                    // (r-1)
                    FunctionNode exponent = fASTFactory.createFunction(fASTFactory.createSymbol("Plus"),
                            new DoubleNode(-1.0), f.get(2));
                    // r*u^(r-1)
                    FunctionNode fun = fASTFactory.createFunction(fASTFactory.createSymbol("Times"), f.get(2),
                            fASTFactory.createFunction(fASTFactory.createSymbol("Power"), f.get(1), exponent));
                    return getDerivativeResult(arg1Derived, fun);
                }
                if (f.get(1).isFree(var)) {
                    // derive a^u: Log(a) * a^u, combined with u'
                    ASTNode arg2Derived = derivative(f.getNode(2), var);
                    FunctionNode fun = fASTFactory.createFunction(fASTFactory.createSymbol("Times"),
                            fASTFactory.createFunction(fASTFactory.createSymbol("Log"), f.get(1)), f);
                    return getDerivativeResult(arg2Derived, fun);
                }
                // Both base and exponent depend on var: no rule, fall through to D(node, var).
            } else {
                if (isSymbol(head, "Plus")) {
                    // Sum rule: differentiate each term, dropping terms whose derivative is 0.0.
                    FunctionNode result = new FunctionNode(f.getNode(0));
                    for (int i = 1; i < f.size(); i++) {
                        ASTNode deriv = derivative(f.getNode(i), var);
                        if (!deriv.equals(new DoubleNode(0.0))) {
                            result.add(deriv);
                        }
                    }
                    return result;
                }
                if (isSymbol(head, "Times")) {
                    // Product rule: sum over i of (factor i differentiated, all other factors unchanged).
                    FunctionNode plusResult = new FunctionNode(fASTFactory.createSymbol("Plus"));
                    for (int i = 1; i < f.size(); i++) {
                        FunctionNode timesResult = new FunctionNode(f.getNode(0));
                        boolean valid = true;
                        for (int j = 1; j < f.size(); j++) {
                            if (j == i) {
                                ASTNode deriv = derivative(f.getNode(j), var);
                                if (deriv.equals(new DoubleNode(0.0))) {
                                    // This summand is zero; omit it entirely.
                                    valid = false;
                                } else {
                                    timesResult.add(deriv);
                                }
                            } else {
                                timesResult.add(f.getNode(j));
                            }
                        }
                        if (valid) {
                            plusResult.add(timesResult);
                        }
                    }
                    return plusResult;
                }
            }
        }
        // No rule applies: return the derivative unevaluated.
        return new FunctionNode(new SymbolNode("D"), node, var);
    }
    if (node instanceof SymbolNode) {
        if (isSymbol((SymbolNode) node, var)) {
            return new DoubleNode(1.0);
        }
        // Any other symbol (bound variable, named constant or unknown) is treated as constant.
        return new DoubleNode(0.0);
    } else if (node instanceof NumberNode) {
        return new DoubleNode(0.0);
    }
    throw new ArithmeticMathException("EvalDouble#evaluate(ASTNode) not possible for: " + node.toString());
}
Aggregations