Search in sources:

Example 71 with StopWatch

use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.

The following example is from the class IMLGradientBoostingTest, method spam_build.

/**
 * Builds a 2-class multi-label dataset from the single-label spam TREC data,
 * trains an IMLGradientBoosting model for 200 rounds, prints per-instance
 * predictions and overall accuracy, serializes the model, and reports the most
 * frequently matched decision path for label 0.
 *
 * @throws Exception if the dataset cannot be loaded or the model cannot be serialized
 */
private static void spam_build() throws Exception {
    ClfDataSet singleLabeldataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/train.trec"), DataSetType.CLF_DENSE, true);
    int numDataPoints = singleLabeldataSet.getNumDataPoints();
    int numFeatures = singleLabeldataSet.getNumFeatures();
    // Convert the single-label dataset into an equivalent 2-class multi-label dataset
    // by copying each label and every feature value.
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numDataPoints(numDataPoints).numFeatures(numFeatures).numClasses(2).build();
    int[] labels = singleLabeldataSet.getLabels();
    for (int i = 0; i < numDataPoints; i++) {
        dataSet.addLabel(i, labels[i]);
        for (int j = 0; j < numFeatures; j++) {
            double value = singleLabeldataSet.getRow(i).get(j);
            dataSet.setFeatureValue(i, j, value);
        }
    }
    IMLGradientBoosting boosting = new IMLGradientBoosting(2);
    IMLGBConfig trainConfig = new IMLGBConfig.Builder(dataSet).numLeaves(7).learningRate(0.1).numSplitIntervals(50).minDataPerLeaf(3).dataSamplingRate(1).featureSamplingRate(1).build();
    System.out.println(Arrays.toString(trainConfig.getActiveFeatures()));
    IMLGBTrainer trainer = new IMLGBTrainer(trainConfig, boosting);
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    // Boosting loop: each iteration adds one round of regressors; training
    // accuracy is reported after every round.
    for (int round = 0; round < 200; round++) {
        System.out.println("round=" + round);
        trainer.iterate();
        System.out.println("accuracy=" + Accuracy.accuracy(boosting, dataSet));
    }
    stopWatch.stop();
    System.out.println(stopWatch);
    System.out.println(boosting);
    // Per-instance diagnostics: index, true label set, predicted label set.
    for (int i = 0; i < numDataPoints; i++) {
        org.apache.mahout.math.Vector featureRow = dataSet.getRow(i);
        System.out.println(i);
        System.out.println(dataSet.getMultiLabels()[i]);
        System.out.println(boosting.predict(featureRow));
    }
    System.out.println("accuracy");
    System.out.println(Accuracy.accuracy(boosting, dataSet));
    boosting.serialize(new File(TMP, "/imlgb/boosting.ser"));
    // Take the most frequent path match directly with max() instead of sorting the
    // whole entry list just to read element 0 (O(n) vs O(n log n));
    // Map.Entry.comparingByValue() replaces the hand-written comparator lambda.
    System.out.println(IMLGBInspector.countPathMatches(boosting, dataSet, 0).entrySet().stream().max(Map.Entry.comparingByValue()).orElseThrow(() -> new IllegalStateException("no path matches found")));
}
Also used : StopWatch(org.apache.commons.lang3.time.StopWatch) Vector(org.apache.mahout.math.Vector) File(java.io.File)

Example 72 with StopWatch

use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.

The following example is from the class MLLogisticTrainerTest, method test6.

/**
 * Trains a multi-label logistic regression on the ohsumed/3 training split and
 * reports timing plus accuracy/overlap on both the training and test splits.
 *
 * @throws Exception if either TREC dataset cannot be loaded
 */
static void test6() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    // Candidate label combinations are restricted to those observed in training data.
    List<MultiLabel> observedAssignments = DataSetUtil.gatherMultiLabels(trainSet);
    MLLogisticTrainer logisticTrainer = MLLogisticTrainer.getBuilder().setGaussianPriorVariance(1).build();
    StopWatch timer = new StopWatch();
    timer.start();
    MLLogisticRegression model = logisticTrainer.train(trainSet, observedAssignments);
    // Elapsed training time (the watch keeps running; toString reports the split so far).
    System.out.println(timer);
    System.out.println("training accuracy=" + Accuracy.accuracy(model, trainSet));
    System.out.println("training overlap = " + Overlap.overlap(model, trainSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(model, testSet));
    System.out.println("test overlap = " + Overlap.overlap(model, testSet));
}
Also used : File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 73 with StopWatch

