Usage of org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException in the deeplearning4j project (by deeplearning4j).
Example: class PopularityWalker, method next.
/**
 * Returns the next walk sequence from the source graph.
 *
 * <p>The walk begins at {@code order[p]}, where {@code p} is read from (and advances) the
 * {@code position} counter, and adds up to {@code walkLength} vertices. At each hop the
 * neighbours of the current vertex that are not yet part of this walk are ranked by popularity
 * (their own number of connected vertices). A window of between 1 and {@code spread} candidates
 * is taken from that ranking according to {@code popularityMode}, and the next vertex is drawn
 * from the window either uniformly ({@code PLAIN}) or proportionally to popularity
 * ({@code PROPORTIONAL}).
 *
 * <p>If {@code alpha > 0} the walk may jump back to its start vertex with probability
 * {@code alpha}. If the current vertex has no eligible neighbours, {@code noEdgeHandling}
 * decides whether to throw, stop the walk, stay on the current vertex, or restart.
 *
 * <p>Every random choice is drawn from the walker's {@code rng}, so walks depend only on how
 * {@code rng} was seeded (NOTE(review): confirm {@code rng} is seeded where it is constructed).
 *
 * @return the sequence of vertex values visited by this walk
 * @throws NoEdgesException if a dead end is reached and {@code noEdgeHandling} is
 *     {@code EXCEPTION_ON_DISCONNECTED}
 * @throws UnsupportedOperationException if the walk direction, popularity mode, spectrum or
 *     no-edge handling setting is not supported
 */
@Override
public Sequence<T> next() {
    Sequence<T> sequence = new Sequence<>();

    // Vertex ids already placed in this walk; they are removed from the candidate neighbours.
    // -1 fills unused slots (assumed never to be a real vertex id).
    int[] visitedHops = new int[walkLength];
    Arrays.fill(visitedHops, -1);

    int startPoint = order[position.getAndIncrement()];
    int startPosition = startPoint;
    // Vertex the walk was on at the previous hop; -1 before the first hop.
    int lastId = -1;

    for (int i = 0; i < walkLength; i++) {
        Vertex<T> vertex = sourceGraph.getVertex(startPosition);
        int currentPosition = startPosition;

        sequence.addElement(vertex.getValue());
        visitedHops[i] = vertex.vertexID();

        // Random restart: with probability alpha, jump back to the start vertex.
        if (alpha > 0 && lastId != startPoint && lastId != -1 && alpha > rng.nextDouble()) {
            startPosition = startPoint;
            continue;
        }

        switch (walkDirection) {
            case RANDOM:
            case FORWARD_ONLY:
            case FORWARD_UNIQUE:
            case FORWARD_PREFERRED: {
                int[] connections = ArrayUtils.removeElements(
                        sourceGraph.getConnectedVertexIndices(vertex.vertexID()), visitedHops);

                if (connections.length > 0) {
                    List<Node<T>> candidates = selectCandidates(connections);

                    int picked;
                    switch (spectrum) {
                        case PLAIN:
                            // Draw from rng (not a separate static RNG) so the walk honours the seed.
                            picked = rng.nextInt(candidates.size());
                            break;
                        case PROPORTIONAL: {
                            double[] weights = new double[candidates.size()];
                            for (int w = 0; w < weights.length; w++) {
                                weights[w] = candidates.get(w).getWeight();
                            }
                            picked = pickWeightedIndex(weights, rng.nextDouble());
                            break;
                        }
                        default:
                            throw new UnsupportedOperationException("Unsupported spectrum: [" + spectrum + "]");
                    }

                    startPosition = candidates.get(picked).getVertexId();
                    // Both spectra record the vertex we are leaving, so the restart check above
                    // behaves the same regardless of spectrum.
                    lastId = vertex.vertexID();
                } else {
                    switch (noEdgeHandling) {
                        case EXCEPTION_ON_DISCONNECTED:
                            throw new NoEdgesException("No more edges at vertex [" + currentPosition + "]");
                        case CUTOFF_ON_DISCONNECTED:
                            // Push i past the end so the walk stops after this hop.
                            i += walkLength;
                            break;
                        case SELF_LOOP_ON_DISCONNECTED:
                            startPosition = currentPosition;
                            break;
                        case RESTART_ON_DISCONNECTED:
                            startPosition = startPoint;
                            break;
                        default:
                            throw new UnsupportedOperationException("Unsupported noEdgeHandling: [" + noEdgeHandling + "]");
                    }
                }
            }
            break;
            default:
                throw new UnsupportedOperationException("Unknown WalkDirection: [" + walkDirection + "]");
        }
    }
    return sequence;
}

