Usage of edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor in the pyramid project by cheng-li: class BRCalibration, method calibrate.
/**
 * Trains a label calibrator and a set calibrator on top of a previously trained
 * binary-relevance (BR) class probability estimator, serializes the calibrators
 * and the prediction feature extractor, and logs classification and calibration
 * performance on the calibration and validation sets.
 *
 * @param config configuration with data locations, calibrator/feature choices and output folders
 * @param logger destination for progress and performance messages
 * @throws Exception if data loading, model (de)serialization or evaluation fails
 */
private static void calibrate(Config config, Logger logger) throws Exception {
    logger.info("start training calibrators");
    DataSetType dataSetType;
    switch (config.getString("dataSetType")) {
        case "sparse_random":
            dataSetType = DataSetType.ML_CLF_SPARSE;
            break;
        case "sparse_sequential":
            dataSetType = DataSetType.ML_CLF_SEQ_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("unknown dataSetType");
    }
    MultiLabelClfDataSet train = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), dataSetType, true);
    MultiLabelClfDataSet cal = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.calibrationData"), dataSetType, true);
    MultiLabelClfDataSet valid = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.validData"), dataSetType, true);
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "support").toFile());
    MultiLabelClassifier.ClassProbEstimator classProbEstimator = (MultiLabelClassifier.ClassProbEstimator) Serialization.deserialize(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "classifier"));
    // split the calibration set in half: even rows train the label calibrator,
    // odd rows train the set calibrator
    List<Integer> labelCalIndices = IntStream.range(0, cal.getNumDataPoints()).filter(i -> i % 2 == 0).boxed().collect(Collectors.toList());
    List<Integer> setCalIndices = IntStream.range(0, cal.getNumDataPoints()).filter(i -> i % 2 == 1).boxed().collect(Collectors.toList());
    MultiLabelClfDataSet labelCalData = DataSetUtil.sampleData(cal, labelCalIndices);
    MultiLabelClfDataSet setCalData = DataSetUtil.sampleData(cal, setCalIndices);
    logger.info("start training label calibrator");
    LabelCalibrator labelCalibrator;
    switch (config.getString("labelCalibrator")) {
        case "isotonic":
            labelCalibrator = new IsoLabelCalibrator(classProbEstimator, labelCalData, false);
            break;
        case "identity":
            labelCalibrator = new IdentityLabelCalibrator();
            break;
        default:
            // bug fix: previously an unrecognized value silently left the calibrator
            // null and caused an NPE downstream; fail fast like the other switches
            throw new IllegalArgumentException("unknown labelCalibrator");
    }
    logger.info("finish training label calibrator");
    logger.info("start training set calibrator");
    List<PredictionFeatureExtractor> extractors = new ArrayList<>();
    if (config.getBoolean("brProb")) {
        // todo order matters; the first one will be used by iso, card iso
        extractors.add(new BRProbFeatureExtractor());
    }
    if (config.getBoolean("expectedF1")) {
        extractors.add(new ExpectedF1FeatureExtractor());
    }
    if (config.getBoolean("expectedPrecision")) {
        extractors.add(new ExpectedPrecisionFeatureExtractor());
    }
    if (config.getBoolean("expectedRecall")) {
        extractors.add(new ExpectedRecallFeatureExtractor());
    }
    if (config.getBoolean("setPrior")) {
        extractors.add(new PriorFeatureExtractor(train));
    }
    if (config.getBoolean("card")) {
        extractors.add(new CardFeatureExtractor());
    }
    if (config.getBoolean("encodeLabel")) {
        extractors.add(new LabelBinaryFeatureExtractor(classProbEstimator.getNumClasses(), train.getLabelTranslator()));
    }
    if (config.getBoolean("useInitialFeatures")) {
        // keep categorical features and plain (non-ngram) features whose names
        // match one of the configured field prefixes
        Set<String> prefixes = new HashSet<>(config.getStrings("featureFieldPrefix"));
        FeatureList featureList = train.getFeatureList();
        List<Integer> featureIds = new ArrayList<>();
        for (int j = 0; j < featureList.size(); j++) {
            Feature feature = featureList.get(j);
            if (feature instanceof CategoricalFeature) {
                if (matchPrefixes(((CategoricalFeature) feature).getVariableName(), prefixes)) {
                    featureIds.add(j);
                }
            } else if (!(feature instanceof Ngram)) {
                if (matchPrefixes(feature.getName(), prefixes)) {
                    featureIds.add(j);
                }
            }
        }
        extractors.add(new InstanceFeatureExtractor(featureIds, train.getFeatureList()));
    }
    PredictionFeatureExtractor predictionFeatureExtractor = new CombinedPredictionFeatureExtractor(extractors);
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    CalibrationDataGenerator.TrainData caliTrainingData = calibrationDataGenerator.createCaliTrainingData(setCalData, classProbEstimator, config.getInt("numCandidates"), config.getString("calibrate.target"), support, 10);
    CalibrationDataGenerator.TrainData caliValidData = calibrationDataGenerator.createCaliTrainingData(valid, classProbEstimator, config.getInt("numCandidates"), config.getString("calibrate.target"), support, 10);
    RegDataSet calibratorTrainData = caliTrainingData.regDataSet;
    double[] weights = caliTrainingData.instanceWeights;
    VectorCalibrator setCalibrator;
    switch (config.getString("setCalibrator")) {
        case "cardinality_isotonic":
            setCalibrator = new VectorCardIsoSetCalibrator(calibratorTrainData, 0, 2, false);
            break;
        case "reranker":
            RerankerTrainer rerankerTrainer = RerankerTrainer.newBuilder().numCandidates(config.getInt("numCandidates")).numLeaves(config.getInt("numLeaves")).monotonicityType("weak").build();
            setCalibrator = rerankerTrainer.trainWithSigmoid(calibratorTrainData, weights, classProbEstimator, predictionFeatureExtractor, labelCalibrator, caliValidData.regDataSet);
            break;
        case "isotonic":
            setCalibrator = new VectorIsoSetCalibrator(calibratorTrainData, 0, false);
            break;
        case "identity":
            setCalibrator = new VectorIdentityCalibrator(0);
            break;
        case "zero":
            setCalibrator = new ZeroCalibrator();
            break;
        default:
            throw new IllegalArgumentException("illegal setCalibrator");
    }
    logger.info("finish training set calibrator");
    // persist all calibration components next to the model so BRPrediction can reload them
    Serialization.serialize(labelCalibrator, Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "calibrators", config.getString("output.calibratorFolder"), "label_calibrator").toFile());
    Serialization.serialize(setCalibrator, Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "calibrators", config.getString("output.calibratorFolder"), "set_calibrator").toFile());
    Serialization.serialize(predictionFeatureExtractor, Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "calibrators", config.getString("output.calibratorFolder"), "prediction_feature_extractor").toFile());
    logger.info("finish training calibrators");
    MultiLabelClassifier classifier;
    switch (config.getString("predict.mode")) {
        case "independent":
            classifier = new IndependentPredictor(classProbEstimator, labelCalibrator);
            break;
        case "support":
            classifier = new edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor(classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, support);
            break;
        case "reranker":
            Reranker reranker = (Reranker) setCalibrator;
            reranker.setMinPredictionSize(config.getInt("predict.minSize"));
            reranker.setMaxPredictionSize(config.getInt("predict.maxSize"));
            classifier = reranker;
            break;
        default:
            throw new IllegalArgumentException("illegal predict.mode");
    }
    MultiLabel[] predictions = classifier.predict(cal);
    MultiLabel[] validPredictions = classifier.predict(valid);
    // calibration performance on the calibration set (dead "if (true)" wrapper removed)
    logger.info("calibration performance on " + config.getString("input.calibrationFolder") + " set");
    List<CalibrationDataGenerator.CalibrationInstance> calInstances = IntStream.range(0, cal.getNumDataPoints()).parallel().boxed()
            .map(i -> calibrationDataGenerator.createInstance(classProbEstimator, cal.getRow(i), predictions[i], cal.getMultiLabels()[i], config.getString("calibrate.target")))
            .collect(Collectors.toList());
    eval(calInstances, setCalibrator, logger, config.getString("calibrate.target"));
    logger.info("classification performance on " + config.getString("input.validFolder") + " set");
    logger.info(new MLMeasures(valid.getNumClasses(), valid.getMultiLabels(), validPredictions).toString());
    // calibration performance on the validation set
    logger.info("calibration performance on " + config.getString("input.validFolder") + " set");
    List<CalibrationDataGenerator.CalibrationInstance> validInstances = IntStream.range(0, valid.getNumDataPoints()).parallel().boxed()
            .map(i -> calibrationDataGenerator.createInstance(classProbEstimator, valid.getRow(i), validPredictions[i], valid.getMultiLabels()[i], config.getString("calibrate.target")))
            .collect(Collectors.toList());
    eval(validInstances, setCalibrator, logger, config.getString("calibrate.target"));
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor in the pyramid project by cheng-li: class BRPrediction, method report.
/**
 * Loads a trained BR classifier and its serialized calibrators, predicts on the
 * TREC-format data set at {@code dataPath}, and writes a suite of reports under
 * the model's "&lt;dataset&gt;_reports" folder: a per-document CSV, a top-k sets CSV,
 * optional per-document rule JSON files, per-label performance JSON, data info
 * JSON, overall performance JSON, and optionally HTML pages.
 *
 * @param config   configuration with model locations, predict.mode and report switches
 * @param dataPath path of the data set to evaluate
 * @param logger   destination for progress and performance messages
 * @throws Exception if loading, prediction or report writing fails
 */
