Use of org.deeplearning4j.eval.meta.Prediction in project deeplearning4j by deeplearning4j.
The class Evaluation, method getPredictionByPredictedClass.
/**
 * Get a list of predictions, for all data with the specified <i>predicted</i> class, regardless of the actual data
 * class.
 * <p>
 * <b>Note</b>: Predictions are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)}
 * Otherwise (if the metadata hasn't been recorded via that previously mentioned eval method), there is no value in
 * splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts,
 * via {@link #getConfusionMatrix()}
 *
 * @param predictedClass Predicted class to get predictions for
 * @return List of predictions, or null if the "evaluate with metadata" method was not used
 */
public List<Prediction> getPredictionByPredictedClass(int predictedClass) {
    if (confusionMatrixMetaData == null) {
        return null;
    }
    List<Prediction> out = new ArrayList<>();
    for (Map.Entry<Pair<Integer, Integer>, List<Object>> entry : confusionMatrixMetaData.entrySet()) {
        //Entry key: (actual, predicted)
        Pair<Integer, Integer> key = entry.getKey();
        int predicted = key.getSecond();
        if (predicted != predictedClass) {
            continue;
        }
        int actual = key.getFirst();
        for (Object m : entry.getValue()) {
            out.add(new Prediction(actual, predicted, m));
        }
    }
    return out;
}
Use of org.deeplearning4j.eval.meta.Prediction in project deeplearning4j by deeplearning4j.
The class Evaluation, method getPredictionErrors.
/**
 * Returns one Prediction per recorded example whose actual class differs from its predicted class, i.e. every
 * off-diagonal cell of the confusion matrix. Results are ordered by actual class first, then by predicted class.
 * <p>
 * <b>Note</b>: Prediction errors are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)}
 * Without that metadata there is nothing to gain from splitting results into individual Prediction objects - use
 * {@link #getConfusionMatrix()} to get the counts instead.
 *
 * @return A list of prediction errors, or null if no metadata has been recorded
 */
public List<Prediction> getPredictionErrors() {
    if (this.confusionMatrixMetaData == null) {
        return null;
    }
    //Copy the entries so they can be ordered by (actual, predicted) without touching the underlying map
    List<Map.Entry<Pair<Integer, Integer>, List<Object>>> entries =
                    new ArrayList<>(confusionMatrixMetaData.entrySet());
    Collections.sort(entries, new Comparator<Map.Entry<Pair<Integer, Integer>, List<Object>>>() {
        @Override
        public int compare(Map.Entry<Pair<Integer, Integer>, List<Object>> a,
                        Map.Entry<Pair<Integer, Integer>, List<Object>> b) {
            int byActual = Integer.compare(a.getKey().getFirst(), b.getKey().getFirst());
            return byActual != 0 ? byActual : Integer.compare(a.getKey().getSecond(), b.getKey().getSecond());
        }
    });

    List<Prediction> errors = new ArrayList<>();
    for (Map.Entry<Pair<Integer, Integer>, List<Object>> cell : entries) {
        Integer actual = cell.getKey().getFirst();
        Integer predicted = cell.getKey().getSecond();
        //Diagonal cells (actual == predicted) are correct classifications, not errors
        if (!actual.equals(predicted)) {
            for (Object meta : cell.getValue()) {
                errors.add(new Prediction(actual, predicted, meta));
            }
        }
    }
    return errors;
}
Use of org.deeplearning4j.eval.meta.Prediction in project deeplearning4j by deeplearning4j.
The class EvalTest, method testEvaluationWithMetaData.
/**
 * Trains a small softmax classifier on iris, evaluates it while recording per-example metadata, and checks the
 * following:
 * <ul>
 *     <li>the number of prediction errors agrees with the reported accuracy;</li>
 *     <li>every confusion matrix cell has one Prediction per counted example;</li>
 *     <li>the per-actual-class and per-predicted-class prediction lists match the confusion matrix row/column sums.</li>
 * </ul>
 */
