Search in sources:

Example 1 with InitNeighbor

Use of com.tencent.angel.graph.client.initNeighbor5.InitNeighbor in project angel by Tencent.

The class GetNodeFeatsTest2, method testCSR.

/**
 * Initializes six nodes (ids 1, 2, 3, 4, 5, 8) with CSR-encoded neighbor lists and feature
 * vectors of dense, sparse and sorted storage, then fetches the features of ids 1..8 and logs them.
 *
 * <p>Ids 6 and 7 are never initialized here, so they exercise the lookup path for unknown nodes.
 */
@Test
public void testCSR() throws Exception {
    Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
    MatrixClient client = worker.getPSAgent().getMatrixClient(NODE, 0);
    int matrixId = client.getMatrixId();

    // Init node neighbors and feats; neighbors[i] and feats[i] belong to nodeIds[i]
    long[] nodeIds = new long[] { 1, 2, 3, 4, 5, 8 };
    long[][] neighbors = new long[][] {
        { 2, 3, 4, 5, 8 },
        { 1, 3, 4, 5, 8 },
        { 1, 2, 4, 5, 8 },
        { 1, 2, 3, 5, 8 },
        { 1, 2, 3, 4, 8 },
        { 1, 2, 3, 4, 5 }
    };

    IntFloatVector[] feats = new IntFloatVector[nodeIds.length];
    feats[0] = VFactory.denseFloatVector(5);
    feats[0].set(0, 0.2f);
    feats[0].set(1, 0.3f);
    feats[0].set(2, 0.4f);
    feats[0].set(3, 0.5f);
    feats[0].set(4, 0.6f);
    feats[1] = VFactory.sparseFloatVector(5, 2);
    feats[1].set(1, 0.4f);
    feats[1].set(3, 0.5f);
    feats[2] = VFactory.sortedFloatVector(5, 3);
    feats[2].set(0, 0.4f);
    feats[2].set(1, 0.5f);
    feats[2].set(4, 0.6f);
    feats[3] = VFactory.sparseFloatVector(5, 2);
    feats[3].set(4, 0.6f);
    feats[3].set(1, 0.5f);
    feats[4] = VFactory.sparseFloatVector(5, 1);
    feats[4].set(2, 0.6f);
    feats[5] = VFactory.sparseFloatVector(5, 2);
    feats[5].set(0, 0.3f);
    feats[5].set(1, 0.4f);

    // Flatten the neighbor lists into CSR form: the neighbors of nodeIds[i] are stored in
    // ns[indptr[i] .. indptr[i + 1]). Deriving indptr from the arrays keeps the two consistent
    // instead of relying on hand-maintained offsets.
    int[] indptr = new int[neighbors.length + 1];
    for (int i = 0; i < neighbors.length; i++) {
        indptr[i + 1] = indptr[i] + neighbors[i].length;
    }
    long[] ns = new long[indptr[neighbors.length]];
    for (int i = 0; i < neighbors.length; i++) {
        System.arraycopy(neighbors[i], 0, ns, indptr[i], neighbors[i].length);
    }

    InitNeighbor initFunc = new InitNeighbor(new InitNeighborParam(matrixId, nodeIds, indptr, ns));
    client.asyncUpdate(initFunc).get();
    InitNodeFeats func = new InitNodeFeats(new InitNodeFeatsParam(matrixId, nodeIds, feats));
    client.asyncUpdate(func).get();

    // Fetch the features of the initialized nodes plus ids 6 and 7, which were never initialized
    long[] queryIds = new long[] { 1, 2, 3, 4, 5, 6, 7, 8 };
    GetNodeFeatsParam param = new GetNodeFeatsParam(matrixId, queryIds);
    Long2ObjectOpenHashMap<IntFloatVector> result = ((GetNodeFeatsResult) (client.get(new GetNodeFeats(param)))).getResult();
    logNodeFeats(result);
}

