Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
class CBMGB, method main.
/**
 * Entry point for the CBM gradient-boosting pipeline. Reads a single properties file and,
 * depending on the boolean switches "tune", "train" and "test", runs hyper parameter
 * tuning (grid search over numLeaves x numComponents), model training (with tuned or
 * manually supplied hyper parameters), and/or evaluation on the test set.
 *
 * @param args exactly one argument: the path to the properties file
 * @throws Exception if configuration loading, data loading, or any pipeline stage fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    VERBOSE = config.getBoolean("output.verbose");
    new File(config.getString("output.dir")).mkdirs();

    if (config.getBoolean("tune")) {
        System.out.println("============================================================");
        System.out.println("Start hyper parameter tuning");
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        List<TuneResult> tuneResults = new ArrayList<>();
        // dataSets.get(0) is the training split, dataSets.get(1) the validation split
        List<MultiLabelClfDataSet> dataSets = loadTrainValidData(config);
        List<Integer> leaveNums = config.getIntegers("tune.numLeaves.candidates");
        List<Integer> components = config.getIntegers("tune.numComponents.candidates");
        // grid search; tune() additionally finds the optimal iteration count per cell
        for (int numLeaves : leaveNums) {
            for (int component : components) {
                StopWatch trialWatch = new StopWatch();
                trialWatch.start();
                HyperParameters hyperParameters = new HyperParameters();
                hyperParameters.numComponents = component;
                hyperParameters.numLeaves = numLeaves;
                System.out.println("---------------------------");
                System.out.println("Trying hyper parameters:");
                System.out.println("train.numComponents = " + hyperParameters.numComponents);
                System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
                TuneResult tuneResult = tune(config, hyperParameters, dataSets.get(0), dataSets.get(1));
                System.out.println("Found optimal train.iterations = " + tuneResult.hyperParameters.iterations);
                System.out.println("Validation performance = " + tuneResult.performance);
                tuneResults.add(tuneResult);
                System.out.println("Time spent on trying this set of hyper parameters = " + trialWatch);
            }
        }
        Comparator<TuneResult> comparator = Comparator.comparing(res -> res.performance);
        TuneResult best;
        String predictTarget = config.getString("tune.targetMetric");
        switch (predictTarget) {
            // higher is better for set accuracy and F1; lower is better for hamming loss
            case "instance_set_accuracy":
            case "instance_f1":
                best = tuneResults.stream().max(comparator).get();
                break;
            case "instance_hamming_loss":
                best = tuneResults.stream().min(comparator).get();
                break;
            default:
                throw new IllegalArgumentException("tune.targetMetric should be instance_set_accuracy, instance_f1 or instance_hamming_loss");
        }
        System.out.println("---------------------------");
        System.out.println("Hyper parameter tuning done.");
        System.out.println("Time spent on entire hyper parameter tuning = " + stopWatch);
        System.out.println("Best validation performance = " + best.performance);
        System.out.println("Best hyper parameters:");
        System.out.println("train.numComponents = " + best.hyperParameters.numComponents);
        System.out.println("train.numLeaves = " + best.hyperParameters.numLeaves);
        System.out.println("train.iterations = " + best.hyperParameters.iterations);
        Config tunedHypers = best.hyperParameters.asConfig();
        File tunedFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
        tunedHypers.store(tunedFile);
        System.out.println("Tuned hyper parameters saved to " + tunedFile.getAbsolutePath());
        System.out.println("============================================================");
    }

    if (config.getBoolean("train")) {
        System.out.println("============================================================");
        HyperParameters hyperParameters;
        if (config.getBoolean("train.useTunedHyperParameters")) {
            File hyperFile = new File(config.getString("output.dir"), "tuned_hyper_parameters.properties");
            if (!hyperFile.exists()) {
                System.out.println("train.useTunedHyperParameters is set to true. But no tuned hyper parameters can be found in the output directory.");
                System.out.println("Please either run hyper parameter tuning, or provide hyper parameters manually and set train.useTunedHyperParameters=false.");
                System.exit(1);
            }
            hyperParameters = new HyperParameters(new Config(hyperFile));
            System.out.println("Start training with tuned hyper parameters:");
        } else {
            hyperParameters = new HyperParameters(config);
            System.out.println("Start training with given hyper parameters:");
        }
        // shared between both branches: report the chosen hyper parameters and train
        System.out.println("train.numComponents = " + hyperParameters.numComponents);
        System.out.println("train.numLeaves = " + hyperParameters.numLeaves);
        System.out.println("train.iterations = " + hyperParameters.iterations);
        MultiLabelClfDataSet trainSet = loadTrainData(config);
        train(config, hyperParameters, trainSet);
        System.out.println("============================================================");
    }

    if (config.getBoolean("test")) {
        System.out.println("============================================================");
        test(config);
        System.out.println("============================================================");
    }
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
class CBMGB, method reportAccPrediction.
/**
 * Evaluates the instance set accuracy optimal predictor on the test set, prints and saves
 * the multi-label performance measures, and writes each predicted label set together with
 * its exact (non-approximated) joint probability under the CBM model.
 *
 * @param config  configuration; reads "output.dir" and "predict.piThreshold"
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to predict on
 * @throws Exception if writing the performance or prediction files fails
 */
