Usage of com.tencent.angel.ml.math2.vector.IntFloatVector in the Angel project by Tencent.
Class NodeUtils, method deserialize.
/**
 * Decodes an {@link IntFloatVector} from a data stream.
 *
 * <p>The stream is read in this order: the vector dimension ({@code int}), the number of stored
 * entries ({@code int}), an {@code int} code resolved via {@code StorageMethod.valuesOf}, then the
 * entries. For DENSE the entries are bare floats and the vector is built from them alone (the
 * dimension header is read but not passed on). For SPARSE and SORTED each entry is an
 * {@code int} key followed by a {@code float} value.
 *
 * @param input the stream to read from
 * @return the decoded vector
 * @throws IOException if reading from {@code input} fails
 * @throws UnsupportedOperationException if the storage code is not DENSE, SPARSE or SORTED
 */
public static IntFloatVector deserialize(DataInputStream input) throws IOException {
  final int dimension = input.readInt();
  final int entryCount = input.readInt();
  final StorageMethod layout = StorageMethod.valuesOf(input.readInt());

  switch (layout) {
    case DENSE: {
      // One float per slot; the dense vector's length is entryCount.
      final float[] slots = new float[entryCount];
      for (int pos = 0; pos < slots.length; pos++) {
        slots[pos] = input.readFloat();
      }
      return VFactory.denseFloatVector(slots);
    }
    case SPARSE: {
      // Entries are inserted one pair at a time into a freshly created sparse vector.
      final IntFloatVector sparse = VFactory.sparseFloatVector(dimension, entryCount);
      for (int remaining = entryCount; remaining > 0; remaining--) {
        final int key = input.readInt();
        final float value = input.readFloat();
        sparse.set(key, value);
      }
      return sparse;
    }
    case SORTED: {
      // Keys and values are collected into parallel arrays and handed over in one call.
      final int[] keyArr = new int[entryCount];
      final float[] valueArr = new float[entryCount];
      for (int pos = 0; pos < keyArr.length; pos++) {
        keyArr[pos] = input.readInt();
        valueArr[pos] = input.readFloat();
      }
      return VFactory.sortedFloatVector(dimension, keyArr, valueArr);
    }
    default:
      throw new UnsupportedOperationException("Unsupport storage type " + layout);
  }
}
Usage of com.tencent.angel.ml.math2.vector.IntFloatVector in the Angel project by Tencent.
Class MergeUtils, method combineIntFloatIndexRowSplits.
// //////////////////////////////////////////////////////////////////////////////
// Combine Int key Float value vector
// //////////////////////////////////////////////////////////////////////////////
/**
 * Merges per-partition row results into a single sparse int-key / float-value vector.
 *
 * <p>{@code keyParts[i]} is merged together with {@code valueParts[i]}. Every
 * {@code keyParts} entry must have a matching {@code valueParts} entry, and each value part
 * must be a {@code FloatValuesPart}.
 *
 * @param matrixId matrix id stamped on the returned vector
 * @param rowId row id stamped on the returned vector
 * @param resultSize expected number of merged entries; forwarded as the second argument of
 *     {@code VFactory.sparseFloatVector}
 * @param keyParts per-partition keys
 * @param valueParts per-partition values, index-aligned with {@code keyParts}
 * @param matrixMeta metadata whose column count becomes the vector dimension
 * @return the merged vector, with row id and matrix id set
 * @throws ArithmeticException if the matrix column count does not fit in an {@code int}
 */
public static Vector combineIntFloatIndexRowSplits(int matrixId, int rowId, int resultSize, KeyPart[] keyParts, ValuePart[] valueParts, MatrixMeta matrixMeta) {
  // A plain (int) cast would silently wrap an oversized column count and yield a vector with a
  // bogus dimension; fail loudly instead.
  int dim = Math.toIntExact(matrixMeta.getColNum());
  IntFloatVector vector = VFactory.sparseFloatVector(dim, resultSize);
  for (int i = 0; i < keyParts.length; i++) {
    mergeTo(vector, keyParts[i], (FloatValuesPart) valueParts[i]);
  }
  vector.setRowId(rowId);
  vector.setMatrixId(matrixId);
  return vector;
}
Usage of com.tencent.angel.ml.math2.vector.IntFloatVector in the Angel project by Tencent.
Class CompIntFloatVectorSplitter, method split.
/**
 * Splits a composite int-key / float-value vector into one row update split per partition.
 *
 * <p>The i-th sub-vector returned by {@code getPartitions()} is paired with {@code parts.get(i)},
 * so the two must line up one-to-one.
 *
 * @param vector the update to split; must be a {@link CompIntFloatVector}
 * @param parts partition keys, in the same order as the vector's sub-vectors
 * @return a map from each partition key to its update split
 * @throws IllegalArgumentException if the sub-vector count differs from {@code parts.size()}
 * @throws ArithmeticException if a partition's column width does not fit in an {@code int}
 */
