Usage of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
Class App1, method ngramSelection:
/**
 * Selects, for every label, the best-scoring multi-word n-grams (n &gt; 1) and rewrites the
 * feature list so that it only keeps non-n-gram features, all unigrams, and the selected
 * longer n-grams.
 *
 * <p>All files live in {@code <output.folder>/meta_data}. A per-label summary of the selection is
 * written to {@code selected_ngrams.txt}; the previous {@code feature_list.ser} and
 * {@code feature_list.txt} are copied to {@code feature_list_all.ser} / {@code feature_list_all.txt}
 * before being overwritten with the selected features. Text files are written as UTF-8.
 *
 * @param config    run configuration ({@code output.folder}, {@code train.splitQuery},
 *                  {@code train.feature.ngram.matchScoreType},
 *                  {@code train.feature.ngram.selectPerLabel})
 * @param index     index the n-grams are scored against
 * @param docFilter document filter forwarded to {@code StumpSelector.scores}
 * @param logger    progress logger
 * @throws IllegalArgumentException if {@code train.feature.ngram.matchScoreType} is not one of
 *                                  {@code es_original}, {@code binary}, {@code frequency},
 *                                  {@code tfifl}
 * @throws Exception on I/O or (de)serialization failure
 */
private static void ngramSelection(Config config, MultiLabelIndex index, String docFilter, Logger logger) throws Exception {
    logger.info("start ngram selection");
    // Resolve the score type first so a misconfiguration fails before any expensive loading.
    FeatureLoader.MatchScoreType matchScoreType =
            parseNgramMatchScoreType(config.getString("train.feature.ngram.matchScoreType"));
    File metaDataFolder = new File(config.getString("output.folder"), "meta_data");
    String[] indexIds = getDocsForSplitFromQuery(index, config.getString("train.splitQuery"));
    IdTranslator idTranslator = loadIdTranslator(indexIds);
    LabelTranslator labelTranslator = (LabelTranslator) Serialization.deserialize(new File(metaDataFolder, "label_translator.ser"));
    double[][] labels = loadLabels(config, index, idTranslator, labelTranslator);
    // assumes labels is indexed label-first (labels.length == number of labels) — TODO confirm
    int numLabels = labels.length;
    int toKeep = config.getInt("train.feature.ngram.selectPerLabel");

    // One score-ordered queue per label; presumably bounded to toKeep entries — verify against
    // BoundedBlockPriorityQueue.
    List<BoundedBlockPriorityQueue<Pair<Ngram, Double>>> queues = new ArrayList<>();
    Comparator<Pair<Ngram, Double>> comparator = Comparator.comparing(p -> p.getSecond());
    for (int l = 0; l < numLabels; l++) {
        queues.add(new BoundedBlockPriorityQueue<>(toKeep, comparator));
    }

    FeatureList featureList = (FeatureList) Serialization.deserialize(new File(metaDataFolder, "feature_list.ser"));
    // NOTE(review): n-grams are scored in parallel and every worker adds to the shared queues,
    // so this relies on BoundedBlockPriorityQueue.add being thread-safe — confirm.
    featureList.getAll().stream().parallel()
            .filter(feature -> feature instanceof Ngram)
            .map(feature -> (Ngram) feature)
            .filter(ngram -> ngram.getN() > 1)
            .forEach(ngram -> {
                double[] scores = StumpSelector.scores(index, labels, ngram, idTranslator, matchScoreType, docFilter);
                for (int l = 0; l < numLabels; l++) {
                    queues.get(l).add(new Pair<>(ngram, scores[l]));
                }
            });

    // Drain every label's queue into the human-readable report and the global keep-set.
    Set<Ngram> kept = new HashSet<>();
    StringBuilder stringBuilder = new StringBuilder();
    for (int l = 0; l < numLabels; l++) {
        stringBuilder.append("-------------------------").append("\n");
        stringBuilder.append(labelTranslator.toExtLabel(l)).append(":").append("\n");
        BoundedBlockPriorityQueue<Pair<Ngram, Double>> queue = queues.get(l);
        while (queue.size() > 0) {
            Ngram ngram = queue.poll().getFirst();
            kept.add(ngram);
            stringBuilder.append(ngram.getNgram()).append(", ");
        }
        stringBuilder.append("\n");
    }
    File selectionFile = new File(metaDataFolder, "selected_ngrams.txt");
    FileUtils.writeStringToFile(selectionFile, stringBuilder.toString(), java.nio.charset.StandardCharsets.UTF_8);
    logger.info("finish ngram selection");
    logger.info("selected ngrams are written to " + selectionFile.getAbsolutePath());

    // After feature selection, back up the full feature list and overwrite feature_list.ser/.txt
    // with the selected subset.
    FeatureList selectedFeatures = new FeatureList();
    for (Feature feature : featureList.getAll()) {
        if (isRetainedAfterNgramSelection(feature, kept)) {
            selectedFeatures.add(feature);
        }
    }
    FileUtils.copyFile(new File(metaDataFolder, "feature_list.ser"), new File(metaDataFolder, "feature_list_all.ser"));
    FileUtils.copyFile(new File(metaDataFolder, "feature_list.txt"), new File(metaDataFolder, "feature_list_all.txt"));
    Serialization.serialize(selectedFeatures, new File(metaDataFolder, "feature_list.ser"));
    try (BufferedWriter bufferedWriter = java.nio.file.Files.newBufferedWriter(
            new File(metaDataFolder, "feature_list.txt").toPath(), java.nio.charset.StandardCharsets.UTF_8)) {
        for (Feature feature : selectedFeatures.getAll()) {
            bufferedWriter.write(feature.toString());
            bufferedWriter.newLine();
        }
    }
}

