Usage of edu.neu.ccs.pyramid.eval.RMSE in the project pyramid by cheng-li.
Class EMLevelEval, method main.
/**
 * Evaluates regression predictions on a test set, both as-is and after snapping each prediction
 * with {@code round(d, levels)} to the distinct label levels seen in the training set.
 *
 * <p>The single command-line argument is the path of a properties file. The keys read are
 * {@code input.trainData}, {@code input.testData} and {@code input.prediction}. The method prints
 * RMSE before and after rounding, the exact-match accuracy after rounding, and per-level
 * breakdowns in both directions (true label to predicted, and predicted label to true).
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if the argument count is not one, or if the number of
 *     predictions differs from the number of test labels
 * @throws Exception propagated from loading the config, the data sets or the predictions
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    RegDataSet train = TRECFormat.loadRegDataSet(config.getString("input.trainData"), DataSetType.REG_SPARSE, true);
    // Distinct training labels in ascending order; predictions are rounded onto these levels.
    List<Double> levels = Arrays.stream(train.getLabels())
            .boxed()
            .distinct()
            .sorted()
            .collect(Collectors.toList());
    RegDataSet test = TRECFormat.loadRegDataSet(config.getString("input.testData"), DataSetType.REG_SPARSE, true);
    double[] doubleTruth = test.getLabels();
    double[] doublePred = loadPrediction(config.getString("input.prediction"));
    // Fail fast on mismatched inputs. Otherwise a short prediction file surfaces as an
    // index error in the accuracy computation, and a long one is silently truncated.
    if (doublePred.length != doubleTruth.length) {
        throw new IllegalArgumentException("number of predictions (" + doublePred.length
                + ") does not match number of test labels (" + doubleTruth.length + ")");
    }
    double[] roundedPred = Arrays.stream(doublePred).map(d -> round(d, levels)).toArray();
    System.out.println("before rounding");
    System.out.println("rmse = " + RMSE.rmse(doubleTruth, doublePred));
    System.out.println("after rounding");
    System.out.println("rmse = " + RMSE.rmse(doubleTruth, roundedPred));
    // Exact equality is used on purpose: presumably round(...) returns one of the level values
    // exactly, so a correct prediction equals the true label bit-for-bit. TODO confirm.
    long numCorrect = IntStream.range(0, test.getNumDataPoints())
            .filter(i -> doubleTruth[i] == roundedPred[i])
            .count();
    System.out.println("accuracy = " + numCorrect / (double) test.getNumDataPoints());
    System.out.println("the distribution of predicted label for a given true label");
    for (double level : levels) {
        System.out.println("for true label " + level);
        truthToPred(test.getLabels(), roundedPred, level, levels);
    }
    System.out.println("=============================");
    System.out.println("the distribution of true label for a given predicted label");
    for (double level : levels) {
        System.out.println("for predicted label " + level);
        predToTruth(test.getLabels(), roundedPred, level, levels);
    }
}
Usage of edu.neu.ccs.pyramid.eval.RMSE in the project pyramid by cheng-li.
Class PMMLConverterTest, method main.
/**
 * Trains a small LSBoost model (trees limited to 3 leaves, shrinkage 0.1, 10 iterations),
 * prints train/test RMSE during training, converts the ensemble's regression trees plus its
 * constant term to PMML, prints the PMML, and writes it as XML.
 *
 * @param args optional overrides:
 *     {@code args[0]} is the data directory containing {@code train} and {@code test};
 *     {@code args[1]} is the output PMML path.
 *     When an argument is absent, the original hard-coded locations are used.
 * @throws IllegalStateException if the first regressor of ensemble 0 is not a
 *     {@code ConstantRegressor}
 * @throws Exception propagated from data loading, training or marshalling
 */
