Search in sources:

Example 91 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The train method of class BRLREN.

/**
 * Trains a CBM on {@code trainSet}, using {@code validSet} for early stopping, and writes the run's
 * artifacts below {@code <output.dir>/model_predictions/<output.modelFolder>}.
 *
 * <p>Files written:
 * <ul>
 *   <li>{@code analysis/unobserved_labels.txt}: only when some labels never occur in the training set;</li>
 *   <li>{@code models/classifier}: the model, re-serialized whenever an iteration becomes the best one;</li>
 *   <li>{@code models/support}: the multi-labels gathered from the training set;</li>
 *   <li>{@code analysis/top_features.json} and {@code analysis/top_features.txt}: the top 100
 *       features of every class of the best model.</li>
 * </ul>
 *
 * @param config          configuration providing output locations and training/prediction options
 * @param hyperParameters hyper-parameters forwarded to model and optimizer construction
 * @param trainSet        training data
 * @param validSet        validation data; its instance-averaged accuracy drives early stopping
 * @param logger          destination for progress messages
 * @throws Exception if model construction, (de)serialization or file output fails
 */
private static void train(Config config, HyperParameters hyperParameters, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet validSet, Logger logger) throws Exception {
    String output = config.getString("output.dir");
    // Every artifact of this run lives below the same model folder.
    java.nio.file.Path modelDir = Paths.get(output, "model_predictions", config.getString("output.modelFolder"));
    java.nio.file.Path classifierPath = modelDir.resolve("models").resolve("classifier");
    File analysisFolder = modelDir.resolve("analysis").toFile();

    List<Integer> unobservedLabels = DataSetUtil.unobservedLabels(trainSet);
    if (!unobservedLabels.isEmpty()) {
        String unobserved = ListUtil.toSimpleString(unobservedLabels);
        logger.info("The following labels do not actually appear in the training set and therefore cannot be learned:");
        logger.info(unobserved);
        // Explicit UTF-8 instead of the platform default, consistent with the JSON written below.
        FileUtils.writeStringToFile(new File(analysisFolder, "unobserved_labels.txt"), unobserved, java.nio.charset.StandardCharsets.UTF_8);
    }

    EarlyStopper earlyStopper = loadNewEarlyStopper();
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    CBM cbm = newCBM(config, trainSet, hyperParameters, logger);
    ENCBMOptimizer optimizer = getOptimizer(config, hyperParameters, cbm, trainSet);
    logger.info("Initializing the model");
    if (config.getBoolean("train.randomInitialize")) {
        optimizer.randInitialize();
    } else {
        optimizer.initialize();
    }
    logger.info("Initialization done");
    // Built once around `cbm` and reused every iteration; presumably it observes the optimizer's
    // in-place updates to the model (TODO confirm AccPredictor does not copy the model).
    AccPredictor accPredictor = new AccPredictor(cbm);
    accPredictor.setComponentContributionThreshold(config.getDouble("predict.piThreshold"));
    // Validate every `interval` iterations.
    int interval = 1;
    for (int iter = 1; true; iter++) {
        logger.info("Training progress: iteration " + iter);
        optimizer.iterate();
        if (iter % interval == 0) {
            MLMeasures validMeasures = new MLMeasures(accPredictor, validSet);
            if (VERBOSE) {
                logger.info("validation performance");
                logger.info(validMeasures.toString());
            }
            earlyStopper.add(iter, validMeasures.getInstanceAverage().getAccuracy());
            if (earlyStopper.getBestIteration() == iter) {
                // Checkpoint the best model so far; it is reloaded once training stops.
                Serialization.serialize(cbm, classifierPath);
            }
            if (earlyStopper.shouldStop()) {
                if (VERBOSE) {
                    logger.info("Early Stopper: the training should stop now!");
                    logger.info("Early Stopper: best iteration found = " + earlyStopper.getBestIteration());
                    logger.info("Early Stopper: best validation performance = " + earlyStopper.getBestValue());
                }
                break;
            }
        }
    }
    logger.info("training done!");
    logger.info("time spent on training = " + stopWatch);
    List<MultiLabel> support = DataSetUtil.gatherMultiLabels(trainSet);
    Serialization.serialize(support, modelDir.resolve("models").resolve("support"));
    CBM bestModel = (CBM) Serialization.deserialize(classifierPath);

    analysisFolder.mkdirs();
    logger.info("start writing top features");
    List<TopFeatures> topFeaturesList = IntStream.range(0, bestModel.getNumClasses()).mapToObj(k -> topFeatures(bestModel, trainSet.getFeatureList(), trainSet.getLabelTranslator(), k, 100)).collect(Collectors.toList());
    ObjectMapper mapper = new ObjectMapper();
    mapper.writeValue(new File(analysisFolder, "top_features.json"), topFeaturesList);
    // Human-readable companion of the JSON: one section per class listing its top features.
    StringBuilder sb = new StringBuilder();
    for (int l = 0; l < bestModel.getNumClasses(); l++) {
        sb.append("-------------------------").append("\n");
        sb.append(bestModel.getLabelTranslator().toExtLabel(l)).append(":").append("\n");
        for (Feature feature : topFeaturesList.get(l).getTopFeatures()) {
            sb.append(feature.simpleString()).append(", ");
        }
        sb.append("\n");
    }
    FileUtils.writeStringToFile(new File(analysisFolder, "top_features.txt"), sb.toString(), java.nio.charset.StandardCharsets.UTF_8);
    logger.info("finish writing top features");
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) SimpleFormatter(java.util.logging.SimpleFormatter) FileHandler(java.util.logging.FileHandler) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) ListUtil(edu.neu.ccs.pyramid.util.ListUtil) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) BufferedWriter(java.io.BufferedWriter) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) MAP(edu.neu.ccs.pyramid.eval.MAP) FileWriter(java.io.FileWriter) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) FileUtils(org.apache.commons.io.FileUtils) IOException(java.io.IOException) StopWatch(org.apache.commons.lang3.time.StopWatch) FeatureList(edu.neu.ccs.pyramid.feature.FeatureList) Logger(java.util.logging.Logger) Collectors(java.util.stream.Collectors) File(java.io.File) IMLGBInspector(edu.neu.ccs.pyramid.multilabel_classification.imlgb.IMLGBInspector) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Feature(edu.neu.ccs.pyramid.feature.Feature) Serialization(edu.neu.ccs.pyramid.util.Serialization) PrintUtil(edu.neu.ccs.pyramid.util.PrintUtil) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) LogisticRegressionInspector(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegressionInspector) edu.neu.ccs.pyramid.multilabel_classification.cbm(edu.neu.ccs.pyramid.multilabel_classification.cbm) TopFeatures(edu.neu.ccs.pyramid.feature.TopFeatures) EarlyStopper(edu.neu.ccs.pyramid.optimization.EarlyStopper) Feature(edu.neu.ccs.pyramid.feature.Feature) StopWatch(org.apache.commons.lang3.time.StopWatch) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) File(java.io.File) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper)