/**
 * Maps a {@code train.feature.ngram.matchScoreType} config value to its enum constant.
 *
 * @throws IllegalArgumentException if the value is not recognised
 */
private static FeatureLoader.MatchScoreType parseNgramMatchScoreType(String matchScoreTypeString) {
    switch (matchScoreTypeString) {
        case "es_original":
            return FeatureLoader.MatchScoreType.ES_ORIGINAL;
        case "binary":
            return FeatureLoader.MatchScoreType.BINARY;
        case "frequency":
            return FeatureLoader.MatchScoreType.FREQUENCY;
        case "tfifl":
            return FeatureLoader.MatchScoreType.TFIFL;
        default:
            throw new IllegalArgumentException("unknown ngramMatchScoreType: " + matchScoreTypeString);
    }
}

/**
 * Decides whether a feature survives n-gram selection: non-n-gram features and unigrams are
 * always kept; n-grams with n &gt; 1 only if some label selected them. N-grams with any other
 * n are dropped.
 */
private static boolean isRetainedAfterNgramSelection(Feature feature, Set<Ngram> kept) {
    if (!(feature instanceof Ngram)) {
        return true;
    }
    int n = ((Ngram) feature).getN();
    return n == 1 || (n > 1 && kept.contains(feature));
}
Usage of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
Class TRECFormat, method writeMatrixFile:
/**
 * Writes the feature matrix of {@code dataSet} to the file {@code TREC_MATRIX_FILE_NAME} inside
 * the directory {@code trecFile}, overwriting any existing file.
 *
 * <p>Each data point becomes one line: the integer label, then every non-zero feature as
 * {@code index:value} with 0-based indices in ascending order; every token is followed by a
 * single space.
 *
 * @param dataSet  data set whose rows and labels are written
 * @param trecFile directory that receives the matrix file
 * @throws java.io.UncheckedIOException if the matrix file cannot be written
 */
private static void writeMatrixFile(ClfDataSet dataSet, File trecFile) {
    File matrixFile = new File(trecFile, TREC_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    int[] labels = dataSet.getLabels();
    // nonZeroes() is not relied on to iterate in index order, so each row is sorted explicitly.
    Comparator<Pair<Integer, Double>> byIndex = Comparator.comparing(Pair::getFirst);
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            bw.write(labels[i] + " ");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                pairs.add(new Pair<>(element.index(), element.get()));
            }
            pairs.sort(byIndex);
            for (Pair<Integer, Double> pair : pairs) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        // Propagate instead of printing: callers must not assume a half-written file is valid.
        throw new java.io.UncheckedIOException("failed to write " + matrixFile.getAbsolutePath(), e);
    }
}
Usage of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
Class LibSvmFormat, method save:
/**
 * Saves a regression data set in LibSVM text format, overwriting {@code libSvmFile}.
 *
 * <p>Each data point becomes one line: the (double) label, then every non-zero feature as
 * {@code index:value} with 1-based indices in ascending order; every token is followed by a
 * single space.
 *
 * @param dataSet    data set to save
 * @param libSvmFile path of the output file
 * @throws java.io.UncheckedIOException if the file cannot be written
 */