public static void main(String[] args) throws Exception {
    // The default dir ends with '/' so that dataDir + "/train" reproduces the original
    // ".../abalone//train" path exactly.
    String dataDir = args.length > 0 ? args[0] : "/Users/chengli/Dropbox/Public/pyramid/abalone/";
    String outputPath = args.length > 1 ? args[1] : "/Users/chengli/tmp/pmml.xml";
    RegDataSet trainSet = TRECFormat.loadRegDataSet(new File(dataDir + "/train"), DataSetType.REG_DENSE, true);
    RegDataSet testSet = TRECFormat.loadRegDataSet(new File(dataDir + "/test"), DataSetType.REG_DENSE, true);
    LSBoost lsBoost = new LSBoost();
    RegTreeConfig regTreeConfig = new RegTreeConfig().setMaxNumLeaves(3);
    RegTreeFactory regTreeFactory = new RegTreeFactory(regTreeConfig);
    LSBoostOptimizer optimizer = new LSBoostOptimizer(lsBoost, trainSet, regTreeFactory);
    optimizer.setShrinkage(0.1);
    optimizer.initialize();
    for (int i = 0; i < 10; i++) {
        System.out.println("iteration " + i);
        System.out.println("train RMSE = " + RMSE.rmse(lsBoost, trainSet));
        System.out.println("test RMSE = " + RMSE.rmse(lsBoost, testSet));
        optimizer.iterate();
    }
    // The loop reports metrics *before* each iteration. Report the final model too,
    // since that is the model exported below.
    System.out.println("final train RMSE = " + RMSE.rmse(lsBoost, trainSet));
    System.out.println("final test RMSE = " + RMSE.rmse(lsBoost, testSet));
    FeatureList featureList = trainSet.getFeatureList();
    List<RegressionTree> regressionTrees = lsBoost.getEnsemble(0).getRegressors().stream()
            .filter(a -> a instanceof RegressionTree)
            .map(a -> (RegressionTree) a)
            .collect(Collectors.toList());
    System.out.println(regressionTrees);
    // The PMML export needs the ensemble's constant term. This code assumes it is the first
    // regressor of ensemble 0; fail with a clear message instead of an opaque ClassCastException.
    if (!(lsBoost.getEnsemble(0).get(0) instanceof ConstantRegressor)) {
        throw new IllegalStateException("expected the first regressor of ensemble 0 to be a ConstantRegressor");
    }
    double constant = ((ConstantRegressor) lsBoost.getEnsemble(0).get(0)).getScore();
    PMML pmml = PMMLConverter.encodePMML(null, null, featureList, regressionTrees, (float) constant);
    System.out.println(pmml.toString());
    try (OutputStream os = new FileOutputStream(outputPath)) {
        MetroJAXBUtil.marshalPMML(pmml, os);
    }
}
Usage of edu.neu.ccs.pyramid.eval.RMSE in the project pyramid by cheng-li.
Class ElasticNetLinearRegTrainerTest, method test3.
/**
 * Fits an elastic-net linear regression (regularization 0.001, L1 ratio 0.1) to
 * {@code RegressionSynthesizer.linear()} data. Prints train RMSE before and after optimization,
 * then checks whether the four largest-magnitude non-zero weights are exactly features
 * {0, 1, 2, 3}. Prints "correct" or "incorrect".
 *
 * <p>NOTE(review): the expected set {0, 1, 2, 3} presumably mirrors the features the synthesizer
 * actually uses; confirm against {@code RegressionSynthesizer.linear()}.
 */
private static void test3() throws Exception {
    RegDataSet dataSet = RegressionSynthesizer.linear();
    LinearRegression linearRegression = new LinearRegression(dataSet.getNumFeatures());
    ElasticNetLinearRegOptimizer trainer = new ElasticNetLinearRegOptimizer(linearRegression, dataSet);
    trainer.setRegularization(0.001);
    trainer.setL1Ratio(0.1);
    System.out.println("train rmse before training = " + RMSE.rmse(linearRegression, dataSet));
    trainer.optimize();
    System.out.println("train rmse after training = " + RMSE.rmse(linearRegression, dataSet));
    System.out.println("non-zeros = " + linearRegression.getWeights().getWeightsWithoutBias());
    List<Pair<Integer, Double>> pairs = new ArrayList<>();
    for (Vector.Element element : linearRegression.getWeights().getWeightsWithoutBias().nonZeroes()) {
        pairs.add(new Pair<>(element.index(), element.get()));
    }
    // Rank by |weight|: a relevant feature may have a negative coefficient, and ranking by
    // signed value would push it to the bottom and wrongly fail the selection check.
    Comparator<Pair<Integer, Double>> byMagnitude = Comparator.comparingDouble(pair -> Math.abs(pair.getSecond()));
    Set<Integer> set = pairs.stream()
            .sorted(byMagnitude.reversed())
            .limit(4)
            .map(pair -> pair.getFirst())
            .collect(Collectors.toSet());
    Set<Integer> trueSet = new HashSet<>(Arrays.asList(0, 1, 2, 3));
    if (set.equals(trueSet)) {
        System.out.println("correct");
    } else {
        System.out.println("incorrect");
    }
}
Aggregations