Use of de.lmu.ifi.dbs.elki.evaluation.classification.holdout.TrainingAndTestSet in project elki by elki-project.
The run method of the class ClassifierHoldoutEvaluationTask.
/**
 * Runs the holdout evaluation of the classifier.
 * <p>
 * Loads all data from the database connection and initializes the holdout
 * with it. Then, for every partition supplied by the holdout: builds a
 * database from the training set, trains the classifier on it, classifies
 * each object of the test set, and counts the outcome in a confusion matrix.
 * Timing statistics and the final confusion matrix are written to the log.
 *
 * @throws IllegalStateException if a predicted or true class label is not
 *         found in the label list reported by the holdout
 */
@Override
public void run() {
  Duration ptime = LOG.newDuration("evaluation.time.load").begin();
  MultipleObjectsBundle allData = databaseConnection.loadData();
  holdout.initialize(allData);
  LOG.statistics(ptime.end());

  Duration time = LOG.newDuration("evaluation.time.total").begin();
  // NOTE(review): the binary searches below assume this list is sorted in
  // the natural order of ClassLabel - confirm against the holdout implementation.
  ArrayList<ClassLabel> labels = holdout.getLabels();
  // confusion[predicted][true], both indexed by position in "labels".
  int[][] confusion = new int[labels.size()][labels.size()];
  for (int p = 0; p < holdout.numberOfPartitions(); p++) {
    TrainingAndTestSet partition = holdout.nextPartitioning();
    // Common prefix of the per-fold statistic keys (folds are numbered from 1).
    final String prefix = this.getClass().getName() + ".fold-" + (p + 1);

    // Load the data set into a database structure (for indexing)
    Duration dur = LOG.newDuration(prefix + ".init.time").begin();
    Database db = new StaticArrayDatabase(new MultipleObjectsBundleDatabaseConnection(partition.getTraining()), indexFactories);
    db.initialize();
    LOG.statistics(dur.end());

    // Train the classifier
    dur = LOG.newDuration(prefix + ".train.time").begin();
    Relation<ClassLabel> lrel = db.getRelation(TypeUtil.CLASSLABEL);
    algorithm.buildClassifier(db, lrel);
    LOG.statistics(dur.end());

    // Evaluate the test set
    dur = LOG.newDuration(prefix + ".evaluation.time").begin();
    // FIXME: this part is still a big hack, unfortunately!
    MultipleObjectsBundle test = partition.getTest();
    int lcol = AbstractHoldout.findClassLabelColumn(test);
    // NOTE(review): picks column 0 or 1 as the data column, i.e. presumably
    // expects a bundle with exactly one label and one data column - confirm.
    int tcol = (lcol == 0) ? 1 : 0;
    for (int i = 0, l = test.dataLength(); i < l; ++i) {
      @SuppressWarnings("unchecked")
      O obj = (O) test.data(i, tcol);
      ClassLabel truelbl = (ClassLabel) test.data(i, lcol);
      ClassLabel predlbl = algorithm.classify(obj);
      // binarySearch returns a negative value when the key is absent; using
      // it as an array index would fail with an uninformative exception.
      int pred = Collections.binarySearch(labels, predlbl);
      if (pred < 0) {
        throw new IllegalStateException("Predicted class label '" + predlbl + "' for test object " + i + " of fold " + (p + 1) + " is not among the known class labels.");
      }
      int real = Collections.binarySearch(labels, truelbl);
      if (real < 0) {
        throw new IllegalStateException("True class label '" + truelbl + "' of test object " + i + " of fold " + (p + 1) + " is not among the known class labels.");
      }
      confusion[pred][real]++;
    }
    LOG.statistics(dur.end());
  }
  LOG.statistics(time.end());
  ConfusionMatrix m = new ConfusionMatrix(labels, confusion);
  LOG.statistics(m.toString());
}
Aggregations