@Override
public Map<PartitionKey, RowUpdateSplit> split(Vector vector, List<PartitionKey> parts) {
  IntFloatVector[] vecParts = ((CompIntFloatVector) vector).getPartitions();
  // A Java `assert` is skipped unless the JVM runs with -ea. Without it, a mismatch either threw an
  // opaque IndexOutOfBoundsException or silently left trailing partition keys without a split.
  if (vecParts.length != parts.size()) {
    throw new IllegalArgumentException(
        "Vector has " + vecParts.length + " partitions but " + parts.size()
            + " partition keys were given");
  }
  Map<PartitionKey, RowUpdateSplit> updateSplitMap = new HashMap<>(parts.size());
  for (int i = 0; i < vecParts.length; i++) {
    PartitionKey part = parts.get(i);
    int width = Math.toIntExact(part.getEndCol() - part.getStartCol());
    updateSplitMap.put(part, new CompIntFloatRowUpdateSplit(vector.getRowId(), vecParts[i], width));
  }
  return updateSplitMap;
}
Usage of com.tencent.angel.ml.math2.vector.IntFloatVector in the Angel project by Tencent.
Class UpdatePSFTest, method testSparseFloatUDF.
/**
 * Pushes a sparse float delta to row 0 of {@code SPARSE_FLOAT_MAT} via {@code IncrementRows},
 * then reads the row back. Each pushed entry must be present, and the row size must equal the
 * number of generated indices.
 */
public void testSparseFloatUDF() throws Exception {
  Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
  MatrixClient client1 = worker.getPSAgent().getMatrixClient(SPARSE_FLOAT_MAT, 0);
  int matrixW1Id = client1.getMatrixId();
  int[] index = genIndexs(feaNum, nnz);

  // The delta stores value == key at every generated index.
  IntFloatVector deltaVec =
      new IntFloatVector(feaNum, new IntFloatSparseVectorStorage(feaNum, nnz));
  for (int idx : index) {
    deltaVec.set(idx, idx);
  }
  deltaVec.setRowId(0);

  Vector[] updates = new Vector[] {deltaVec};
  client1.asyncUpdate(new IncrementRows(new IncrementRowsParam(matrixW1Id, updates))).get();

  IntFloatVector row = (IntFloatVector) client1.getRow(0);
  for (int id : index) {
    // JUnit's order is (expected, actual); the original had them swapped.
    Assert.assertEquals("value at index " + id, deltaVec.get(id), row.get(id), 0.000001);
  }
  // assertEquals reports both sizes on failure, unlike assertTrue(a == b).
  Assert.assertEquals(index.length, row.size());
}
Usage of com.tencent.angel.ml.math2.vector.IntFloatVector in the Angel project by Tencent.
Class UpdatePSFTest, method testDenseFloatUDF.
/**
 * Pushes a dense float delta (value == index at every position) to row 0 of
 * {@code DENSE_FLOAT_MAT} via {@code IncrementRows}, then reads the row back. The values at a
 * generated subset of indices must match, and the row size must equal {@code feaNum}.
 */
public void testDenseFloatUDF() throws Exception {
  Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
  MatrixClient client1 = worker.getPSAgent().getMatrixClient(DENSE_FLOAT_MAT, 0);
  int matrixW1Id = client1.getMatrixId();
  int[] index = genIndexs(feaNum, nnz);

  IntFloatVector deltaVec = new IntFloatVector(feaNum, new IntFloatDenseVectorStorage(feaNum));
  for (int i = 0; i < feaNum; i++) {
    deltaVec.set(i, i);
  }
  deltaVec.setRowId(0);

  Vector[] updates = new Vector[] {deltaVec};
  client1.asyncUpdate(new IncrementRows(new IncrementRowsParam(matrixW1Id, updates))).get();

  IntFloatVector row = (IntFloatVector) client1.getRow(0);
  for (int id : index) {
    // A zero tolerance keeps the original exact check, and failures now report both values.
    Assert.assertEquals("value at index " + id, deltaVec.get(id), row.get(id), 0.0);
  }
  Assert.assertEquals(feaNum, row.size());
}
Aggregations