Usage of `com.tencent.angel.graph.client.initnodefeats4.InitNodeFeatsParam` in the Angel project by Tencent.
Class `GetNodeFeatsTest2`, method `testCSR`.
/**
 * Initializes six nodes (ids 1, 2, 3, 4, 5, 8) on the NODE matrix with CSR-encoded neighbor
 * lists and with feature vectors backed by dense, sparse and sorted storage, then fetches the
 * features for ids 1..8 (ids 6 and 7 were never initialized) and logs what comes back.
 *
 * <p>NOTE(review): this test only logs the fetched features; it asserts nothing about them.
 */
@Test
public void testCSR() throws Exception {
  Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
  MatrixClient client = worker.getPSAgent().getMatrixClient(NODE, 0);
  int matrixId = client.getMatrixId();

  // Init node neighbors and feats
  long[] nodeIds = {1, 2, 3, 4, 5, 8};
  long[][] neighbors = {
      {2, 3, 4, 5, 8},
      {1, 3, 4, 5, 8},
      {1, 2, 4, 5, 8},
      {1, 2, 3, 5, 8},
      {1, 2, 3, 4, 8},
      {1, 2, 3, 4, 5}
  };

  IntFloatVector[] feats = new IntFloatVector[nodeIds.length];
  feats[0] = VFactory.denseFloatVector(5);
  feats[0].set(0, 0.2f);
  feats[0].set(1, 0.3f);
  feats[0].set(2, 0.4f);
  feats[0].set(3, 0.5f);
  feats[0].set(4, 0.6f);
  feats[1] = VFactory.sparseFloatVector(5, 2);
  feats[1].set(1, 0.4f);
  feats[1].set(3, 0.5f);
  feats[2] = VFactory.sortedFloatVector(5, 3);
  feats[2].set(0, 0.4f);
  feats[2].set(1, 0.5f);
  feats[2].set(4, 0.6f);
  feats[3] = VFactory.sparseFloatVector(5, 2);
  feats[3].set(4, 0.6f);
  feats[3].set(1, 0.5f);
  feats[4] = VFactory.sparseFloatVector(5, 1);
  feats[4].set(2, 0.6f);
  feats[5] = VFactory.sparseFloatVector(5, 2);
  feats[5].set(0, 0.3f);
  feats[5].set(1, 0.4f);

  // Derive the CSR offsets from the neighbor lists instead of maintaining them by hand,
  // so they can never drift out of sync with the arrays above.
  int[] indptr = buildCsrIndptr(neighbors);
  long[] ns = flattenNeighbors(neighbors, indptr);

  InitNeighbor initFunc = new InitNeighbor(new InitNeighborParam(matrixId, nodeIds, indptr, ns));
  client.asyncUpdate(initFunc).get();
  InitNodeFeats func = new InitNodeFeats(new InitNodeFeatsParam(matrixId, nodeIds, feats));
  client.asyncUpdate(func).get();

  // Get the node features; ids 6 and 7 were never initialized.
  long[] queryIds = {1, 2, 3, 4, 5, 6, 7, 8};
  GetNodeFeatsParam param = new GetNodeFeatsParam(matrixId, queryIds);
  Long2ObjectOpenHashMap<IntFloatVector> result =
      ((GetNodeFeatsResult) (client.get(new GetNodeFeats(param)))).getResult();

  LOG.info("==============================get node feats result============================");
  ObjectIterator<Long2ObjectMap.Entry<IntFloatVector>> iter =
      result.long2ObjectEntrySet().fastIterator();
  while (iter.hasNext()) {
    Long2ObjectMap.Entry<IntFloatVector> entry = iter.next();
    logNodeFeats(entry.getLongKey(), entry.getValue());
  }
}

/** Builds CSR row offsets: entry i is where node i's neighbors start in the flattened array. */
private static int[] buildCsrIndptr(long[][] neighbors) {
  int[] indptr = new int[neighbors.length + 1];
  for (int i = 0; i < neighbors.length; i++) {
    indptr[i + 1] = indptr[i] + neighbors[i].length;
  }
  return indptr;
}

/** Concatenates the per-node neighbor lists into one array laid out by {@code indptr}. */
private static long[] flattenNeighbors(long[][] neighbors, int[] indptr) {
  long[] flat = new long[indptr[neighbors.length] - indptr[0]];
  for (int i = 0; i < neighbors.length; i++) {
    System.arraycopy(neighbors[i], 0, flat, indptr[i] - indptr[0], neighbors[i].length);
  }
  return flat;
}

/** Logs every (index, value) pair of one node's feature vector, according to its storage kind. */
private void logNodeFeats(long nodeId, IntFloatVector vector) {
  // NOTE(review): guard in case the result maps uninitialized ids (6, 7) to null rather than
  // omitting them — confirm against GetNodeFeats.
  if (vector == null) {
    LOG.info("node " + nodeId + " has no features");
    return;
  }
  if (vector.isDense()) {
    LOG.info("node " + nodeId + " has a dense features");
    float[] values = vector.getStorage().getValues();
    for (int i = 0; i < values.length; i++) {
      LOG.info("feat index " + i + " values = " + values[i]);
    }
  } else if (vector.isSparse()) {
    LOG.info("node " + nodeId + " has a sparse features");
    ObjectIterator<Int2FloatMap.Entry> valueIter = vector.getStorage().entryIterator();
    while (valueIter.hasNext()) {
      Int2FloatMap.Entry keyValue = valueIter.next();
      LOG.info("feat index " + keyValue.getIntKey() + " values = " + keyValue.getFloatValue());
    }
  } else {
    // Neither dense nor sparse: treated as sorted (parallel index/value arrays).
    LOG.info("node " + nodeId + " has a sorted features");
    int[] keys = vector.getStorage().getIndices();
    float[] values = vector.getStorage().getValues();
    for (int i = 0; i < values.length; i++) {
      LOG.info("feat index " + keys[i] + " values = " + values[i]);
    }
  }
}
Aggregations