
Example 6 with DAGNode

use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.

the class StyleTransfer method getStyleComponents.

/**
 * Builds the weighted style-loss components (covariance and mean terms) for one layer,
 * wiring each loss node into the supplied network.
 *
 * @param node          the network node whose activations are style-matched
 * @param network       the pipeline network the loss nodes are added to
 * @param styleParams   the per-layer style weights (mean and covariance)
 * @param mean          the target per-band mean statistics
 * @param covariance    the target covariance (Gramian) statistics
 * @param centeringMode how activations are re-centered before the covariance term
 * @return the list of (weight, loss node) pairs
 */
@Nonnull
public ArrayList<Tuple2<Double, DAGNode>> getStyleComponents(final DAGNode node, final PipelineNetwork network, final LayerStyleParams styleParams, final Tensor mean, final Tensor covariance, final CenteringMode centeringMode) {
    ArrayList<Tuple2<Double, DAGNode>> styleComponents = new ArrayList<>();
    if (null != styleParams && (styleParams.cov != 0 || styleParams.mean != 0)) {
        double meanRms = mean.rms();
        double meanScale = 0 == meanRms ? 1 : (1.0 / meanRms);
        InnerNode negTarget = network.wrap(new ValueLayer(mean.scale(-1)), new DAGNode[] {});
        InnerNode negAvg = network.wrap(new BandAvgReducerLayer().setAlpha(-1), node);
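        // negTarget holds the negated target means; negAvg computes the negated per-band average of the current activations.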
        if (styleParams.cov != 0) {
            DAGNode recentered;
            switch(centeringMode) {
                case Origin:
                    recentered = node;
                    break;
                case Dynamic:
                    recentered = network.wrap(new GateBiasLayer(), node, negAvg);
                    break;
                case Static:
                    recentered = network.wrap(new GateBiasLayer(), node, negTarget);
                    break;
                default:
                    throw new RuntimeException("Unsupported centering mode: " + centeringMode);
            }
            int[] covDim = covariance.getDimensions();
            assert 0 < covDim[2] : Arrays.toString(covDim);
            int inputBands = mean.getDimensions()[2];
            assert 0 < inputBands : Arrays.toString(mean.getDimensions());
            int outputBands = covDim[2] / inputBands;
            assert 0 < outputBands : Arrays.toString(covDim) + " / " + inputBands;
            double covRms = covariance.rms();
            double covScale = 0 == covRms ? 1 : (1.0 / covRms);
            styleComponents.add(new Tuple2<>(styleParams.cov,
                network.wrap(new MeanSqLossLayer().setAlpha(covScale),
                    network.wrap(new ValueLayer(covariance), new DAGNode[] {}),
                    network.wrap(ArtistryUtil.wrapTilesAvg(new GramianLayer()), recentered))));
        }
        if (styleParams.mean != 0) {
            styleComponents.add(new Tuple2<>(styleParams.mean, network.wrap(new MeanSqLossLayer().setAlpha(meanScale), negAvg, negTarget)));
        }
    }
    return styleComponents;
}
Also used : ArrayList(java.util.ArrayList) ValueLayer(com.simiacryptus.mindseye.layers.cudnn.ValueLayer) DAGNode(com.simiacryptus.mindseye.network.DAGNode) RangeConstraint(com.simiacryptus.mindseye.opt.region.RangeConstraint) MeanSqLossLayer(com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer) InnerNode(com.simiacryptus.mindseye.network.InnerNode) GramianLayer(com.simiacryptus.mindseye.layers.cudnn.GramianLayer) Tuple2(com.simiacryptus.util.lang.Tuple2) BandAvgReducerLayer(com.simiacryptus.mindseye.layers.cudnn.BandAvgReducerLayer) GateBiasLayer(com.simiacryptus.mindseye.layers.cudnn.GateBiasLayer) Nonnull(javax.annotation.Nonnull)
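 
A hypothetical caller sketch (not part of the project source; the styleTransfer instance, the style weights, and the target statistics are placeholder names for illustration). The returned pairs are loss nodes already wired into the supplied network, each tagged with its weight:

PipelineNetwork network = new PipelineNetwork(1);
DAGNode node = network.getInput(0);
ArrayList<Tuple2<Double, DAGNode>> components = styleTransfer.getStyleComponents(
    node, network, styleParams, targetMean, targetCovariance, CenteringMode.Dynamic);
// Zero-weight terms are skipped inside getStyleComponents, so `components` may be empty.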

Example 7 with DAGNode

use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.

the class EncodingProblem method run.

@Nonnull
@Override
public EncodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    Tensor[][] trainingData;
    try {
        trainingData = data.trainingData().map(labeledObject -> {
            return new Tensor[] { new Tensor(features).set(this::random), labeledObject.data };
        }).toArray(i -> new Tensor[i][]);
    } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
    }
    @Nonnull final DAGNetwork imageNetwork = revFactory.vectorToImage(log, features);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(imageNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    @Nonnull final PipelineNetwork trainingNetwork = new PipelineNetwork(2);
    @Nullable final DAGNode image = trainingNetwork.add(imageNetwork, trainingNetwork.getInput(0));
    @Nullable final DAGNode softmax = trainingNetwork.add(new SoftmaxActivationLayer(), trainingNetwork.getInput(0));
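    // Composite loss: the entropy of the softmaxed latent code (input 0) plus the square root of the reconstruction MSE between the decoded image and the target image (input 1).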
    trainingNetwork.add(new SumInputsLayer(),
        trainingNetwork.add(new EntropyLossLayer(), softmax, softmax),
        trainingNetwork.add(new NthPowerActivationLayer().setPower(1.0 / 2.0),
            trainingNetwork.add(new MeanSqLossLayer(), image, trainingNetwork.getInput(1))));
    log.h3("Training");
    log.p("We start by training apply a very small population to improve initial convergence performance:");
    TestUtil.instrumentPerformance(trainingNetwork);
    @Nonnull final Tensor[][] primingData = Arrays.copyOfRange(trainingData, 0, 1000);
    @Nonnull final ValidatingTrainer preTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(primingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(primingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        preTrainer.setTimeout(timeoutMinutes / 2, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    log.p("Then our main training phase:");
    TestUtil.instrumentPerformance(trainingNetwork);
    @Nonnull final ValidatingTrainer mainTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(trainingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(trainingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        mainTrainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        @Nonnull String filename = log.getName().toString() + EncodingProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        log.appendFrontMatterProperty("result_plot", filename, ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // log.file()
    @Nonnull final String modelName = "encoding_model_" + EncodingProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(trainingNetwork.getJson().toString(), modelName, modelName));
    log.h3("Results");
    @Nonnull final PipelineNetwork testNetwork = new PipelineNetwork(2);
    testNetwork.add(imageNetwork, testNetwork.getInput(0));
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        Arrays.stream(trainingData).map(tensorArray -> {
            @Nullable final Tensor predictionSignal = testNetwork.eval(tensorArray).getData().get(0);
            @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
            row.put("Source", log.image(tensorArray[1].toImage(), ""));
            row.put("Echo", log.image(predictionSignal.toImage(), ""));
            return row;
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Learned Model Statistics:");
    log.code(() -> {
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        trainingNetwork.state().stream().flatMapToDouble(x -> Arrays.stream(x)).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Learned Representation Statistics:");
    log.code(() -> {
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        Arrays.stream(trainingData).flatMapToDouble(row -> Arrays.stream(row[0].getData())).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Some rendered unit vectors:");
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = imageNetwork.eval(input).getData().get(0);
        TestUtil.renderToImages(tensor, true).forEach(img -> {
            log.out(log.image(img, ""));
        });
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) Format(guru.nidi.graphviz.engine.Format) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork)

Example 8 with DAGNode

use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.

the class PolynomialNetwork method getHead.

@Override
public synchronized DAGNode getHead() {
    if (null == head) {
        synchronized (this) {
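            // Double-checked locking: head is re-tested under the lock before the graph is lazily assembled.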
            if (null == head) {
                if (null == alpha) {
                    alpha = newSynapse(1e-8);
                    alphaBias = newBias(inputDims, 0.0);
                }
                reset();
                final DAGNode input = getInput(0);
                @Nonnull final ArrayList<DAGNode> terms = new ArrayList<>();
                terms.add(add(alpha, add(alphaBias, input)));
                for (@Nonnull final Correcton c : corrections) {
                    terms.add(c.add(input));
                }
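                // A single term becomes the head directly; otherwise all terms are multiplied together via a product layer.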
                head = terms.size() == 1 ? terms.get(0) : add(newProductLayer(), terms.toArray(new DAGNode[] {}));
            }
        }
    }
    return head;
}
Also used : Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) DAGNode(com.simiacryptus.mindseye.network.DAGNode)

Example 9 with DAGNode

use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.

the class PolynomialNetwork method getJson.

@Override
public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
    assertConsistent();
    @Nullable final DAGNode head = getHead();
    final JsonObject json = super.getJson(resources, dataSerializer);
    json.addProperty("head", head.getId().toString());
    if (null != alpha) {
        json.addProperty("alpha", alpha.getId().toString());
    }
    if (null != alphaBias) {
        json.addProperty("alphaBias", alphaBias.getId().toString());
    }
    json.add("inputDims", PolynomialNetwork.toJson(inputDims));
    json.add("outputDims", PolynomialNetwork.toJson(outputDims));
    @Nonnull final JsonArray elements = new JsonArray();
    for (@Nonnull final Correcton c : corrections) {
        elements.add(c.getJson());
    }
    json.add("corrections", elements);
    assert null != Layer.fromJson(json) : "Smoke test deserialization";
    return json;
}
Also used : JsonArray(com.google.gson.JsonArray) Nonnull(javax.annotation.Nonnull) JsonObject(com.google.gson.JsonObject) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Nullable(javax.annotation.Nullable)
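 
A minimal round-trip sketch (the polynomialNetwork instance, the resources map, and the serializer are placeholder names the caller would already hold, not values taken from the project):

final JsonObject json = polynomialNetwork.getJson(resources, serializer);
// Layer.fromJson is the same single-argument call used in the smoke-test assertion above.
final Layer restored = Layer.fromJson(json);
assert null != restored;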

Example 10 with DAGNode

use of com.simiacryptus.mindseye.network.DAGNode in project MindsEye by SimiaCryptus.

the class RescaledSubnetLayer method eval.

@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
    assert 1 == inObj.length;
    final TensorList batch = inObj[0].getData();
    @Nonnull final int[] inputDims = batch.getDimensions();
    assert 3 == inputDims.length;
    if (1 == scale)
        return subnetwork.eval(inObj);
    @Nonnull final PipelineNetwork network = new PipelineNetwork();
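    // Condense: the reshape (expand = false) folds each scale x scale pixel block into extra bands, shrinking the spatial resolution.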
    @Nullable final DAGNode condensed = network.wrap(new ImgReshapeLayer(scale, scale, false));
    network.wrap(new ImgConcatLayer(), IntStream.range(0, scale * scale).mapToObj(subband -> {
        @Nonnull final int[] select = new int[inputDims[2]];
        for (int i = 0; i < inputDims[2]; i++) {
            select[i] = subband * inputDims[2] + i;
        }
        return network.add(subnetwork, network.wrap(new ImgBandSelectLayer(select), condensed));
    }).toArray(i -> new DAGNode[i]));
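    // Expand: the inverse reshape (expand = true) restores the original spatial resolution of the concatenated result.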
    network.wrap(new ImgReshapeLayer(scale, scale, true));
    Result eval = network.eval(inObj);
    network.freeRef();
    return eval;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Result(com.simiacryptus.mindseye.lang.Result) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) ImgConcatLayer(com.simiacryptus.mindseye.layers.cudnn.ImgConcatLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)

Aggregations

DAGNode (com.simiacryptus.mindseye.network.DAGNode)17 Nonnull (javax.annotation.Nonnull)14 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)11 ArrayList (java.util.ArrayList)9 Nullable (javax.annotation.Nullable)8 Tensor (com.simiacryptus.mindseye.lang.Tensor)6 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)6 JsonObject (com.google.gson.JsonObject)5 Layer (com.simiacryptus.mindseye.lang.Layer)5 Arrays (java.util.Arrays)5 List (java.util.List)5 IntStream (java.util.stream.IntStream)5 Result (com.simiacryptus.mindseye.lang.Result)4 Map (java.util.Map)4 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)3 MeanSqLossLayer (com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer)3 ValueLayer (com.simiacryptus.mindseye.layers.cudnn.ValueLayer)3 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)2 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)2 TensorList (com.simiacryptus.mindseye.lang.TensorList)2