public static void save(RegDataSet dataSet, String libSvmFile) {
    File matrixFile = new File(libSvmFile);
    int numDataPoints = dataSet.getNumDataPoints();
    double[] labels = dataSet.getLabels();
    // nonZeroes() is not relied on to iterate in index order, so each row is sorted explicitly.
    Comparator<Pair<Integer, Double>> byIndex = Comparator.comparing(Pair::getFirst);
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            bw.write(labels[i] + " ");
            Vector vector = dataSet.getRow(i);
            // only write non-zeros; LibSVM feature indices are 1-based
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                pairs.add(new Pair<>(element.index() + 1, element.get()));
            }
            pairs.sort(byIndex);
            for (Pair<Integer, Double> pair : pairs) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            bw.write("\n");
        }
    } catch (IOException e) {
        // Propagate instead of printing: callers must not assume a half-written file is valid.
        throw new java.io.UncheckedIOException("failed to write " + matrixFile.getAbsolutePath(), e);
    }
}
Usage of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
Class DataSetUtil, method toMultiClass:
/**
 * Converts a multi-label data set into a multi-class one in which each label set returned by
 * {@code DataSetUtil.gatherMultiLabels} becomes a class.
 *
 * <p>The stored non-zero feature values are copied row by row, and each row's class is the
 * translator index of that row's label set. The result receives a {@code LabelTranslator} whose
 * external labels are the label sets' {@code toString()} forms, and it shares the source's
 * feature list object.
 *
 * @param dataSet multi-label data set to convert
 * @return the multi-class data set, paired with the translator from class index to label set
 */
public static Pair<ClfDataSet, Translator<MultiLabel>> toMultiClass(MultiLabelClfDataSet dataSet) {
    final int rowCount = dataSet.getNumDataPoints();
    final int featureCount = dataSet.getNumFeatures();

    // Each gathered label set gets a translator index, which serves as its class id.
    List<MultiLabel> labelSets = DataSetUtil.gatherMultiLabels(dataSet);
    Translator<MultiLabel> labelSetTranslator = new Translator<>();
    labelSetTranslator.addAll(labelSets);

    ClfDataSet multiClassData = ClfDataSetBuilder.getBuilder()
            .numDataPoints(rowCount)
            .numFeatures(featureCount)
            .dense(dataSet.isDense())
            .missingValue(dataSet.hasMissingValue())
            .numClasses(labelSetTranslator.size())
            .build();

    int row = 0;
    while (row < rowCount) {
        // Copy only the row's non-zero entries.
        for (Vector.Element entry : dataSet.getRow(row).nonZeroes()) {
            multiClassData.setFeatureValue(row, entry.index(), entry.get());
        }
        multiClassData.setLabel(row, labelSetTranslator.getIndex(dataSet.getMultiLabels()[row]));
        row++;
    }

    List<String> classNames = labelSets.stream()
            .map(MultiLabel::toString)
            .collect(Collectors.toList());
    multiClassData.setLabelTranslator(new LabelTranslator(classNames));
    multiClassData.setFeatureList(dataSet.getFeatureList());
    return new Pair<>(multiClassData, labelSetTranslator);
}
Usage of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
Class LinearRegElasticNet, method main:
/**
 * Command-line entry point.
 *
 * <p>Trains an elastic-net linear regression on a TREC-format training set and prints training
 * and test RMSE before and after training. It then dumps the non-zero weights, sorted by feature
 * index and by absolute weight, and writes test-set predictions.
 *
 * <p>Reads these keys from the properties file given as the single argument:
 * {@code output.folder}, {@code featureMatrix.sparsity} ({@code dense} or {@code sparse}),
 * {@code input.trainSet}, {@code input.testSet}, {@code regularization}, {@code l1Ratio}.
 * Writes {@code features_sorted_by_indices.txt}, {@code features_sorted_by_weights.txt} (UTF-8)
 * and {@code test_predictions.txt} into the output folder.
 *
 * @param args exactly one element: the path of the properties file
 * @throws IllegalArgumentException if the argument count or {@code featureMatrix.sparsity} is invalid
 * @throws Exception on I/O failure
 */
