Use of org.deeplearning4j.graph.exception.NoEdgesException in the deeplearning4j project (owner: deeplearning4j).
Class RandomWalkIterator, method next():
/**
 * Generates the next (unweighted) random walk.
 *
 * <p>The walk starts at vertex {@code order[position]} (and {@code position} is then advanced).
 * It holds {@code walkLength + 1} vertex indices: the start vertex followed by {@code walkLength}
 * steps. Each step moves to the vertex returned by
 * {@code graph.getRandomConnectedVertex(current, rng)}.
 *
 * <p>If any step is taken from a vertex that has no edges ({@code NoEdgesException}), the
 * behavior depends on {@code mode}:
 * <ul>
 *   <li>{@code SELF_LOOP_ON_DISCONNECTED}: that vertex fills every remaining position of the
 *       walk, up to and including index {@code walkLength};</li>
 *   <li>{@code EXCEPTION_ON_DISCONNECTED}: the {@code NoEdgesException} is rethrown;</li>
 *   <li>any other mode: a {@code RuntimeException} is thrown.</li>
 * </ul>
 *
 * @return the sequence of vertex indices visited by the walk
 * @throws NoSuchElementException if {@link #hasNext()} is false
 */
@Override
public IVertexSequence<V> next() {
    if (!hasNext())
        throw new NoSuchElementException();
    //Generate a random walk starting at vertex order[current]
    int currVertexIdx = order[position++];
    int[] indices = new int[walkLength + 1];
    indices[0] = currVertexIdx;
    if (walkLength == 0)
        return new VertexSequence<>(graph, indices);
    //<= walk length: i.e., if walk length = 2, it contains 3 vertices etc
    for (int i = 1; i <= walkLength; i++) {
        Vertex<V> next;
        try {
            next = graph.getRandomConnectedVertex(currVertexIdx, rng);
        } catch (NoEdgesException e) {
            //Handled at every step, not only the first: a directed walk can reach a vertex with no outgoing edges later on
            switch (mode) {
                case SELF_LOOP_ON_DISCONNECTED:
                    //Fill all remaining slots, including the last one (index walkLength)
                    for (int j = i; j <= walkLength; j++)
                        indices[j] = currVertexIdx;
                    return new VertexSequence<>(graph, indices);
                case EXCEPTION_ON_DISCONNECTED:
                    throw e;
                default:
                    throw new RuntimeException("Unknown/not implemented NoEdgeHandling mode: " + mode);
            }
        }
        currVertexIdx = next.vertexID();
        indices[i] = currVertexIdx;
    }
    return new VertexSequence<>(graph, indices);
}
Use of org.deeplearning4j.graph.exception.NoEdgesException in the deeplearning4j project (owner: deeplearning4j).
Class Graph, method getRandomConnectedVertex():
/**
 * Chooses one edge stored for {@code vertex} and returns the vertex at its other end.
 *
 * <p>The edge is picked by index using {@code rng.nextInt(numberOfEdges)}.
 *
 * @param vertex index of the vertex to start from
 * @param rng    source of randomness used to pick the edge
 * @return the neighbouring vertex reached through the chosen edge
 * @throws IllegalArgumentException if {@code vertex} is not a valid vertex index
 * @throws NoEdgesException         if {@code vertex} has no stored (outgoing/undirected) edges
 */
