Usage of edu.cmu.minorthird.classify.Classifier in the lucida project by claritylab.
Class HierarchicalClassifier, method classification:
/**
 * Classifies an instance by walking down the label hierarchy one level at a time.
 *
 * <p>At each of the {@code classLevels} levels, the classifier stored under the label
 * prefix built so far (starting from the empty string) is applied to the instance. Its
 * best class name is folded into the prefix via {@code getNewLabelName}, and its best
 * weight is multiplied into the running weight.
 *
 * @param instance the instance to classify
 * @return a label whose name is the full hierarchical label and whose weight is the
 *     product of the per-level best weights (an empty name with weight 1 when
 *     {@code classLevels} is 0)
 * @throws IllegalStateException if no classifier is found for a label prefix reached
 *     during the walk
 */
public ClassLabel classification(Instance instance) {
    String labelName = "";
    double weight = 1;
    for (int i = 0; i < classLevels; i++) {
        Classifier currentClassifier = (Classifier) classifiers.get(labelName);
        if (currentClassifier == null) {
            // Report which prefix/level is missing instead of failing with a bare NPE below.
            throw new IllegalStateException(
                "No classifier for label prefix \"" + labelName + "\" at level " + i);
        }
        ClassLabel currentLabel = currentClassifier.classification(instance);
        labelName = getNewLabelName(labelName, currentLabel.bestClassName(), i);
        // NOTE(review): the product is only meaningful if bestWeight() is a
        // probability-like, non-negative score for the learners used - confirm.
        weight *= currentLabel.bestWeight();
    }
    return new ClassLabel(labelName, weight);
}
Usage of edu.cmu.minorthird.classify.Classifier in the lucida project by claritylab.
Class HierarchicalClassifier, method explain:
/**
 * Explains the hierarchical classification of an instance.
 *
 * <p>Walks the label hierarchy exactly as {@code classification} does: at each level it
 * selects the classifier stored under the current label prefix and extends the prefix
 * with that classifier's best class name. It concatenates, without separators, each
 * visited classifier's own explanation of the instance.
 *
 * @param instance the instance to explain
 * @return the concatenated per-level explanations (empty when {@code classLevels} is 0)
 * @throws IllegalStateException if no classifier is found for a label prefix reached
 *     during the walk
 */
public String explain(Instance instance) {
    String labelName = "";
    // StringBuilder avoids re-copying the accumulated text on every level.
    StringBuilder explanation = new StringBuilder();
    for (int i = 0; i < classLevels; i++) {
        Classifier currentClassifier = (Classifier) classifiers.get(labelName);
        if (currentClassifier == null) {
            // Report which prefix/level is missing instead of failing with a bare NPE below.
            throw new IllegalStateException(
                "No classifier for label prefix \"" + labelName + "\" at level " + i);
        }
        ClassLabel currentLabel = currentClassifier.classification(instance);
        labelName = getNewLabelName(labelName, currentLabel.bestClassName(), i);
        explanation.append(currentClassifier.explain(instance));
    }
    return explanation.toString();
}
Usage of edu.cmu.minorthird.classify.Classifier in the lucida project by claritylab.
Class ScoreNormalizationFilter, method main:
/**
 * Trains a classifier on serialized results and serializes it to disk.
 *
 * <p>The features and model come from {@code SELECTED_FEATURES} and
 * {@code SELECTED_MODEL}. The commented-out lines in the body show the alternative:
 * evaluate all combinations of features and models (writing reports to
 * {@code <output_dir>/reports}) and use the best combination according to F1.
 *
 * <p>The trained classifier is written to {@code <output_dir>/classifiers} under the name
 * {@code <model>_<features joined by +>_<data sets joined by +>.serialized}. The data
 * sets are the visible subdirectories of the serialized results directory.
 *
 * <p>Exits with status 1 if fewer than two arguments are given or if serialization fails.
 *
 * @param args {directory containing serialized results,
 *             output directory for evaluation reports and classifier}
 */
public static void main(String[] args) {
    // enable output of status and error messages
    MsgPrinter.enableStatusMsgs(true);
    MsgPrinter.enableErrorMsgs(true);
    // get command line parameters
    if (args.length < 2) {
        MsgPrinter.printUsage("java ScoreNormalizationFilter " + "serialized_results_dir output_dir");
        System.exit(1);
    }
    String serializedDir = args[0];
    String outputDir = args[1];
    // // evaluate all combinations of features and models,
    // // get best combination according to F1 measure
    // String reportsDir = new File(outputDir, "reports").getPath();
    // String[][] combination = evaluateAll(serializedDir, reportsDir);
    // String[] features = combination[0];
    // String model = combination[1][0];
    // or simply get selected features and model
    String[] features = SELECTED_FEATURES;
    String model = SELECTED_MODEL;
    // train classifier using best/selected features and model
    String msg = "Training classifier using model " + model + " with feature(s) " + StringUtils.concat(features, ", ") + " (" + MsgPrinter.getTimestamp() + ")...";
    // the banner frames the message with rules of equal length
    String rule = StringUtils.repeat("-", msg.length());
    MsgPrinter.printStatusMsg(rule);
    MsgPrinter.printStatusMsg(msg);
    MsgPrinter.printStatusMsg(rule);
    Classifier classifier = train(serializedDir, features, model);
    // serialize classifier to file
    File classifiersDir = new File(outputDir, "classifiers");
    String[] dataSets = FileUtils.getVisibleSubDirs(serializedDir);
    String filename = model + "_" + StringUtils.concat(features, "+") + "_" + StringUtils.concat(dataSets, "+") + ".serialized";
    File classifierFile = new File(classifiersDir, filename);
    try {
        FileUtils.writeSerialized(classifier, classifierFile);
    } catch (IOException e) {
        // report the full target path so the user can see which directory was used
        MsgPrinter.printErrorMsg("Failed to serialize classifier to file " + classifierFile.getPath() + ":");
        MsgPrinter.printErrorMsg(e.toString());
        System.exit(1);
    }
    MsgPrinter.printStatusMsg("...done.");
}
Usage of edu.cmu.minorthird.classify.Classifier in the lucida project by claritylab.
Class ScoreNormalizationFilter, method train:
/**
 * Builds a training set from serialized results and fits a classifier to it.
 *
 * <p>The data set is assembled first from the given features, then the learner for the
 * requested model is created. A {@code DatasetClassifierTeacher} then trains the learner
 * on that data set.
 *
 * @param serializedDir directory containing serialized results
 * @param features selected features
 * @param model selected model
 * @return the classifier produced by the teacher
 */
public static Classifier train(String serializedDir, String[] features, String model) {
    final Dataset data = createDataset(features, serializedDir);
    final ClassifierLearner modelLearner = createLearner(model);
    final DatasetClassifierTeacher teacher = new DatasetClassifierTeacher(data);
    return teacher.train(modelLearner);
}
Aggregations