Search in sources:

Example 56 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

Class RegressionTreeTest, method test1.

/**
 * Builds a hand-wired two-level regression tree, feeds it a vector whose features
 * are all NaN, and prints the probability reported for every node, the tree's
 * prediction, and a hand-computed expected prediction for comparison.
 */
private static void test1() {
    // Internal split nodes: feature index, threshold, and left/right branch
    // probabilities (the expectation printed at the end multiplies these along
    // each root-to-leaf path).
    Node root = new Node();
    root.setFeatureIndex(0);
    root.setThreshold(0.0);
    root.setLeftProb(0.3);
    root.setRightProb(0.7);
    Node leftSplit = new Node();
    leftSplit.setFeatureIndex(1);
    leftSplit.setThreshold(0.1);
    leftSplit.setLeftProb(0.8);
    leftSplit.setRightProb(0.2);
    Node rightSplit = new Node();
    rightSplit.setFeatureIndex(2);
    rightSplit.setThreshold(0.2);
    rightSplit.setLeftProb(0.1);
    rightSplit.setRightProb(0.9);

    // Four leaves, left to right, carrying the values 1, 2, 3, 4.
    Node[] leafNodes = new Node[4];
    for (int i = 0; i < leafNodes.length; i++) {
        Node leaf = new Node();
        leaf.setLeaf(true);
        leaf.setValue(i + 1);
        leafNodes[i] = leaf;
    }

    // Wire the tree: root -> {leftSplit, rightSplit} -> leaves.
    root.setLeftChild(leftSplit);
    root.setRightChild(rightSplit);
    leftSplit.setLeftChild(leafNodes[0]);
    leftSplit.setRightChild(leafNodes[1]);
    rightSplit.setLeftChild(leafNodes[2]);
    rightSplit.setRightChild(leafNodes[3]);

    RegressionTree tree = new RegressionTree();
    tree.root = root;
    for (Node leaf : leafNodes) {
        tree.leaves.add(leaf);
    }

    // Every one of the three features is missing.
    Vector allMissing = new DenseVector(3);
    for (int feature = 0; feature < 3; feature++) {
        allMissing.set(feature, Double.NaN);
    }

    Node[] everyNode = {
        root, leftSplit, rightSplit, leafNodes[0], leafNodes[1], leafNodes[2], leafNodes[3]
    };
    for (Node node : everyNode) {
        System.out.println(tree.probability(allMissing, node));
    }
    System.out.println(tree.predict(allMissing));

    // Path probabilities 0.3*0.8, 0.3*0.2, 0.7*0.1, 0.7*0.9 weighted by leaf values 1..4.
    double expected = 0.24 + 0.06 * 2 + 0.07 * 3 + 0.63 * 4;
    System.out.println(expected);
}
Also used: DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)

Example 57 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

Class IntervalSplitterTest, method test7.

/**
 * Generates split intervals for a four-point feature (values 0..3) with the
 * given per-point probabilities and labels, then prints the raw intervals
 * followed by their compressed form.
 */
static void test7() {
    int numPoints = 4;
    RegTreeConfig regTreeConfig = new RegTreeConfig().setNumSplitIntervals(4);

    // Feature value of point i is simply i.
    Vector vector = new DenseVector(numPoints);
    for (int point = 0; point < numPoints; point++) {
        vector.set(point, point);
    }

    double[] probs = { 0, 0.5, 0.2, 0.6 };
    double[] labels = { 1, 2, 3, 4 };
    Splitter.GlobalStats globalStats = new Splitter.GlobalStats(labels, probs);

    List<Interval> intervals =
            IntervalSplitter.generateIntervals(regTreeConfig, vector, probs, labels, globalStats);
    System.out.println(intervals);
    System.out.println(IntervalSplitter.compress(intervals));
}
Also used: DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)

Example 58 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

Class VectorsTest, method main.

/**
 * Prints the result of appending 4.5 to the vector [1, 2, 5], then runs the
 * remaining tests in this class.
 */
public static void main(String[] args) {
    Vector base = new DenseVector(new double[] { 1, 2, 5 });
    System.out.println(Vectors.concatenate(base, 4.5));
    test2();
    test3();
}
Also used: DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector)

Example 59 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

Class SerializableVector, method readObject.

