Usage of edu.neu.ccs.pyramid.feature.CategoricalFeature in the pyramid project by cheng-li.
Class BRCalibration, method calibrate.
/**
 * Trains a label calibrator and a set calibrator for a previously trained multi-label model,
 * serializes them (together with the prediction feature extractor) into the model's calibrator
 * folder, and logs classification and calibration performance.
 *
 * <p>The calibration data set is split by index parity: even-indexed points train the label
 * calibrator, odd-indexed points are used to generate training data for the set calibrator.
 *
 * @param config run configuration; supplies data paths, output locations and the choice of
 *     label calibrator, set calibrator and prediction mode
 * @param logger destination for progress and evaluation messages
 * @throws IllegalArgumentException if {@code dataSetType}, {@code labelCalibrator},
 *     {@code setCalibrator} or {@code predict.mode} holds an unsupported value, or if
 *     {@code predict.mode} is {@code reranker} while the trained set calibrator is not a
 *     {@code Reranker}
 * @throws Exception if loading data, deserializing the model, or serializing results fails
 */
private static void calibrate(Config config, Logger logger) throws Exception {
    logger.info("start training calibrators");

    // Read settings that are used repeatedly (including inside parallel lambdas) once, up front,
    // so a missing key fails before any expensive training work starts.
    final String outputDir = config.getString("output.dir");
    final String modelFolder = config.getString("output.modelFolder");
    final String calibratorFolder = config.getString("output.calibratorFolder");
    final String calibrateTarget = config.getString("calibrate.target");
    final int numCandidates = config.getInt("numCandidates");

    DataSetType dataSetType;
    String dataSetTypeName = config.getString("dataSetType");
    switch (dataSetTypeName) {
        case "sparse_random":
            dataSetType = DataSetType.ML_CLF_SPARSE;
            break;
        case "sparse_sequential":
            dataSetType = DataSetType.ML_CLF_SEQ_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("unknown dataSetType: " + dataSetTypeName);
    }

    MultiLabelClfDataSet train = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), dataSetType, true);
    MultiLabelClfDataSet cal = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.calibrationData"), dataSetType, true);
    MultiLabelClfDataSet valid = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.validData"), dataSetType, true);

    @SuppressWarnings("unchecked") // the "support" file is expected to hold a List<MultiLabel>
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(
            Paths.get(outputDir, "model_predictions", modelFolder, "models", "support").toFile());
    // Same File-based deserialize call as used for "support" above.
    MultiLabelClassifier.ClassProbEstimator classProbEstimator = (MultiLabelClassifier.ClassProbEstimator) Serialization.deserialize(
            Paths.get(outputDir, "model_predictions", modelFolder, "models", "classifier").toFile());

    // Split the calibration set by index parity: even -> label calibrator, odd -> set calibrator.
    List<Integer> labelCalIndices = IntStream.range(0, cal.getNumDataPoints()).filter(i -> i % 2 == 0).boxed().collect(Collectors.toList());
    List<Integer> setCalIndices = IntStream.range(0, cal.getNumDataPoints()).filter(i -> i % 2 == 1).boxed().collect(Collectors.toList());
    MultiLabelClfDataSet labelCalData = DataSetUtil.sampleData(cal, labelCalIndices);
    MultiLabelClfDataSet setCalData = DataSetUtil.sampleData(cal, setCalIndices);

    logger.info("start training label calibrator");
    LabelCalibrator labelCalibrator;
    String labelCalibratorName = config.getString("labelCalibrator");
    switch (labelCalibratorName) {
        case "isotonic":
            labelCalibrator = new IsoLabelCalibrator(classProbEstimator, labelCalData, false);
            break;
        case "identity":
            labelCalibrator = new IdentityLabelCalibrator();
            break;
        default:
            // Fail fast rather than continuing (and serializing) with a null label calibrator.
            throw new IllegalArgumentException("illegal labelCalibrator: " + labelCalibratorName);
    }
    logger.info("finish training label calibrator");

    logger.info("start training set calibrator");
    List<PredictionFeatureExtractor> extractors = new ArrayList<>();
    if (config.getBoolean("brProb")) {
        // todo order matters; the first one will be used by iso, card iso
        extractors.add(new BRProbFeatureExtractor());
    }
    if (config.getBoolean("expectedF1")) {
        extractors.add(new ExpectedF1FeatureExtractor());
    }
    if (config.getBoolean("expectedPrecision")) {
        extractors.add(new ExpectedPrecisionFeatureExtractor());
    }
    if (config.getBoolean("expectedRecall")) {
        extractors.add(new ExpectedRecallFeatureExtractor());
    }
    if (config.getBoolean("setPrior")) {
        extractors.add(new PriorFeatureExtractor(train));
    }
    if (config.getBoolean("card")) {
        extractors.add(new CardFeatureExtractor());
    }
    if (config.getBoolean("encodeLabel")) {
        extractors.add(new LabelBinaryFeatureExtractor(classProbEstimator.getNumClasses(), train.getLabelTranslator()));
    }
    if (config.getBoolean("useInitialFeatures")) {
        Set<String> prefixes = new HashSet<>(config.getStrings("featureFieldPrefix"));
        FeatureList featureList = train.getFeatureList();
        List<Integer> featureIds = new ArrayList<>();
        for (int j = 0; j < featureList.size(); j++) {
            Feature feature = featureList.get(j);
            if (feature instanceof CategoricalFeature) {
                // Categorical features are matched on their variable name, not the feature name.
                if (matchPrefixes(((CategoricalFeature) feature).getVariableName(), prefixes)) {
                    featureIds.add(j);
                }
            } else if (!(feature instanceof Ngram) && matchPrefixes(feature.getName(), prefixes)) {
                // N-gram features are never selected as initial features.
                featureIds.add(j);
            }
        }
        extractors.add(new InstanceFeatureExtractor(featureIds, train.getFeatureList()));
    }

    PredictionFeatureExtractor predictionFeatureExtractor = new CombinedPredictionFeatureExtractor(extractors);
    CalibrationDataGenerator calibrationDataGenerator = new CalibrationDataGenerator(labelCalibrator, predictionFeatureExtractor);
    // NOTE(review): the meaning of the trailing argument 10 is not visible here — confirm and name it.
    CalibrationDataGenerator.TrainData caliTrainingData = calibrationDataGenerator.createCaliTrainingData(setCalData, classProbEstimator, numCandidates, calibrateTarget, support, 10);
    CalibrationDataGenerator.TrainData caliValidData = calibrationDataGenerator.createCaliTrainingData(valid, classProbEstimator, numCandidates, calibrateTarget, support, 10);
    RegDataSet calibratorTrainData = caliTrainingData.regDataSet;
    double[] weights = caliTrainingData.instanceWeights;

    VectorCalibrator setCalibrator;
    String setCalibratorName = config.getString("setCalibrator");
    switch (setCalibratorName) {
        case "cardinality_isotonic":
            setCalibrator = new VectorCardIsoSetCalibrator(calibratorTrainData, 0, 2, false);
            break;
        case "reranker":
            RerankerTrainer rerankerTrainer = RerankerTrainer.newBuilder()
                    .numCandidates(numCandidates)
                    .numLeaves(config.getInt("numLeaves"))
                    .monotonicityType("weak")
                    .build();
            setCalibrator = rerankerTrainer.trainWithSigmoid(calibratorTrainData, weights, classProbEstimator, predictionFeatureExtractor, labelCalibrator, caliValidData.regDataSet);
            break;
        case "isotonic":
            setCalibrator = new VectorIsoSetCalibrator(calibratorTrainData, 0, false);
            break;
        case "identity":
            setCalibrator = new VectorIdentityCalibrator(0);
            break;
        case "zero":
            setCalibrator = new ZeroCalibrator();
            break;
        default:
            throw new IllegalArgumentException("illegal setCalibrator: " + setCalibratorName);
    }
    logger.info("finish training set calibrator");

    Serialization.serialize(labelCalibrator, Paths.get(outputDir, "model_predictions", modelFolder, "models", "calibrators", calibratorFolder, "label_calibrator").toFile());
    Serialization.serialize(setCalibrator, Paths.get(outputDir, "model_predictions", modelFolder, "models", "calibrators", calibratorFolder, "set_calibrator").toFile());
    Serialization.serialize(predictionFeatureExtractor, Paths.get(outputDir, "model_predictions", modelFolder, "models", "calibrators", calibratorFolder, "prediction_feature_extractor").toFile());
    logger.info("finish training calibrators");

    MultiLabelClassifier classifier;
    String predictMode = config.getString("predict.mode");
    switch (predictMode) {
        case "independent":
            classifier = new IndependentPredictor(classProbEstimator, labelCalibrator);
            break;
        case "support":
            classifier = new edu.neu.ccs.pyramid.multilabel_classification.predictor.SupportPredictor(classProbEstimator, labelCalibrator, setCalibrator, predictionFeatureExtractor, support);
            break;
        case "reranker":
            // Only a Reranker set calibrator can act as the classifier itself.
            if (!(setCalibrator instanceof Reranker)) {
                throw new IllegalArgumentException("predict.mode=reranker requires a Reranker set calibrator, but setCalibrator=" + setCalibratorName);
            }
            Reranker reranker = (Reranker) setCalibrator;
            reranker.setMinPredictionSize(config.getInt("predict.minSize"));
            reranker.setMaxPredictionSize(config.getInt("predict.maxSize"));
            classifier = reranker;
            break;
        default:
            throw new IllegalArgumentException("illegal predict.mode: " + predictMode);
    }

    MultiLabel[] predictions = classifier.predict(cal);
    MultiLabel[] predictionsValid = classifier.predict(valid);

    // Both halves of `cal` were used for training above, so these numbers are in-sample.
    logger.info("calibration performance on " + config.getString("input.calibrationFolder") + " set");
    MultiLabel[] calLabels = cal.getMultiLabels();
    List<CalibrationDataGenerator.CalibrationInstance> calInstances = IntStream.range(0, cal.getNumDataPoints()).parallel().boxed()
            .map(i -> calibrationDataGenerator.createInstance(classProbEstimator, cal.getRow(i), predictions[i], calLabels[i], calibrateTarget))
            .collect(Collectors.toList());
    eval(calInstances, setCalibrator, logger, calibrateTarget);

    logger.info("classification performance on " + config.getString("input.validFolder") + " set");
    logger.info(new MLMeasures(valid.getNumClasses(), valid.getMultiLabels(), predictionsValid).toString());

    logger.info("calibration performance on " + config.getString("input.validFolder") + " set");
    MultiLabel[] validLabels = valid.getMultiLabels();
    List<CalibrationDataGenerator.CalibrationInstance> validInstances = IntStream.range(0, valid.getNumDataPoints()).parallel().boxed()
            .map(i -> calibrationDataGenerator.createInstance(classProbEstimator, valid.getRow(i), predictionsValid[i], validLabels[i], calibrateTarget))
            .collect(Collectors.toList());
    eval(validInstances, setCalibrator, logger, calibrateTarget);
}
Aggregations