Search in sources :

Example 11 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class LogRiskOptimizerTest method test1.

/**
 * Manual experiment: builds a CMLCRF on the data from
 * {@code MultiLabelSynthesizer.independentNoise()}, seeds its per-class weights by hand,
 * and runs a {@code LogRiskOptimizer} (with an {@code FScorer}) until its terminator
 * reports termination. Measures for the raw CRF and for an {@code InstanceF1Predictor}
 * wrapped around it are printed before the first iteration and after every iteration.
 */
private static void test1() {
    MultiLabelClfDataSet trainData = MultiLabelSynthesizer.independentNoise();
    MultiLabelClfDataSet testData = MultiLabelSynthesizer.independent();
    CMLCRF crf = new CMLCRF(trainData);

    // Hand-picked starting weights: row = class index, column = feature index.
    int[][] seedWeights = { { 0, 1 }, { 1, 1 }, { 1, 0 }, { 1, -1 } };
    for (int classIndex = 0; classIndex < seedWeights.length; classIndex++) {
        for (int featureIndex = 0; featureIndex < seedWeights[classIndex].length; featureIndex++) {
            crf.getWeights().getWeightsWithoutBiasForClass(classIndex).set(featureIndex, seedWeights[classIndex][featureIndex]);
        }
    }

    MLScorer scorer = new FScorer();
    LogRiskOptimizer riskOptimizer = new LogRiskOptimizer(trainData, scorer, crf, 1, false, false, 1, 1);
    InstanceF1Predictor f1Predictor = new InstanceF1Predictor(crf);

    System.out.println(crf);
    System.out.println("initial loss = " + riskOptimizer.objective());
    printAccAndF1Measures(crf, f1Predictor, trainData, testData);

    while (!riskOptimizer.getTerminator().shouldTerminate()) {
        System.out.println("------------");
        riskOptimizer.iterate();
        System.out.println(riskOptimizer.getTerminator().getLastValue());
        printAccAndF1Measures(crf, f1Predictor, trainData, testData);
    }
    System.out.println(crf);
}

/**
 * Prints MLMeasures on the training and test sets, first for the raw CRF and then for
 * the F1 plug-in predictor, each preceded by its caption line.
 */
private static void printAccAndF1Measures(CMLCRF crf, InstanceF1Predictor f1Predictor, MultiLabelClfDataSet trainData, MultiLabelClfDataSet testData) {
    System.out.println("training performance acc");
    System.out.println(new MLMeasures(crf, trainData));
    System.out.println("test performance acc");
    System.out.println(new MLMeasures(crf, testData));
    System.out.println("training performance f1");
    System.out.println(new MLMeasures(f1Predictor, trainData));
    System.out.println("test performance f1");
    System.out.println(new MLMeasures(f1Predictor, testData));
}
Also used : MLScorer(edu.neu.ccs.pyramid.multilabel_classification.MLScorer) FScorer(edu.neu.ccs.pyramid.multilabel_classification.FScorer) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 12 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class App6 method train.

/**
 * Trains a CMLCRF on the configured training set with L-BFGS, printing training and
 * test measures as it goes, then serializes the model and writes the training-set
 * predictions to {@code train_predictions.txt} in the output folder.
 *
 * <p>Config keys read: {@code input.trainData}, {@code input.testData},
 * {@code train.gaussianVariance}, {@code train.maxIteration}, {@code predict.target},
 * {@code train.showProgress.interval}, {@code output.folder},
 * {@code train.generateReports}.
 *
 * @param config application configuration
 * @throws IllegalArgumentException if {@code predict.target} is neither
 *         {@code subsetAccuracy} nor {@code instanceFMeasure}, or if
 *         {@code train.showProgress.interval} is 0
 * @throws Exception propagated from data loading, model serialization, file writing or
 *         report generation
 */