use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.

The following example is from the class MLPlattScalingTest, method test1.

/**
 * Trains an IMLGradientBoosting model on the ohsumed/3 training split, fits
 * Platt scaling on top of it, and prints raw scores, the booster's own
 * probabilities, and the Platt-scaled probabilities for the first ten rows.
 *
 * @throws Exception if the TREC dataset cannot be loaded
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    IMLGradientBoosting booster = new IMLGradientBoosting(trainSet.getNumClasses());
    // Restrict predictions to label combinations observed in the training data.
    List<MultiLabel> observedAssignments = DataSetUtil.gatherMultiLabels(trainSet);
    booster.setAssignments(observedAssignments);
    IMLGBConfig boostConfig = new IMLGBConfig.Builder(trainSet).numLeaves(2).learningRate(0.1).numSplitIntervals(1000).minDataPerLeaf(2).dataSamplingRate(1).featureSamplingRate(1).build();
    IMLGBTrainer boostTrainer = new IMLGBTrainer(boostConfig, booster);
    StopWatch timer = new StopWatch();
    timer.start();
    for (int round = 0; round < 100; round++) {
        System.out.println("round=" + round);
        boostTrainer.iterate();
        // Report cumulative elapsed time after each boosting round.
        System.out.println(timer);
    }
    // Calibrate probabilities with Platt scaling on the trained booster.
    MLPlattScaling plattScaling = new MLPlattScaling(trainSet, booster);
    for (int i = 0; i < 10; i++) {
        org.apache.mahout.math.Vector row = trainSet.getRow(i);
        System.out.println(Arrays.toString(booster.predictClassScores(row)));
        System.out.println(Arrays.toString(booster.predictClassProbs(row)));
        System.out.println(Arrays.toString(plattScaling.predictClassProbs(row)));
        System.out.println("======================");
    }
}
Also used : IMLGradientBoosting(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGradientBoosting) IMLGBTrainer(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBTrainer) IMLGBConfig(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBConfig) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 74 with StopWatch

use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.

The following example is from the class AdaBoostMHTest, method test1.

/**
 * Trains AdaBoost.MH for 500 rounds on the ohsumed/3 training split and
 * reports accuracy and overlap on both the training and test splits.
 *
 * @throws Exception if either TREC dataset cannot be loaded
 */
static void test1() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    AdaBoostMH booster = new AdaBoostMH(trainSet.getNumClasses());
    AdaBoostMHTrainer boostTrainer = new AdaBoostMHTrainer(trainSet, booster);
    StopWatch timer = new StopWatch();
    timer.start();
    for (int round = 0; round < 500; round++) {
        System.out.println("round=" + round);
        boostTrainer.iterate();
        // Cumulative elapsed time after each boosting round.
        System.out.println(timer);
    }
    System.out.println("training accuracy=" + Accuracy.accuracy(booster, trainSet));
    System.out.println("training overlap = " + Overlap.overlap(booster, trainSet));
    System.out.println("test accuracy=" + Accuracy.accuracy(booster, testSet));
    System.out.println("test overlap = " + Overlap.overlap(booster, testSet));
}
Also used : File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 75 with StopWatch

use of org.apache.commons.lang3.time.StopWatch in project pyramid by cheng-li.

The following example is from the class App2, method train.

/**
 * Trains an IMLGradientBoosting model driven entirely by the given config,
 * with optional per-label early stopping, then serializes the model and writes
 * top-feature reports (JSON and plain text) to the output folder.
 *
 * @param config configuration supplying all training hyper-parameters and paths
 * @param logger destination for progress and performance messages
 * @throws Exception if data loading, (de)serialization, or report writing fails
 */
