Search in sources:

Example 16 with DAGNode

use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.

In class StyleTransfer, method getContentComponents:

/**
 * Gets the content-loss components for each configured layer type.
 * <p>
 * For every layer type with a nonzero content weight, a mean-squared-error node is wired
 * into that layer's network, comparing the layer's activation ({@code node}) against the
 * fixed content-target tensor. The loss is normalized by the target's RMS so that layers
 * with different activation magnitudes contribute on a comparable scale.
 *
 * @param setup   the setup holding content weights and the content-target activations
 * @param nodeMap map from layer type to its node in the evaluation network
 * @return one (weight, loss-node) tuple per layer type with a nonzero content weight
 */
@Nonnull
public ArrayList<Tuple2<Double, DAGNode>> getContentComponents(NeuralSetup<T> setup, final Map<T, DAGNode> nodeMap) {
    ArrayList<Tuple2<Double, DAGNode>> contentComponents = new ArrayList<>();
    for (final T layerType : getLayerTypes()) {
        // Layers absent from the params map contribute nothing (weight 0);
        // getOrDefault avoids the original containsKey-then-get double lookup.
        final double coeff_content = setup.style.content.params.getOrDefault(layerType, 0.0);
        if (coeff_content != 0) {
            // Only resolve the node/network for layers that actually contribute.
            final DAGNode node = nodeMap.get(layerType);
            final PipelineNetwork network1 = (PipelineNetwork) node.getNetwork();
            Tensor content = setup.contentTarget.content.get(layerType);
            // MSE against a constant ValueLayer holding the target, scaled by 1/rms for normalization.
            contentComponents.add(new Tuple2<>(coeff_content, network1.wrap(new MeanSqLossLayer().setAlpha(1.0 / content.rms()), node, network1.wrap(new ValueLayer(content), new DAGNode[] {}))));
        }
    }
    return contentComponents;
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) Tuple2(com.simiacryptus.util.lang.Tuple2) ArrayList(java.util.ArrayList) ValueLayer(com.simiacryptus.mindseye.layers.cudnn.ValueLayer) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) DAGNode(com.simiacryptus.mindseye.network.DAGNode) MeanSqLossLayer(com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer) Nonnull(javax.annotation.Nonnull)

Example 17 with DAGNode

use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.

In class TrainingTester, method train:

/**
 * Trains {@code layer} against the supplied data using the given optimization strategy and
 * reports the outcome to the notebook log.
 * <p>
 * A fresh {@link PipelineNetwork} is built that feeds all but the last input column into
 * {@code layer} and compares its output to the last input column via {@link MeanSqLossLayer}.
 * Ownership of the reference-counted arguments ({@code layer} and every {@code Tensor} in
 * {@code data}) transfers to this method; all are released before returning.
 *
 * @param log   notebook output sink for progress and diagnostics
 * @param opt   optimization strategy: maps (log, trainable) to the recorded step history
 * @param layer the layer under test; freed by this method
 * @param data  training rows; each row's last tensor is the regression target; all freed here
 * @param mask  per-input trainable mask; when nonempty, the inputs themselves are regressed
 * @return the step history produced by {@code opt}
 */
private List<StepRecord> train(@Nonnull NotebookOutput log, @Nonnull BiFunction<NotebookOutput, Trainable, List<StepRecord>> opt, @Nonnull Layer layer, @Nonnull Tensor[][] data, @Nonnull boolean... mask) {
    try {
        int inputs = data[0].length;
        @Nonnull final PipelineNetwork network = new PipelineNetwork(inputs);
        // Wire: MeanSqLoss( layer(input[0..n-2]), input[n-1] ) — the last input is the target.
        network.wrap(new MeanSqLossLayer(), network.add(layer, IntStream.range(0, inputs - 1).mapToObj(i -> network.getInput(i)).toArray(i -> new DAGNode[i])), network.getInput(inputs - 1));
        @Nonnull ArrayTrainable trainable = new ArrayTrainable(data, network);
        if (0 < mask.length)
            trainable.setMask(mask);
        List<StepRecord> history;
        try {
            history = opt.apply(log, trainable);
            // Converged iff the best fitness seen dropped to 1e-5 or below
            // (an empty history defaults to 1, i.e. non-converged).
            if (history.stream().mapToDouble(x -> x.fitness).min().orElse(1) > 1e-5) {
                if (!network.isFrozen()) {
                    log.p("This training apply resulted in the following configuration:");
                    log.code(() -> {
                        return network.state().stream().map(Arrays::toString).reduce((a, b) -> a + "\n" + b).orElse("");
                    });
                }
                // A nonempty mask means inputs were trainable; show the first regressed row.
                if (0 < mask.length) {
                    log.p("And regressed input:");
                    log.code(() -> {
                        return Arrays.stream(data).flatMap(x -> Arrays.stream(x)).limit(1).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
                    });
                }
                log.p("To produce the following output:");
                log.code(() -> {
                    Result[] array = ConstantResult.batchResultArray(pop(data));
                    @Nullable Result eval = layer.eval(array);
                    // NOTE(review): eval is declared @Nullable but is dereferenced below
                    // without a null check — confirm layer.eval cannot return null here.
                    // Release the constant inputs and their data before reading eval's output.
                    for (@Nonnull Result result : array) {
                        result.freeRef();
                        result.getData().freeRef();
                    }
                    TensorList tensorList = eval.getData();
                    eval.freeRef();
                    String str = tensorList.stream().limit(1).map(x -> {
                        String s = x.prettyPrint();
                        // Each streamed tensor is owned by this consumer; free after printing.
                        x.freeRef();
                        return s;
                    }).reduce((a, b) -> a + "\n" + b).orElse("");
                    tensorList.freeRef();
                    return str;
                });
            } else {
                log.p("Training Converged");
            }
        } finally {
            // Release the trainable and network regardless of optimizer outcome.
            trainable.freeRef();
            network.freeRef();
        }
        return history;
    } finally {
        // Caller transferred ownership of layer and data; release them unconditionally.
        layer.freeRef();
        for (@Nonnull Tensor[] tensors : data) {
            for (@Nonnull Tensor tensor : tensors) {
                tensor.freeRef();
            }
        }
    }
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) StepRecord(com.simiacryptus.mindseye.test.StepRecord) Arrays(java.util.Arrays) Nullable(javax.annotation.Nullable)

Aggregations

DAGNode (com.simiacryptus.mindseye.network.DAGNode)17 Nonnull (javax.annotation.Nonnull)14 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)11 ArrayList (java.util.ArrayList)9 Nullable (javax.annotation.Nullable)8 Tensor (com.simiacryptus.mindseye.lang.Tensor)6 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)6 JsonObject (com.google.gson.JsonObject)5 Layer (com.simiacryptus.mindseye.lang.Layer)5 Arrays (java.util.Arrays)5 List (java.util.List)5 IntStream (java.util.stream.IntStream)5 Result (com.simiacryptus.mindseye.lang.Result)4 Map (java.util.Map)4 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)3 MeanSqLossLayer (com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer)3 ValueLayer (com.simiacryptus.mindseye.layers.cudnn.ValueLayer)3 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)2 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)2 TensorList (com.simiacryptus.mindseye.lang.TensorList)2