Search in sources:

Example 1 with NoEdgesException

Usage of org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException in the deeplearning4j project (by deeplearning4j).

Class PopularityWalker, method next.

/**
 * Returns the next walk sequence from this graph.
 * <p>
 * The walk starts at vertex {@code order[p]}, where {@code p} is taken from (and then increments) the
 * {@code position} counter, and performs {@code walkLength} steps. At each step the current vertex's
 * neighbours that have not yet been visited in this walk are ranked by their own connection count
 * ("popularity"). A window of at most {@code spread} ranked candidates is chosen according to
 * {@code popularityMode}. The next vertex is then drawn from that window, uniformly for PLAIN or
 * proportionally to popularity for PROPORTIONAL.
 * <p>
 * When {@code alpha > 0}, the walk may jump back to its starting vertex with probability {@code alpha}.
 * This does not happen on the first step, or when the vertex it just left was the start vertex.
 *
 * @return the sequence of vertex values visited by this walk
 * @throws NoEdgesException if the current vertex has no unvisited neighbours and {@code noEdgeHandling}
 *         is EXCEPTION_ON_DISCONNECTED
 * @throws UnsupportedOperationException for an unsupported {@code noEdgeHandling} or {@code walkDirection}
 */
@Override
public Sequence<T> next() {
    Sequence<T> sequence = new Sequence<>();
    int[] visitedHops = new int[walkLength];
    Arrays.fill(visitedHops, -1);
    int startPosition = position.getAndIncrement();
    // id of the vertex the walk most recently stepped away from; -1 until the first step is taken
    int lastId = -1;
    int startPoint = order[startPosition];
    startPosition = startPoint;
    for (int i = 0; i < walkLength; i++) {
        Vertex<T> vertex = sourceGraph.getVertex(startPosition);
        int currentPosition = startPosition;
        sequence.addElement(vertex.getValue());
        visitedHops[i] = vertex.vertexID();
        int cSpread = 0;
        // random restart back to the walk's origin
        if (alpha > 0 && lastId != startPoint && lastId != -1 && alpha > rng.nextDouble()) {
            startPosition = startPoint;
            continue;
        }
        switch (walkDirection) {
            case RANDOM:
            case FORWARD_ONLY:
            case FORWARD_UNIQUE:
            case FORWARD_PREFERRED: {
                // rank the unvisited neighbours of the current vertex by their own connection count
                PriorityQueue<Node<T>> queue = new PriorityQueue<>();
                int[] connections = ArrayUtils.removeElements(
                        sourceGraph.getConnectedVertexIndices(vertex.vertexID()), visitedHops);
                int start = 0;
                int stop = 0;
                if (connections.length > 0) {
                    for (int connected : connections) {
                        int degree = sourceGraph.getConnectedVertices(connected).size();
                        queue.add(new Node<T>(connected, degree), degree);
                    }
                    cSpread = spread > connections.length ? connections.length : spread;
                    // [start, stop] is an inclusive window of exactly cSpread positions in the ranking
                    switch (popularityMode) {
                        case MAXIMUM:
                            start = 0;
                            stop = start + cSpread - 1;
                            break;
                        case MINIMUM:
                            start = connections.length - cSpread;
                            stop = connections.length - 1;
                            break;
                        case AVERAGE:
                            // centre the window on the middle of the ranking; using mid + cSpread / 2 as
                            // the end would include one extra element when cSpread is even and overflow
                            // both `weights` and `connections`
                            int mid = connections.length / 2;
                            start = mid - (cSpread / 2);
                            stop = start + cSpread - 1;
                            break;
                    }
                    List<Node<T>> list = new ArrayList<>();
                    double[] weights = new double[cSpread];
                    int fcnt = 0;
                    int cnt = 0;
                    // drain the queue in priority order; `connections` is rewritten in that order so
                    // that PLAIN selection can index it by rank
                    while (queue.hasNext()) {
                        Node<T> node = queue.next();
                        if (cnt >= start && cnt <= stop) {
                            list.add(node);
                            weights[fcnt] = node.getWeight();
                            fcnt++;
                        }
                        connections[cnt] = node.getVertexId();
                        cnt++;
                    }
                    switch (spectrum) {
                        case PLAIN: {
                            int con = RandomUtils.nextInt(start, stop + 1);
                            Vertex<T> nV = sourceGraph.getVertex(connections[con]);
                            startPosition = nV.vertexID();
                            lastId = vertex.vertexID();
                        }
                            break;
                        case PROPORTIONAL: {
                            double[] norm = MathArrays.normalizeArray(weights, 1);
                            double prob = rng.nextDouble();
                            double floor = 0.0;
                            // default to the last candidate in case floating-point rounding leaves the
                            // cumulative sum just below prob and no bucket matches
                            int picked = list.size() - 1;
                            for (int b = 0; b < norm.length; b++) {
                                if (prob >= floor && prob < floor + norm[b]) {
                                    picked = b;
                                    break;
                                }
                                floor += norm[b];
                            }
                            startPosition = list.get(picked).getVertexId();
                            // record the vertex being left, consistent with PLAIN, so the restart check
                            // does not depend on the selection spectrum
                            lastId = vertex.vertexID();
                        }
                            break;
                    }
                } else {
                    switch (noEdgeHandling) {
                        case EXCEPTION_ON_DISCONNECTED:
                            throw new NoEdgesException("No more edges at vertex [" + currentPosition + "]");
                        case CUTOFF_ON_DISCONNECTED:
                            // terminate the walk
                            i += walkLength;
                            break;
                        case SELF_LOOP_ON_DISCONNECTED:
                            startPosition = currentPosition;
                            break;
                        case RESTART_ON_DISCONNECTED:
                            startPosition = startPoint;
                            break;
                        default:
                            throw new UnsupportedOperationException(
                                    "Unsupported noEdgeHandling: [" + noEdgeHandling + "]");
                    }
                }
            }
                break;
            default:
                throw new UnsupportedOperationException("Unknown WalkDirection: [" + walkDirection + "]");
        }
    }
    return sequence;
}
Also used: NoEdgesException (org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException), Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence), PriorityQueue (org.deeplearning4j.berkeley.PriorityQueue), ArrayList (java.util.ArrayList), List (java.util.List)