public static void main(String[] args) throws Exception {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    String output = config.getString("output.folder");
    new File(output).mkdirs();
    String sparsity = config.getString("featureMatrix.sparsity").toLowerCase();
    DataSetType dataSetType;
    switch (sparsity) {
        case "dense":
            dataSetType = DataSetType.REG_DENSE;
            break;
        case "sparse":
            dataSetType = DataSetType.REG_SPARSE;
            break;
        default:
            throw new IllegalArgumentException("featureMatrix.sparsity can be either dense or sparse");
    }
    RegDataSet trainSet = TRECFormat.loadRegDataSet(config.getString("input.trainSet"), dataSetType, true);
    RegDataSet testSet = TRECFormat.loadRegDataSet(config.getString("input.testSet"), dataSetType, true);
    LinearRegression linearRegression = new LinearRegression(trainSet.getNumFeatures());
    ElasticNetLinearRegOptimizer optimizer = new ElasticNetLinearRegOptimizer(linearRegression, trainSet);
    optimizer.setRegularization(config.getDouble("regularization"));
    optimizer.setL1Ratio(config.getDouble("l1Ratio"));
    System.out.println("before training");
    System.out.println("training set RMSE = " + RMSE.rmse(linearRegression, trainSet));
    System.out.println("test set RMSE = " + RMSE.rmse(linearRegression, testSet));
    System.out.println("start training");
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    optimizer.optimize();
    System.out.println("training done");
    System.out.println("time spent on training = " + stopWatch);
    System.out.println("after training");
    System.out.println("training set RMSE = " + RMSE.rmse(linearRegression, trainSet));
    System.out.println("test set RMSE = " + RMSE.rmse(linearRegression, testSet));
    System.out.println("number of non-zeros weights in linear regression (not including bias) = " + linearRegression.getWeights().getWeightsWithoutBias().getNumNonZeroElements());

    List<Pair<Integer, Double>> weights = new ArrayList<>();
    for (Vector.Element element : linearRegression.getWeights().getWeightsWithoutBias().nonZeroes()) {
        weights.add(new Pair<>(element.index(), element.get()));
    }

    Comparator<Pair<Integer, Double>> comparatorByIndex = Comparator.comparing(pair -> pair.getFirst());
    weights.sort(comparatorByIndex);
    File byIndexFile = new File(output, "features_sorted_by_indices.txt");
    FileUtils.writeStringToFile(byIndexFile, formatFeatureWeights(weights, trainSet), java.nio.charset.StandardCharsets.UTF_8);
    System.out.println("all selected features (sorted by indices) are saved to " + byIndexFile.getAbsolutePath());

    // List.sort is stable, so features with equal |weight| stay in ascending index order.
    Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(pair -> Math.abs(pair.getSecond()));
    weights.sort(comparator.reversed());
    File byWeightFile = new File(output, "features_sorted_by_weights.txt");
    FileUtils.writeStringToFile(byWeightFile, formatFeatureWeights(weights, trainSet), java.nio.charset.StandardCharsets.UTF_8);
    System.out.println("all selected features (sorted by absolute weights) are saved to " + byWeightFile.getAbsolutePath());

    File reportFile = new File(output, "test_predictions.txt");
    report(linearRegression, testSet, reportFile);
    System.out.println("predictions on the test set are written to " + reportFile.getAbsolutePath());
}

/**
 * Renders one {@code index(featureName):weight} line per entry, in list order, looking feature
 * names up in {@code dataSet}'s feature list.
 */
private static String formatFeatureWeights(List<Pair<Integer, Double>> weights, RegDataSet dataSet) {
    StringBuilder sb = new StringBuilder();
    for (Pair<Integer, Double> pair : weights) {
        int index = pair.getFirst();
        sb.append(index).append("(").append(dataSet.getFeatureList().get(index).getName()).append(")")
                .append(":").append(pair.getSecond()).append("\n");
    }
    return sb.toString();
}
Aggregations