Search in sources:

Example 1 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

Class CMLCRFElasticNet, method iterate.

/**
 * Runs one outer iteration of the elastic-net optimizer.
 *
 * <p>First every cached score/probability structure is refreshed from the current weights.
 * Then, for each supported label combination, the optimizer runs from the same weight snapshot.
 * The resulting weights are summed, and the snapshot is restored before the next combination.
 * Finally a line search is performed along the summed weights, and the new objective value is
 * recorded with the terminator.
 */
public void iterate() {
    // Refresh cached matrices; each update reads results of the previous ones, so keep this order.
    updateClassScoreMatrix();
    cmlcrf.updateCombLabelPartScores();
    updateAssignmentScoreMatrix();
    updateAssignmentProbMatrix();
    updateCombProbSums();
    updatePredictedCounts();
    updateClassProbMatrix();

    Vector summedWeights = new SequentialAccessSparseVector(numParameters);
    // Snapshot taken once; every combination restarts from it.
    Vector startingWeights = cmlcrf.getWeights().deepCopy().getAllWeights();
    int comb = 0;
    while (comb < numSupport) {
        DataSet expanded = expandData(comb);
        iterateForOneComb(expanded, comb);
        summedWeights = summedWeights.plus(cmlcrf.getWeights().getAllWeights());
        cmlcrf.getWeights().setWeightVector(startingWeights);
        comb++;
    }

    // Line search along the summed weights with gradient (predictedCounts - empiricalCounts) / numData.
    Vector gradient = predictedCounts.minus(empiricalCounts).divide(numData);
    lineSearch(summedWeights, gradient);
    terminator.add(getValue());
}
Also used: SequentialSparseDataSet (edu.neu.ccs.pyramid.dataset.SequentialSparseDataSet), MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet), DataSet (edu.neu.ccs.pyramid.dataset.DataSet), DenseVector (org.apache.mahout.math.DenseVector), SequentialAccessSparseVector (org.apache.mahout.math.SequentialAccessSparseVector), Vector (org.apache.mahout.math.Vector)

Example 2 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

Class StumpSelector, method scores.

/**
 * Scores a single n-gram feature against every row of {@code labels}.
 *
 * <p>The feature is deep-copied so the caller's instance is not modified. The copy is given
 * column index 0, and its values are loaded into a one-column {@link SparseDataSet} through
 * {@code FeatureLoader.loadNgramFeature}. Each label row is then scored against that data set
 * with {@code score(DataSet, double[])}.
 *
 * @param index          index passed to {@code FeatureLoader.loadNgramFeature}
 * @param labels         size = num labels * num data; one row per label, all rows the same length
 * @param feature        the n-gram to score; not modified (a deep copy is used)
 * @param idTranslator   passed to {@code FeatureLoader.loadNgramFeature}
 * @param matchScoreType passed to {@code FeatureLoader.loadNgramFeature}
 * @param docFilter      passed to {@code FeatureLoader.loadNgramFeature}
 * @param fieldLength    passed to {@code FeatureLoader.loadNgramFeature}
 * @return one score per row of {@code labels}; an empty array when {@code labels} is empty
 * @throws IllegalStateException if {@code feature} cannot be deep-copied
 */
public static double[] scores(ESIndex index, double[][] labels, Ngram feature, IdTranslator idTranslator, FeatureLoader.MatchScoreType matchScoreType, String docFilter, Map<String, float[]> fieldLength) {
    if (labels.length == 0) {
        // Nothing to score; also avoids indexing labels[0] below.
        return new double[0];
    }
    Ngram ngram;
    try {
        ngram = (Ngram) Serialization.deepCopy(feature);
    } catch (IOException | ClassNotFoundException e) {
        // Previously swallowed with printStackTrace(), which left ngram null and caused an NPE
        // on the next line. Fail fast and keep the original cause.
        throw new IllegalStateException("failed to deep-copy ngram feature " + feature, e);
    }
    // The copy becomes column 0 of the single-column data set built below.
    ngram.setIndex(0);
    DataSet dataSet = new SparseDataSet(labels[0].length, 1, false, null);
    FeatureLoader.loadNgramFeature(index, dataSet, ngram, idTranslator, matchScoreType, docFilter, fieldLength);
    double[] scores = new double[labels.length];
    for (int l = 0; l < scores.length; l++) {
        scores[l] = score(dataSet, labels[l]);
    }
    return scores;
}
Also used: DataSet (edu.neu.ccs.pyramid.dataset.DataSet), SparseDataSet (edu.neu.ccs.pyramid.dataset.SparseDataSet), IOException (java.io.IOException), Ngram (edu.neu.ccs.pyramid.feature.Ngram)