/**
 * Ranks {@code connections} by popularity (each vertex's own number of connected vertices) and
 * returns the slice of that ranking chosen by {@code popularityMode}, in ranking order.
 *
 * <p>The slice is exactly {@code min(max(spread, 1), connections.length)} nodes wide, so it never
 * runs past either end of the ranking: {@code MAXIMUM} takes the first nodes the queue yields,
 * {@code MINIMUM} the last ones, and {@code AVERAGE} the ones around the middle.
 *
 * @param connections candidate vertex ids; must not be empty
 * @return the selected candidate nodes (never empty)
 * @throws UnsupportedOperationException if {@code popularityMode} is not supported
 */
private List<Node<T>> selectCandidates(int[] connections) {
    PriorityQueue<Node<T>> queue = new PriorityQueue<>();
    for (int connected : connections) {
        int degree = sourceGraph.getConnectedVertices(connected).size();
        queue.add(new Node<T>(connected, degree), degree);
    }

    int total = connections.length;
    // Clamp to [1, total]: a non-positive spread would otherwise produce an empty window.
    int cSpread = Math.max(1, Math.min(spread, total));

    int start;
    switch (popularityMode) {
        case MAXIMUM:
            start = 0;
            break;
        case MINIMUM:
            start = total - cSpread;
            break;
        case AVERAGE:
            // Centred on the middle; with width cSpread the last index is at most total - 1.
            start = total / 2 - cSpread / 2;
            break;
        default:
            throw new UnsupportedOperationException("Unsupported popularityMode: [" + popularityMode + "]");
    }
    int stop = start + cSpread - 1;

    List<Node<T>> window = new ArrayList<>(cSpread);
    for (int rank = 0; rank <= stop && queue.hasNext(); rank++) {
        Node<T> node = queue.next();
        if (rank >= start) {
            window.add(node);
        }
    }
    return window;
}

/**
 * Picks an index with probability proportional to its weight.
 *
 * <p>If all weights are zero (or their sum is not a positive number) the index is chosen
 * uniformly instead of failing. If floating-point rounding leaves the sample past the final
 * cumulative bound, the last index is returned rather than choosing nothing.
 *
 * @param weights non-negative weights; must not be empty
 * @param u       uniform sample in {@code [0, 1)}
 * @return the chosen index in {@code [0, weights.length)}
 */
private static int pickWeightedIndex(double[] weights, double u) {
    double total = 0.0;
    for (double w : weights) {
        total += w;
    }
    if (!(total > 0.0)) {
        return Math.min((int) (u * weights.length), weights.length - 1);
    }

    double target = u * total;
    double cumulative = 0.0;
    for (int b = 0; b < weights.length; b++) {
        cumulative += weights[b];
        if (target < cumulative) {
            return b;
        }
    }
    return weights.length - 1;
}
Usage of org.deeplearning4j.models.sequencevectors.graph.exception.NoEdgesException in the deeplearning4j project (by deeplearning4j).
Example: class Graph, method getRandomConnectedVertex.
/**
 * Returns the vertex at the other end of a randomly chosen edge of {@code vertex}.
 *
 * <p>One entry of {@code edges[vertex]} is chosen uniformly with {@code rng}. If the edge's
 * {@code from} endpoint is {@code vertex}, its {@code to} endpoint is returned. Otherwise the edge
 * is treated as an undirected edge stored as {@code x -> vertex}, and its {@code from} endpoint is
 * returned.
 *
 * @param vertex index of the vertex whose neighbour is wanted
 * @param rng    source of randomness for the edge choice
 * @return a vertex connected to {@code vertex}
 * @throws IllegalArgumentException if {@code vertex} is outside {@code [0, vertices.size())}
 * @throws NoEdgesException if {@code vertex} has no stored edges
 */
@Override
public Vertex<V> getRandomConnectedVertex(int vertex, Random rng) throws NoEdgesException {
    boolean validIndex = vertex >= 0 && vertex < vertices.size();
    if (!validIndex) {
        throw new IllegalArgumentException("Invalid vertex index: " + vertex);
    }

    boolean hasEdges = edges[vertex] != null && !edges[vertex].isEmpty();
    if (!hasEdges) {
        throw new NoEdgesException("Cannot generate random connected vertex: vertex " + vertex
                + " has no outgoing/undirected edges");
    }

    Edge<E> picked = edges[vertex].get(rng.nextInt(edges[vertex].size()));
    // "from == vertex" covers directed edges and undirected edges stored as vertex -> x;
    // anything else is an undirected edge stored as x -> vertex.
    int neighbour = picked.getFrom() == vertex ? picked.getTo() : picked.getFrom();
    return vertices.get(neighbour);
}
Aggregations