Usage of `org.apache.mahout.math.DenseVector` in the project pyramid by cheng-li.
Class `Weights`, method `readObject`:
/**
 * Custom deserialization hook that rebuilds the dense weight vector from a sparse encoding.
 *
 * <p>Reads, in order: {@code numClasses} and {@code numFeatures} as ints, then an
 * {@code int[]} of vector indices and a {@code double[]} of the values at those indices.
 * The weight vector is allocated with {@code (numFeatures + 1) * numClasses} entries
 * (presumably one extra slot per class for a bias term — TODO confirm). Each listed index
 * is set to its value; every other entry stays 0.
 *
 * @param in the stream being deserialized
 * @throws IOException if the stream cannot be read; a {@link java.io.InvalidObjectException}
 *     if the decoded data is inconsistent (negative dimensions, null arrays, arrays of
 *     different lengths, or an index outside the allocated vector)
 * @throws ClassNotFoundException if a serialized array class cannot be resolved
 */
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
    numClasses = in.readInt();
    numFeatures = in.readInt();
    if (numClasses < 0 || numFeatures < 0) {
        throw new java.io.InvalidObjectException(
                "negative dimensions: numClasses=" + numClasses + ", numFeatures=" + numFeatures);
    }
    int[] indices = (int[]) in.readObject();
    double[] values = (double[]) in.readObject();
    if (indices == null || values == null) {
        throw new java.io.InvalidObjectException("missing index or value array");
    }
    // Without this check, a shorter values array fails with an unchecked exception
    // and a longer one is silently truncated.
    if (indices.length != values.length) {
        throw new java.io.InvalidObjectException(
                "indices.length=" + indices.length + " != values.length=" + values.length);
    }
    int size = (numFeatures + 1) * numClasses;
    weightVector = new DenseVector(size);
    for (int i = 0; i < indices.length; i++) {
        int index = indices[i];
        if (index < 0 || index >= size) {
            throw new java.io.InvalidObjectException(
                    "index " + index + " out of range for vector of size " + size);
        }
        weightVector.set(index, values[i]);
    }
}
Usage of `org.apache.mahout.math.DenseVector` in the project pyramid by cheng-li.
Class `RidgeBinaryLogisticLoss`, method `Hv`:
/**
 * Computes a Hessian-vector product into {@code Hs}.
 *
 * <p>Steps: apply {@code Xv} to {@code s} into a temporary of length {@code numRows}. Scale
 * each entry by {@code regularization.get(i) * diagonals.get(i)}. Apply {@code XTv} to the
 * scaled temporary, writing into {@code Hs}. Finally add {@code s} element-wise over the
 * first {@code numColumns} entries. The structure appears to mirror LIBLINEAR's
 * {@code l2r_lr_fun::Hv} (NOTE(review): confirm against the original port).
 *
 * @param s  the vector to multiply by the Hessian
 * @param Hs output vector; overwritten with the product
 */
public void Hv(Vector s, Vector Hs) {
    Vector rowTerms = new DenseVector(numRows);
    Xv(s, rowTerms);
    for (int row = 0; row < numRows; row++) {
        double rowWeight = regularization.get(row) * diagonals.get(row);
        rowTerms.set(row, rowWeight * rowTerms.get(row));
    }
    XTv(rowTerms, Hs);
    for (int col = 0; col < numColumns; col++) {
        Hs.set(col, s.get(col) + Hs.get(col));
    }
}
Usage of `org.apache.mahout.math.DenseVector` in the project pyramid by cheng-li.
Class `StandardFormat`, method `save`:
/**
 * Writes a multi-label dataset as two delimiter-separated text files.
 *
 * <p>{@code featureFile} receives one line per data point holding all {@code numFeatures}
 * feature values. Each row is copied into a {@code DenseVector}, so every column is written,
 * formatted with {@code Double.toString}. {@code labelFile} receives one line per data point
 * holding {@code numClasses} entries: {@code 1} where the point matches that class, {@code 0}
 * otherwise. Missing parent directories are created first; bare file names with no parent
 * directory are supported.
 *
 * @param dataSet     the dataset to export
 * @param featureFile destination of the feature matrix
 * @param labelFile   destination of the label matrix
 * @param delimiter   separator written between values on a line
 * @throws Exception if either file cannot be written
 */
public static void save(MultiLabelClfDataSet dataSet, File featureFile, File labelFile, String delimiter) throws Exception {
    mkParentDirs(featureFile);
    mkParentDirs(labelFile);
    writeFeatureMatrix(dataSet, featureFile, delimiter);
    writeLabelMatrix(dataSet, labelFile, delimiter);
}

/** Creates the parent directories of {@code file}, if it has any (a bare file name has none). */
private static void mkParentDirs(File file) {
    File parent = file.getParentFile();
    if (parent != null) {
        // Result ignored on purpose: if creation fails, opening the writer fails with a clear error.
        parent.mkdirs();
    }
}