private static void train(Config config) throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(trainSet);
    double gaussianVariance = config.getDouble("train.gaussianVariance");
    cmlcrf.setConsiderPair(true);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
    int maxIteration = config.getInt("train.maxIteration");
    crfLoss.setRegularizeAll(true);
    LBFGS optimizer = new LBFGS(crfLoss);
    optimizer.getTerminator().setMaxIteration(maxIteration);

    // Assigned exactly once in the switch below; the default branch throws.
    final PluginPredictor<CMLCRF> predictor;
    String predictTarget = config.getString("predict.target");
    switch(predictTarget) {
        case "subsetAccuracy":
            predictor = new SubsetAccPredictor(cmlcrf);
            break;
        case "instanceFMeasure":
            predictor = new InstanceF1Predictor(cmlcrf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }

    int progressInterval = config.getInt("train.showProgress.interval");
    if (progressInterval == 0) {
        // Fail before training starts instead of hitting "/ by zero" in the modulo below.
        throw new IllegalArgumentException("train.showProgress.interval must not be 0");
    }

    System.out.println("start training");
    int iteration = 0;
    while (true) {
        optimizer.iterate();
        iteration += 1;
        boolean finished = optimizer.getTerminator().shouldTerminate();
        // Report on every interval boundary and once more on the final iteration. The
        // two conditions are merged so that a final iteration that falls on an interval
        // boundary is evaluated and printed only once.
        if (finished || iteration % progressInterval == 0) {
            System.out.println("iteration " + iteration);
            System.out.println("training objective = " + optimizer.getTerminator().getLastValue());
            System.out.println("training performance:");
            System.out.println(new MLMeasures(predictor, trainSet));
            System.out.println("test performance:");
            System.out.println(new MLMeasures(predictor, testSet));
        }
        if (finished) {
            System.out.println("training done!");
            break;
        }
    }

    String modelName = "model_crf";
    String output = config.getString("output.folder");
    (new File(output)).mkdirs();
    File serializeModel = new File(output, modelName);
    cmlcrf.serialize(serializeModel);

    // NOTE(review): predictions are taken from the raw CRF, not from the predictor chosen
    // via predict.target that was used for the progress measures above; confirm intended.
    MultiLabel[] predictions = cmlcrf.predict(trainSet);
    File predictionFile = new File(output, "train_predictions.txt");
    FileUtils.writeStringToFile(predictionFile, PrintUtil.toMutipleLines(predictions));
    System.out.println("predictions on the training set are written to " + predictionFile.getAbsolutePath());
    if (config.getBoolean("train.generateReports")) {
        report(config, trainSet, "trainSet");
    }
}
Also used : LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File)

Example 13 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class App6 method report.

/**
 * Generates the report files for one data set using the serialized {@code model_crf}
 * found in the configured output folder.
 *
 * <p>Files written under {@code <output.folder>/reports_crf/<dataName>_reports}
 * (the folder is created if needed and emptied first): {@code report.csv},
 * {@code data_info.json}, {@code model_config.json}, {@code performance.json} and
 * {@code individual_performance.json}.
 *
 * @param config   application configuration; reads {@code output.folder},
 *                 {@code predict.target} and {@code report.classProbThreshold}
 * @param dataSet  data set to evaluate the deserialized model on
 * @param dataName name used in console messages, the report folder name and the
 *                 {@code dataSet} field of {@code data_info.json}
 * @throws IllegalArgumentException if {@code predict.target} is neither
 *         {@code subsetAccuracy} nor {@code instanceFMeasure}
 * @throws Exception propagated from deserialization and file I/O
 */