@Test
public void testEvaluationWithMetaData() throws Exception {
    RecordReader csv = new CSVRecordReader();
    csv.initialize(new FileSplit(new ClassPathResource("iris.txt").getTempFileFromArchive()));
    int batchSize = 10;
    int labelIdx = 4;
    int numClasses = 3;
    RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(csv, batchSize, labelIdx, numClasses);
    NormalizerStandardize ns = new NormalizerStandardize();
    ns.fit(rrdsi);
    rrdsi.setPreProcessor(ns);
    rrdsi.reset();
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).iterations(1)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD)
                    .learningRate(0.1).list()
                    .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(4).nOut(numClasses).build())
                    .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    for (int i = 0; i < 4; i++) {
        net.fit(rrdsi);
        rrdsi.reset();
    }

    Evaluation e = new Evaluation();
    //*** New: Enable collection of metadata (stored in the DataSets) ***
    rrdsi.setCollectMetaData(true);
    //Count the examples actually evaluated, rather than hard-coding the size of the iris data set
    int totalExamples = 0;
    while (rrdsi.hasNext()) {
        DataSet ds = rrdsi.next();
        //*** New - cross dependencies here make types difficult, using Object internally in DataSet for this ***
        List<RecordMetaData> meta = ds.getExampleMetaData(RecordMetaData.class);
        INDArray out = net.output(ds.getFeatures());
        //*** New - evaluate and also store metadata ***
        e.eval(ds.getLabels(), out, meta);
        totalExamples += ds.numExamples();
    }
    System.out.println(e.stats());

    System.out.println("\n\n*** Prediction Errors: ***");
    //*** New - get list of prediction errors from evaluation (non-null, as metadata was recorded above) ***
    List<Prediction> errors = e.getPredictionErrors();
    List<RecordMetaData> metaForErrors = new ArrayList<>();
    for (Prediction p : errors) {
        metaForErrors.add((RecordMetaData) p.getRecordMetaData());
    }
    //*** New - dynamically load a subset of the data, just for prediction errors ***
    DataSet ds = rrdsi.loadFromMetaData(metaForErrors);
    INDArray output = net.output(ds.getFeatures());

    int count = 0;
    for (Prediction t : errors) {
        //*** New - load subset of data from MetaData object (usually batched for efficiency) ***
        System.out.println(t + "\t\tRaw Data: "
                        + csv.loadFromMetaData((RecordMetaData) t.getRecordMetaData()).getRecord()
                        + "\tNormalized: " + ds.getFeatures().getRow(count) + "\tLabels: "
                        + ds.getLabels().getRow(count) + "\tNetwork predictions: " + output.getRow(count));
        count++;
    }

    int errorCount = errors.size();
    double expAcc = 1.0 - errorCount / (double) totalExamples;
    assertEquals(expAcc, e.accuracy(), 1e-5);

    ConfusionMatrix<Integer> confusion = e.getConfusionMatrix();
    int[] actualCounts = new int[numClasses];
    int[] predictedCounts = new int[numClasses];
    for (int i = 0; i < numClasses; i++) {
        for (int j = 0; j < numClasses; j++) {
            //(actual,predicted)
            int entry = confusion.getCount(i, j);
            List<Prediction> list = e.getPredictions(i, j);
            assertEquals(entry, list.size());
            actualCounts[i] += entry;
            predictedCounts[j] += entry;
        }
    }

    //Per-class lists must agree with the confusion matrix row sums (actual) and column sums (predicted)
    for (int i = 0; i < numClasses; i++) {
        List<Prediction> actualClassI = e.getPredictionsByActualClass(i);
        List<Prediction> predictedClassI = e.getPredictionByPredictedClass(i);
        assertEquals(actualCounts[i], actualClassI.size());
        assertEquals(predictedCounts[i], predictedClassI.size());
    }
}
Use of org.deeplearning4j.eval.meta.Prediction in project deeplearning4j by deeplearning4j.
The class Evaluation, method getPredictionsByActualClass.
/**
 * Collects every recorded prediction whose <i>actual</i> (true) class equals {@code actualClass}, whatever class
 * the network predicted for it.
 * <p>
 * <b>Note</b>: Predictions are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)}
 * Without that metadata there is nothing to gain from splitting results into individual Prediction objects - use
 * {@link #getConfusionMatrix()} to get the counts instead.
 *
 * @param actualClass Actual class to get predictions for
 * @return List of predictions, or null if the "evaluate with metadata" method was not used
 */
public List<Prediction> getPredictionsByActualClass(int actualClass) {
    if (confusionMatrixMetaData == null) {
        return null;
    }
    List<Prediction> matches = new ArrayList<>();
    for (Map.Entry<Pair<Integer, Integer>, List<Object>> cell : confusionMatrixMetaData.entrySet()) {
        //Key layout: (actual, predicted)
        Pair<Integer, Integer> classes = cell.getKey();
        int actual = classes.getFirst();
        if (actual != actualClass) {
            continue;
        }
        int predicted = classes.getSecond();
        for (Object recordMeta : cell.getValue()) {
            matches.add(new Prediction(actual, predicted, recordMeta));
        }
    }
    return matches;
}
Use of org.deeplearning4j.eval.meta.Prediction in project deeplearning4j by deeplearning4j.
The class Evaluation, method getPredictions.
/**
 * Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair)
 * <p>
 * <b>Note</b>: Predictions are ONLY available if the "evaluate with metadata" method is used: {@link #eval(INDArray, INDArray, List)}
 *
 * @param actualClass    Actual class
 * @param predictedClass Predicted class
 * @return List of predictions that match the specified actual/predicted classes (empty if no examples were recorded
 *         for that pair), or null if the "evaluate with metadata" method was not used
 */
public List<Prediction> getPredictions(int actualClass, int predictedClass) {
    if (confusionMatrixMetaData == null) {
        return null;
    }
    List<Object> metaList = confusionMatrixMetaData.get(new Pair<>(actualClass, predictedClass));
    if (metaList == null) {
        //No examples were recorded for this (actual, predicted) cell
        return new ArrayList<>();
    }
    //Exactly one Prediction per recorded metadata object, so the final size is known up front
    List<Prediction> out = new ArrayList<>(metaList.size());
    for (Object meta : metaList) {
        out.add(new Prediction(actualClass, predictedClass, meta));
    }
    return out;
}
Aggregations