Use of edu.neu.ccs.pyramid.classification.PredictionAnalysis in project pyramid by cheng-li.
The class LogisticRegressionInspector, method analyzePrediction.
public static PredictionAnalysis analyzePrediction(LogisticRegression logisticRegression, ClfDataSet dataSet, int dataPointIndex, int limit) {
    PredictionAnalysis predictionAnalysis = new PredictionAnalysis();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    predictionAnalysis.setInternalId(dataPointIndex)
            .setId(idTranslator.toExtId(dataPointIndex))
            .setInternalLabel(dataSet.getLabels()[dataPointIndex])
            .setLabel(labelTranslator.toExtLabel(dataSet.getLabels()[dataPointIndex]));
    int prediction = logisticRegression.predict(dataSet.getRow(dataPointIndex));
    predictionAnalysis.setInternalPrediction(prediction);
    predictionAnalysis.setPrediction(labelTranslator.toExtLabel(prediction));
    double[] probs = logisticRegression.predictClassProbs(dataSet.getRow(dataPointIndex));
    List<ClassProbability> classProbabilities = new ArrayList<>();
    for (int k = 0; k < probs.length; k++) {
        ClassProbability classProbability = new ClassProbability(k, labelTranslator.toExtLabel(k), probs[k]);
        classProbabilities.add(classProbability);
    }
    predictionAnalysis.setClassProbabilities(classProbabilities);
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k = 0; k < probs.length; k++) {
        ClassScoreCalculation classScoreCalculation = decisionProcess(logisticRegression, labelTranslator, dataSet.getRow(dataPointIndex), k, limit);
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
    return predictionAnalysis;
}
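A minimal usage sketch (not from the pyramid source): the PredictionAnalysis returned by this method can be written out as JSON with Jackson, mirroring the LKTreeBoostTest excerpt below. The trained LogisticRegression instance and the file paths here are assumed placeholders.

// Hedged sketch: assumes a trained LogisticRegression "logisticRegression" is already available
// and that a TREC-format test set exists; the paths are illustrative placeholders.
ClfDataSet testSet = TRECFormat.loadClfDataSet(new File("/path/to/test.trec"), DataSetType.CLF_DENSE, true);
PredictionAnalysis analysis = LogisticRegressionInspector.analyzePrediction(logisticRegression, testSet, 0, 10);
ObjectMapper mapper = new ObjectMapper();
mapper.writeValue(new File("/tmp/prediction_analysis.json"), analysis);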
Use of edu.neu.ccs.pyramid.classification.PredictionAnalysis in project pyramid by cheng-li.
The class LKTreeBoostTest, method spam_load.
static void spam_load() throws Exception {
    System.out.println("loading ensemble");
    LKBoost lkBoost = LKBoost.deserialize(new File(TMP, "/LKTreeBoostTest/ensemble.ser"));
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(new File(DATASETS, "/spam/trec_data/test.trec"), DataSetType.CLF_DENSE, true);
    System.out.println("test data:");
    System.out.println(dataSet.getMetaInfo());
    System.out.println(dataSet.getLabelTranslator());
    double accuracy = Accuracy.accuracy(lkBoost, dataSet);
    System.out.println(accuracy);
    System.out.println("auc = " + AUC.auc(lkBoost, dataSet));
    ConfusionMatrix confusionMatrix = new ConfusionMatrix(lkBoost, dataSet);
    System.out.println("confusion matrix:");
    System.out.println(confusionMatrix.printWithExtLabels());
    System.out.println("top featureList for class 0");
    System.out.println(LKBInspector.topFeatures(lkBoost, 0));
    System.out.println(new PerClassMeasures(confusionMatrix, 0));
    System.out.println(new PerClassMeasures(confusionMatrix, 1));
    System.out.println("macro-averaged:");
    System.out.println(new MacroAveragedMeasures(confusionMatrix));
    // System.out.println(lkTreeBoost);
    int[] labels = dataSet.getLabels();
    List<double[]> classProbs = lkBoost.predictClassProbs(dataSet);
    for (int k = 0; k < dataSet.getNumClasses(); k++) {
        int numMatches = 0;
        double sumProbs = 0;
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            if (labels[i] == k) {
                numMatches += 1;
            }
            sumProbs += classProbs.get(i)[k];
        }
        System.out.println("for class " + k);
        System.out.println("number of matches = " + numMatches);
        System.out.println("sum of probs = " + sumProbs);
    }
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    ClassScoreCalculation classScoreCalculation = LKBInspector.decisionProcess(lkBoost, labelTranslator, dataSet.getRow(0), 0, 10);
    ObjectMapper mapper = new ObjectMapper();
    mapper.writeValue(new File(TMP, "score_calculation.json"), classScoreCalculation);
    PredictionAnalysis predictionAnalysis = LKBInspector.analyzePrediction(lkBoost, dataSet, 0, 10);
    ObjectMapper mapper1 = new ObjectMapper();
    mapper1.writeValue(new File(TMP, "prediction_analysis.json"), predictionAnalysis);
}
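The per-class loop in the middle of this test is a simple calibration sanity check: for each class it compares the number of data points whose true label is that class with the sum of the predicted probabilities the ensemble assigns to that class; for a well-calibrated model the two totals should be close. A hedged sketch pulling that check into a standalone helper (the name checkCalibration is illustrative, not part of the pyramid API):

// Illustrative helper (not in pyramid): per-class calibration check,
// extracted from the loop in spam_load above.
static void checkCalibration(LKBoost lkBoost, ClfDataSet dataSet) {
    int[] labels = dataSet.getLabels();
    List<double[]> classProbs = lkBoost.predictClassProbs(dataSet);
    for (int k = 0; k < dataSet.getNumClasses(); k++) {
        int numMatches = 0;
        double sumProbs = 0;
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            if (labels[i] == k) {
                numMatches += 1;
            }
            sumProbs += classProbs.get(i)[k];
        }
        System.out.println("class " + k + ": count = " + numMatches + ", sum of probs = " + sumProbs);
    }
}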
Use of edu.neu.ccs.pyramid.classification.PredictionAnalysis in project pyramid by cheng-li.
The class LKBInspector, method analyzePrediction.
// TODO: speed up
public static PredictionAnalysis analyzePrediction(LKBoost boosting, ClfDataSet dataSet, int dataPointIndex, int limit) {
    PredictionAnalysis predictionAnalysis = new PredictionAnalysis();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    predictionAnalysis.setInternalId(dataPointIndex)
            .setId(idTranslator.toExtId(dataPointIndex))
            .setInternalLabel(dataSet.getLabels()[dataPointIndex])
            .setLabel(labelTranslator.toExtLabel(dataSet.getLabels()[dataPointIndex]));
    int prediction = boosting.predict(dataSet.getRow(dataPointIndex));
    predictionAnalysis.setInternalPrediction(prediction);
    predictionAnalysis.setPrediction(labelTranslator.toExtLabel(prediction));
    double[] probs = boosting.predictClassProbs(dataSet.getRow(dataPointIndex));
    List<ClassProbability> classProbabilities = new ArrayList<>();
    for (int k = 0; k < probs.length; k++) {
        ClassProbability classProbability = new ClassProbability(k, labelTranslator.toExtLabel(k), probs[k]);
        classProbabilities.add(classProbability);
    }
    predictionAnalysis.setClassProbabilities(classProbabilities);
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k = 0; k < probs.length; k++) {
        ClassScoreCalculation classScoreCalculation = decisionProcess(boosting, labelTranslator, dataSet.getRow(dataPointIndex), k, limit);
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
    return predictionAnalysis;
}
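One way the // TODO: speed up note could be addressed: the per-class decisionProcess calls are independent of each other, so the second loop could in principle run as a parallel stream. This is a hedged sketch of a drop-in replacement for that loop, not code from the project; it assumes decisionProcess is thread-safe and needs java.util.stream.IntStream and java.util.stream.Collectors on the import list.

// Hedged sketch (not from pyramid): compute the per-class score calculations in parallel,
// assuming decisionProcess is thread-safe. Encounter order is preserved by the stream,
// so the list stays indexed by class, as in the original loop.
List<ClassScoreCalculation> classScoreCalculations = IntStream.range(0, probs.length)
        .parallel()
        .mapToObj(k -> decisionProcess(boosting, labelTranslator, dataSet.getRow(dataPointIndex), k, limit))
        .collect(Collectors.toList());
predictionAnalysis.setClassScoreCalculations(classScoreCalculations);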