private static void report(Config config, String dataPath, Logger logger) throws Exception {
// translate the configured dataSetType string into the matching enum constant
DataSetType dataSetType;
switch(config.getString("dataSetType")) {
case "sparse_random":
dataSetType = DataSetType.ML_CLF_SPARSE;
break;
case "sparse_sequential":
dataSetType = DataSetType.ML_CLF_SEQ_SPARSE;
break;
default:
throw new IllegalArgumentException("unknown dataSetType");
}
MultiLabelClfDataSet test = TRECFormat.loadMultiLabelClfDataSet(dataPath, dataSetType, true);
// deserialize the classifier and the calibration components saved by BRCalibration
MultiLabelClassifier.ClassProbEstimator classProbEstimator = (MultiLabelClassifier.ClassProbEstimator) Serialization.deserialize(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "classifier"));
LabelCalibrator labelCalibrator = (LabelCalibrator) Serialization.deserialize(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "calibrators", config.getString("output.calibratorFolder"), "label_calibrator").toFile());
VectorCalibrator setCalibrator = (VectorCalibrator) Serialization.deserialize(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "calibrators", config.getString("output.calibratorFolder"), "set_calibrator").toFile());
PredictionFeatureExtractor predictionFeatureExtractor = (PredictionFeatureExtractor) Serialization.deserialize(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "calibrators", config.getString("output.calibratorFolder"), "prediction_feature_extractor").toFile());
File testDataFile = new File(dataPath);
List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "models", "support").toFile());
// all reports for this data set go under "<dataset name>_reports"
String reportFolder = Paths.get(config.getString("output.dir"), "model_predictions", config.getString("output.modelFolder"), "predictions", testDataFile.getName() + "_reports").toString();
// build the final predictor according to predict.mode
MultiLabelClassifier classifier = null;
switch(config.getString("predict.mode")) {
case "independent":
classifier = new IndependentPredictor(classProbEstimator, labelCalibrator);
break;
case "support":
classifier = new edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor(classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, support);
break;
case "reranker":
// the reranker doubles as both the set calibrator and the classifier
Reranker reranker = (Reranker) setCalibrator;
reranker.setMinPredictionSize(config.getInt("predict.minSize"));
reranker.setMaxPredictionSize(config.getInt("predict.maxSize"));
classifier = reranker;
break;
default:
throw new IllegalArgumentException("illegal predict.mode");
}
MultiLabel[] predictions = classifier.predict(test);
logger.info("classification performance on dataset " + testDataFile.getName());
MLMeasures mlMeasures = new MLMeasures(test.getNumClasses(), test.getMultiLabels(), predictions);
mlMeasures.getMacroAverage().updateAveragePrecision(classProbEstimator, test);
logger.info(mlMeasures.toString());
CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
// NOTE(review): dead "if (true)" guard kept as-is; the block only scopes the local list
if (true) {
logger.info("calibration performance on dataset " + testDataFile.getName());
List<CalibrationDataGenerator.CalibrationInstance> instances = IntStream.range(0, test.getNumDataPoints()).parallel().boxed().map(i -> calibrationDataGenerator.createInstance(classProbEstimator, test.getRow(i), predictions[i], test.getMultiLabels()[i], config.getString("calibrate.target"))).collect(Collectors.toList());
BRCalibration.eval(instances, setCalibrator, logger, config.getString("calibrate.target"));
}
// effectively-final alias so the classifier can be captured by the lambdas below
MultiLabelClassifier fClassifier = classifier;
// per-document prediction CSV (tab-separated despite the .csv extension)
boolean simpleCSV = true;
if (simpleCSV) {
File csv = Paths.get(reportFolder, "report.csv").toFile();
csv.getParentFile().mkdirs();
if (csv.exists()) {
csv.delete();
}
StringBuilder sb = new StringBuilder();
sb.append("doc_id").append("\t").append("predictions").append("\t").append("prediction_type").append("\t").append("confidence").append("\t").append("truth").append("\t").append("ground_truth").append("\t").append("precision").append("\t").append("recall").append("\t").append("F1").append("\n");
// NOTE(review): deprecated charset-less overload — writes with the platform default charset
FileUtils.writeStringToFile(csv, sb.toString());
List<Integer> list = IntStream.range(0, test.getNumDataPoints()).boxed().collect(Collectors.toList());
ParallelStringMapper<Integer> mapper = (list1, i) -> simplePredictionAnalysisCalibrated(classProbEstimator, labelCalibrator, setCalibrator, test, i, fClassifier, predictionFeatureExtractor);
ParallelFileWriter.mapToString(mapper, list, csv, 100, true);
}
// top candidate label sets per document
boolean topSets = true;
if (topSets) {
File csv = Paths.get(reportFolder, "top_sets.csv").toFile();
csv.getParentFile().mkdirs();
List<Integer> list = IntStream.range(0, test.getNumDataPoints()).boxed().collect(Collectors.toList());
ParallelStringMapper<Integer> mapper = (list1, i) -> topKSets(config, classProbEstimator, labelCalibrator, setCalibrator, test, i, fClassifier, predictionFeatureExtractor);
ParallelFileWriter.mapToString(mapper, list, csv, 100);
}
// optional per-document prediction-explanation JSON, partitioned into numbered files
boolean rulesToJson = config.getBoolean("report.showPredictionDetail");
if (rulesToJson) {
logger.info("start writing rules to json");
int ruleLimit = config.getInt("report.rule.limit");
int numDocsPerFile = config.getInt("report.numDocsPerFile");
int numFiles = (int) Math.ceil((double) test.getNumDataPoints() / numDocsPerFile);
double probThreshold = config.getDouble("report.classProbThreshold");
int labelSetLimit = config.getInt("report.labelSetLimit");
IntStream.range(0, numFiles).forEach(i -> {
int start = i * numDocsPerFile;
int end = start + numDocsPerFile;
// Math.min guards the last partition against running past the data set
List<MultiLabelPredictionAnalysis> partition = IntStream.range(start, Math.min(end, test.getNumDataPoints())).parallel().mapToObj(a -> BRInspector.analyzePrediction(classProbEstimator, labelCalibrator, setCalibrator, test, fClassifier, predictionFeatureExtractor, a, ruleLimit, labelSetLimit, probThreshold)).collect(Collectors.toList());
ObjectMapper mapper = new ObjectMapper();
File jsonFile = Paths.get(reportFolder, "report_" + (i + 1) + ".json").toFile();
try {
mapper.writeValue(jsonFile, partition);
} catch (IOException e) {
e.printStackTrace();
}
logger.info("progress = " + Progress.percentage(i + 1, numFiles));
});
logger.info("finish writing rules to json");
}
// per-label (macro) performance JSON
boolean individualPerformance = true;
if (individualPerformance) {
logger.info("start writing individual label performance to json");
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.writeValue(Paths.get(reportFolder, "individual_performance.json").toFile(), mlMeasures.getMacroAverage());
logger.info("finish writing individual label performance to json");
}
// copy the data set's own config JSON into the report folder, if present
boolean dataConfigToJson = true;
if (dataConfigToJson) {
logger.info("start writing data config to json");
File dataConfigFile = Paths.get(dataPath, "data_config.json").toFile();
if (dataConfigFile.exists()) {
FileUtils.copyFileToDirectory(dataConfigFile, new File(reportFolder));
}
logger.info("finish writing data config to json");
}
// summary of label overlap between model and data set, written via the streaming API
boolean dataInfoToJson = true;
if (dataInfoToJson) {
logger.info("start writing data info to json");
Set<String> modelLabels = IntStream.range(0, classifier.getNumClasses()).mapToObj(i -> classProbEstimator.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
Set<String> dataSetLabels = DataSetUtil.gatherLabels(test).stream().map(i -> test.getLabelTranslator().toExtLabel(i)).collect(Collectors.toSet());
JsonGenerator jsonGenerator = new JsonFactory().createGenerator(Paths.get(reportFolder, "data_info.json").toFile(), JsonEncoding.UTF8);
jsonGenerator.writeStartObject();
jsonGenerator.writeStringField("dataSet", testDataFile.getName());
jsonGenerator.writeNumberField("numClassesInModel", classifier.getNumClasses());
jsonGenerator.writeNumberField("numClassesInDataSet", dataSetLabels.size());
jsonGenerator.writeNumberField("numClassesInModelDataSetCombined", test.getNumClasses());
// set differences in both directions: labels only in the model vs. only in the data
Set<String> modelNotDataLabels = SetUtil.complement(modelLabels, dataSetLabels);
Set<String> dataNotModelLabels = SetUtil.complement(dataSetLabels, modelLabels);
jsonGenerator.writeNumberField("numClassesInDataSetButNotModel", dataNotModelLabels.size());
jsonGenerator.writeNumberField("numClassesInModelButNotDataSet", modelNotDataLabels.size());
jsonGenerator.writeArrayFieldStart("classesInDataSetButNotModel");
for (String label : dataNotModelLabels) {
jsonGenerator.writeObject(label);
}
jsonGenerator.writeEndArray();
jsonGenerator.writeArrayFieldStart("classesInModelButNotDataSet");
for (String label : modelNotDataLabels) {
jsonGenerator.writeObject(label);
}
jsonGenerator.writeEndArray();
jsonGenerator.writeNumberField("labelCardinality", test.labelCardinality());
jsonGenerator.writeEndObject();
jsonGenerator.close();
logger.info("finish writing data info to json");
}
// overall performance JSON
boolean performanceToJson = true;
if (performanceToJson) {
ObjectMapper objectMapper = new ObjectMapper();
objectMapper.writeValue(Paths.get(reportFolder, "performance.json").toFile(), mlMeasures);
}
// optional HTML report generation against an Elasticsearch index
if (config.getBoolean("report.produceHTML")) {
logger.info("start producing html files");
Config savedApp1Config = new Config(Paths.get(config.getString("output.dir"), "meta_data", "saved_config_app1").toFile());
List<String> hosts = savedApp1Config.getStrings("index.hosts");
List<Integer> ports = savedApp1Config.getIntegers("index.ports");
// todo make it better
if (savedApp1Config.getString("index.clientType").equals("node")) {
// node client: replace configured hosts with localhost, one per configured port
hosts = new ArrayList<>();
for (int port : ports) {
hosts.add("localhost");
}
// default setting
hosts.add("localhost");
ports.add(9200);
}
try (Visualizer visualizer = new Visualizer(logger, hosts, ports)) {
visualizer.produceHtml(new File(reportFolder));
logger.info("finish producing html files");
}
}
}
Usage of edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor in the pyramid project by cheng-li: class BRRerank, method report.
/**
 * Loads a trained CBM model together with its label and set calibrators, predicts
 * on the test split under {@code dataPath}, and writes one tab-separated row per
 * instance (set prediction, uncalibrated and calibrated confidence, ground truth,
 * and whether the predicted set exactly matches the truth) to
 * reports/set_prediction_and_confidence.txt.
 *
 * @param config configuration providing dataPath, outputDir and predictMode
 * @throws Exception if deserialization, data loading or report writing fails
 */