/** Logs every fetched feature vector, walking its entries according to its storage layout. */
private void logNodeFeats(Long2ObjectOpenHashMap<IntFloatVector> result) {
    LOG.info("==============================get node feats result============================");
    ObjectIterator<Long2ObjectMap.Entry<IntFloatVector>> iter = result.long2ObjectEntrySet().fastIterator();
    while (iter.hasNext()) {
        Long2ObjectMap.Entry<IntFloatVector> entry = iter.next();
        long nodeId = entry.getLongKey();
        IntFloatVector vector = entry.getValue();
        if (vector == null) {
            // NOTE(review): not confirmed whether the server maps unknown ids (6, 7) to null;
            // guard so the test logs the gap instead of failing with a NullPointerException.
            LOG.info("node " + nodeId + " has no features");
        } else if (vector.isDense()) {
            LOG.info("node " + nodeId + " has a dense features");
            float[] values = vector.getStorage().getValues();
            for (int i = 0; i < values.length; i++) {
                LOG.info("feat index " + i + " values = " + values[i]);
            }
        } else if (vector.isSparse()) {
            LOG.info("node " + nodeId + " has a sparse features");
            ObjectIterator<Int2FloatMap.Entry> valueIter = vector.getStorage().entryIterator();
            while (valueIter.hasNext()) {
                Int2FloatMap.Entry keyValue = valueIter.next();
                LOG.info("feat index " + keyValue.getIntKey() + " values = " + keyValue.getFloatValue());
            }
        } else {
            // Remaining layout: sorted storage with parallel index/value arrays
            LOG.info("node " + nodeId + " has a sorted features");
            int[] keys = vector.getStorage().getIndices();
            float[] values = vector.getStorage().getValues();
            for (int i = 0; i < values.length; i++) {
                LOG.info("feat index " + keys[i] + " values = " + values[i]);
            }
        }
    }
}
Also used: InitNeighborParam (com.tencent.angel.graph.client.initNeighbor5.InitNeighborParam), InitNeighbor (com.tencent.angel.graph.client.initNeighbor5.InitNeighbor), ObjectIterator (it.unimi.dsi.fastutil.objects.ObjectIterator), Worker (com.tencent.angel.worker.Worker), MatrixClient (com.tencent.angel.psagent.matrix.MatrixClient), InitNodeFeats (com.tencent.angel.graph.client.initnodefeats4.InitNodeFeats), InitNodeFeatsParam (com.tencent.angel.graph.client.initnodefeats4.InitNodeFeatsParam), GetNodeFeats (com.tencent.angel.graph.client.getnodefeats2.GetNodeFeats), Long2ObjectMap (it.unimi.dsi.fastutil.longs.Long2ObjectMap), GetNodeFeatsParam (com.tencent.angel.graph.client.getnodefeats2.GetNodeFeatsParam), Long2ObjectOpenHashMap (it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap), IntFloatVector (com.tencent.angel.ml.math2.vector.IntFloatVector), GetNodeFeatsResult (com.tencent.angel.graph.client.getnodefeats2.GetNodeFeatsResult), Int2FloatMap (it.unimi.dsi.fastutil.ints.Int2FloatMap), Test (org.junit.Test)

Example 2 with InitNeighbor

Use of com.tencent.angel.graph.client.initNeighbor5.InitNeighbor in project angel by Tencent.

The class InitNeighborTest2, method testCSR.

/**
 * Initializes neighbor lists for six nodes and samples neighbors for ids 1..8. It then
 * checkpoints the matrix, stops parameter server 0, reports it to the master as failed, and
 * samples again. Presumably this verifies that the neighbor data survives the PS restart.
 */
@Test
public void testCSR() throws Exception {
    Worker worker = LocalClusterContext.get().getWorker(workerAttempt0Id).getWorker();
    MatrixClient client = worker.getPSAgent().getMatrixClient(SPARSE_INT_MAT, 0);
    int matrixId = client.getMatrixId();
    ParameterServer ps = LocalClusterContext.get().getPS(psAttempt0Id).getPS();
    Location masterLoc = LocalClusterContext.get().getMaster().getAppMaster().getAppContext().getMasterService().getLocation();
    TConnection connection = TConnectionManager.getConnection(ps.getConf());
    MasterProtocol master = connection.getMasterService(masterLoc.getIp(), masterLoc.getPort());

    // Init node neighbors
    Long2ObjectOpenHashMap<long[]> nodeIdToNeighbors = new Long2ObjectOpenHashMap<>();
    nodeIdToNeighbors.put(1, new long[] { 2, 3, 4, 5, 6 });
    nodeIdToNeighbors.put(2, new long[] { 4, 5 });
    nodeIdToNeighbors.put(3, new long[] { 4, 5, 6 });
    nodeIdToNeighbors.put(4, new long[] { 5, 6 });
    nodeIdToNeighbors.put(5, new long[] { 6 });
    nodeIdToNeighbors.put(8, new long[] { 3, 4 });
    InitNeighbor func = new InitNeighbor(new InitNeighborParam(matrixId, nodeIdToNeighbors));
    client.asyncUpdate(func).get();
    nodeIdToNeighbors.clear();

    // Sample the neighbors; ids 6 and 7 were not given a neighbor list above
    long[] nodeIds = new long[] { 1, 2, 3, 4, 5, 6, 7, 8 };
    logSampledNeighbors(fetchSampledNeighbors(client, matrixId, nodeIds, 2));

    // Persist the matrix, then take PS 0 down and report it to the master as failed
    client.checkpoint(0).get();
    ps.stop(-1);
    PSErrorRequest request = PSErrorRequest.newBuilder().setPsAttemptId(ProtobufUtil.convertToIdProto(psAttempt0Id)).setMsg("out of memory").build();
    master.psError(null, request);
    // NOTE(review): fixed wait for the PS to come back; no readiness condition is polled,
    // so this can be flaky on slow machines.
    Thread.sleep(10000);

    // Sample again after the PS failure; presumably -1 means "no limit" — confirm in SampleNeighborParam
    logSampledNeighbors(fetchSampledNeighbors(client, matrixId, nodeIds, -1));
}

