Use of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.
Class CRFInspector, method simplePredictionAnalysis.
/**
 * Builds a tab-separated, human-readable report of the CRF's prediction for a single data point.
 *
 * <p>First, one "single" line is written for every class that either reaches
 * {@code classProbThreshold}, is one of the point's true labels, or is part of the plug-in
 * prediction. These lines are ordered by descending class probability; ties keep ascending
 * class index order. Each line has the columns: id, external label, "single", class
 * probability, and 1/0 for whether the class is a true label.
 *
 * <p>Then one "set" line is written with the columns: id, the predicted labels (external
 * names, comma-separated), "set", the combination probability the CRF assigns to exactly
 * that predicted set, and 1/0 for whether the predicted set equals the true set. The
 * probability is 0 when the predicted set is not among {@code crf.getSupportCombinations()}.
 *
 * @param crf                the CRF supplying combination and per-class probabilities
 * @param pluginPredictor    predictor whose output is reported as the predicted label set
 * @param dataSet            data set holding the row, its true labels and the id/label translators
 * @param dataPointIndex     index of the data point to analyze
 * @param classProbThreshold minimum class probability for a class to be listed on its own
 * @return the report; every line, including the last, ends with "\n"
 */
public static String simplePredictionAnalysis(CMLCRF crf, PluginPredictor<CMLCRF> pluginPredictor, MultiLabelClfDataSet dataSet, int dataPointIndex, double classProbThreshold) {
    final MultiLabel trueLabels = dataSet.getMultiLabels()[dataPointIndex];
    final String id = dataSet.getIdTranslator().toExtId(dataPointIndex);
    final LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    final double[] combProbs = crf.predictCombinationProbs(dataSet.getRow(dataPointIndex));
    final double[] classProbs = crf.calClassProbs(combProbs);
    final MultiLabel predicted = pluginPredictor.predict(dataSet.getRow(dataPointIndex));

    // Classes worth reporting individually: confident ones, true ones, and predicted ones.
    List<Integer> reported = new ArrayList<>();
    for (int k = 0; k < crf.getNumClasses(); k++) {
        boolean confident = classProbs[k] >= classProbThreshold;
        if (confident || trueLabels.matchClass(k) || predicted.matchClass(k)) {
            reported.add(k);
        }
    }
    // Most probable first. List.sort is stable, so equal probabilities keep ascending class order.
    reported.sort(Comparator.comparingDouble((Integer k) -> classProbs[k]).reversed());

    StringBuilder report = new StringBuilder();
    for (int label : reported) {
        double prob = classProbs[label];
        int match = trueLabels.matchClass(label) ? 1 : 0;
        report.append(id).append("\t")
                .append(labelTranslator.toExtLabel(label)).append("\t")
                .append("single").append("\t")
                .append(prob).append("\t")
                .append(match).append("\n");
    }

    // Probability of the exact predicted combination: the first matching support entry, else 0.
    double setProb = 0;
    List<MultiLabel> support = crf.getSupportCombinations();
    int idx = 0;
    while (idx < support.size() && !support.get(idx).equals(predicted)) {
        idx++;
    }
    if (idx < support.size()) {
        setProb = combProbs[idx];
    }

    report.append(id).append("\t");
    String separator = "";
    for (Integer label : predicted.getMatchedLabelsOrdered()) {
        report.append(separator).append(labelTranslator.toExtLabel(label));
        separator = ",";
    }
    int setMatch = predicted.equals(trueLabels) ? 1 : 0;
    report.append("\t")
            .append("set").append("\t")
            .append(setProb).append("\t")
            .append(setMatch).append("\n");
    return report.toString();
}
Use of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.
Class ElasticNetLinearRegTrainerTest, method test3.
/**
 * Smoke test for elastic-net linear regression on {@code RegressionSynthesizer.linear()}.
 *
 * <p>Prints the training RMSE before and after optimization and the learned weights (bias
 * excluded). It then prints "correct" if the {@code numTrueFeatures} non-zero coefficients
 * with the largest magnitude are exactly features 0..numTrueFeatures-1, and "incorrect"
 * otherwise.
 *
 * <p>Features are ranked by |weight|, not by signed weight. Ranking by signed weight would
 * make it impossible to recover an informative feature whose coefficient is negative.
 * NOTE(review): assumes the synthesizer's informative features are 0..3, as the original
 * expected set implies. Confirm against RegressionSynthesizer.linear().
 *
 * @throws Exception declared by the original signature; any checked exception raised by the
 *                   synthesizer/optimizer calls propagates
 */
private static void test3() throws Exception {
    // A single constant drives both the top-k cut-off and the expected index set,
    // so the two cannot drift apart.
    final int numTrueFeatures = 4;

    RegDataSet dataSet = RegressionSynthesizer.linear();
    LinearRegression linearRegression = new LinearRegression(dataSet.getNumFeatures());
    ElasticNetLinearRegOptimizer trainer = new ElasticNetLinearRegOptimizer(linearRegression, dataSet);
    trainer.setRegularization(0.001);
    trainer.setL1Ratio(0.1);
    System.out.println("train rmse before training = " + RMSE.rmse(linearRegression, dataSet));
    trainer.optimize();
    System.out.println("train rmse after training = " + RMSE.rmse(linearRegression, dataSet));

    // Fetched once: the weights no longer change after optimize().
    Vector weights = linearRegression.getWeights().getWeightsWithoutBias();
    System.out.println("non-zeros = " + weights);

    List<Pair<Integer, Double>> pairs = new ArrayList<>();
    for (Vector.Element element : weights.nonZeroes()) {
        pairs.add(new Pair<>(element.index(), element.get()));
    }

    // Rank by magnitude: a large negative coefficient is as informative as a large positive one.
    Comparator<Pair<Integer, Double>> byMagnitude = Comparator.comparingDouble(pair -> Math.abs(pair.getSecond()));
    Set<Integer> selected = pairs.stream()
            .sorted(byMagnitude.reversed())
            .limit(numTrueFeatures)
            .map(Pair::getFirst)
            .collect(Collectors.toSet());

    Set<Integer> trueSet = new HashSet<>();
    for (int j = 0; j < numTrueFeatures; j++) {
        trueSet.add(j);
    }

    System.out.println(selected.equals(trueSet) ? "correct" : "incorrect");
}
Aggregations