private static void report(Config config) throws Exception {
    MultiLabelClfDataSet dataset = TRECFormat.loadMultiLabelClfDataSet(Paths.get(config.getString("dataPath"), "test").toFile(), DataSetType.ML_CLF_SPARSE, true);
    CBM cbm = (CBM) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "model"));
    cbm.setAllowEmpty(true);
    LabelCalibrator labelCalibrator = (LabelCalibrator) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "label_calibrator"));
    VectorCalibrator setCalibrator = (VectorCalibrator) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "set_calibrator"));
    PredictionFeatureExtractor predictionFeatureExtractor = (PredictionFeatureExtractor) Serialization.deserialize(Paths.get(config.getString("outputDir"), "models", "calibration_feature_extractor"));
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    MultiLabelClassifier classifier;
    switch (config.getString("predictMode")) {
        case "independent":
            classifier = new IndependentPredictor(cbm, labelCalibrator);
            break;
        case "rerank":
            // the reranker doubles as both the set calibrator and the classifier
            classifier = (Reranker) setCalibrator;
            break;
        default:
            // bug fix: the message used to say "predict.mode" although the key read here is "predictMode"
            throw new IllegalArgumentException("illegal predictMode");
    }
    MultiLabel[] predictions = classifier.predict(dataset);
    // for each instance, record the uncalibrated confidence (first component of the
    // calibration feature vector) and the calibrated confidence from the set calibrator
    List<CalibInfo> confidenceScores = IntStream.range(0, dataset.getNumDataPoints()).parallel().boxed().map(i -> {
        CalibrationDataGenerator.CalibrationInstance predictionInstance = calibrationDataGenerator.createInstance(cbm, dataset.getRow(i), predictions[i], dataset.getMultiLabels()[i], "accuracy");
        double calibrated = setCalibrator.calibrate(predictionInstance.vector);
        CalibInfo calibInfo = new CalibInfo();
        calibInfo.uncalibrated = predictionInstance.vector.get(0);
        calibInfo.calibrated = calibrated;
        calibInfo.accuracy = predictionInstance.correctness;
        return calibInfo;
    }).collect(Collectors.toList());
    StringBuilder stringBuilder = new StringBuilder();
    stringBuilder.append("set_prediction").append("\t").append("uncalibrated_confidence").append("\t").append("calibrated_confidence").append("\t").append("ground_truth").append("\t").append("set_accuracy").append("\n");
    for (int i = 0; i < dataset.getNumDataPoints(); i++) {
        stringBuilder.append(predictions[i]).append("\t").append(confidenceScores.get(i).uncalibrated).append("\t").append(confidenceScores.get(i).calibrated).append("\t").append(dataset.getMultiLabels()[i]).append("\t").append(predictions[i].equals(dataset.getMultiLabels()[i]) ? 1 : 0).append("\n");
    }
    FileUtils.writeStringToFile(Paths.get(config.getString("outputDir"), "reports", "set_prediction_and_confidence.txt").toFile(), stringBuilder.toString());
    System.out.println("set predictions and confidence scores are saved to " + Paths.get(config.getString("outputDir"), "reports", "set_prediction_and_confidence.txt").toString() + "\n");
}
Aggregations