Search in sources:

Example 26 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

From the class L2NormalizationTest, method train.

/**
 * Trains {@code network} on {@code trainingData} under an L2 weight penalty, recording the training code and its
 * result in the notebook.
 *
 * <p>The network is paired with an {@code EntropyLossLayer} in a {@code SimpleLossNetwork}, and fed through a
 * {@code SampledArrayTrainable} (constructed with the argument 1000). It is then wrapped in an {@code L12Normalizer}
 * whose L1 weight is 0.0 and whose L2 weight is 1e4 for every layer. An {@code IterativeTrainer} runs it with a
 * 3 minute timeout and at most 500 iterations.
 *
 * @param log          notebook receiving the explanatory paragraph and the logged training code
 * @param network      model to be trained
 * @param trainingData training rows handed to the sampled trainable
 * @param monitor      training monitor attached to the iterative trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.p("Training a model involves a few different components. First, our model is combined with a loss function. " + "Then we take that model and combine it with our training data to define a trainable object. " + "Finally, we use a simple iterative scheme to refine the weights of our model. " + "The final output is the last output value of the loss function when evaluating the last batch.");
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        // NOTE(review): 1000 is presumably the number of rows sampled per evaluation — confirm against SampledArrayTrainable.
        @Nonnull final Trainable trainable = new L12Normalizer(new SampledArrayTrainable(trainingData, supervisedNetwork, 1000)) {

            // Delegates to the layer exposed by inner.getLayer().
            @Override
            public Layer getLayer() {
                return inner.getLayer();
            }

            // No L1 penalty on any layer.
            @Override
            protected double getL1(final Layer layer) {
                return 0.0;
            }

            // The same L2 penalty weight is applied to every layer.
            @Override
            protected double getL2(final Layer layer) {
                return 1e4;
            }
        };
        return new IterativeTrainer(trainable).setMonitor(monitor).setTimeout(3, TimeUnit.MINUTES).setMaxIterations(500).runAndFree();
    });
}
Also used: IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) L12Normalizer(com.simiacryptus.mindseye.eval.L12Normalizer) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) Layer(com.simiacryptus.mindseye.lang.Layer)

Example 27 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

From the class NLayerTest, method test.

/**
 * Runs the layer test against networks of increasing depth.
 *
 * <p>Entries of {@code dimList} are appended one at a time to a growing specification. After each append, a network
 * is built from the input dimensions plus the specification so far; its graph is drawn and it is tested.
 *
 * @param log the notebook receiving the heading, graphs and test output
 */
public void test(@Nonnull final NotebookOutput log) {
    final String heading = getClass().getSimpleName();
    log.h1("%s", heading);
    @Nonnull final int[] inputDims = getInputDims();
    // Specification prefix: holds the first k entries of dimList on the k-th iteration.
    @Nonnull final ArrayList<int[]> prefixSpec = new ArrayList<>();
    for (final int[] nextDims : dimList) {
        prefixSpec.add(nextDims);
        @Nonnull final Layer candidate = buildNetwork(concat(inputDims, prefixSpec));
        graphviz(log, candidate);
        test(log, candidate, inputDims);
    }
}
Also used: Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) Layer(com.simiacryptus.mindseye.lang.Layer)

Example 28 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

From the class MnistTestBase, method validate.

/**
 * Evaluates {@code network} on the MNIST validation set and writes two notebook sections.
 *
 * <p>The first section is the percentage of validation examples for which {@code predict(network, ...)[0]} equals
 * the parsed label. The second is a table of up to ten examples whose highest-scoring class differs from the label.
 * Each row shows the image and the three highest-scoring classes with their scores as percentages.
 *
 * @param log     notebook receiving the headings, text and logged code
 * @param network model evaluated against the validation examples
 */