Example 92 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The classification_eval method of class BRRerank.

/**
 * Evaluates set predictions against the test set and writes the resulting measures to
 * {@code <outputDir>/reports/classification_performance.txt}.
 *
 * <p>Predictions are read from {@code <outputDir>/reports/set_prediction_and_confidence.txt}. The
 * first line is skipped as a header. On every other line, the first tab-separated column has its
 * curly braces stripped and is parsed into a {@code MultiLabel}.
 *
 * @param config must provide {@code dataPath} (directory containing the {@code test} data set)
 *               and {@code outputDir}
 * @throws IllegalStateException if the number of predictions differs from the number of test instances
 * @throws Exception if loading the data set or reading/writing files fails
 */
private static void classification_eval(Config config) throws Exception {
    System.out.println("classification performance on test set");
    MultiLabelClfDataSet dataset = TRECFormat.loadMultiLabelClfDataSet(Paths.get(config.getString("dataPath"), "test").toFile(), DataSetType.ML_CLF_SPARSE, true);
    String outputDir = config.getString("outputDir");
    MultiLabel[] predictions = FileUtils.readLines(Paths.get(outputDir, "reports", "set_prediction_and_confidence.txt").toFile(), java.nio.charset.StandardCharsets.UTF_8)
            .stream()
            .skip(1) // header line
            .map(line -> new MultiLabel(line.split("\t")[0].replace("{", "").replace("}", "")))
            .toArray(MultiLabel[]::new);
    MultiLabel[] truth = dataset.getMultiLabels();
    // Predictions are matched to test instances by position, so the counts must agree.
    if (predictions.length != truth.length) {
        throw new IllegalStateException("found " + predictions.length + " predictions but the test set has " + truth.length + " instances");
    }
    MLMeasures mlMeasures = new MLMeasures(dataset.getNumClasses(), truth, predictions);
    System.out.println("classification performance");
    System.out.println(mlMeasures);
    // Create the directory the report is actually written into (previously only outputDir was created).
    Paths.get(outputDir, "reports").toFile().mkdirs();
    FileUtils.writeStringToFile(Paths.get(outputDir, "reports", "classification_performance.txt").toFile(), mlMeasures.toString(), java.nio.charset.StandardCharsets.UTF_8);
}
Also used : IntStream(java.util.stream.IntStream) IndependentPredictor(edu.neu.ccs.pyramid.multilabel_classification.predictor.IndependentPredictor) MultiLabelClassifier(edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier) CalibrationEval(edu.neu.ccs.pyramid.eval.CalibrationEval) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) FileUtils(org.apache.commons.io.FileUtils) StopWatch(org.apache.commons.lang3.time.StopWatch) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) edu.neu.ccs.pyramid.calibration(edu.neu.ccs.pyramid.calibration) List(java.util.List) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures) Serialization(edu.neu.ccs.pyramid.util.Serialization) Paths(java.nio.file.Paths) edu.neu.ccs.pyramid.dataset(edu.neu.ccs.pyramid.dataset) Pair(edu.neu.ccs.pyramid.util.Pair) Config(edu.neu.ccs.pyramid.configuration.Config) MLMeasures(edu.neu.ccs.pyramid.eval.MLMeasures)