@Override
public Vertex<V> getRandomConnectedVertex(int vertex, Random rng) throws NoEdgesException {
    boolean validIndex = vertex >= 0 && vertex < vertices.size();
    if (!validIndex) {
        throw new IllegalArgumentException("Invalid vertex index: " + vertex);
    }

    boolean hasEdges = edges[vertex] != null && !edges[vertex].isEmpty();
    if (!hasEdges) {
        throw new NoEdgesException("Cannot generate random connected vertex: vertex " + vertex + " has no outgoing/undirected edges");
    }

    Edge<E> chosen = edges[vertex].get(rng.nextInt(edges[vertex].size()));

    // Edge stored as vertex -> x (directed or undirected): the neighbour is the "to" end.
    // Otherwise it is an undirected edge stored as x -> vertex: the neighbour is the "from" end.
    int neighbourIdx = (chosen.getFrom() == vertex) ? chosen.getTo() : chosen.getFrom();
    return vertices.get(neighbourIdx);
}
Use of org.deeplearning4j.graph.exception.NoEdgesException in the deeplearning4j project (owner: deeplearning4j).
Class WeightedRandomWalkIterator, method next():
/**
 * Generates the next weighted random walk.
 *
 * <p>The walk starts at vertex {@code order[position]} (and {@code position} is then advanced).
 * It holds {@code walkLength + 1} vertex indices: the start vertex followed by {@code walkLength}
 * steps.
 *
 * <p>At each step, one of the current vertex's outgoing edges ({@code graph.getEdgesOut}) is
 * chosen. The choice compares a running sum of edge values against
 * {@code rng.nextDouble() * totalWeight}. Assuming non-negative edge values, this picks each edge
 * with probability proportional to its value.
 *
 * <p>If the current vertex has no outgoing edges, the behavior depends on {@code mode}:
 * <ul>
 *   <li>{@code SELF_LOOP_ON_DISCONNECTED}: that vertex fills every remaining position of the
 *       walk, up to and including index {@code walkLength};</li>
 *   <li>{@code EXCEPTION_ON_DISCONNECTED}: a {@code NoEdgesException} is thrown;</li>
 *   <li>any other mode: a {@code RuntimeException} is thrown.</li>
 * </ul>
 *
 * @return the sequence of vertex indices visited by the walk
 * @throws NoSuchElementException if {@link #hasNext()} is false
 */
@Override
public IVertexSequence<V> next() {
    if (!hasNext())
        throw new NoSuchElementException();
    //Generate a weighted random walk starting at vertex order[current]
    int currVertexIdx = order[position++];
    int[] indices = new int[walkLength + 1];
    indices[0] = currVertexIdx;
    if (walkLength == 0)
        return new VertexSequence<>(graph, indices);
    for (int i = 1; i <= walkLength; i++) {
        List<? extends Edge<? extends Number>> edgeList = graph.getEdgesOut(currVertexIdx);
        //First: check if there are any outgoing edges from this vertex. If not: handle the situation
        if (edgeList == null || edgeList.isEmpty()) {
            switch (mode) {
                case SELF_LOOP_ON_DISCONNECTED:
                    //Fill all remaining slots, including the last one (index walkLength)
                    for (int j = i; j <= walkLength; j++)
                        indices[j] = currVertexIdx;
                    return new VertexSequence<>(graph, indices);
                case EXCEPTION_ON_DISCONNECTED:
                    throw new NoEdgesException("Cannot conduct random walk: vertex " + currVertexIdx + " has no outgoing edges. " + " Set NoEdgeHandling mode to NoEdgeHandlingMode.SELF_LOOP_ON_DISCONNECTED to self loop instead of " + "throwing an exception in this situation.");
                default:
                    throw new RuntimeException("Unknown/not implemented NoEdgeHandling mode: " + mode);
            }
        }
        //To do a weighted random walk: we need to know total weight of all outgoing edges
        double totalWeight = 0.0;
        for (Edge<? extends Number> edge : edgeList) {
            totalWeight += edge.getValue().doubleValue();
        }
        double d = rng.nextDouble();
        double threshold = d * totalWeight;
        //Accumulate in the same order as totalWeight, so the final running sum equals totalWeight exactly
        //NOTE(review): if no edge meets the threshold (e.g. NaN or negative weights), indices[i] is left unwritten
        //and the walk stays on the current vertex - confirm whether such weights can occur
        double sumWeight = 0.0;
        for (Edge<? extends Number> edge : edgeList) {
            sumWeight += edge.getValue().doubleValue();
            if (sumWeight >= threshold) {
                if (edge.isDirected()) {
                    currVertexIdx = edge.getTo();
                } else {
                    if (edge.getFrom() == currVertexIdx) {
                        currVertexIdx = edge.getTo();
                    } else {
                        //Undirected edge: might be next--currVertexIdx instead of currVertexIdx--next
                        currVertexIdx = edge.getFrom();
                    }
                }
                indices[i] = currVertexIdx;
                break;
            }
        }
    }
    return new VertexSequence<>(graph, indices);
}
Aggregations