/** Runs a SampleNeighbor request for {@code nodeIds} and returns the node-id-to-neighbors map. */
private Long2ObjectOpenHashMap<long[]> fetchSampledNeighbors(MatrixClient client, int matrixId, long[] nodeIds, int count) throws Exception {
    SampleNeighborParam param = new SampleNeighborParam(matrixId, nodeIds, count);
    return ((SampleNeighborResult) (client.get(new SampleNeighbor(param)))).getNodeIdToNeighbors();
}

/** Logs each node id together with its sampled neighbor array. */
private void logSampledNeighbors(Long2ObjectOpenHashMap<long[]> result) {
    ObjectIterator<Long2ObjectMap.Entry<long[]>> iter = result.long2ObjectEntrySet().fastIterator();
    LOG.info("==============================sample neighbors result============================");
    while (iter.hasNext()) {
        Long2ObjectMap.Entry<long[]> entry = iter.next();
        LOG.info("node id = " + entry.getLongKey() + ", neighbors = " + Arrays.toString(entry.getValue()));
    }
}
Also used: InitNeighborParam (com.tencent.angel.graph.client.initneighbor2.InitNeighborParam), SampleNeighborResult (com.tencent.angel.graph.client.sampleneighbor2.SampleNeighborResult), Long2ObjectMap (it.unimi.dsi.fastutil.longs.Long2ObjectMap), SampleNeighbor (com.tencent.angel.graph.client.sampleneighbor2.SampleNeighbor), InitNeighbor (com.tencent.angel.graph.client.initneighbor2.InitNeighbor), Long2ObjectOpenHashMap (it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap), ParameterServer (com.tencent.angel.ps.ParameterServer), TConnection (com.tencent.angel.ipc.TConnection), SampleNeighborParam (com.tencent.angel.graph.client.sampleneighbor2.SampleNeighborParam), Worker (com.tencent.angel.worker.Worker), MatrixClient (com.tencent.angel.psagent.matrix.MatrixClient), MasterProtocol (com.tencent.angel.master.MasterProtocol), PSErrorRequest (com.tencent.angel.protobuf.generated.PSMasterServiceProtos.PSErrorRequest), Location (com.tencent.angel.common.location.Location), Test (org.junit.Test)

Aggregations

MatrixClient (com.tencent.angel.psagent.matrix.MatrixClient): 2; Worker (com.tencent.angel.worker.Worker): 2; Long2ObjectMap (it.unimi.dsi.fastutil.longs.Long2ObjectMap): 2; Long2ObjectOpenHashMap (it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap): 2; Test (org.junit.Test): 2; Location (com.tencent.angel.common.location.Location): 1; GetNodeFeats (com.tencent.angel.graph.client.getnodefeats2.GetNodeFeats): 1; GetNodeFeatsParam (com.tencent.angel.graph.client.getnodefeats2.GetNodeFeatsParam): 1; GetNodeFeatsResult (com.tencent.angel.graph.client.getnodefeats2.GetNodeFeatsResult): 1; InitNeighbor (com.tencent.angel.graph.client.initNeighbor5.InitNeighbor): 1; InitNeighborParam (com.tencent.angel.graph.client.initNeighbor5.InitNeighborParam): 1; InitNeighbor (com.tencent.angel.graph.client.initneighbor2.InitNeighbor): 1; InitNeighborParam (com.tencent.angel.graph.client.initneighbor2.InitNeighborParam): 1; InitNodeFeats (com.tencent.angel.graph.client.initnodefeats4.InitNodeFeats): 1; InitNodeFeatsParam (com.tencent.angel.graph.client.initnodefeats4.InitNodeFeatsParam): 1; SampleNeighbor (com.tencent.angel.graph.client.sampleneighbor2.SampleNeighbor): 1; SampleNeighborParam (com.tencent.angel.graph.client.sampleneighbor2.SampleNeighborParam): 1; SampleNeighborResult (com.tencent.angel.graph.client.sampleneighbor2.SampleNeighborResult): 1; TConnection (com.tencent.angel.ipc.TConnection): 1; MasterProtocol (com.tencent.angel.master.MasterProtocol): 1