Search in sources:

Example 66 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

Class PartitionTrainingFunction, method call.

/**
 * Trains on the sequences of one Spark partition.
 * <p>
 * First lazily initializes (from the broadcasts, only while the corresponding field is still
 * {@code null}) the vectors configuration, the parameter server, the shallow vocabulary and the
 * elements/sequence learning algorithms. Then every sequence element is mapped to its
 * {@link ShallowSequenceElement}; elements the vocabulary does not know are dropped. Sequence
 * labels are mapped the same way when a sequence learning algorithm is present and sequence
 * vectors are trained. Converted sequences are passed to {@code trainAllAtOnce} in batches of 8,
 * followed by one final, possibly smaller, batch.
 *
 * @param sequenceIterator sequences of this partition
 * @throws ND4JIllegalStateException if neither an elements nor a sequence learning algorithm is available
 * @throws RuntimeException wrapping any failure to instantiate a learning algorithm class
 */
@SuppressWarnings("unchecked")
@Override
public void call(Iterator<Sequence<T>> sequenceIterator) throws Exception {
    /*
     * first we initialize
     */
    if (vectorsConfiguration == null)
        vectorsConfiguration = configurationBroadcast.getValue();
    if (paramServer == null) {
        paramServer = VoidParameterServer.getInstance();
        // the training driver is obtained from the elements learning algorithm, so it must exist before init
        if (elementsLearningAlgorithm == null) {
            try {
                elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                        .forName(vectorsConfiguration.getElementsLearningAlgorithm())
                        .getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        driver = elementsLearningAlgorithm.getTrainingDriver();
        // FIXME: init line should probably be removed, basically init happens in VocabRddFunction
        paramServer.init(paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), driver);
    }
    if (shallowVocabCache == null)
        shallowVocabCache = vocabCacheBroadcast.getValue();

    if (elementsLearningAlgorithm == null && vectorsConfiguration.getElementsLearningAlgorithm() != null) {
        // TODO: do ELA initialization
        try {
            elementsLearningAlgorithm = (SparkElementsLearningAlgorithm) Class
                    .forName(vectorsConfiguration.getElementsLearningAlgorithm())
                    .getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    if (elementsLearningAlgorithm != null)
        elementsLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (sequenceLearningAlgorithm == null && vectorsConfiguration.getSequenceLearningAlgorithm() != null) {
        // TODO: do SLA initialization
        try {
            sequenceLearningAlgorithm = (SparkSequenceLearningAlgorithm) Class
                    .forName(vectorsConfiguration.getSequenceLearningAlgorithm())
                    .getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
    // configured exactly once per call; a freshly created SLA used to be configured twice
    if (sequenceLearningAlgorithm != null)
        sequenceLearningAlgorithm.configure(shallowVocabCache, null, vectorsConfiguration);

    if (elementsLearningAlgorithm == null && sequenceLearningAlgorithm == null) {
        throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
    }

    // number of converted sequences handed to trainAllAtOnce per round
    final int batchSize = 8;
    List<Sequence<ShallowSequenceElement>> sequences = new ArrayList<>();
    // now we roll through Sequences and prepare/convert/"learn" them
    while (sequenceIterator.hasNext()) {
        Sequence<T> sequence = sequenceIterator.next();
        Sequence<ShallowSequenceElement> mergedSequence = new Sequence<>();
        for (T element : sequence.getElements()) {
            // it's possible to get null here, i.e. if frequency for this element is below minWordFrequency threshold
            ShallowSequenceElement reduced = shallowVocabCache.tokenFor(element.getStorageId());
            if (reduced != null)
                mergedSequence.addElement(reduced);
        }
        // do the same with labels, transfer them, if any
        if (sequenceLearningAlgorithm != null && vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement reduced = shallowVocabCache.tokenFor(label.getStorageId());
                if (reduced != null)
                    mergedSequence.addSequenceLabel(reduced);
            }
        }
        sequences.add(mergedSequence);
        if (sequences.size() >= batchSize) {
            trainAllAtOnce(sequences);
            sequences.clear();
        }
    }
    if (!sequences.isEmpty()) {
        // finishing training round, to make sure we don't have trails
        trainAllAtOnce(sequences);
        sequences.clear();
    }
}
Also used: ShallowSequenceElement(org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement) ArrayList(java.util.ArrayList) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Sequence(org.deeplearning4j.models.sequencevectors.sequence.Sequence) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport)

Example 67 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

Class NetworkOrganizer, method toBinaryOctet.

/**
 * Renders one IPv4 octet as an 8-character, zero-padded binary string (e.g. 5 becomes "00000101").
 *
 * @param value octet value, must be within [0, 255]
 * @return the binary representation, always exactly 8 characters long
 * @throws ND4JIllegalStateException if {@code value} is outside [0, 255]
 */
protected static String toBinaryOctet(@NonNull Integer value) {
    boolean outOfRange = value > 255 || value < 0;
    if (outOfRange)
        throw new ND4JIllegalStateException("IP octets cant hold values below 0 or above 255");

    final String bits = Integer.toBinaryString(value);
    // bits has at most 8 characters here, so this prefix supplies exactly the missing leading zeros
    return "00000000".substring(bits.length()) + bits;
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 68 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project deeplearning4j by deeplearning4j.

Class NetworkOrganizer, method getIntersections.

/**
 * Selects up to {@code numShards} IP addresses that lie in the hottest "A" network reported by
 * {@code tree.getHottestNetworkA()}.
 * <p>
 * When {@code primary} is {@code null}, every known IP is first mapped into the network tree, and
 * every node must own an IP inside the hottest A network. Otherwise the tree is not re-mapped, and
 * IPs contained in {@code primary} are excluded. At most one IP is taken per node. The candidates
 * are shuffled before the first {@code numShards} are returned.
 *
 * @param numShards maximum number of addresses to return
 * @param primary   addresses to exclude, or {@code null} for the initial (primary) selection
 * @return a new list holding at most {@code numShards} IP addresses
 * @throws ND4JIllegalStateException if {@code primary} is {@code null} and some node has no IP in
 *                                   the hottest A network
 */
protected List<String> getIntersections(int numShards, Collection<String> primary) {
    if (primary == null) {
        /*
         * Since each ip address can be represented in 4-byte sequence, 1 byte per value, with
         * leading order - we map their binary forms into the tree to find the most popular networks
         */
        for (NetworkInformation information : informationCollection) {
            for (String ip : information.getIpAddresses()) {
                tree.map(convertIpToOctets(ip));
            }
        }
    }
    // NOTE(review): with primary != null the tree is not re-mapped here; presumably an earlier
    // call with primary == null populated it - confirm with callers
    String octetA = tree.getHottestNetworkA();

    // take at most one matching IP per node, skipping addresses already chosen as primary
    List<String> candidates = new ArrayList<>();
    for (NetworkInformation node : informationCollection) {
        for (String ip : node.getIpAddresses()) {
            String octet = convertIpToOctets(ip);
            if (octet.startsWith(octetA) && (primary == null || !primary.contains(ip))) {
                candidates.add(ip);
                break;
            }
        }
    }

    // TODO: improve this. we just need to iterate over popular networks instead of single top A network
    if (primary == null && candidates.size() != informationCollection.size())
        throw new ND4JIllegalStateException("Mismatching A class");

    Collections.shuffle(candidates);
    return new ArrayList<>(candidates.subList(0, Math.min(numShards, candidates.size())));
}
Also used : AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) NetworkInformation(org.deeplearning4j.spark.models.sequencevectors.primitives.NetworkInformation)

Example 69 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

Class DefaultOpExecutioner, method checkForWorkspaces.

/**
 * Verifies that the X, Y and Z arguments of {@code op} do not point into stale workspace memory.
 * <p>
 * The arguments are checked in X, Y, Z order. For every non-null, attached argument whose parent
 * workspace is not {@code MemoryWorkspace.Type.CIRCULAR}, the workspace scope must still be active,
 * and its generation id must equal the generation id stored in the argument's buffer.
 *
 * @param op op whose arguments are validated
 * @throws ND4JIllegalStateException if an argument uses a leaked or outdated workspace pointer
 */
protected void checkForWorkspaces(Op op) {
    checkArgumentWorkspace(op, op.x(), "X");
    checkArgumentWorkspace(op, op.y(), "Y");
    checkArgumentWorkspace(op, op.z(), "Z");
}

/**
 * Applies the workspace checks of {@code checkForWorkspaces} to a single op argument.
 *
 * @param op      op the argument belongs to, used for the error message
 * @param array   the argument, may be {@code null}
 * @param argName argument label ("X", "Y" or "Z") used in the error message
 * @throws ND4JIllegalStateException if the argument uses a leaked or outdated workspace pointer
 */
private void checkArgumentWorkspace(Op op, org.nd4j.linalg.api.ndarray.INDArray array, String argName) {
    if (array == null || !array.isAttached())
        return;

    val ws = array.data().getParentWorkspace();
    // circular workspaces are exempt from these checks
    if (ws.getWorkspaceType() == MemoryWorkspace.Type.CIRCULAR)
        return;

    if (!ws.isScopeActive()) {
        throw new ND4JIllegalStateException("Op [" + op.opName() + "] " + argName
                + " argument uses leaked workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
    }
    if (ws.getGenerationId() != array.data().getGenerationId()) {
        throw new ND4JIllegalStateException("Op [" + op.opName() + "] " + argName
                + " argument uses outdated workspace pointer from workspace [" + ws.getId() + "]\n" + SCOPE_PANIC_MSG);
    }
}
Also used : lombok.val(lombok.val) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 70 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

Class OpExecutionerUtil, method checkForNaN.

/**
 * Throws if {@code z} contains NaN values while NaN-panic profiling is enabled.
 * <p>
 * Does nothing unless the executioner's profiling mode is {@code NAN_PANIC} or {@code ANY_PANIC}.
 * Non-scalar arrays are scanned with a {@link MatchCondition} over {@code Conditions.isNan()}.
 * Scalars are read directly, as a double for {@code DataBuffer.Type.DOUBLE} buffers and as a
 * float otherwise.
 *
 * @param z array to inspect
 * @throws ND4JIllegalStateException if at least one NaN value is found
 */
public static void checkForNaN(INDArray z) {
    // read the profiling mode once instead of querying the executioner twice
    OpExecutioner.ProfilingMode mode = Nd4j.getExecutioner().getProfilingMode();
    if (mode != OpExecutioner.ProfilingMode.NAN_PANIC && mode != OpExecutioner.ProfilingMode.ANY_PANIC)
        return;

    int match;
    if (!z.isScalar()) {
        MatchCondition condition = new MatchCondition(z, Conditions.isNan());
        match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
    } else if (z.data().dataType() == DataBuffer.Type.DOUBLE) {
        match = Double.isNaN(z.getDouble(0)) ? 1 : 0;
    } else {
        match = Float.isNaN(z.getFloat(0)) ? 1 : 0;
    }

    // message previously ended with a dangling ": " after which nothing was appended
    if (match > 0)
        throw new ND4JIllegalStateException("P.A.N.I.C.! Op.Z() contains " + match + " NaN value(s)");
}
Also used : MatchCondition(org.nd4j.linalg.api.ops.impl.accum.MatchCondition) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Aggregations

ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)116 lombok.val (lombok.val)26 INDArray (org.nd4j.linalg.api.ndarray.INDArray)23 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)21 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)19 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)17 CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)15 PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer)12 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)8 BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer)7 IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray)6 Pointer (org.bytedeco.javacpp.Pointer)5 ArrayList (java.util.ArrayList)4 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)4 Aeron (io.aeron.Aeron)3 FragmentAssembler (io.aeron.FragmentAssembler)3 MediaDriver (io.aeron.driver.MediaDriver)3 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)3 Slf4j (lombok.extern.slf4j.Slf4j)3 CloseHelper (org.agrona.CloseHelper)3