Example 2 with NoEdgesException

Usage of org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException in the deeplearning4j project (by deeplearning4j).

Class Graph, method getRandomConnectedVertex.

/**
 * Picks one of the edges stored for {@code vertex} uniformly at random and returns the vertex at the
 * other end of it.
 *
 * @param vertex index of the vertex whose neighbour is wanted
 * @param rng    source of randomness used to choose the edge
 * @return the neighbouring vertex reached through the randomly chosen edge
 * @throws IllegalArgumentException if {@code vertex} is not a valid index into {@code vertices}
 * @throws NoEdgesException         if no edges are stored for {@code vertex}
 */
@Override
public Vertex<V> getRandomConnectedVertex(int vertex, Random rng) throws NoEdgesException {
    boolean outOfRange = vertex < 0 || vertex >= vertices.size();
    if (outOfRange) {
        throw new IllegalArgumentException("Invalid vertex index: " + vertex);
    }
    boolean noEdges = edges[vertex] == null || edges[vertex].isEmpty();
    if (noEdges) {
        throw new NoEdgesException("Cannot generate random connected vertex: vertex " + vertex
                        + " has no outgoing/undirected edges");
    }

    Edge<E> chosen = edges[vertex].get(rng.nextInt(edges[vertex].size()));

    // An edge stored for this vertex either starts here (vertex -> x) or, for an undirected edge
    // recorded the other way round, ends here (x -> vertex); take whichever endpoint is the neighbour.
    int neighbour = (chosen.getFrom() == vertex) ? chosen.getTo() : chosen.getFrom();
    return vertices.get(neighbour);
}
Also used: NoEdgesException (org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException)

Aggregations

NoEdgesException (org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException): 2; ArrayList (java.util.ArrayList): 1; List (java.util.List): 1; PriorityQueue (org.deeplearning4j.berkeley.PriorityQueue): 1; Sequence (org.deeplearning4j.models.sequencevectors.sequence.Sequence): 1