Example 3 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

Class LogisticRegressionInspector, method topFeatures.

/**
 * Returns the {@code limit} features with the largest absolute utility for one class.
 *
 * <p>A feature's utility is the mean of {@code weight * value} over the non-zero entries of its
 * column in {@code dataSet}. Here {@code weight} is the class weight for that feature, with the
 * bias excluded. Features whose column has no non-zero entries get utility 0.
 *
 * @param logisticRegression trained model supplying feature list, weights and label translator
 * @param dataSet            data whose columns are scored
 * @param classIndex         class whose weights are used
 * @param limit              maximum number of features returned
 * @return top features, their utilities, the class index and the external class name
 */
public static TopFeatures topFeatures(LogisticRegression logisticRegression, DataSet dataSet, int classIndex, int limit) {
    FeatureList features = logisticRegression.getFeatureList();
    Vector classWeights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);
    Comparator<FeatureUtility> byAbsUtilityDesc =
            Comparator.comparingDouble((FeatureUtility fu) -> Math.abs(fu.getUtility())).reversed();

    List<FeatureUtility> ranked = IntStream.range(0, classWeights.size())
            .parallel()
            .mapToObj(featureIndex -> {
                Vector column = dataSet.getColumn(featureIndex);
                int nonZeros = column.getNumNonZeroElements();
                FeatureUtility utility = new FeatureUtility(features.get(featureIndex));
                if (nonZeros == 0) {
                    return utility.setUtility(0);
                }
                double w = classWeights.get(featureIndex);
                double total = 0;
                for (Vector.Element entry : column.nonZeroes()) {
                    total += w * entry.get();
                }
                return utility.setUtility(total / nonZeros);
            })
            .sorted(byAbsUtilityDesc)
            .limit(limit)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(ranked.stream().map(FeatureUtility::getFeature).collect(Collectors.toList()));
    result.setUtilities(ranked.stream().map(FeatureUtility::getUtility).collect(Collectors.toList()));
    result.setClassIndex(classIndex);
    LabelTranslator translator = logisticRegression.getLabelTranslator();
    result.setClassName(translator.toExtLabel(classIndex));
    return result;
}
Also used: edu.neu.ccs.pyramid.regression, IntStream (java.util.stream.IntStream), ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability), PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis), java.util, LabelTranslator (edu.neu.ccs.pyramid.dataset.LabelTranslator), IdTranslator (edu.neu.ccs.pyramid.dataset.IdTranslator), Vector (org.apache.mahout.math.Vector), DataSet (edu.neu.ccs.pyramid.dataset.DataSet), ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet), Collectors (java.util.stream.Collectors), edu.neu.ccs.pyramid.feature

Example 4 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

Class BMSelector, method selectAll.

/**
 * Fits a BM model to a set of multi-labels and returns it together with the trainer's gammas.
 *
 * <p>The labels are first encoded as an indicator data set. Row {@code i} corresponds to
 * {@code multiLabels[i]}, column {@code c} to label {@code c}, and the value is 1 when that
 * label is matched.
 *
 * @param numClasses  number of label columns
 * @param multiLabels one multi-label per data point
 * @param numClusters number of clusters passed to {@code BMSelector.selectTrainer}
 * @return pair of (trained BM, {@code trainer.gammas})
 */