Example 93 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The main method of class GBRegressor.

/**
 * Command-line entry point. Loads the properties file given as the only argument and, when
 * {@code output.log} is non-empty, redirects logging to that file (appending). It then runs
 * training and/or testing as selected by the {@code train} and {@code test} flags.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if {@code args} does not contain exactly one element
 * @throws Exception if configuration loading, log setup, training or testing fails
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    Logger logger = Logger.getAnonymousLogger();
    String logFile = config.getString("output.log");
    FileHandler fileHandler = null;
    if (!logFile.isEmpty()) {
        // A bare file name ("log.txt") has no parent in relative form, so getParentFile() would
        // return null; resolving against the working directory first avoids the NPE.
        File logDir = new File(logFile).getAbsoluteFile().getParentFile();
        if (logDir != null) {
            logDir.mkdirs();
        }
        // todo should append?
        fileHandler = new FileHandler(logFile, true);
        fileHandler.setFormatter(new SimpleFormatter());
        logger.addHandler(fileHandler);
        logger.setUseParentHandlers(false);
    }
    try {
        logger.info(config.toString());
        if (config.getBoolean("train")) {
            train(config, logger);
        }
        if (config.getBoolean("test")) {
            test(config, logger);
        }
    } finally {
        // Close even when train/test throws, so the handler is flushed and its lock file released.
        if (fileHandler != null) {
            fileHandler.close();
        }
    }
}
Also used : RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) Config(edu.neu.ccs.pyramid.configuration.Config) SimpleFormatter(java.util.logging.SimpleFormatter) Logger(java.util.logging.Logger) File(java.io.File) FileHandler(java.util.logging.FileHandler)

Example 94 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The main method of class InstanceConcatenator.

/**
 * Stacks several multi-label data sets row-wise and saves the merged result.
 *
 * <p>The properties file supplies {@code input.dataSets}, the paths of the data sets loaded with
 * {@code TRECFormat} as sparse data, and {@code output.dataSet}, where the merged data set is saved.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if {@code args} does not contain exactly one element
 * @throws Exception if loading or saving a data set fails
 */