static void report(Config config, MultiLabelClfDataSet dataSet, String dataName) throws Exception {
    System.out.println("generating reports for data set " + dataName);
    String output = config.getString("output.folder");
    String modelName = "model_crf";
    File analysisFolder = new File(new File(output, "reports_crf"), dataName + "_reports");
    analysisFolder.mkdirs();
    FileUtils.cleanDirectory(analysisFolder);
    CMLCRF crf = (CMLCRF) Serialization.deserialize(new File(output, modelName));

    // Assigned exactly once in the switch (the default branch throws), so it can be
    // captured by the lambda below without a temporary copy.
    final PluginPredictor<CMLCRF> predictor;
    String predictTarget = config.getString("predict.target");
    switch(predictTarget) {
        case "subsetAccuracy":
            predictor = new SubsetAccPredictor(crf);
            break;
        case "instanceFMeasure":
            predictor = new InstanceF1Predictor(crf);
            break;
        default:
            throw new IllegalArgumentException("predict.target must be subsetAccuracy or instanceFMeasure");
    }

    MLMeasures mlMeasures = new MLMeasures(predictor, dataSet);
    mlMeasures.getMacroAverage().setLabelTranslator(crf.getLabelTranslator());
    System.out.println("performance on dataset " + dataName);
    System.out.println(mlMeasures);

    // One mapper is enough for all JSON files written below.
    ObjectMapper objectMapper = new ObjectMapper();

    boolean simpleCSV = true;
    if (simpleCSV) {
        double probThreshold = config.getDouble("report.classProbThreshold");
        File csv = new File(analysisFolder, "report.csv");
        // Per-instance analyses are computed in parallel; joining() concatenates them in
        // encounter order, i.e. by data point index.
        String csvContent = IntStream.range(0, dataSet.getNumDataPoints()).parallel()
                .mapToObj(i -> CRFInspector.simplePredictionAnalysis(crf, predictor, dataSet, i, probThreshold))
                .collect(Collectors.joining());
        FileUtils.writeStringToFile(csv, csvContent, false);
    }

    boolean dataInfoToJson = true;
    if (dataInfoToJson) {
        Set<String> modelLabels = IntStream.range(0, crf.getNumClasses()).mapToObj(i -> crf.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> dataSetLabels = DataSetUtil.gatherLabels(dataSet).stream().map(i -> dataSet.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
        Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
        Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
        // try-with-resources closes the generator (and its file) even when a write fails.
        try (JsonGenerator jsonGenerator = new JsonFactory().createGenerator(new File(analysisFolder, "data_info.json"), JsonEncoding.UTF8)) {
            jsonGenerator.writeStartObject();
            jsonGenerator.writeStringField("dataSet", dataName);
            jsonGenerator.writeNumberField("numClassesInModel", crf.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", dataSet.getNumClasses());
            jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
            jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
            jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
            for (String label : dataNotModelLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
            for (String label : modelNotDataLabels) {
                jsonGenerator.writeObject(label);
            }
            jsonGenerator.writeEndArray();
            jsonGenerator.writeNumberField("labelCardinality", dataSet.labelCardinality());
            jsonGenerator.writeEndObject();
        }
    }

    boolean modelConfigToJson = true;
    if (modelConfigToJson) {
        objectMapper.writeValue(new File(analysisFolder, "model_config.json"), config);
    }

    boolean performanceToJson = true;
    if (performanceToJson) {
        objectMapper.writeValue(new File(analysisFolder, "performance.json"), mlMeasures);
    }

    boolean individualPerformance = true;
    if (individualPerformance) {
        objectMapper.writeValue(new File(analysisFolder, "individual_performance.json"), mlMeasures.getMacroAverage());
    }
    System.out.println("reports generated");
}
Also used : edu.neu.ccs.pyramid.multilabel_classification.crf(edu.neu.ccs.pyramid.multilabel_classification.crf) IntStream(java.util.stream.IntStream) MacroF1Predictor(edu.neu.ccs.pyramid.multilabel_classification.imlgb.MacroF1Predictor) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) HammingPredictor(edu.neu.ccs.pyramid.multilabel_classification.imlgb.HammingPredictor) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) LBFGS(edu.neu.ccs.pyramid.optimization.LBFGS) Accuracy(edu.neu.ccs.pyramid.eval.Accuracy) Overlap(edu.neu.ccs.pyramid.eval.Overlap) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) Collection(java.util.Collection) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FMeasure(edu.neu.ccs.pyramid.eval.FMeasure) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) File(java.io.File) IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) TimeUnit(java.util.concurrent.TimeUnit) Progress(edu.neu.ccs.pyramid.util.Progress) IMLGBInspector(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBInspector) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) Handlers.output(edu.stanford.nlp.util.logging.RedwoodConfiguration.Handlers.output) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) 
SetUtil(edu.neu.ccs.pyramid.util.SetUtil) JsonFactory(com.fasterxml.jackson.core.JsonFactory) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 14 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class CBMEN method tune.

