Search in sources:

Example 1 with RMSE

Use of edu.neu.ccs.pyramid.eval.RMSE in the project pyramid by cheng-li.

From class EMLevelEval, method main.

/**
 * Evaluates regression predictions against a test set, both as-is and after rounding each
 * prediction to one of the distinct label values ("levels") seen in the training set.
 *
 * <p>Reads {@code input.trainData}, {@code input.testData} and {@code input.prediction} from the
 * given properties file. Prints the RMSE before and after rounding and the exact-match accuracy of
 * the rounded predictions. For every level, it then delegates to {@code truthToPred} and
 * {@code predToTruth} to print the label distributions.
 *
 * @param args exactly one element: the path to the properties file
 * @throws IllegalArgumentException if the argument count is wrong, or if the number of loaded
 *     predictions differs from the number of test data points
 * @throws Exception presumably propagated from config/data/prediction loading — confirm
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    RegDataSet train = TRECFormat.loadRegDataSet(config.getString("input.trainData"), DataSetType.REG_SPARSE, true);
    // The distinct training labels, in ascending order, are the admissible rounding targets.
    Set<Double> unique = new HashSet<>();
    for (double d : train.getLabels()) {
        unique.add(d);
    }
    List<Double> levels = unique.stream().sorted().collect(Collectors.toList());
    RegDataSet test = TRECFormat.loadRegDataSet(config.getString("input.testData"), DataSetType.REG_SPARSE, true);
    double[] doubleTruth = test.getLabels();
    double[] doublePred = loadPrediction(config.getString("input.prediction"));
    int numDataPoints = test.getNumDataPoints();
    // Predictions are paired with test points by position. Without this check, a short file
    // crashes with ArrayIndexOutOfBoundsException and a long one is silently truncated.
    if (doublePred.length != numDataPoints) {
        throw new IllegalArgumentException("Number of predictions (" + doublePred.length
                + ") does not match number of test data points (" + numDataPoints + ")");
    }
    double[] roundedPred = Arrays.stream(doublePred).map(d -> round(d, levels)).toArray();
    System.out.println("before rounding");
    System.out.println("rmse = " + RMSE.rmse(doubleTruth, doublePred));
    System.out.println("after rounding");
    System.out.println("rmse = " + RMSE.rmse(doubleTruth, roundedPred));
    // Exact equality is intended here: both sides are drawn from the same discrete label set.
    long numCorrect = IntStream.range(0, numDataPoints).filter(i -> doubleTruth[i] == roundedPred[i]).count();
    System.out.println("accuracy = " + numCorrect / (double) numDataPoints);
    System.out.println("the distribution of predicted label for a given true label");
    for (double level : levels) {
        System.out.println("for true label " + level);
        truthToPred(doubleTruth, roundedPred, level, levels);
    }
    System.out.println("=============================");
    System.out.println("the distribution of true label for a given predicted label");
    for (double level : levels) {
        System.out.println("for predicted label " + level);
        predToTruth(doubleTruth, roundedPred, level, levels);
    }
}
Also used: IntStream (java.util.stream.IntStream), Arrays (java.util.Arrays), Set (java.util.Set), FileUtils (org.apache.commons.io.FileUtils), IOException (java.io.IOException), Collectors (java.util.stream.Collectors), File (java.io.File), HashSet (java.util.HashSet), List (java.util.List), edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset), Accuracy (edu.neu.ccs.pyramid.eval.Accuracy), RMSE (edu.neu.ccs.pyramid.eval.RMSE), Config (edu.neu.ccs.pyramid.configuration.Config)

Example 2 with RMSE

Use of edu.neu.ccs.pyramid.eval.RMSE in the project pyramid by cheng-li.

From class PMMLConverterTest, method main.

/**
 * Trains a small LSBoost regression model (shrinkage 0.1, trees of at most 3 leaves, 10
 * iterations) on the abalone data and exports it as a PMML document.
 *
 * <p>The optional positional arguments override the original hard-coded paths:
 * {@code args[0]} = training data, {@code args[1]} = test data, {@code args[2]} = PMML output file.
 * With no arguments the original paths are used.
 *
 * @param args optional train path, test path and PMML output path, in that order
 * @throws IllegalStateException if the first regressor of ensemble 0 is not a
 *     {@link ConstantRegressor} (the PMML encoding needs its score as the constant term)
 * @throws Exception presumably propagated from data loading and PMML marshalling — confirm
 */