public void validate(@Nonnull final NotebookOutput log, @Nonnull final Layer network) {
    log.h1("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        // Each example contributes 1 when the first predicted class matches the label, else 0; the mean is the hit rate.
        final double hitRate = MNIST.validationDataStream()
            .mapToDouble(sample -> predict(network, sample)[0] == parse(sample.label) ? 1 : 0)
            .average()
            .getAsDouble();
        return hitRate * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        MNIST.validationDataStream().map(sample -> {
            final int expected = parse(sample.label);
            final double[] signal = network.eval(sample.data).getData().get(0).getData();
            // Class indices 0..9 ordered by descending score; the stable sort keeps ties in ascending index order.
            final int[] ranked = IntStream.range(0, 10)
                .boxed()
                .sorted(Comparator.comparing(i -> -signal[i]))
                .mapToInt(Integer::intValue)
                .toArray();
            if (ranked[0] != expected) {
                @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
                row.put("Image", log.image(sample.data.toGrayImage(), sample.label));
                row.put("Prediction", Arrays.stream(ranked)
                    .limit(3)
                    .mapToObj(i -> String.format("%d (%.1f%%)", i, 100.0 * signal[i]))
                    .reduce((left, right) -> left + ", " + right)
                    .get());
                return row;
            }
            // Correctly predicted examples map to null and are discarded by the filter below.
            return null;
        }).filter(entry -> entry != null).limit(10).forEach(table::putRow);
        return table;
    });
}
Also used: PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) TableOutput(com.simiacryptus.util.TableOutput) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) LabeledObject(com.simiacryptus.util.test.LabeledObject) Layer(com.simiacryptus.mindseye.lang.Layer) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) MNIST(com.simiacryptus.mindseye.test.data.MNIST) Nullable(javax.annotation.Nullable) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) Test(org.junit.Test) IOException(java.io.IOException) Category(org.junit.experimental.categories.Category) MonitoredObject(com.simiacryptus.util.MonitoredObject) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) List(java.util.List) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) BiasLayer(com.simiacryptus.mindseye.layers.java.BiasLayer) TestCategories(com.simiacryptus.util.test.TestCategories)

Example 29 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

From the class MnistTestBase, method report.

/**
 * Writes a training report to the notebook.
 *
 * <p>The report consists of three parts:
 * <ul>
 *   <li>a convergence plot of iteration against log10 of each step's mean fitness, only when {@code history} is
 *       non-empty;</li>
 *   <li>the network's JSON, saved through {@code log.file} under a name of the form {@code model<N>.json};</li>
 *   <li>the metrics of {@code monitoringRoot}, serialized as JSON.</li>
 * </ul>
 *
 * @param log            notebook receiving the plot, links and metrics
 * @param monitoringRoot source of the metrics written in the "Metrics" section
 * @param history        training steps to plot; the plot is skipped when empty
 * @param network        model whose JSON is saved
 */
public void report(@Nonnull final NotebookOutput log, @Nonnull final MonitoredObject monitoringRoot, @Nonnull final List<Step> history, @Nonnull final Layer network) {
    if (!history.isEmpty()) {
        log.code(() -> {
            // One point per step: (iteration, log10(mean fitness)).
            final double[][] points = history.stream()
                .map(step -> new double[] { step.iteration, Math.log10(step.point.getMean()) })
                .toArray(double[][]::new);
            @Nonnull final PlotCanvas canvas = ScatterPlot.plot(points);
            canvas.setTitle("Convergence Plot");
            canvas.setAxisLabels("Iteration", "log10(Fitness)");
            canvas.setSize(600, 400);
            return canvas;
        });
    }
    // modelNo is post-incremented so each saved model gets a distinct file name.
    @Nonnull final String modelName = "model" + modelNo++ + ".json";
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h1("Metrics");
    log.code(() -> {
        @Nonnull final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try {
            JsonUtil.writeJson(buffer, monitoringRoot.getMetrics());
        } catch (@Nonnull final IOException ioe) {
            throw new RuntimeException(ioe);
        }
        return buffer.toString();
    });
}
Also used: PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) TableOutput(com.simiacryptus.util.TableOutput) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) LabeledObject(com.simiacryptus.util.test.LabeledObject) Layer(com.simiacryptus.mindseye.lang.Layer) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) MNIST(com.simiacryptus.mindseye.test.data.MNIST) Nullable(javax.annotation.Nullable) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) Test(org.junit.Test) IOException(java.io.IOException) Category(org.junit.experimental.categories.Category) MonitoredObject(com.simiacryptus.util.MonitoredObject) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) List(java.util.List) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) BiasLayer(com.simiacryptus.mindseye.layers.java.BiasLayer) TestCategories(com.simiacryptus.util.test.TestCategories)