/** Writes one delimiter-separated line of feature values per data point. */
private static void writeFeatureMatrix(MultiLabelClfDataSet dataSet, File featureFile, String delimiter) throws java.io.IOException {
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(featureFile))) {
        for (int i = 0; i < numDataPoints; i++) {
            Vector dense = new DenseVector(dataSet.getRow(i));
            for (int j = 0; j < numFeatures; j++) {
                if (j > 0) {
                    writer.write(delimiter);
                }
                writer.write(Double.toString(dense.get(j)));
            }
            // Terminate every row, even when there are no features, matching the label file.
            writer.write("\n");
        }
    }
}

/** Writes one delimiter-separated line of 0/1 class indicators per data point. */
private static void writeLabelMatrix(MultiLabelClfDataSet dataSet, File labelFile, String delimiter) throws java.io.IOException {
    int numDataPoints = dataSet.getNumDataPoints();
    int numClasses = dataSet.getNumClasses();
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(labelFile))) {
        MultiLabel[] labels = dataSet.getMultiLabels();
        for (int i = 0; i < numDataPoints; i++) {
            MultiLabel label = labels[i];
            for (int l = 0; l < numClasses; l++) {
                if (l > 0) {
                    writer.write(delimiter);
                }
                writer.write(label.matchClass(l) ? "1" : "0");
            }
            writer.write("\n");
        }
    }
}
Usage of `org.apache.mahout.math.DenseVector` in the project pyramid by cheng-li.
Class `BMSelectorTest`, method `test1`:
/**
 * Smoke test for {@code BMSelector.select}.
 *
 * <p>Builds a dense 20x5 dataset with three row patterns:
 * <ul>
 *   <li>rows 0-4 set feature 0;</li>
 *   <li>rows 5-9 set feature 1;</li>
 *   <li>rows 10-19 set features 1, 2 and 3.</li>
 * </ul>
 * It then prints the selected model, five samples, and the probabilities of one probe
 * vector per pattern.
 */
private static void test1() {
    DataSet dataSet = DataSetBuilder.getBuilder().numFeatures(5).numDataPoints(20).dense(true).build();
    for (int row = 0; row < 20; row++) {
        if (row < 5) {
            dataSet.setFeatureValue(row, 0, 1);
        } else if (row < 10) {
            dataSet.setFeatureValue(row, 1, 1);
        } else {
            for (int feature = 1; feature <= 3; feature++) {
                dataSet.setFeatureValue(row, feature, 1);
            }
        }
    }
    System.out.println("dataset = " + dataSet);
    BM bm = BMSelector.select(dataSet, 3, 10);
    System.out.println(bm);
    for (int draw = 0; draw < 5; draw++) {
        System.out.println("sample " + draw);
        System.out.println(bm.sample());
    }
    // One probe vector per row pattern, listing the features set to 1.
    int[][] activeFeatures = {{0}, {1}, {1, 2, 3}};
    Vector[] probes = new Vector[activeFeatures.length];
    for (int p = 0; p < probes.length; p++) {
        probes[p] = new DenseVector(5);
        for (int feature : activeFeatures[p]) {
            probes[p].set(feature, 1);
        }
    }
    for (Vector probe : probes) {
        System.out.println(Math.exp(bm.logProbability(probe)));
    }
}
Usage of `org.apache.mahout.math.DenseVector` in the project pyramid by cheng-li.
Class `BMTrainerTest`, method `test5`:
/**
 * Smoke test for {@code BMTrainer.train}.
 *
 * <p>Builds a dense 20x5 dataset with three row patterns:
 * <ul>
 *   <li>rows 0-4 set feature 0;</li>
 *   <li>rows 5-9 set feature 1;</li>
 *   <li>rows 10-19 set features 2 and 3.</li>
 * </ul>
 * It prints the model before and after training, then the probabilities of one probe vector
 * per pattern.
 */
private static void test5() {
    DataSet dataSet = DataSetBuilder.getBuilder().numFeatures(5).numDataPoints(20).dense(true).build();
    for (int row = 0; row < 20; row++) {
        if (row < 5) {
            dataSet.setFeatureValue(row, 0, 1);
        } else if (row < 10) {
            dataSet.setFeatureValue(row, 1, 1);
        } else {
            dataSet.setFeatureValue(row, 2, 1);
            dataSet.setFeatureValue(row, 3, 1);
        }
    }
    System.out.println("dataset = " + dataSet);
    BMTrainer trainer = new BMTrainer(dataSet, 3, 0);
    System.out.println(trainer.bm);
    trainer.train();
    System.out.println(trainer.bm);
    // One probe vector per row pattern, listing the features set to 1.
    int[][] activeFeatures = {{0}, {1}, {2, 3}};
    Vector[] probes = new Vector[activeFeatures.length];
    for (int p = 0; p < probes.length; p++) {
        probes[p] = new DenseVector(5);
        for (int feature : activeFeatures[p]) {
            probes[p].set(feature, 1);
        }
    }
    for (Vector probe : probes) {
        System.out.println(Math.exp(trainer.bm.logProbability(probe)));
    }
}
Aggregations