private static void reportAccPrediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance set accuracy optimal predictor");
    String output = config.getString("output.dir");
    AccPredictor accPredictor = new AccPredictor(cbm);
    // skip mixture components whose prior weight falls below the threshold
    accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = accPredictor.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance set accuracy optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "performance.txt").toFile();
    // use an explicit encoding; the two-argument overload writes in the platform default charset
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString(), "UTF-8");
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_accuracy_optimal", "predictions.txt").toFile();
    // NOTE(review): FileWriter uses the platform default charset — consider an explicit UTF-8 writer
    try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write("" + setProbs[i]);
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
class CBMGB, method reportF1Prediction.
/**
 * Evaluates the instance F1 optimal (GFM plug-in) predictor on the test set, prints and
 * saves the multi-label performance measures, and writes each predicted label set together
 * with its exact (non-approximated) joint probability under the CBM model.
 *
 * @param config  configuration; reads "output.dir" and "predict.piThreshold"
 * @param cbm     the trained conditional Bernoulli mixture model
 * @param dataSet the test set to predict on
 * @throws Exception if the support file cannot be deserialized or output files cannot be written
 */
private static void reportF1Prediction(Config config, CBM cbm, MultiLabelClfDataSet dataSet) throws Exception {
    System.out.println("============================================================");
    System.out.println("Making predictions on test set with the instance F1 optimal predictor");
    String output = config.getString("output.dir");
    PluginF1 pluginF1 = new PluginF1(cbm);
    // the serialized "support" file is produced during training; cast is unchecked by necessity
    @SuppressWarnings("unchecked")
    List<MultiLabel> support = (List<MultiLabel>) Serialization.deserialize(new File(output, "support"));
    pluginF1.setSupport(support);
    // skip mixture components whose prior weight falls below the threshold
    pluginF1.setPiThreshold(config.getDouble("predict.piThreshold"));
    MultiLabel[] predictions = pluginF1.predict(dataSet);
    MLMeasures mlMeasures = new MLMeasures(dataSet.getNumClasses(), dataSet.getMultiLabels(), predictions);
    System.out.println("test performance with the instance F1 optimal predictor");
    System.out.println(mlMeasures);
    File performanceFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "performance.txt").toFile();
    // use an explicit encoding; the two-argument overload writes in the platform default charset
    FileUtils.writeStringToFile(performanceFile, mlMeasures.toString(), "UTF-8");
    System.out.println("test performance is saved to " + performanceFile.toString());
    // Here we do not use approximation
    double[] setProbs = IntStream.range(0, predictions.length).parallel()
            .mapToDouble(i -> cbm.predictAssignmentProb(dataSet.getRow(i), predictions[i])).toArray();
    File predictionFile = Paths.get(output, "test_predictions", "instance_f1_optimal", "predictions.txt").toFile();
    // NOTE(review): FileWriter uses the platform default charset — consider an explicit UTF-8 writer
    try (BufferedWriter br = new BufferedWriter(new FileWriter(predictionFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            br.write(predictions[i].toString());
            br.write(":");
            br.write("" + setProbs[i]);
            br.newLine();
        }
    }
    System.out.println("predicted sets and their probabilities are saved to " + predictionFile.getAbsolutePath());
    System.out.println("============================================================");
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
class FoldPartitioner, method main.
/**
 * Entry point for partitioning a data set into folds. Dispatches on the "dataSetType"
 * configuration key ("clf", "reg", or "mlclf") to the matching partitioner.
 *
 * @param args exactly one argument: the path to the properties file
 * @throws Exception if the configuration or data cannot be loaded
 */
public static void main(String[] args) throws Exception {
    // guard against a bare ArrayIndexOutOfBoundsException when no argument is given
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    String dataType = config.getString("dataSetType");
    switch (dataType) {
        case "clf":
            partitionClfData(config);
            break;
        case "reg":
            partitionRegData(config);
            break;
        case "mlclf":
            partitionMLClfData(config);
            break;
        default:
            // include the offending value so misconfigurations are easy to diagnose
            throw new IllegalArgumentException("illegal dataSetType: " + dataType + "; expected clf, reg or mlclf");
    }
}
Usage of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li:
class App1, method loadData.
// TODO: keep track of feature types (numerical / binary)
/**
 * Builds a multi-label classification data set from a search index: assigns the (optionally
 * prefix-filtered) labels of every document and loads its n-gram features, honoring the
 * settings persisted in the saved App1 configuration under the output meta_data folder.
 *
 * @param config          current run configuration; reads "output.folder"
 * @param index           search index holding documents and their external labels
 * @param featureList     features to load for every document
 * @param idTranslator    maps internal row indices to external document ids
 * @param totalDim        number of feature columns in the resulting data set
 * @param labelTranslator maps external label strings to internal label indices
 * @param docFilter       query restricting which documents contribute feature values
 * @return the populated sparse data set with id and label translators attached
 * @throws Exception if the saved configuration cannot be read or feature loading fails
 */
static MultiLabelClfDataSet loadData(Config config, MultiLabelIndex index, FeatureList featureList, IdTranslator idTranslator, int totalDim, LabelTranslator labelTranslator, String docFilter) throws Exception {
    File metaDataFolder = new File(config.getString("output.folder"), "meta_data");
    // feature/label settings must match those used at training time, so read the saved config
    Config savedConfig = new Config(new File(metaDataFolder, "saved_config_app1"));
    int numDataPoints = idTranslator.numData();
    int numClasses = labelTranslator.getNumClasses();
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numDataPoints(numDataPoints).numFeatures(totalDim).numClasses(numClasses).density(Density.SPARSE_RANDOM).missingValue(savedConfig.getBoolean("train.feature.missingValue")).build();
    // assign labels row by row, optionally keeping only labels with the configured prefix
    for (int i = 0; i < numDataPoints; i++) {
        String dataIndexId = idTranslator.toExtId(i);
        List<String> extMultiLabel = index.getExtMultiLabel(dataIndexId);
        if (savedConfig.getBoolean("train.label.filter")) {
            String prefix = savedConfig.getString("train.label.filter.prefix");
            extMultiLabel = extMultiLabel.stream().filter(extLabel -> extLabel.startsWith(prefix)).collect(Collectors.toList());
        }
        for (String extLabel : extMultiLabel) {
            int intLabel = labelTranslator.toIntLabel(extLabel);
            dataSet.addLabel(i, intLabel);
        }
    }
    String matchScoreTypeString = savedConfig.getString("train.feature.ngram.matchScoreType");
    FeatureLoader.MatchScoreType matchScoreType;
    switch (matchScoreTypeString) {
        case "es_original":
            matchScoreType = FeatureLoader.MatchScoreType.ES_ORIGINAL;
            break;
        case "binary":
            matchScoreType = FeatureLoader.MatchScoreType.BINARY;
            break;
        case "frequency":
            matchScoreType = FeatureLoader.MatchScoreType.FREQUENCY;
            break;
        case "tfifl":
            matchScoreType = FeatureLoader.MatchScoreType.TFIFL;
            break;
        default:
            // include the offending value so misconfigurations are easy to diagnose
            throw new IllegalArgumentException("unknown ngramMatchScoreType: " + matchScoreTypeString);
    }
    FeatureLoader.loadFeatures(index, dataSet, featureList, idTranslator, matchScoreType, docFilter);
    dataSet.setIdTranslator(idTranslator);
    dataSet.setLabelTranslator(labelTranslator);
    return dataSet;
}
Aggregations