public static void main(String[] args) throws Exception {
    String trainPath = args.length > 0 ? args[0] : "/Users/chengli/Dropbox/Public/pyramid/abalone//train";
    String testPath = args.length > 1 ? args[1] : "/Users/chengli/Dropbox/Public/pyramid/abalone//test";
    String outputPath = args.length > 2 ? args[2] : "/Users/chengli/tmp/pmml.xml";
    RegDataSet trainSet = TRECFormat.loadRegDataSet(new File(trainPath), DataSetType.REG_DENSE, true);
    RegDataSet testSet = TRECFormat.loadRegDataSet(new File(testPath), DataSetType.REG_DENSE, true);
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(3);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    int numIterations = 10;
    for (int i = 0; i < numIterations; i++) {
        // Reports the model state *before* iteration i is applied.
        System.out.println("iteration " + i);
        System.out.println("train RMSE = " + RMSE.rmse(lsBoost, trainSet));
        System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        optimizer.iterate();
    }
    // The loop never evaluates the model after its last iterate() call, so report it here.
    System.out.println("final model after " + numIterations + " iterations");
    System.out.println("train RMSE = " + RMSE.rmse(lsBoost, trainSet));
    System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
    FeatureList featureList = trainSet.getFeatureList();
    List<RegressionTree> regressionTrees = lsBoost.getEnsemble(0).getRegressors().stream().filter(a -> a instanceof RegressionTree).map(a -> (RegressionTree) a).collect(Collectors.toList());
    System.out.println(regressionTrees);
    // The first member of the ensemble is expected to be the constant (bias) regressor.
    Object first = lsBoost.getEnsemble(0).get(0);
    if (!(first instanceof ConstantRegressor)) {
        throw new IllegalStateException("Expected the first regressor of ensemble 0 to be a ConstantRegressor but found "
                + (first == null ? "null" : first.getClass().getName()));
    }
    double constant = ((ConstantRegressor) first).getScore();
    PMML pmml = PMMLConverter.encodePMML(null, null, featureList, regressionTrees, (float) constant);
    System.out.println(pmml.toString());
    try (OutputStream os = new FileOutputStream(outputPath)) {
        MetroJAXBUtil.marshalPMML(pmml, os);
    }
}
Also used: OutputStream (java.io.OutputStream), DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType), TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat), PMML (org.dmg.pmml.PMML), FileOutputStream (java.io.FileOutputStream), FeatureList (edu.neu.ccs.pyramid.feature.FeatureList), RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig), Collectors (java.util.stream.Collectors), File (java.io.File), RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree), RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet), List (java.util.List), RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory), MetroJAXBUtil (org.jpmml.model.MetroJAXBUtil), ConstantRegressor (edu.neu.ccs.pyramid.regression.ConstantRegressor), RMSE (edu.neu.ccs.pyramid.eval.RMSE)

Example 3 with RMSE

Use of edu.neu.ccs.pyramid.eval.RMSE in the project pyramid by cheng-li.

From class ElasticNetLinearRegTrainerTest, method test3.

/**
 * Smoke test for elastic-net linear regression on synthetic data.
 *
 * <p>Fits a {@code LinearRegression} on {@code RegressionSynthesizer.linear()} data (regularization
 * 0.001, L1 ratio 0.1) and prints the training RMSE before and after fitting, then the learned
 * weight vector (without bias). It then prints "correct" if the four largest non-zero weights
 * belong to features 0, 1, 2 and 3, and "incorrect" otherwise.
 *
 * <p>NOTE(review): weights are ranked by signed value, so features with large negative weights
 * are never selected — confirm whether ranking by magnitude was intended.
 */
private static void test3() throws Exception {
    RegDataSet data = RegressionSynthesizer.linear();
    LinearRegression model = new LinearRegression(data.getNumFeatures());
    ElasticNetLinearRegOptimizer optimizer = new ElasticNetLinearRegOptimizer(model, data);
    optimizer.setRegularization(0.001);
    optimizer.setL1Ratio(0.1);

    System.out.println("train rmse before training = " + RMSE.rmse(model, data));
    optimizer.optimize();
    System.out.println("train rmse after training = " + RMSE.rmse(model, data));
    System.out.println("non-zeros = " + model.getWeights().getWeightsWithoutBias());

    // Collect (featureIndex, weight) for every non-zero learned weight.
    List<Pair<Integer, Double>> nonZeroWeights = new ArrayList<>();
    for (Vector.Element entry : model.getWeights().getWeightsWithoutBias().nonZeroes()) {
        nonZeroWeights.add(new Pair<>(entry.index(), entry.get()));
    }

    Comparator<Pair<Integer, Double>> byWeight = Comparator.comparing(Pair::getSecond);
    Set<Integer> topFour = nonZeroWeights.stream()
            .sorted(byWeight.reversed())
            .limit(4)
            .map(Pair::getFirst)
            .collect(Collectors.toSet());

    // The features the synthetic generator is expected to make relevant.
    Set<Integer> expected = new HashSet<>();
    for (int feature = 0; feature < 4; feature++) {
        expected.add(feature);
    }

    System.out.println(topFour.equals(expected) ? "correct" : "incorrect");
}
Also used: DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType), RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet), java.util (java.util), Grid (edu.neu.ccs.pyramid.util.Grid), StandardFormat (edu.neu.ccs.pyramid.dataset.StandardFormat), Vector (org.apache.mahout.math.Vector), RegressionSynthesizer (edu.neu.ccs.pyramid.simulation.RegressionSynthesizer), RMSE (edu.neu.ccs.pyramid.eval.RMSE), Collectors (java.util.stream.Collectors), Pair (edu.neu.ccs.pyramid.util.Pair), File (java.io.File), Config (edu.neu.ccs.pyramid.configuration.Config)

Aggregations

RMSE (edu.neu.ccs.pyramid.eval.RMSE)3 File (java.io.File)3 Collectors (java.util.stream.Collectors)3 Config (edu.neu.ccs.pyramid.configuration.Config)2 DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType)2 RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet)2 List (java.util.List)2 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)1 StandardFormat (edu.neu.ccs.pyramid.dataset.StandardFormat)1 TRECFormat (edu.neu.ccs.pyramid.dataset.TRECFormat)1 Accuracy (edu.neu.ccs.pyramid.eval.Accuracy)1 FeatureList (edu.neu.ccs.pyramid.feature.FeatureList)1 ConstantRegressor (edu.neu.ccs.pyramid.regression.ConstantRegressor)1 RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)1 RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)1 RegressionTree (edu.neu.ccs.pyramid.regression.regression_tree.RegressionTree)1 RegressionSynthesizer (edu.neu.ccs.pyramid.simulation.RegressionSynthesizer)1 Grid (edu.neu.ccs.pyramid.util.Grid)1 Pair (edu.neu.ccs.pyramid.util.Pair)1 FileOutputStream (java.io.FileOutputStream)1