/**
 * Trains a fresh CBM with the given hyper parameters and uses early stopping on the
 * validation set to pick the number of iterations.
 *
 * <p>Every {@code tune.monitorInterval} iterations the validation measures are computed
 * with the predictor selected by {@code tune.targetMetric} and the matching
 * instance-average metric is handed to the early stopper. The loop ends when the early
 * stopper says training should stop.
 *
 * @param config          application configuration; reads {@code train.randomInitialize},
 *                        {@code tune.targetMetric}, {@code predict.piThreshold} and
 *                        {@code tune.monitorInterval}
 * @param hyperParameters hyper parameters to evaluate; {@code iterations} is overwritten
 *                        with the early stopper's best iteration
 * @param trainSet        training data
 * @param validSet        validation data used for monitoring
 * @return the (mutated) hyper parameters together with the early stopper's best value
 * @throws IllegalArgumentException if {@code tune.targetMetric} is unsupported or
 *         {@code tune.monitorInterval} is 0
 * @throws Exception propagated from model construction and optimization
 */
private static TuneResult tune(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet) throws Exception {
    // Shared by both switches so the two error messages cannot drift apart again.
    final String unsupportedTargetMessage = "predictTarget should be instance_set_accuracy, instance_f1 or instance_hamming_loss";
    CBM cbm = newCBM(config, trainSet, hyperParameters);
    EarlyStopper earlyStopper = loadNewEarlyStopper(config);
    ENCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    if (config.getBoolean("train.randomInitialize")) {
        optimizer.randInitialize();
    } else {
        optimizer.initialize();
    }
    MultiLabelClassifier classifier;
    String predictTarget = config.getString("tune.targetMetric");
    switch(predictTarget) {
        case "instance_set_accuracy":
            AccPredictor accPredictor = new AccPredictor(cbm);
            accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
            classifier = accPredictor;
            break;
        case "instance_f1":
            PluginF1 pluginF1 = new PluginF1(cbm);
            List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
            pluginF1.setSupport(support);
            pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = pluginF1;
            break;
        case "instance_hamming_loss":
            MarginalPredictor marginalPredictor = new MarginalPredictor(cbm);
            marginalPredictor.setPiThreshold(config.getDouble("predict.piThreshold"));
            classifier = marginalPredictor;
            break;
        default:
            throw new IllegalArgumentException(unsupportedTargetMessage);
    }
    int interval = config.getInt("tune.monitorInterval");
    if (interval == 0) {
        // Fail before the first iteration instead of hitting "/ by zero" in the modulo below.
        throw new IllegalArgumentException("tune.monitorInterval must not be 0");
    }
    for (int iter = 1; true; iter++) {
        if (VERBOSE) {
            System.out.println("iteration " + iter);
        }
        optimizer.iterate();
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(classifier, validSet);
            if (VERBOSE) {
                System.out.println("validation performance with " + predictTarget + " optimal predictor:");
                System.out.println(validMeasures);
            }
            switch(predictTarget) {
                case "instance_set_accuracy":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
                    break;
                case "instance_f1":
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getF1());
                    break;
                case "instance_hamming_loss":
                    // NOTE(review): unlike the other two metrics, a smaller hamming loss is
                    // better; presumably loadNewEarlyStopper(config) configures the stopper's
                    // direction accordingly. TODO confirm.
                    earlyStopper.add(iter, validMeasures.getInstanceAverage().getHammingLoss());
                    break;
                default:
                    // Not reachable after the validation above; kept as a safety net.
                    throw new IllegalArgumentException(unsupportedTargetMessage);
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    System.out.println("Early Stopper: the training should stop now!");
                }
                break;
            }
        }
    }
    if (VERBOSE) {
        System.out.println("done!");
    }
    hyperParameters.iterations = earlyStopper.getBestIteration();
    TuneResult tuneResult = new TuneResult();
    tuneResult.hyperParameters = hyperParameters;
    tuneResult.performance = earlyStopper.getBestValue();
    return tuneResult;
}
Also used : EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Example 15 with MLMeasures