public static Pair<BM, double[][]> selectAll(int numClasses, MultiLabel[] multiLabels, int numClusters) {
    DataSet indicators = DataSetBuilder.getBuilder()
            .numDataPoints(multiLabels.length)
            .numFeatures(numClasses)
            .density(Density.SPARSE_RANDOM)
            .build();
    int row = 0;
    for (MultiLabel multiLabel : multiLabels) {
        for (int label : multiLabel.getMatchedLabels()) {
            indicators.setFeatureValue(row, label, 1);
        }
        row++;
    }
    // NOTE(review): 10 is forwarded unchanged to selectTrainer; its meaning is defined there — confirm.
    BMTrainer trainer = BMSelector.selectTrainer(indicators, numClusters, 10);
    Pair<BM, double[][]> result = new Pair<>();
    result.setFirst(trainer.getBm());
    result.setSecond(trainer.gammas);
    return result;
}
Also used: MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel), DataSet (edu.neu.ccs.pyramid.dataset.DataSet), Pair (edu.neu.ccs.pyramid.util.Pair)

Example 5 with DataSet

Use of edu.neu.ccs.pyramid.dataset.DataSet in project pyramid by cheng-li.

Class BMSelectorTest, method test1.

/**
 * Manual check of BMSelector on a tiny synthetic data set with three obvious patterns.
 *
 * <p>It prints the data set, the selected model and five samples. It then prints the
 * probability the model assigns to one probe vector per pattern.
 */
private static void test1() {
    // 20 rows over 5 binary features:
    // rows 0-4 -> {0}, rows 5-9 -> {1}, rows 10-19 -> {1, 2, 3}.
    DataSet dataSet = DataSetBuilder.getBuilder().numFeatures(5).numDataPoints(20).dense(true).build();
    for (int row = 0; row < 20; row++) {
        if (row < 5) {
            dataSet.setFeatureValue(row, 0, 1);
        } else if (row < 10) {
            dataSet.setFeatureValue(row, 1, 1);
        } else {
            dataSet.setFeatureValue(row, 1, 1);
            dataSet.setFeatureValue(row, 2, 1);
            dataSet.setFeatureValue(row, 3, 1);
        }
    }
    System.out.println("dataset = " + dataSet);

    BM bm = BMSelector.select(dataSet, 3, 10);
    System.out.println(bm);
    for (int s = 0; s < 5; s++) {
        System.out.println("sample " + s);
        System.out.println(bm.sample());
    }

    // One probe vector per pattern above, in the same order.
    int[][] patterns = {{0}, {1}, {1, 2, 3}};
    Vector[] probes = new Vector[patterns.length];
    for (int p = 0; p < patterns.length; p++) {
        probes[p] = new DenseVector(5);
        for (int feature : patterns[p]) {
            probes[p].set(feature, 1);
        }
    }
    for (Vector probe : probes) {
        System.out.println(Math.exp(bm.logProbability(probe)));
    }
}
Also used: DataSet (edu.neu.ccs.pyramid.dataset.DataSet), DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)

Aggregations

DataSet (edu.neu.ccs.pyramid.dataset.DataSet): 20; Vector (org.apache.mahout.math.Vector): 7; File (java.io.File): 6; java.util (java.util): 4; Collectors (java.util.stream.Collectors): 4; IntStream (java.util.stream.IntStream): 4; DataSetBuilder (edu.neu.ccs.pyramid.dataset.DataSetBuilder): 3; Accuracy (edu.neu.ccs.pyramid.eval.Accuracy): 3; ArgMax (edu.neu.ccs.pyramid.util.ArgMax): 3; BufferedImage (java.awt.image.BufferedImage): 3; List (java.util.List): 3; ImageIO (javax.imageio.ImageIO): 3; FileUtils (org.apache.commons.io.FileUtils): 3; DenseVector (org.apache.mahout.math.DenseVector): 3; MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 2; Ngram (edu.neu.ccs.pyramid.feature.Ngram): 2; IOException (java.io.IOException): 2; ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability): 1; PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis): 1; ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet): 1