Use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.
From the class SerializationTest, method test().
@Nullable
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer layer, final Tensor... inputPrototype) {
// Demonstrates and verifies the layer's serialization round-trips: first JSON, then
// (conditionally) zip files at every SerialPrecision. Returns null: this test reports
// via the notebook log rather than producing tolerance statistics.
log.h1("Serialization");
log.p("This apply will demonstrate the layer's JSON serialization, and verify deserialization integrity.");
String prettyPrint = "";
log.h2("Raw Json");
try {
prettyPrint = log.code(() -> {
final JsonObject json = layer.getJson();
// Round-trip: deserialize the freshly produced JSON and compare with the original.
@Nonnull final Layer echo = Layer.fromJson(json);
// NOTE(review): echo is declared @Nonnull, so this null check is effectively dead;
// presumably Layer.fromJson throws on failure — confirm.
if (echo == null)
throw new AssertionError("Failed to deserialize");
// Deserialization must yield a distinct object (a copy, not the same instance)...
if (layer == echo)
throw new AssertionError("Serialization did not copy");
// ...that is nonetheless value-equal to the original.
if (!layer.equals(echo))
throw new AssertionError("Serialization not equal");
echo.freeRef();
return new GsonBuilder().setPrettyPrinting().create().toJson(json);
});
@Nonnull String filename = layer.getClass().getSimpleName() + "_" + log.getName() + ".json";
log.p(log.file(prettyPrint, filename, String.format("Wrote Model to %s; %s characters", filename, prettyPrint.length())));
} catch (RuntimeException e) {
// Best-effort test: serialization failures are reported but do not abort the notebook.
e.printStackTrace();
Util.sleep(1000);
} catch (OutOfMemoryError e) {
e.printStackTrace();
Util.sleep(1000);
}
log.p("");
// Guards notebook output: log calls from the parallel stream below are serialized on this monitor.
@Nonnull Object outSync = new Object();
// NOTE(review): zip round-trip runs only when the JSON is empty (serialization failed above)
// or larger than 64 KiB — the condition looks like it may be inverted; confirm intent.
if (prettyPrint.isEmpty() || prettyPrint.length() > 1024 * 64)
Arrays.stream(SerialPrecision.values()).parallel().forEach(precision -> {
try {
@Nonnull File file = new File(log.getResourceDir(), log.getName() + "_" + precision.name() + ".zip");
layer.writeZip(file, precision);
@Nonnull final Layer echo = Layer.fromZip(new ZipFile(file));
// Keep the deserialized copy for later inspection; getModels() is presumably
// a concurrent map since this runs in a parallel stream — TODO confirm.
getModels().put(precision, echo);
synchronized (outSync) {
log.h2(String.format("Zipfile %s", precision.name()));
log.p(log.link(file, String.format("Wrote Model apply %s precision to %s; %.3fMiB bytes", precision, file.getName(), file.length() * 1.0 / (0x100000))));
}
if (!isPersist())
file.delete();
// Same dead null check as the JSON path above: echo is @Nonnull.
if (echo == null)
throw new AssertionError("Failed to deserialize");
if (layer == echo)
throw new AssertionError("Serialization did not copy");
if (!layer.equals(echo))
throw new AssertionError("Serialization not equal");
} catch (RuntimeException e) {
e.printStackTrace();
} catch (OutOfMemoryError e) {
e.printStackTrace();
} catch (ZipException e) {
// ZipException extends IOException; this narrower catch must precede the one below.
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
});
return null;
}
Use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.
From the class BatchDerivativeTester, method testLearning().
/**
 * Validates the implemented learning (weight) gradient of each state array of the
 * given layer against a finite-difference measurement, combining the per-weight
 * results into the supplied statistics.
 *
 * @param component  the layer under test
 * @param IOPair     the input/output prototype pair used to evaluate gradients
 * @param statistics the statistics accumulated so far; per-weight results are combined into it
 * @return the combined tolerance statistics
 */
public ToleranceStatistics testLearning(@Nonnull Layer component, @Nonnull IOPair IOPair, ToleranceStatistics statistics) {
final ToleranceStatistics prev = statistics;
statistics = IntStream.range(0, component.state().size()).mapToObj(i -> {
// Finite-difference measurement is skipped entirely when verification is disabled.
@Nullable final Tensor measuredGradient = !verify ? null : measureLearningGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
@Nonnull final Tensor implementedGradient = getLearningGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
try {
// Element-wise comparison; a null measured gradient yields an empty range and neutral statistics.
final ToleranceStatistics result = IntStream.range(0, null == measuredGradient ? 0 : measuredGradient.length())
.mapToObj(i1 -> new ToleranceStatistics().accumulate(measuredGradient.getData()[i1], implementedGradient.getData()[i1]))
.reduce((a, b) -> a.combine(b))
.orElse(new ToleranceStatistics());
if (!(result.absoluteTol.getMax() < tolerance)) {
// Thrown inside the try so the catch below logs full diagnostics before rethrowing.
throw new AssertionError(result.toString());
}
if (verbose) {
log.info(String.format("Learning Gradient for weight setByCoord %s", i));
log.info(String.format("Weights: %s", new Tensor(component.state().get(i)).prettyPrint()));
logLearningGradients(measuredGradient, implementedGradient);
}
return result;
} catch (@Nonnull final Throwable e) {
// Failure path: emit the same diagnostics (minus the weights dump) and propagate.
log.info(String.format("Learning Gradient for weight setByCoord %s", i));
logLearningGradients(measuredGradient, implementedGradient);
throw e;
}
}).reduce((a, b) -> a.combine(b)).map(x -> x.combine(prev)).orElseGet(() -> prev);
return statistics;
}

/**
 * Logs the implemented gradient and, when available, the measured gradient plus the
 * element-wise error between them. Shared by the verbose and failure paths of
 * {@link #testLearning}.
 *
 * @param measuredGradient    the finite-difference gradient, or null when verification is off
 * @param implementedGradient the gradient reported by the layer
 */
private void logLearningGradients(@Nullable final Tensor measuredGradient, @Nonnull final Tensor implementedGradient) {
log.info(String.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
if (null != measuredGradient) {
log.info(String.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
log.info(String.format("Gradient Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
}
}
Use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.
From the class BatchDerivativeTester, method test().
/**
 * Runs the full differential validation suite (feedback, learning, frozen/unfrozen)
 * for a layer and reports the combined tolerance statistics to the notebook.
 *
 * @param log            the notebook output to record results to
 * @param component      the component under test
 * @param inputPrototype the input prototype; only the first element is used to build the IO pair
 * @return the tolerance statistics
 */
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer component, @Nonnull final Tensor... inputPrototype) {
log.h1("Differential Validation");
@Nonnull IOPair ioPair = new IOPair(component, inputPrototype[0]).invoke();
if (verbose) {
log.code(() -> {
BatchDerivativeTester.log.info(String.format("Inputs: %s", Arrays.stream(inputPrototype).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).get()));
BatchDerivativeTester.log.info(String.format("Inputs Statistics: %s", Arrays.stream(inputPrototype).map(x -> new ScalarStatistics().add(x.getData()).toString()).reduce((a, b) -> a + ",\n" + b).get()));
BatchDerivativeTester.log.info(String.format("Output: %s", ioPair.getOutputPrototype().prettyPrint()));
BatchDerivativeTester.log.info(String.format("Outputs Statistics: %s", new ScalarStatistics().add(ioPair.getOutputPrototype().getData())));
});
}
ToleranceStatistics _statistics = new ToleranceStatistics();
if (isTestFeedback()) {
log.h2("Feedback Validation");
log.p("We validate the agreement between the implemented derivative _of the inputs_ apply finite difference estimations:");
// log.code requires an effectively-final capture, hence the local copy.
ToleranceStatistics statistics = _statistics;
_statistics = log.code(() -> {
return testFeedback(component, ioPair, statistics);
});
}
if (isTestLearning()) {
log.h2("Learning Validation");
log.p("We validate the agreement between the implemented derivative _of the internal weights_ apply finite difference estimations:");
ToleranceStatistics statistics = _statistics;
_statistics = log.code(() -> {
return testLearning(component, ioPair, statistics);
});
}
log.h2("Total Accuracy");
log.p("The overall agreement accuracy between the implemented derivative and the finite difference estimations:");
ToleranceStatistics statistics = _statistics;
log.code(() -> {
// Plain string: the original wrapped this constant in String.format with no specifiers.
BatchDerivativeTester.log.info("Finite-Difference Derivative Accuracy:");
BatchDerivativeTester.log.info(String.format("absoluteTol: %s", statistics.absoluteTol));
BatchDerivativeTester.log.info(String.format("relativeTol: %s", statistics.relativeTol));
});
log.h2("Frozen and Alive Status");
log.code(() -> {
testFrozen(component, ioPair.getInputPrototype());
testUnFrozen(component, ioPair.getInputPrototype());
});
return _statistics;
}
Use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.
From the class CudaLayerTester, method compareLayerDerivatives().
/**
 * Compare layer derivatives tolerance statistics.
 *
 * @param expected the expected
 * @param actual   the actual
 * @return the combined tolerance statistics, or null if neither result has layer derivatives
 */
@Nullable
public ToleranceStatistics compareLayerDerivatives(final SimpleResult expected, final SimpleResult actual) {
// NOTE(review): the lambda below never uses `batch`, so every batch iteration computes
// identical statistics over the full derivative maps and the reduce combines duplicates
// — confirm whether per-batch slicing was intended here.
@Nonnull final ToleranceStatistics derivativeAgreement = IntStream.range(0, getBatchSize()).mapToObj(batch -> {
// Compares actual vs expected delta arrays for one input layer key.
@Nonnull Function<Layer, ToleranceStatistics> compareInputDerivative = input -> {
double[] b = actual.getLayerDerivative().getMap().get(input).getDelta();
double[] a = expected.getLayerDerivative().getMap().get(input).getDelta();
ToleranceStatistics statistics = new ToleranceStatistics().accumulate(a, b);
return statistics;
};
// Union of keys from both maps so a key missing on either side is still visited
// (and will NPE in the comparator above — presumably treated as a test failure).
return Stream.concat(actual.getLayerDerivative().getMap().keySet().stream(), expected.getLayerDerivative().getMap().keySet().stream()).distinct().map(compareInputDerivative).reduce((a, b) -> a.combine(b));
}).filter(x -> x.isPresent()).map(x -> x.get()).reduce((a, b) -> a.combine(b)).orElse(null);
if (null != derivativeAgreement && !(derivativeAgreement.absoluteTol.getMax() < tolerance)) {
// On disagreement, dump both derivative sets before failing; tensors are freed after printing.
logger.info("Expected Derivative: " + Arrays.stream(expected.getInputDerivative()).flatMap(TensorList::stream).map(x -> {
String str = x.prettyPrint();
x.freeRef();
return str;
}).collect(Collectors.toList()));
logger.info("Actual Derivative: " + Arrays.stream(actual.getInputDerivative()).flatMap(TensorList::stream).map(x -> {
String str = x.prettyPrint();
x.freeRef();
return str;
}).collect(Collectors.toList()));
throw new AssertionError("Layer Derivatives Corrupt: " + derivativeAgreement);
}
return derivativeAgreement;
}
Use of com.simiacryptus.mindseye.test.ToleranceStatistics in project MindsEye by SimiaCryptus.
From the class BatchDerivativeTester, method testFeedback().
/**
 * Validates the implemented feedback (input) derivative of each input of the
 * given layer against a finite-difference measurement, combining the per-input
 * results into the supplied statistics.
 *
 * @param component  the layer under test
 * @param IOPair     the input/output prototype pair used to evaluate gradients
 * @param statistics the statistics accumulated so far; per-input results are combined into it
 * @return the combined tolerance statistics
 */
public ToleranceStatistics testFeedback(@Nonnull Layer component, @Nonnull IOPair IOPair, ToleranceStatistics statistics) {
// Local final copy so the lambda can capture it; also lets us fall back to the prior
// statistics when there are no inputs (the original called Optional.get(), which would
// throw NoSuchElementException on an empty input array). Mirrors testLearning's pattern.
final ToleranceStatistics prev = statistics;
statistics = IntStream.range(0, IOPair.getInputPrototype().length).mapToObj(i -> {
// Finite-difference measurement is skipped entirely when verification is disabled.
@Nullable final Tensor measuredGradient = !verify ? null : measureFeedbackGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
@Nonnull final Tensor implementedGradient = getFeedbackGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
try {
// Element-wise comparison; a null measured gradient yields an empty range and neutral statistics.
final ToleranceStatistics result = IntStream.range(0, null == measuredGradient ? 0 : measuredGradient.length())
.mapToObj(i1 -> new ToleranceStatistics().accumulate(measuredGradient.getData()[i1], implementedGradient.getData()[i1]))
.reduce((a, b) -> a.combine(b))
.orElse(new ToleranceStatistics());
if (!(result.absoluteTol.getMax() < tolerance))
throw new AssertionError(result.toString());
if (verbose) {
log.info(String.format("Feedback for input %s", i));
log.info(String.format("Inputs Values: %s", IOPair.getInputPrototype()[i].prettyPrint()));
log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(IOPair.getInputPrototype()[i].getData())));
log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
if (null != measuredGradient) {
log.info(String.format("Measured Feedback: %s", measuredGradient.prettyPrint()));
log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
log.info(String.format("Feedback Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
}
}
return result;
} catch (@Nonnull final Throwable e) {
// Failure path: emit the same diagnostics as the verbose branch and propagate.
log.info(String.format("Feedback for input %s", i));
log.info(String.format("Inputs Values: %s", IOPair.getInputPrototype()[i].prettyPrint()));
log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(IOPair.getInputPrototype()[i].getData())));
log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
if (null != measuredGradient) {
// Label normalized from "Measured:" to match the verbose branch above.
log.info(String.format("Measured Feedback: %s", measuredGradient.prettyPrint()));
log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
log.info(String.format("Feedback Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
}
throw e;
}
}).reduce((a, b) -> a.combine(b)).map(x -> prev.combine(x)).orElse(prev);
return statistics;
}
Aggregations