use of edu.neu.ccs.pyramid.eval.MLMeasures in project pyramid by cheng-li.

the class CBMEN method reportF1Prediction.

/**
 * Predicts label sets for {@code dataSet} with a {@code PluginF1} predictor built on
 * {@code cbm}, prints the resulting measures, and saves both the measures and the
 * per-instance predictions (with their assignment probabilities under {@code cbm}) below
 * {@code <output.dir>/test_predictions/instance_f1_optimal}.
 *
 * @param config  application configuration; reads {@code output.dir} and
 *                {@code predict.piThreshold}
 * @param cbm     model used for prediction and for scoring the predicted sets
 * @param dataSet data set to predict on
 * @throws Exception propagated from deserialization and file I/O
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    final String separatorLine = "============================================================";
    final String predictionsFolder = "test_predictions";
    final String predictorFolder = "instance_f1_optimal";

    System.out.println(separatorLine);
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");

    // The support is read back from the "support" file in the output directory.
    PluginF1 f1Predictor = new PluginF1(cbm);
    List<MultiLabel> storedSupport = (List<MultiLabel>) Serialization.deserialize(new File(output, "support"));
    f1Predictor.setSupport(storedSupport);
    f1Predictor.setPiThreshold(config.getDouble("predict.piThreshold"));

    MultiLabel[] predicted = f1Predictor.predict(dataSet);
    MLMeasures measures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predicted);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(measures);

    File performanceFile = Paths.get(output, predictionsFolder, predictorFolder, "performance.txt").toFile();
    FileUtils.writeStringToFile(performanceFile, measures.toString());
    System.out.println("test performance is saved to " + performanceFile);

    // Here we do not use approximation
    double[] assignmentProbs = IntStream.range(0, predicted.length)
            .parallel()
            .mapToDouble(row -> cbm.predictAssignmentProb(dataSet.getRow(row), predicted[row]))
            .toArray();

    File predictionFile = Paths.get(output, predictionsFolder, predictorFolder, "predictions.txt").toFile();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(predictionFile))) {
        int numDataPoints = dataSet.getNumDataPoints();
        for (int row = 0; row < numDataPoints; row++) {
            // One line per instance: <predicted set>:<probability>
            writer.write(predicted[row].toString() + ":" + assignmentProbs[row]);
            writer.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println(separatorLine);
}
Also used : IntStream(java.util.stream.IntStream) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) java.util(java.util) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) FileWriter(java.io.FileWriter) BufferedWriter(java.io.BufferedWriter) File(java.io.File) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Aggregations

MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)24 File (java.io.File)17 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)12 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)12 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)11 Config (edu.neu.ccs.pyramid.configuration.Config)10 PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil)10 Serialization (edu.neu.ccs.pyramid.util.Serialization)10 Paths (java.nio.file.Paths)10 Collectors (java.util.stream.Collectors)10 IntStream (java.util.stream.IntStream)10 FileUtils (org.apache.commons.io.FileUtils)10 edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm)9 ListUtil (edu.neu.ccs.pyramid.util.ListUtil)9 Pair (edu.neu.ccs.pyramid.util.Pair)9 BufferedWriter (java.io.BufferedWriter)9 FileWriter (java.io.FileWriter)9 java.util (java.util)9 StopWatch (org.apache.commons.lang3.time.StopWatch)9 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)7