public static void main(String[] args) throws Exception {
    boolean hasSingleArgument = args.length == 1;
    if (!hasSingleArgument) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    List<String> sourcePaths = config.getStrings("input.dataSets");
    String targetPath = config.getString("output.dataSet");
    List<MultiLabelClfDataSet> parts = new ArrayList<>();
    for (String sourcePath : sourcePaths) {
        // Inputs are loaded as sparse; DataSetType.ML_CLF_DENSE is the dense alternative.
        parts.add(TRECFormat.loadMultiLabelClfDataSet(sourcePath, DataSetType.ML_CLF_SPARSE, true));
    }
    TRECFormat.save(DataSetUtil.concatenateByRow(parts), targetPath);
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) ArrayList(java.util.ArrayList) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Example 95 with Config

use of edu.neu.ccs.pyramid.configuration.Config in project pyramid by cheng-li.

The main method of class Meka2Trec.

/**
 * Converts MEKA-format data sets to TREC format. Only multi-label classification data sets are supported.
 *
 * <p>The properties file must define {@code meka} (input paths) and {@code trec} (output paths),
 * paired by position. It must also define {@code numLabels}, {@code numFeatures} and
 * {@code dataMode}, which are passed to {@code MekaFormat.loadMLClfDataset} for every input.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if {@code args} does not contain exactly one element, or if
 *         the {@code meka} and {@code trec} lists have different lengths
 * @throws IOException propagated from configuration loading or data set I/O
 */
public static void main(String[] args) throws IOException {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    List<String> trecs = config.getStrings("trec");
    List<String> mekas = config.getStrings("meka");
    // Inputs and outputs are paired by index; fail before converting anything if they don't line up,
    // instead of crashing half-way or silently ignoring extra outputs.
    if (trecs.size() != mekas.size()) {
        throw new IllegalArgumentException("'meka' lists " + mekas.size() + " inputs but 'trec' lists " + trecs.size() + " outputs; they must match one-to-one.");
    }
    int numLabels = config.getInt("numLabels");
    int numFeatures = config.getInt("numFeatures");
    String dataMode = config.getString("dataMode");
    for (int i = 0; i < mekas.size(); i++) {
        System.out.println("processing on: " + trecs.get(i));
        MultiLabelClfDataSet dataSet = MekaFormat.loadMLClfDataset(mekas.get(i), numFeatures, numLabels, dataMode);
        TRECFormat.save(dataSet, trecs.get(i));
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)

Aggregations

Config (edu.neu.ccs.pyramid.configuration.Config)119 File (java.io.File)68 Collectors (java.util.stream.Collectors)40 FileUtils (org.apache.commons.io.FileUtils)40 Paths (java.nio.file.Paths)39 IntStream (java.util.stream.IntStream)37 Pair (edu.neu.ccs.pyramid.util.Pair)36 Serialization (edu.neu.ccs.pyramid.util.Serialization)35 StopWatch (org.apache.commons.lang3.time.StopWatch)34 MLMeasures (edu.neu.ccs.pyramid.eval.MLMeasures)33 BufferedWriter (java.io.BufferedWriter)32 FileWriter (java.io.FileWriter)32 java.util (java.util)32 edu.neu.ccs.pyramid.dataset (edu.neu.ccs.pyramid.dataset)31 MultiLabelClassifier (edu.neu.ccs.pyramid.multilabel_classification.MultiLabelClassifier)29 EarlyStopper (edu.neu.ccs.pyramid.optimization.EarlyStopper)28 PrintUtil (edu.neu.ccs.pyramid.util.PrintUtil)26 edu.neu.ccs.pyramid.multilabel_classification.cbm (edu.neu.ccs.pyramid.multilabel_classification.cbm)25 ListUtil (edu.neu.ccs.pyramid.util.ListUtil)25 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)22