static void train(Config config, Logger logger) throws Exception {
    String output = config.getString("output.folder");
    int numIterations = config.getInt("train.numIterations");
    int numLeaves = config.getInt("train.numLeaves");
    double learningRate = config.getDouble("train.learningRate");
    int minDataPerLeaf = config.getInt("train.minDataPerLeaf");
    String modelName = "model_app3";
    //        double featureSamplingRate = config.getDouble("train.featureSamplingRate");
    //        double dataSamplingRate = config.getDouble("train.dataSamplingRate");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    MultiLabelClfDataSet dataSet = loadData(config, config.getString("input.trainData"));
    // Test set is only loaded when test-progress reporting is enabled;
    // otherwise it stays null and is never dereferenced below.
    MultiLabelClfDataSet testSet = null;
    if (config.getBoolean("train.showTestProgress")) {
        testSet = loadData(config, config.getString("input.testData"));
    }
    int numClasses = dataSet.getNumClasses();
    logger.info("number of class = " + numClasses);
    IMLGBConfig imlgbConfig = new IMLGBConfig.Builder(dataSet).learningRate(learningRate).minDataPerLeaf(minDataPerLeaf).numLeaves(numLeaves).numSplitIntervals(config.getInt("train.numSplitIntervals")).usePrior(config.getBoolean("train.usePrior")).build();
    IMLGradientBoosting boosting;
    // Warm start resumes from a previously serialized model; otherwise start fresh.
    if (config.getBoolean("train.warmStart")) {
        boosting = IMLGradientBoosting.deserialize(new File(output, modelName));
    } else {
        boosting = new IMLGradientBoosting(numClasses);
    }
    logger.info("During training, the performance is reported using Hamming loss optimal predictor");
    logger.info("initialing trainer");
    IMLGBTrainer trainer = new IMLGBTrainer(imlgbConfig, boosting);
    boolean earlyStop = config.getBoolean("train.earlyStop");
    // One EarlyStopper AND one Terminator per label; a label stops training
    // when EITHER of its two stoppers fires (see the loop below).
    List<EarlyStopper> earlyStoppers = new ArrayList<>();
    List<Terminator> terminators = new ArrayList<>();
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            EarlyStopper earlyStopper = new EarlyStopper(EarlyStopper.Goal.MINIMIZE, config.getInt("train.earlyStop.patience"));
            earlyStopper.setMinimumIterations(config.getInt("train.earlyStop.minIterations"));
            earlyStoppers.add(earlyStopper);
        }
        for (int l = 0; l < numClasses; l++) {
            Terminator terminator = new Terminator();
            // minIterations is divided by the progress interval — presumably because
            // the terminator is fed only once per reporting interval, not every
            // iteration; TODO(review) confirm against Terminator's semantics.
            terminator.setMaxStableIterations(config.getInt("train.earlyStop.patience")).setMinIterations(config.getInt("train.earlyStop.minIterations") / config.getInt("train.showProgress.interval")).setAbsoluteEpsilon(config.getDouble("train.earlyStop.absoluteChange")).setRelativeEpsilon(config.getDouble("train.earlyStop.relativeChange")).setOperation(Terminator.Operation.OR);
            terminators.add(terminator);
        }
    }
    logger.info("trainer initialized");
    int numLabelsLeftToTrain = numClasses;
    int progressInterval = config.getInt("train.showProgress.interval");
    for (int i = 1; i <= numIterations; i++) {
        logger.info("iteration " + i);
        trainer.iterate();
        // Progress is reported every progressInterval iterations and on the last one.
        if (config.getBoolean("train.showTrainProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("training set performance");
            logger.info(new MLMeasures(boosting, dataSet).toString());
        }
        // Early-stop checks only run on reporting iterations, so stoppers see
        // one data point per interval rather than one per iteration.
        if (config.getBoolean("train.showTestProgress") && (i % progressInterval == 0 || i == numIterations)) {
            logger.info("test set performance");
            logger.info(new MLMeasures(boosting, testSet).toString());
            if (earlyStop) {
                for (int l = 0; l < numClasses; l++) {
                    EarlyStopper earlyStopper = earlyStoppers.get(l);
                    Terminator terminator = terminators.get(l);
                    // Skip labels whose training has already been stopped.
                    if (!trainer.getShouldStop()[l]) {
                        // KL divergence on the test set is the monitored quantity
                        // (minimized) for this label.
                        double kl = KL(boosting, testSet, l);
                        earlyStopper.add(i, kl);
                        terminator.add(kl);
                        if (earlyStopper.shouldStop() || terminator.shouldTerminate()) {
                            logger.info("training for label " + l + " (" + dataSet.getLabelTranslator().toExtLabel(l) + ") should stop now");
                            logger.info("the best number of training iterations for the label is " + earlyStopper.getBestIteration());
                            trainer.setShouldStop(l);
                            numLabelsLeftToTrain -= 1;
                            logger.info("the number of labels left to be trained on = " + numLabelsLeftToTrain);
                        }
                    }
                }
            }
        }
        // Stop the whole loop once every label has been stopped individually.
        if (numLabelsLeftToTrain == 0) {
            logger.info("all label training finished");
            break;
        }
    }
    logger.info("training done");
    File serializedModel = new File(output, modelName);
    //todo pick best models
    boosting.serialize(serializedModel);
    logger.info(stopWatch.toString());
    if (earlyStop) {
        for (int l = 0; l < numClasses; l++) {
            logger.info("----------------------------------------------------");
            logger.info("test performance history for label " + l + ": " + earlyStoppers.get(l).history());
            // Size minus one — presumably excluding the prior/constant regressor;
            // TODO(review) confirm against getRegressors' contents.
            logger.info("model size for label " + l + " = " + (boosting.getRegressors(l).size() - 1));
        }
    }
    // Top-feature reporting is currently unconditional (flag is hard-coded true).
    boolean topFeaturesToFile = true;
    if (topFeaturesToFile) {
        logger.info("start writing top features");
        int limit = config.getInt("report.topFeatures.limit");
        // One TopFeatures entry per class, each holding up to `limit` features.
        List<TopFeatures> topFeaturesList = IntStream.range(0, boosting.getNumClasses()).mapToObj(k -> IMLGBInspector.topFeatures(boosting, k, limit)).collect(Collectors.toList());
        ObjectMapper mapper = new ObjectMapper();
        String file = "top_features.json";
        mapper.writeValue(new File(output, file), topFeaturesList);
        // Also emit a human-readable text version grouped by external label name.
        StringBuilder sb = new StringBuilder();
        for (int l = 0; l < boosting.getNumClasses(); l++) {
            sb.append("-------------------------").append("\n");
            sb.append(dataSet.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
            for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
                sb.append(feature.simpleString()).append(", ");
            }
            sb.append("\n");
        }
        FileUtils.writeStringToFile(new File(output, "top_features.txt"), sb.toString());
        logger.info("finish writing top features");
    }
}
Also used : IntStream(java.util.stream.IntStream) JsonGenerator(com.fasterxml.jackson.core.JsonGenerator) SimpleFormatter(java.util.logging.SimpleFormatter) PluginPredictor(edu.neu.ccs.pyramid.multilabel_classification.PluginPredictor) edu.neu.ccs.pyramid.multilabel_classification.imlgb(edu.neu.ccs.pyramid.multilabel_classification.imlgb) ArrayList(java.util.ArrayList) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) FileHandler(java.util.logging.FileHandler) FeatureDistribution(edu.neu.ccs.pyramid.feature_selection.FeatureDistribution) JsonEncoding(com.fasterxml.jackson.core.JsonEncoding) Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelPredictionAnalysis(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelPredictionAnalysis) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) edu.neu.ccs.pyramid.eval(edu.neu.ccs.pyramid.eval) Collection(java.util.Collection) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) TunedMarginalClassifier(edu.neu.ccs.pyramid.multilabel_classification.thresholding.TunedMarginalClassifier) FileWriter(java.io.FileWriter) Set(java.util.Set) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) Progress(edu.neu.ccs.pyramid.util.Progress) List(java.util.List) JsonFactory(com.fasterxml.jackson.core.JsonFactory) Feature(edu.neu.ccs.pyramid.feature.Feature) MacroFMeasureTuner(edu.neu.ccs.pyramid.multilabel_classification.thresholding.MacroFMeasureTuner) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Vector(org.apache.mahout.math.Vector) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) SetUtil(edu.neu.ccs.pyramid.util.SetUtil) ArrayList(java.util.ArrayList) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) 
Terminator(edu.neu.ccs.pyramid.optimization.Terminator) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) Feature(edu.neu.ccs.pyramid.feature.Feature) StopWatch(org.apache.commons.lang3.time.StopWatch) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Aggregations

StopWatch (org.apache.commons.lang3.time.StopWatch)78 File (java.io.File)48 ArrayList (java.util.ArrayList)17 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)8 Vector (org.apache.mahout.math.Vector)8 VirtualMachine (com.microsoft.azure.management.compute.VirtualMachine)7 Creatable (com.microsoft.azure.management.resources.fluentcore.model.Creatable)7 Config (edu.neu.ccs.pyramid.configuration.Config)7 Network (com.microsoft.azure.management.network.Network)6 IOException (java.io.IOException)6 ResourceGroup (com.microsoft.azure.management.resources.ResourceGroup)5 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)5 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)4 Region (com.microsoft.azure.management.resources.fluentcore.arm.Region)4 StorageAccount (com.microsoft.azure.management.storage.StorageAccount)4 List (java.util.List)4 PublicIPAddress (com.microsoft.azure.management.network.PublicIPAddress)3 LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)3 HashMap (java.util.HashMap)3 IntStream (java.util.stream.IntStream)3