Use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li: the class LogRiskOptimizerTest, method test1.
private static void test1() {
    MultiLabelClfDataSet train = MultiLabelSynthesizer.independentNoise();
    MultiLabelClfDataSet test = MultiLabelSynthesizer.independent();
    CMLCRF cmlcrf = new CMLCRF(train);
    // manually set the feature weights (without bias) for each of the 4 labels
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(0).set(0, 0);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(0).set(1, 1);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(1).set(0, 1);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(1).set(1, 1);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(2).set(0, 1);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(2).set(1, 0);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(3).set(0, 1);
    cmlcrf.getWeights().getWeightsWithoutBiasForClass(3).set(1, -1);
    MLScorer fScorer = new FScorer();
    LogRiskOptimizer fOptimizer = new LogRiskOptimizer(train, fScorer, cmlcrf, 1, false, false, 1, 1);
    InstanceF1Predictor plugInF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println(cmlcrf);
    System.out.println("initial loss = " + fOptimizer.objective());
    System.out.println("training performance acc");
    System.out.println(new MLMeasures(cmlcrf, train));
    System.out.println("test performance acc");
    System.out.println(new MLMeasures(cmlcrf, test));
    System.out.println("training performance f1");
    System.out.println(new MLMeasures(plugInF1, train));
    System.out.println("test performance f1");
    System.out.println(new MLMeasures(plugInF1, test));
    // iterate until the terminator signals convergence, printing accuracy- and
    // F1-oriented MLMeasures on train and test after every iteration
    while (!fOptimizer.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        fOptimizer.iterate();
        System.out.println(fOptimizer.getTerminator().getLastValue());
        System.out.println("training performance acc");
        System.out.println(new MLMeasures(cmlcrf, train));
        System.out.println("test performance acc");
        System.out.println(new MLMeasures(cmlcrf, test));
        System.out.println("training performance f1");
        System.out.println(new MLMeasures(plugInF1, train));
        System.out.println("test performance f1");
        System.out.println(new MLMeasures(plugInF1, test));
    }
    System.out.println(cmlcrf);
}
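For orientation, here is a minimal sketch of the two ways MLMeasures gets constructed across these examples: directly from a classifier and a data set, or from pre-computed predictions together with the ground-truth labels (the form used in reportF1Prediction further down). The helper class and method names are hypothetical, and the import paths for the pyramid classes other than edu.neu.ccs.pyramid.eval.MLMeasures are assumptions inferred from the class names used above.

import edu.neu.ccs.pyramid.eval.MLMeasures;
// The package paths below are assumptions; only MLMeasures' package is given on this page.
import edu.neu.ccs.pyramid.dataset.MultiLabel;
import edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet;
import edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier;

// Hypothetical helper collecting the two MLMeasures construction patterns seen in these examples.
final class MLMeasuresSketch {

    // Pattern 1: let MLMeasures run the classifier on the data set itself.
    static MLMeasures fromClassifier(MultiLabelClassifier classifier, MultiLabelClfDataSet dataSet) {
        return new MLMeasures(classifier, dataSet);
    }

    // Pattern 2: evaluate predictions that were computed (and possibly post-processed) elsewhere.
    static MLMeasures fromPredictions(MultiLabelClfDataSet dataSet, MultiLabel[] predictions) {
        return new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    }
}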
Use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li: the class App6, method train.
private static void train(Config config) throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(trainSet);
    double gaussianVariance = config.getDouble("train.gaussianVariance");
    cmlcrf.setConsiderPair(true);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
    int maxIteration = config.getInt("train.maxIteration");
    crfLoss.setRegularizeAll(true);
    LBFGS optimizer = new LBFGS(crfLoss);
    optimizer.getTerminator().setMaxIteration(maxIteration);
    PluginPredictor<CMLCRF> predictor = null;
    String predictTarget = config.getString("predict.target");
    switch (predictTarget) {
        case "subsetAccuracy":
            predictor = new SubsetAccPredictor(cmlcrf);
            break;
        case "instanceFMeasure":
            predictor = new InstanceF1Predictor(cmlcrf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }
    int progressInterval = config.getInt("train.showProgress.interval");
    System.out.println("start training");
    int iteration = 0;
    while (true) {
        optimizer.iterate();
        iteration += 1;
        if (iteration % progressInterval == 0) {
            System.out.println("iteration " + iteration);
            System.out.println("training objective = " + optimizer.getTerminator().getLastValue());
            System.out.println("training performance:");
            System.out.println(new MLMeasures(predictor, trainSet));
            System.out.println("test performance:");
            System.out.println(new MLMeasures(predictor, testSet));
        }
        if (optimizer.getTerminator().shouldTerminate()) {
            System.out.println("iteration " + iteration);
            System.out.println("training objective = " + optimizer.getTerminator().getLastValue());
            System.out.println("training performance:");
            System.out.println(new MLMeasures(predictor, trainSet));
            System.out.println("test performance:");
            System.out.println(new MLMeasures(predictor, testSet));
            System.out.println("training done!");
            break;
        }
    }
    String modelName = "model_crf";
    String output = config.getString("output.folder");
    (new File(output)).mkdirs();
    File serializeModel = new File(output, modelName);
    cmlcrf.serialize(serializeModel);
    MultiLabel[] predictions = cmlcrf.predict(trainSet);
    File predictionFile = new File(output, "train_predictions.txt");
    FileUtils.writeStringToFile(predictionFile, PrintUtil.toMutipleLines(predictions));
    System.out.println("predictions on the training set are written to " + predictionFile.getAbsolutePath());
    if (config.getBoolean("train.generateReports")) {
        report(config, trainSet, "trainSet");
    }
}
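The same predict.target switch reappears in the report method below. A small helper along the following lines could factor it out; the pickPredictor name is hypothetical, but the constructors and the error message are exactly the ones used above (imports omitted because only MLMeasures' package is given on this page).

// Hypothetical refactoring: the predict.target -> predictor mapping shared by train() and report().
static PluginPredictor<CMLCRF> pickPredictor(String predictTarget, CMLCRF cmlcrf) {
    switch (predictTarget) {
        case "subsetAccuracy":
            return new SubsetAccPredictor(cmlcrf);
        case "instanceFMeasure":
            return new InstanceF1Predictor(cmlcrf);
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }
}

train() would then call predictor = pickPredictor(config.getString("predict.target"), cmlcrf) and pass the result to new MLMeasures(predictor, trainSet) as before.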
Use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li: the class App6, method report.
static void report(Config config, MultiLabelClfDataSet dataSet, String dataName) throws Exception {
    System.out.println("generating reports for data set " + dataName);
    String output = config.getString("output.folder");
    String modelName = "model_crf";
    File analysisFolder = new File(new File(output, "reports_crf"), dataName + "_reports");
    analysisFolder.mkdirs();
    FileUtils.cleanDirectory(analysisFolder);
    CMLCRF crf = (CMLCRF) Serialization.deserialize(new File(output, modelName));
    PluginPredictor<CMLCRF> predictorTmp = null;
    String predictTarget = config.getString("predict.target");
    switch (predictTarget) {
        case "subsetAccuracy":
            predictorTmp = new SubsetAccPredictor(crf);
            break;
        case "instanceFMeasure":
            predictorTmp = new InstanceF1Predictor(crf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }
    // just to make lambda expressions happy
    final PluginPredictor<CMLCRF> predictor = predictorTmp;
    MLMeasures mlMeasures = new MLMeasures(predictor, dataSet);
    mlMeasures.getMacroAverage().setLabelTranslator(crf.getLabelTranslator());
    System.out.println("performance on dataset " + dataName);
    System.out.println(mlMeasures);
    boolean simpleCSV = true;
    if (simpleCSV) {
        // System.out.println("start generating simple CSV report");
        double probThreshold = config.getDouble("report.classProbThreshold");
        File csv = new File(analysisFolder, "report.csv");
        List<String> strs = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .mapToObj(i -> CRFInspector.simplePredictionAnalysis(crf, predictor, dataSet, i, probThreshold))
                .collect(Collectors.toList());
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            String str = strs.get(i);
            sb.append(str);
        }
        FileUtils.writeStringToFile(csv, sb.toString(), false);
        // System.out.println("finish generating simple CSV report");
    }
    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        // System.out.println("start writing data info to json");
        Set<String> modelLabels = IntStream.range(0, crf.getNumClasses())
                .mapToObj(i -> crf.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream()
                .map(i -> dataSet.getLabelTranslator().toExtLabel(i))
                .collect(Collectors.toSet());
        JsonGenerator jsonGenerator = new JsonFactory()
                .createGenerator(new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8);
        jsonGenerator.writeStartObject();
        jsonGenerator.writeStringField("dataSet", dataName);
        jsonGenerator.writeNumberField("numClassesInModel", crf.getNumClasses());
        jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
        jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
        jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
        jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
        for (String label : dataNotModelLabels) {
            jsonGenerator.writeObject(label);
        }
        jsonGenerator.writeEndArray();
        jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
        for (String label : modelNotDataLabels) {
            jsonGenerator.writeObject(label);
        }
        jsonGenerator.writeEndArray();
        jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
        jsonGenerator.writeEndObject();
        jsonGenerator.close();
        // System.out.println("finish writing data info to json");
    }
    boolean modelConfigToJson = true;
    if (modelConfigToJson) {
        // System.out.println("start writing model config to json");
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
        // System.out.println("finish writing model config to json");
    }
    boolean performanceToJson = true;
    if (performanceToJson) {
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
    }
    boolean individualPerformance = true;
    if (individualPerformance) {
        // System.out.println("start writing individual label performance to json");
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
        // System.out.println("finish writing individual label performance to json");
    }
    System.out.println("reports generated");
}
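Since performance.json and individual_performance.json are produced by plain Jackson serialization of the MLMeasures object and its macro average, the same object can also be pretty-printed to the console with the standard ObjectMapper API. A minimal sketch, assuming an mlMeasures instance like the one built in report above:

import com.fasterxml.jackson.databind.ObjectMapper;
import edu.neu.ccs.pyramid.eval.MLMeasures;

// Sketch: dump the same measures that report() writes to performance.json as a pretty-printed string.
static String measuresAsJson(MLMeasures mlMeasures) throws Exception {
    ObjectMapper objectMapper = new ObjectMapper();
    return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(mlMeasures);
}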
Use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li: the class CBMEN, method tune.
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    ENCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    if (config.getBoolean("train.randomInitialize")) {
        optimizer.randInitialize();
    } else {
        optimizer.initialize();
    }
    // choose the plug-in classifier that is optimal for the target metric
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch (predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        default:
            throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
    }
    int interval = config.getInt("tune.monitorInterval");
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        if (iter % interval == 0) {
            // evaluate on the validation set and feed the target metric to the early stopper
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            switch (predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                default:
                    throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
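The inner switch that feeds the early stopper only differs in which scalar it pulls out of the validation MLMeasures. A hypothetical extraction helper makes that explicit; all three getters appear in tune() above. Note that Hamming loss is a lower-is-better quantity, so the early stopper is presumably configured accordingly elsewhere (not shown on this page).

// Hypothetical helper: map the tune.targetMetric string to the corresponding scalar from MLMeasures.
static double targetMetricValue(MLMeasures validMeasures, String predictTarget) {
    switch (predictTarget) {
        case "instance_set_accuracy":
            return validMeasures.getInstanceAverage().getAccuracy();
        case "instance_f1":
            return validMeasures.getInstanceAverage().getF1();
        case "instance_hamming_loss":
            return validMeasures.getInstanceAverage().getHammingLoss();
        default:
            throw new IllegalArgumentException("predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
    }
}

tune() could then replace the inner switch with earlyStopper.add(iter, targetMetricValue(validMeasures, predictTarget)).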
Use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li: the class CBMEN, method reportF1Prediction.
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");
    PluginF1 pluginF1 = new PluginF1(cbm);
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(new File(output, "support"));
    pluginF1.setSupport(support);
    pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = pluginF1.predict(dataSet);
    // evaluate the pre-computed predictions against the ground-truth labels
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString());
    System.out.println("test performance is saved to " + performanceFile.toString());
    // here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i]))
            .toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write("" + setProbs[i]);
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
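Because reportF1Prediction evaluates pre-computed predictions with the MLMeasures(numClasses, trueLabels, predictions) constructor, the same pattern can compare the outputs of several predictors on one test set. A sketch under the assumption that both prediction arrays were produced for the same dataSet (imports as in the sketch after test1() above):

// Sketch: compare two sets of predictions (e.g. from two different plug-in predictors)
// on the same data set using the predictions-based MLMeasures constructor.
static void comparePredictions(MultiLabelClfDataSet dataSet, MultiLabel[] predictionsA, MultiLabel[] predictionsB) {
    MLMeasures measuresA = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictionsA);
    MLMeasures measuresB = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictionsB);
    System.out.println("predictor A:");
    System.out.println(measuresA);
    System.out.println("predictor B:");
    System.out.println(measuresB);
}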