Use of edu.cmu.minorthird.classify.Dataset in the project lucida by claritylab.
From the class ScoreNormalizationFilter, method evaluate:
/**
 * Cross-validates the selected model on a data set built from the
 * serialized results, using the selected features.
 *
 * @param serializedDir directory containing serialized results
 * @param features selected features
 * @param model selected model
 * @return evaluation statistics
 */
public static Evaluation evaluate(String serializedDir, String[] features, String model) {
    // build the data set and the learner for the requested configuration
    Dataset dataSet = createDataset(features, serializedDir);
    ClassifierLearner learner = createLearner(model);
    // NUM_FOLDS-fold cross-validation, split with a time-seeded RNG
    RandomElement random = new RandomElement(System.currentTimeMillis());
    Splitter folds = new CrossValSplitter(random, NUM_FOLDS);
    CrossValidatedDataset crossValidated = new CrossValidatedDataset(learner, dataSet, folds, true);
    return crossValidated.getEvaluation();
}
Use of edu.cmu.minorthird.classify.Dataset in the project lucida by claritylab.
From the class ScoreNormalizationFilter, method createDataset:
/**
 * Creates a training/evaluation set from serialized judged
 * <code>Result</code> objects.
 *
 * @param features selected features
 * @param serializedDir directory containing serialized results
 * @return training/evaluation set
 */
private static Dataset createDataset(String[] features, String serializedDir) {
    final String SUFFIX = ".serialized";
    Dataset set = new BasicDataset();
    File[] files = FileUtils.getFilesRec(serializedDir);
    for (File file : files) {
        // one file per question
        String filename = file.getName();
        if (!filename.endsWith(SUFFIX))
            continue;
        // Strip only the trailing suffix to get the question ID.
        // (String.replace() would also remove any earlier occurrence of
        // ".serialized" embedded in the ID itself.)
        String qid = filename.substring(0, filename.length() - SUFFIX.length());
        Result[] results = readSerializedResults(file);
        // create examples and add to data set
        for (Result result : results) {
            // keep only factoid answers with a finite positive score that were
            // found by exactly 1 extraction technique
            if (result.getScore() <= 0
                    || result.getScore() == Float.POSITIVE_INFINITY
                    || result.getExtractionTechniques() == null
                    || result.getExtractionTechniques().length != 1)
                continue;
            Example example = createExample(features, result, results, qid);
            set.add(example);
        }
    }
    return set;
}
Use of edu.cmu.minorthird.classify.Dataset in the project lucida by claritylab.
From the class ScoreNormalizationFilter, method train:
/**
 * Trains a classifier of the selected model on the serialized results,
 * using the selected features.
 *
 * @param serializedDir directory containing serialized results
 * @param features selected features
 * @param model selected model
 * @return trained classifier
 */
public static Classifier train(String serializedDir, String[] features, String model) {
    // training set built from the serialized results with the given features
    Dataset trainingSet = createDataset(features, serializedDir);
    // learner instance for the given model, then train it on the whole set
    ClassifierLearner learner = createLearner(model);
    return new DatasetClassifierTeacher(trainingSet).train(learner);
}
Use of edu.cmu.minorthird.classify.Dataset in the project lucida by claritylab.
From the class HierarchicalClassifierTrainer, method makeDataset:
/**
 * Loads examples from the given file into a data set, keeping only examples
 * whose class label is known and only features of the configured types.
 * The first invocation loads the training set and records its labels;
 * later invocations discard test examples whose label was never seen in
 * training.
 *
 * @param fileName file to load examples from
 * @return data set of filtered examples
 */
private Dataset makeDataset(String fileName) {
    // first call: this is the training set, so start collecting its labels
    if (trainingLabels == null) {
        loadTraining = true;
        trainingLabels = new HashSet<String>();
    }
    Dataset set = new BasicDataset();
    extractor.setUseClassLevels(useClassLevels);
    extractor.setClassLevels(learnerNames.length);
    Example[] loaded = extractor.loadFile(fileName);
    for (Example source : loaded) {
        String bestLabel = source.getLabel().bestClassName();
        if (!classLabels.contains(bestLabel)) {
            MLToolkit.println("Discarding example for Class: " + bestLabel);
            continue;
        }
        // rebuild the instance, keeping only features of the allowed types
        MutableInstance filtered = new MutableInstance(source.getSource(), source.getSubpopulationId());
        for (Feature.Looper it = source.binaryFeatureIterator(); it.hasNext(); ) {
            Feature feature = it.nextFeature();
            if (featureTypes.contains(feature.getPart(0))) {
                filtered.addBinary(feature);
            }
        }
        for (Feature.Looper it = source.numericFeatureIterator(); it.hasNext(); ) {
            Feature feature = it.nextFeature();
            if (featureTypes.contains(feature.getPart(0))) {
                filtered.addNumeric(feature, source.getWeight(feature));
            }
        }
        Example example = new Example(filtered, source.getLabel());
        MLToolkit.println(example);
        if (loadTraining) {
            // training pass: remember the label and keep the example
            trainingLabels.add(bestLabel);
            set.add(example);
        } else if (trainingLabels.contains(bestLabel)) {
            set.add(example);
        } else {
            // test example with a label the classifier can never predict
            MLToolkit.println("Label of test example not found in training set (discarding): " + bestLabel);
        }
    }
    // after the first (training) load, subsequent calls are test loads
    if (loadTraining)
        loadTraining = false;
    MLToolkit.println("Loaded " + set.size() + " examples for experiment from " + fileName);
    return set;
}
Aggregations