Use of edu.neu.ccs.pyramid.classification.ClassProbability in project pyramid by cheng-li.
From the class LogisticRegressionInspector, method analyzePrediction.
/**
 * Builds a {@link PredictionAnalysis} describing how {@code logisticRegression} handles a single
 * data point of {@code dataSet}.
 *
 * <p>The analysis is filled with: the point's internal index and external id, its true label
 * (internal and external form), the predicted class (internal and external form), one
 * {@link ClassProbability} per class built from {@code predictClassProbs}, and one
 * {@link ClassScoreCalculation} per class produced by {@code decisionProcess}.
 *
 * @param logisticRegression the model being inspected
 * @param dataSet data set holding the point; also supplies the id and label translators
 * @param dataPointIndex internal row index of the point inside {@code dataSet}
 * @param limit passed through unchanged to {@code decisionProcess} for every class
 * @return the populated analysis object
 */
public static PredictionAnalysis analyzePrediction(LogisticRegression logisticRegression, ClfDataSet dataSet, int dataPointIndex, int limit) {
    PredictionAnalysis analysis = new PredictionAnalysis();
    IdTranslator idTrans = dataSet.getIdTranslator();
    LabelTranslator labelTrans = dataSet.getLabelTranslator();

    // Identity and ground truth of the inspected point (fluent setter chain).
    analysis.setInternalId(dataPointIndex)
            .setId(idTrans.toExtId(dataPointIndex))
            .setInternalLabel(dataSet.getLabels()[dataPointIndex])
            .setLabel(labelTrans.toExtLabel(dataSet.getLabels()[dataPointIndex]));

    // Predicted class, recorded both as an internal index and as its external label.
    int predictedClass = logisticRegression.predict(dataSet.getRow(dataPointIndex));
    analysis.setInternalPrediction(predictedClass);
    analysis.setPrediction(labelTrans.toExtLabel(predictedClass));

    // One probability entry per class, in internal class-index order.
    double[] classProbs = logisticRegression.predictClassProbs(dataSet.getRow(dataPointIndex));
    int numClasses = classProbs.length;
    List<ClassProbability> probabilityList = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        probabilityList.add(
                new ClassProbability(classIndex, labelTrans.toExtLabel(classIndex), classProbs[classIndex]));
    }
    analysis.setClassProbabilities(probabilityList);

    // Per-class score breakdown; the row is fetched from dataSet again for each class.
    List<ClassScoreCalculation> scoreList = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        scoreList.add(decisionProcess(
                logisticRegression, labelTrans, dataSet.getRow(dataPointIndex), classIndex, limit));
    }
    analysis.setClassScoreCalculations(scoreList);
    return analysis;
}
Use of edu.neu.ccs.pyramid.classification.ClassProbability in project pyramid by cheng-li.
From the class LKBInspector, method analyzePrediction.
/**
 * Builds a {@link PredictionAnalysis} describing how the {@code boosting} model handles a single
 * data point of {@code dataSet}.
 *
 * <p>The analysis is filled with: the point's internal index and external id, its true label
 * (internal and external form), the predicted class (internal and external form), one
 * {@link ClassProbability} per class built from {@code predictClassProbs}, and one
 * {@link ClassScoreCalculation} per class produced by {@code decisionProcess}.
 *
 * @param boosting the model being inspected
 * @param dataSet data set holding the point; also supplies the id and label translators
 * @param dataPointIndex internal row index of the point inside {@code dataSet}
 * @param limit passed through unchanged to {@code decisionProcess} for every class
 * @return the populated analysis object
 */
// TODO: speed up. dataSet.getRow(dataPointIndex) is fetched for the prediction, for the class
// probabilities, and again for every class's decisionProcess call; hoisting it into a single
// local is a candidate improvement.
public static PredictionAnalysis analyzePrediction(LKBoost boosting, ClfDataSet dataSet, int dataPointIndex, int limit) {
    PredictionAnalysis analysis = new PredictionAnalysis();
    IdTranslator idTrans = dataSet.getIdTranslator();
    LabelTranslator labelTrans = dataSet.getLabelTranslator();

    // Identity and ground truth of the inspected point (fluent setter chain).
    analysis.setInternalId(dataPointIndex)
            .setId(idTrans.toExtId(dataPointIndex))
            .setInternalLabel(dataSet.getLabels()[dataPointIndex])
            .setLabel(labelTrans.toExtLabel(dataSet.getLabels()[dataPointIndex]));

    // Predicted class, recorded both as an internal index and as its external label.
    int predictedClass = boosting.predict(dataSet.getRow(dataPointIndex));
    analysis.setInternalPrediction(predictedClass);
    analysis.setPrediction(labelTrans.toExtLabel(predictedClass));

    // One probability entry per class, in internal class-index order.
    double[] classProbs = boosting.predictClassProbs(dataSet.getRow(dataPointIndex));
    int numClasses = classProbs.length;
    List<ClassProbability> probabilityList = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        probabilityList.add(
                new ClassProbability(classIndex, labelTrans.toExtLabel(classIndex), classProbs[classIndex]));
    }
    analysis.setClassProbabilities(probabilityList);

    // Per-class score breakdown; the row is fetched from dataSet again for each class.
    List<ClassScoreCalculation> scoreList = new ArrayList<>(numClasses);
    for (int classIndex = 0; classIndex < numClasses; classIndex++) {
        scoreList.add(decisionProcess(
                boosting, labelTrans, dataSet.getRow(dataPointIndex), classIndex, limit));
    }
    analysis.setClassScoreCalculations(scoreList);
    return analysis;
}
Aggregations