/**
 * Custom deserialization hook: restores the non-transient fields, then rebuilds
 * {@code vector} from the payload that follows them in the stream.
 *
 * <p>For {@code Type.DENSE} the payload is a single {@code double[]} of values. For
 * {@code Type.SPARSE_RANDOM} / {@code Type.SPARSE_SEQUENTIAL} it is an {@code int[]}
 * of indices followed by a parallel {@code double[]} of values, which are written
 * into a sparse vector of length {@code size}.
 *
 * @throws java.io.InvalidObjectException if {@code type} is not a recognized vector
 *     type, or the sparse index/value arrays are missing or differ in length
 * @throws IOException on any other stream failure
 * @throws ClassNotFoundException if a payload object's class cannot be resolved
 */
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.defaultReadObject();
    if (type == Type.DENSE) {
        double[] values = (double[]) in.readObject();
        vector = new DenseVector(values);
        return;
    }
    // Fail loudly on an unknown/null type instead of leaving `vector` unset,
    // which would otherwise surface later as a NullPointerException.
    if (type != Type.SPARSE_RANDOM && type != Type.SPARSE_SEQUENTIAL) {
        throw new java.io.InvalidObjectException("unsupported vector type: " + type);
    }
    // Both sparse layouts share the same payload: indices, then parallel values.
    int[] indices = (int[]) in.readObject();
    double[] values = (double[]) in.readObject();
    if (indices == null || values == null || indices.length != values.length) {
        throw new java.io.InvalidObjectException("corrupt sparse vector payload: "
                + (indices == null ? "null" : String.valueOf(indices.length)) + " indices vs "
                + (values == null ? "null" : String.valueOf(values.length)) + " values");
    }
    if (type == Type.SPARSE_RANDOM) {
        vector = new RandomAccessSparseVector(size);
    } else {
        vector = new SequentialAccessSparseVector(size);
    }
    for (int i = 0; i < indices.length; i++) {
        vector.set(indices[i], values[i]);
    }
}
Also used: RandomAccessSparseVector (org.apache.mahout.math.RandomAccessSparseVector), DenseVector (org.apache.mahout.math.DenseVector), SequentialAccessSparseVector (org.apache.mahout.math.SequentialAccessSparseVector)

Example 60 with DenseVector

Use of org.apache.mahout.math.DenseVector in project pyramid by cheng-li.

Class FusedKolmogorovFilterTest, method test2.

/**
 * Runs the fused Kolmogorov filter on a ten-point feature with three classes and
 * prints the per-class inputs, the per-class empirical CDFs, the filter's maximum
 * distance, and every pairwise CDF distance.
 */
private static void test2() {
    // Feature values for the ten points; points 5-7 stay at 0.0.
    double[] featureValues = { 0.1, 0.2, 0.15, 0.4, 0.7, 0, 0, 0, 0.9, 0.8 };
    Vector vector = new DenseVector(featureValues);
    // Class of each point; points not explicitly assigned belong to class 0.
    int[] labels = { 0, 1, 2, 1, 0, 0, 0, 0, 0, 2 };
    int numClasses = 3;

    FusedKolmogorovFilter filter = new FusedKolmogorovFilter();
    filter.setNumBins(10);

    List<List<Double>> inputsEachClass = filter.generateInputsEachClass(vector, labels, numClasses);
    System.out.println(inputsEachClass);
    List<EmpiricalCDF> empiricalCDFs = filter.generateCDFs(vector, inputsEachClass);
    System.out.println(empiricalCDFs);
    System.out.println(filter.maxDistance(empiricalCDFs));

    // Pairwise distances in the order (0,1), (0,2), (1,2).
    for (int first = 0; first < numClasses; first++) {
        for (int second = first + 1; second < numClasses; second++) {
            System.out.println(EmpiricalCDF.distance(empiricalCDFs.get(first), empiricalCDFs.get(second)));
        }
    }
}
Also used: List (java.util.List), DenseVector (org.apache.mahout.math.DenseVector), Vector (org.apache.mahout.math.Vector), EmpiricalCDF (edu.neu.ccs.pyramid.util.EmpiricalCDF)

Aggregations

DenseVector (org.apache.mahout.math.DenseVector): 62; Vector (org.apache.mahout.math.Vector): 56; MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 7; MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 5; RandomAccessSparseVector (org.apache.mahout.math.RandomAccessSparseVector): 5; SequentialAccessSparseVector (org.apache.mahout.math.SequentialAccessSparseVector): 4; List (java.util.List): 3; EnumeratedIntegerDistribution (org.apache.commons.math3.distribution.EnumeratedIntegerDistribution): 3; LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression): 2; DataSet (edu.neu.ccs.pyramid.dataset.DataSet): 2; EmpiricalCDF (edu.neu.ccs.pyramid.util.EmpiricalCDF): 2; IntegerDistribution (org.apache.commons.math3.distribution.IntegerDistribution): 2; MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution): 2; Classifier (edu.neu.ccs.pyramid.classification.Classifier): 1; Weights (edu.neu.ccs.pyramid.classification.logistic_regression.Weights): 1; RegDataSet (edu.neu.ccs.pyramid.dataset.RegDataSet): 1; ConstantRegressor (edu.neu.ccs.pyramid.regression.ConstantRegressor): 1; BernoulliDistribution (edu.neu.ccs.pyramid.util.BernoulliDistribution): 1; Pair (edu.neu.ccs.pyramid.util.Pair): 1; ArrayList (java.util.ArrayList): 1