use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
the class PartitionTrainingFunction method call.
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    /**
     * first we initialize
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();

    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();

        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        driver = elementsLearningAlgorithm.getTrainingDriver();

        // FIXME: this init line should probably be removed; init basically happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }

    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();

    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class.forName(vectorsConfiguration.getElementsLearningAlgorithm()).newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm != null)
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class.forName(vectorsConfiguration.getSequenceLearningAlgorithm()).newInstance();
            sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (sequenceLearningAlgorithm != null)
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }

    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();

    // now we roll through the Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();
        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();

        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if the frequency of this element is below the minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
            if (reduced != null)
                mergedSequence.addElement(reduced);
        }

        // do the same with labels and transfer them, if any
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
                if (reduced != null)
                    mergedSequence.addSequenceLabel(reduced);
            }
        }

        sequences.add(mergedSequence);

        // train in small batches of up to 8 sequences
        if (sequences.size() >= 8) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }

    if (sequences.size() > 0) {
        // finish the training round, to make sure no trailing sequences are left behind
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}
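The ND4JIllegalStateException here acts as a guard: if neither an elements nor a sequence learning algorithm class name is present in the broadcast VectorsConfiguration, the reflective Class.forName(...).newInstance() blocks never run and the function fails fast. A minimal, hedged sketch of the configuration this code expects; the setter name is assumed to match the getElementsLearningAlgorithm accessor used above, and the class name is only illustrative:

    // illustrative only: any class implementing SparkElementsLearningAlgorithm with a no-arg
    // constructor works, since the function instantiates it via Class.forName(name).newInstance()
    VectorsConfiguration vectorsConfiguration = new VectorsConfiguration();
    vectorsConfiguration.setElementsLearningAlgorithm(
            "org.deeplearning4j.spark.models.sequencevectors.learning.elements.SparkSkipGram");

With that in place the guard is never hit; with both algorithm names left null, call() throws "No LearningAlgorithms specified!" before touching any sequences.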
use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
the class NetworkOrganizer method toBinaryOctet.
protected static String toBinaryOctet(@NonNull Integer value) {
    if (value < 0 || value > 255)
        throw new ND4JIllegalStateException("IP octets can't hold values below 0 or above 255");

    String octetBase = Integer.toBinaryString(value);
    StringBuilder builder = new StringBuilder();

    // left-pad the binary representation with zeros up to 8 bits
    for (int i = 0; i < 8 - octetBase.length(); i++) {
        builder.append("0");
    }
    builder.append(octetBase);

    return builder.toString();
}
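For intuition, a hedged sketch of the expected behaviour (the calls are illustrative, assuming access from a test in the same package, and are not taken from the project's test suite): each octet is zero-padded to exactly 8 bits, and out-of-range values trigger the ND4JIllegalStateException above.

    assert NetworkOrganizer.toBinaryOctet(0).equals("00000000");
    assert NetworkOrganizer.toBinaryOctet(192).equals("11000000");   // 192 = 128 + 64
    assert NetworkOrganizer.toBinaryOctet(255).equals("11111111");

    try {
        NetworkOrganizer.toBinaryOctet(256);                         // out of range for an IPv4 octet
    } catch (ND4JIllegalStateException e) {
        // expected
    }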
use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.
the class NetworkOrganizer method getIntersections.
/**
 * This method returns the specified number of IP addresses by mapping the known addresses into a binary tree
 * and picking candidates from the most popular network.
 *
 * @param numShards number of IP addresses to return
 * @param primary   collection of already selected addresses to exclude, or null on the first pass
 * @return shuffled list of candidate IP addresses
 */
protected List<String> getIntersections(int numShards, Collection<String> primary) {
    /**
     * Since each IP address can be represented as a 4-byte sequence, one byte per octet in leading order,
     * we'll use that representation to build the tree.
     */
    if (primary == null) {
        for (NetworkInformation information : informationCollection) {
            for (String ip : information.getIpAddresses()) {
                // first we get the binary representation for each IP
                String octet = convertIpToOctets(ip);

                // then we map each of them into a virtual "tree", to find the most popular networks within the cluster
                tree.map(octet);
            }
        }

        // now we pick the most "popular" class A network from the tree
        String octetA = tree.getHottestNetworkA();

        List<String> candidates = new ArrayList<>();
        AtomicInteger matchCount = new AtomicInteger(0);

        for (NetworkInformation node : informationCollection) {
            for (String ip : node.getIpAddresses()) {
                String octet = convertIpToOctets(ip);

                // counting matches
                if (octet.startsWith(octetA)) {
                    matchCount.incrementAndGet();
                    candidates.add(ip);
                    break;
                }
            }
        }

        /**
         * TODO: improve this. We just need to iterate over the popular networks instead of only the top A network.
         */
        if (matchCount.get() != informationCollection.size())
            throw new ND4JIllegalStateException("Mismatching A class");

        Collections.shuffle(candidates);

        return new ArrayList<>(candidates.subList(0, Math.min(numShards, candidates.size())));
    } else {
        // if primary isn't null, we expect the network to be already filtered
        String octetA = tree.getHottestNetworkA();

        List<String> candidates = new ArrayList<>();

        for (NetworkInformation node : informationCollection) {
            for (String ip : node.getIpAddresses()) {
                String octet = convertIpToOctets(ip);

                // collecting matches, skipping addresses already present in the primary set
                if (octet.startsWith(octetA) && !primary.contains(ip)) {
                    candidates.add(ip);
                    break;
                }
            }
        }

        Collections.shuffle(candidates);

        return new ArrayList<>(candidates.subList(0, Math.min(numShards, candidates.size())));
    }
}
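The class A test is nothing more than a startsWith check on the first 8 characters of the 32-character binary string. A small self-contained illustration (convertIpToOctets is assumed to concatenate the four zero-padded octets produced by toBinaryOctet above):

    // "10.0.1.5" and "10.255.0.7" share the same leading 8 bits ("00001010"), so both match a
    // hottest class A network of 10.0.0.0/8, while "192.168.1.5" ("11000000...") would not
    String a = "00001010" + "00000000" + "00000001" + "00000101"; // 10.0.1.5
    String b = "00001010" + "11111111" + "00000000" + "00000111"; // 10.255.0.7
    String networkA = "00001010";

    assert a.startsWith(networkA);
    assert b.startsWith(networkA);

The "Mismatching A class" exception fires when at least one node has no address under that hottest class A network, which is exactly the case the TODO above wants to handle by walking further down the list of popular networks.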
use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
the class DefaultOpExecutioner method checkForWorkspaces.
protected void checkForWorkspaces(Op op) {
    val x = op.x();
    if (x != null && x.isAttached()) {
        val ws = x.data().getParentWorkspace();

        if (ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {

            if (!ws.isScopeActive()) {
                throw new ND4JIllegalStateException("Op [" + op.opName() + "] X argument uses leaked workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
            }

            if (ws.getGenerationId() != x.data().getGenerationId())
                throw new ND4JIllegalStateException("Op [" + op.opName() + "] X argument uses outdated workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
        }
    }

    val y = op.y();
    if (y != null && y.isAttached()) {
        val ws = y.data().getParentWorkspace();

        if (ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {

            if (!ws.isScopeActive()) {
                throw new ND4JIllegalStateException("Op [" + op.opName() + "] Y argument uses leaked workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
            }

            if (ws.getGenerationId() != y.data().getGenerationId())
                throw new ND4JIllegalStateException("Op [" + op.opName() + "] Y argument uses outdated workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
        }
    }

    val z = op.z();
    if (z != null && z.isAttached()) {
        val ws = z.data().getParentWorkspace();

        if (ws.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {

            if (!ws.isScopeActive()) {
                throw new ND4JIllegalStateException("Op [" + op.opName() + "] Z argument uses leaked workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
            }

            if (ws.getGenerationId() != z.data().getGenerationId())
                throw new ND4JIllegalStateException("Op [" + op.opName() + "] Z argument uses outdated workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
        }
    }
}
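For context, a hedged sketch of the situation the "leaked workspace pointer" branch guards against, assuming the standard Nd4j workspace API (getAndActivateWorkspace, SCOPE_PANIC profiling mode) and an illustrative workspace id:

    // enable the scope-panic checks that route ops through checkForWorkspaces()
    Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC);

    INDArray leaked;
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("EXAMPLE_WS")) {
        leaked = Nd4j.create(10);   // this array is attached to the workspace
    }

    // the workspace scope is closed here; using 'leaked' in an op is expected to trigger the
    // "leaked workspace pointer" ND4JIllegalStateException thrown above
    leaked.addi(1.0);

The generation-id branch covers the related case where the workspace scope has been re-entered since the array was created, so the pointer refers to memory from a previous iteration.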
use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
the class OpExecutionerUtil method checkForNaN.
public static void checkForNaN(INDArray z) {
    if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.NAN_PANIC
                    && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.ANY_PANIC)
        return;

    int match = 0;
    if (!z.isScalar()) {
        MatchCondition condition = new MatchCondition(z, Conditions.isNan());
        match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
    } else {
        if (z.data().dataType() == DataBuffer.Type.DOUBLE) {
            if (Double.isNaN(z.getDouble(0)))
                match = 1;
        } else {
            if (Float.isNaN(z.getFloat(0)))
                match = 1;
        }
    }

    if (match > 0)
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s): ");
}
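The check is a no-op unless NaN or ANY panic mode is enabled, so a hedged way to exercise it looks roughly like this (the profiling-mode setter and array calls are from the standard nd4j API; exact behaviour can depend on the backend):

    // enable the NaN checks that route op outputs through checkForNaN()
    Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC);

    INDArray zeros = Nd4j.zeros(3);

    // 0 / 0 produces NaN, so the executioner is expected to throw the
    // "P.A.N.I.C.!" ND4JIllegalStateException after running this op
    zeros.divi(Nd4j.zeros(3));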