Example 30 with Layer

Use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

From the class LBFGS, method orient.

/**
 * Chooses a line-search direction for the current measurement.
 *
 * <p>The measurement is first added to the history. If {@code lbfgs} returns a direction, the cursor follows it
 * (labelled "LBFGS"). Otherwise the cursor follows the measurement's delta scaled by -1 (labelled "GD"). Afterwards
 * the oldest history entries are polled and released until the history size is at most {@code minHistory} (after
 * the "GD" fallback) or {@code maxHistory} (after an "LBFGS" step).
 *
 * @param subject     trainable passed through to the line-search cursor
 * @param measurement current point sample; its {@code delta} supplies the fallback direction
 * @param monitor     receives verbose history-trimming messages
 * @return a line-search cursor along the chosen direction
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    addToHistory(measurement, monitor);
    // Fixed copy of the live history for the L-BFGS computation. It is deliberately not named "history": the trimming
    // loop below must observe the live field, not this snapshot.
    @Nonnull final List<PointSample> historySnapshot = Arrays.asList(this.history.toArray(new PointSample[] {}));
    @Nullable final DeltaSet<Layer> result = lbfgs(measurement, monitor, historySnapshot);
    SimpleLineSearchCursor returnValue;
    if (null == result) {
        // No L-BFGS direction available: search along the negated delta instead.
        @Nonnull DeltaSet<Layer> scale = measurement.delta.scale(-1);
        returnValue = cursor(subject, measurement, "GD", scale);
        scale.freeRef();
    } else {
        returnValue = cursor(subject, measurement, "LBFGS", result);
        result.freeRef();
    }
    // Retained size is minHistory after the GD fallback and maxHistory after an L-BFGS step. Only the null-ness of
    // result is tested here, so referring to it after freeRef() is safe.
    while (this.history.size() > (null == result ? minHistory : maxHistory)) {
        @Nullable final PointSample remove = this.history.pollFirst();
        if (verbose) {
            // Report the live history size after this removal, not the size of the snapshot taken above.
            monitor.log(String.format("Removed measurement %s from history. Total: %s", Long.toHexString(System.identityHashCode(remove)), this.history.size()));
        }
        remove.freeRef();
    }
    return returnValue;
}
Also used: Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) PointSample(com.simiacryptus.mindseye.lang.PointSample) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable)

Aggregations

Layer (com.simiacryptus.mindseye.lang.Layer)167 Nonnull (javax.annotation.Nonnull)159 Nullable (javax.annotation.Nullable)128 Arrays (java.util.Arrays)117 Tensor (com.simiacryptus.mindseye.lang.Tensor)116 List (java.util.List)108 Result (com.simiacryptus.mindseye.lang.Result)103 IntStream (java.util.stream.IntStream)98 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)95 TensorList (com.simiacryptus.mindseye.lang.TensorList)93 Map (java.util.Map)83 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)76 Logger (org.slf4j.Logger)76 LoggerFactory (org.slf4j.LoggerFactory)76 JsonObject (com.google.gson.JsonObject)70 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)66 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)64 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)51 Collectors (java.util.stream.Collectors)42 Stream (java.util.stream.Stream)37