Search in sources:

Example 11 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

From class StyleTransfer, method styleTransfer:

/**
 * Repaints a canvas toward a target style/content by optimizing its pixels against a frozen
 * fitness network, and returns the resulting image.
 *
 * <p>Before training, the content image, every style image, the input canvas, and the style
 * parameters (serialized with {@code ArtistryUtil.toJson}) are written to {@code log}. The canvas
 * is converted to a {@link Tensor} and optimized by an {@link IterativeTrainer} that uses a
 * {@link TrustRegionStrategy} whose per-layer region is a {@link RangeConstraint} of
 * [1e-2, 256], a {@link BisectionSearch} line search, 100 iterations per sample, and a timeout of
 * {@code trainingMinutes}. The terminate threshold is negative infinity, so the loss value never
 * stops training by itself. The loss history is plotted to {@code log}.
 *
 * <p>All of this runs inside {@code ArtistryUtil.logExceptionWithDefault} with
 * {@code canvasImage} as the default value — presumably what is returned if the body throws;
 * confirm against that helper.
 *
 * @param server          optional server; when non-null, {@code ArtistryUtil.addLayersHandler}
 *                        registers the fitness network with it
 * @param log             notebook receiving the images, parameters, code and plot
 * @param canvasImage     starting canvas (and the default value passed to the exception wrapper)
 * @param styleParameters content image, style images and numeric precision for the network
 * @param trainingMinutes training time budget, in minutes
 * @param measureStyle    setup passed to {@code fitnessNetwork} to build the objective network
 * @return the optimized canvas image, which is also logged as "Output Canvas"
 */
public BufferedImage styleTransfer(final StreamNanoHTTPD server, @Nonnull final NotebookOutput log, final BufferedImage canvasImage, final StyleSetup<T> styleParameters, final int trainingMinutes, final NeuralSetup measureStyle) {
    final BufferedImage output = ArtistryUtil.logExceptionWithDefault(log, () -> {
        // Record every input so the notebook is self-describing.
        log.p("Input Content:");
        log.p(log.image(styleParameters.contentImage, "Content Image"));
        log.p("Style Content:");
        styleParameters.styleImages.forEach((name, image) -> log.p(log.image(image, name)));
        log.p("Input Canvas:");
        log.p(log.image(canvasImage, "Input Canvas"));
        System.gc();
        // The trainer updates this tensor; its final state becomes the returned image.
        final Tensor canvasTensor = Tensor.fromRGB(canvasImage);
        TestUtil.monitorImage(canvasTensor, false, false);
        log.p("Input Parameters:");
        log.code(() -> {
            return ArtistryUtil.toJson(styleParameters);
        });
        final Trainable objective = log.code(() -> {
            final PipelineNetwork fitness = fitnessNetwork(measureStyle);
            // Freeze the network; setMask(true) below presumably makes the canvas input the
            // only trained parameter.
            fitness.setFrozen(true);
            ArtistryUtil.setPrecision(fitness, styleParameters.precision);
            TestUtil.instrumentPerformance(fitness);
            if (server != null) {
                ArtistryUtil.addLayersHandler(fitness, server);
            }
            return new ArrayTrainable(fitness, 1)
                .setVerbose(true)
                .setMask(true)
                .setData(Arrays.asList(new Tensor[][] { { canvasTensor } }));
        });
        log.code(() -> {
            @Nonnull final ArrayList<StepRecord> steps = new ArrayList<>();
            // Every layer shares the same box constraint on its values.
            final TrustRegionStrategy orientation = new TrustRegionStrategy() {

                @Override
                public TrustRegion getRegionPolicy(final Layer layer) {
                    return new RangeConstraint().setMin(1e-2).setMax(256);
                }
            };
            new IterativeTrainer(objective)
                .setMonitor(TestUtil.getMonitor(steps))
                .setOrientation(orientation)
                .setIterationsPerSample(100)
                .setLineSearchFactory(key -> new BisectionSearch().setSpanTol(1e-1).setCurrentRate(1e6))
                .setTimeout(trainingMinutes, TimeUnit.MINUTES)
                .setTerminateThreshold(Double.NEGATIVE_INFINITY)
                .runAndFree();
            return TestUtil.plot(steps);
        });
        return canvasTensor.toImage();
    }, canvasImage);
    log.p("Output Canvas:");
    log.p(log.image(output, "Output Canvas"));
    return output;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) MeanSqLossLayer(com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy) HashMap(java.util.HashMap) NullNotebookOutput(com.simiacryptus.util.io.NullNotebookOutput) MultiLayerImageNetwork(com.simiacryptus.mindseye.models.MultiLayerImageNetwork) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) Trainable(com.simiacryptus.mindseye.eval.Trainable) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Tuple2(com.simiacryptus.util.lang.Tuple2) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) GateBiasLayer(com.simiacryptus.mindseye.layers.cudnn.GateBiasLayer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Logger(org.slf4j.Logger) BufferedImage(java.awt.image.BufferedImage) ValueLayer(com.simiacryptus.mindseye.layers.cudnn.ValueLayer) TestUtil(com.simiacryptus.mindseye.test.TestUtil) UUID(java.util.UUID) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Collectors(java.util.stream.Collectors) BandAvgReducerLayer(com.simiacryptus.mindseye.layers.cudnn.BandAvgReducerLayer) StreamNanoHTTPD(com.simiacryptus.util.StreamNanoHTTPD) TimeUnit(java.util.concurrent.TimeUnit) BisectionSearch(com.simiacryptus.mindseye.opt.line.BisectionSearch) List(java.util.List) GramianLayer(com.simiacryptus.mindseye.layers.cudnn.GramianLayer) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) BinarySumLayer(com.simiacryptus.mindseye.layers.cudnn.BinarySumLayer) 
InnerNode(com.simiacryptus.mindseye.network.InnerNode) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) MultiLayerVGG16(com.simiacryptus.mindseye.models.MultiLayerVGG16) RangeConstraint(com.simiacryptus.mindseye.opt.region.RangeConstraint) LayerEnum(com.simiacryptus.mindseye.models.LayerEnum) MultiLayerVGG19(com.simiacryptus.mindseye.models.MultiLayerVGG19) Tensor(com.simiacryptus.mindseye.lang.Tensor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer) Layer(com.simiacryptus.mindseye.lang.Layer) GateBiasLayer(com.simiacryptus.mindseye.layers.cudnn.GateBiasLayer) ValueLayer(com.simiacryptus.mindseye.layers.cudnn.ValueLayer) BandAvgReducerLayer(com.simiacryptus.mindseye.layers.cudnn.BandAvgReducerLayer) GramianLayer(com.simiacryptus.mindseye.layers.cudnn.GramianLayer) BinarySumLayer(com.simiacryptus.mindseye.layers.cudnn.BinarySumLayer) BufferedImage(java.awt.image.BufferedImage) StepRecord(com.simiacryptus.mindseye.test.StepRecord) RangeConstraint(com.simiacryptus.mindseye.opt.region.RangeConstraint) BisectionSearch(com.simiacryptus.mindseye.opt.line.BisectionSearch) Trainable(com.simiacryptus.mindseye.eval.Trainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)

Example 12 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

From class TrainingTester, method trainMagic:

/**
 * Trains {@code trainable} with an experimental nested optimizer and records every step.
 *
 * <p>The outer {@link IterativeTrainer} uses a {@link RecursiveSubspace} orientation with a
 * {@link StaticLearningRate} of 1.0. For each subspace step, the macro layer it provides is
 * optimized by an inner {@link IterativeTrainer} using {@link QQN} and a {@link QuadraticSearch}.
 * The inner trainer is capped at {@code getIterations()} iterations, and its monitor messages are
 * forwarded, tab-indented, to the subspace monitor. The outer run is configured with a 30-second
 * timeout, 100 iterations per sample, at most 250 iterations, and a terminate threshold of 0.
 *
 * <p>The inner trainable and its placeholder data are released in a {@code finally} block, so a
 * failing subspace step does not leak them.
 *
 * @param log       notebook that receives the description and the executed code
 * @param trainable objective to minimize
 * @return the recorded step history; this may be partial if training failed and
 *         {@code isThrowExceptions()} is false
 * @throws RuntimeException wrapping the failure, if training fails and {@code isThrowExceptions()} is true
 */
@Nonnull
public List<StepRecord> trainMagic(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("Now we train using an experimental optimizer:");
    @Nonnull final List<StepRecord> history = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable).setLineSearchFactory(label -> new StaticLearningRate(1.0)).setOrientation(new RecursiveSubspace() {

                @Override
                public void train(@Nonnull TrainingMonitor subspaceMonitor, Layer macroLayer) {
                    // Placeholder data set (a single empty tensor) - presumably ignored by the macro layer.
                    @Nonnull Tensor[][] nullData = { { new Tensor() } };
                    @Nonnull BasicTrainable inner = new BasicTrainable(macroLayer);
                    @Nonnull ArrayTrainable trainable1 = new ArrayTrainable(inner, nullData);
                    inner.freeRef();
                    try {
                        new IterativeTrainer(trainable1)
                            .setOrientation(new QQN())
                            // QQN's own cursor starts at rate 1.0; any other direction starts at 1e-4.
                            .setLineSearchFactory(n -> new QuadraticSearch().setCurrentRate(n.equals(QQN.CURSOR_NAME) ? 1.0 : 1e-4))
                            .setMonitor(new TrainingMonitor() {

                                @Override
                                public void log(String msg) {
                                    // Indent inner-optimizer output beneath the subspace monitor's messages.
                                    subspaceMonitor.log("\t" + msg);
                                }
                            })
                            .setMaxIterations(getIterations())
                            .setIterationsPerSample(getIterations())
                            .runAndFree();
                    } finally {
                        // Release resources even if the inner optimization throws.
                        trainable1.freeRef();
                        for (@Nonnull Tensor[] tensors : nullData) {
                            for (@Nonnull Tensor tensor : tensors) {
                                tensor.freeRef();
                            }
                        }
                    }
                }
            }).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setIterationsPerSample(100).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
        });
    } catch (Throwable e) {
        // Best-effort by design: failures propagate only when the tester is configured to throw.
        if (isThrowExceptions()) {
            throw new RuntimeException(e);
        }
    }
    return history;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) ArrayList(java.util.ArrayList) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) QQN(com.simiacryptus.mindseye.opt.orient.QQN) StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) Nonnull(javax.annotation.Nonnull)

Example 13 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

From class QuantifyOrientationWrapper, method orient:

/**
 * Delegates orientation to the wrapped strategy and logs, per weight group, how large the
 * proposed step is relative to the current weights.
 *
 * <p>For a {@link SimpleLineSearchCursor}, weights are grouped by {@code getId(...)}. For each
 * weight buffer, the ratio is RMS(direction delta) / RMS(weight values); a zero denominator is
 * treated as 1, and layers missing from the direction contribute 0. Groups with a single buffer
 * log the raw ratio; larger groups log a {@link DoubleStatistics} summary. Any other cursor type
 * is only reported as "Non-simple cursor".
 *
 * @param subject     objective being optimized
 * @param measurement current point sample
 * @param monitor     receives the statistics log line
 * @return the cursor produced by the wrapped strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
    if (!(cursor instanceof SimpleLineSearchCursor)) {
        monitor.log(String.format("Non-simple cursor: %s", cursor));
        return cursor;
    }
    final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
    final DeltaSet<Layer> direction = simpleCursor.direction;
    @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
    final Map<CharSequence, CharSequence> dataMap = weights.stream()
        .collect(Collectors.groupingBy(x -> getId(x), Collectors.toList()))
        .entrySet()
        .stream()
        .collect(Collectors.toMap(entry -> entry.getKey(), entry -> {
            final List<Double> ratios = entry.getValue().stream().map(state -> {
                final DoubleBuffer<Layer> step = direction.getMap().get(state.layer);
                final double weightRms = state.deltaStatistics().rms();
                final double stepRms = null == step ? 0 : step.deltaStatistics().rms();
                return stepRms / (0 == weightRms ? 1 : weightRms);
            }).collect(Collectors.toList());
            return summarizeRatios(ratios);
        }));
    monitor.log(String.format("Line search stats: %s", dataMap));
    return cursor;
}

/**
 * Formats the step/weight ratios of one weight group: the single value for a one-element group,
 * otherwise a {@link DoubleStatistics} summary string.
 */
private static String summarizeRatios(final List<Double> ratios) {
    if (ratios.size() == 1) {
        return Double.toString(ratios.get(0));
    }
    return new DoubleStatistics().accept(ratios.stream().mapToDouble(Double::doubleValue).toArray()).toString();
}
Also used : DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) List(java.util.List)

Example 14 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

From class TrustRegionStrategy, method orient:

/**
 * Wraps the inner strategy's line-search cursor so that every trial position is projected into
 * the per-layer trust region returned by {@code getRegionPolicy(layer)}.
 *
 * <p>Each call prepends {@code origin} to {@code history} (most recent first) and trims the list
 * to {@code maxHistory} entries. For each layer, the weight arrays recorded in those samples are
 * passed to {@code TrustRegion.project} along with the proposed position. The inner strategy must
 * return a {@link SimpleLineSearchCursor}, whose {@code direction} is used as the
 * alpha-derivative of the position.
 *
 * @param subject objective; {@code step} measures it after applying each adjusted position
 * @param origin  current point sample, recorded into {@code history}
 * @param monitor passed to the inner strategy's {@code orient}
 * @return a cursor whose direction type is the inner type suffixed with "+Trust"
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
    // Newest sample first; drop the oldest entries beyond maxHistory.
    history.add(0, origin);
    while (history.size() > maxHistory) {
        history.remove(history.size() - 1);
    }
    final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
    return new LineSearchCursorBase() {

        @Nonnull
        @Override
        public CharSequence getDirectionType() {
            return cursor.getDirectionType() + "+Trust";
        }

        /**
         * Returns the inner cursor's position at {@code alpha}, with its deltas clamped in place
         * into the trust region.
         */
        @Nonnull
        @Override
        public DeltaSet<Layer> position(final double alpha) {
            reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            // Called only for its in-place effect on adjustedPosVector; the projected-direction
            // DeltaSet it returns is discarded.
            // NOTE(review): that returned copy is never freed — confirm whether DeltaSet is
            // reference counted and this leaks.
            project(adjustedPosVector, new TrainingMonitor());
            return adjustedPosVector;
        }

        /**
         * Clamps each layer's proposed position (current weights + delta) into its trust region.
         *
         * <p>{@code deltaIn} is rewritten in place so that it reaches the projected position.
         * The method returns a copy of the inner cursor's direction. In that copy, each clamped
         * layer's direction is replaced by its component orthogonal to the projection normal,
         * i.e. tangent to the constraint boundary. The {@code monitor} parameter is currently
         * unused.
         */
        @Nonnull
        public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
            final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
            @Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
            deltaIn.getMap().forEach((layer, buffer) -> {
                @Nullable final double[] delta = buffer.getDelta();
                if (null == delta)
                    return;
                final double[] currentPosition = buffer.target;
                @Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
                @Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
                @Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
                final TrustRegion region = getRegionPolicy(layer);
                if (null != region) {
                    // This layer's weight arrays from the recorded history samples; samples
                    // without this layer map to null and are filtered out below.
                    final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample x) -> {
                        final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
                        @Nullable final double[] z = null == d ? null : d.getDelta();
                        return z;
                    });
                    final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
                    // Reference comparison: presumably the region returns the same array when no
                    // projection was needed.
                    if (projectedPosition != proposedPosition) {
                        // Rewrite the caller's delta so that current + delta == projected.
                        for (int i = 0; i < projectedPosition.length; i++) {
                            delta[i] = projectedPosition[i] - currentPosition[i];
                        }
                        @Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
                        final double normalMagSq = ArrayUtil.dot(normal, normal);
                        // Stray fragment of a removed debug log line (cf. the commented-out monitor.log below).
                        if (0 < normalMagSq) {
                            final double a = ArrayUtil.dot(originalAlphaD, normal);
                            // NOTE(review): comparing the dot product with -1 has no evident
                            // geometric meaning; the tangent projection below is well-defined for
                            // any a once normalMagSq > 0 — confirm the intent of this guard.
                            if (a != -1) {
                                // tangent = d - (d.n / |n|^2) n, the part of the direction
                                // orthogonal to the projection normal.
                                @Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
                                for (int i = 0; i < tangent.length; i++) {
                                    newAlphaD[i] = tangent[i];
                                }
                            // double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
                            // double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
                            // assert(newAlphaDerivSq <= originalAlphaDerivSq);
                            // assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
                            // monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", layer, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
                            }
                        }
                    }
                }
            });
            return newAlphaDerivative;
        }

        @Override
        public void reset() {
            cursor.reset();
        }

        /**
         * Evaluates the objective at {@code alpha}. Steps: reset the inner cursor, take its
         * position, project it into the trust region, apply it with {@code accumulate(1)}, and
         * measure the subject. The reported derivative is taken along the projected (tangent)
         * direction.
         */
        @Nonnull
        @Override
        public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
            cursor.reset();
            @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
            @Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
            // Presumably writes the (clamped) deltas into the live weights before measuring.
            adjustedPosVector.accumulate(1);
            @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
            return new LineSearchPoint(sample, adjustedGradient.dot(sample.delta));
        }

        // Releases the wrapped inner cursor.
        @Override
        public void _free() {
            cursor.freeRef();
        }
    };
}
Also used : TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) IntStream(java.util.stream.IntStream) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) ArrayUtil(com.simiacryptus.util.ArrayUtil) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) LinkedList(java.util.LinkedList) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 15 with Trainable

use of com.simiacryptus.mindseye.eval.Trainable in project MindsEye by SimiaCryptus.

From class ImageClassifier, method deepDream:

/**
 * Optimizes {@code image} so that {@code network} produces {@code targetValue} for category
 * {@code targetCategoryIndex} ("deep dream"), and plots the training history.
 *
 * <p>The training example pairs {@code image} with a target tensor of shape
 * 1 x 1 x {@code totalCategories}, with {@code targetValue} set at {@code targetCategoryIndex}.
 * A supervised network is assembled as {@code lossLayer(network(clamp(input0)), input1)}. The
 * clamp applies two ReLU / (255 - x) stages, which limits every input value to [0, 255].
 * {@code network.freeze()} is called on the caller's layer (presumably mutating it), and
 * {@code setMask(true, false)} marks which inputs are trained — presumably only the image.
 *
 * <p>Training uses {@link QQN} with an {@link ArmijoWolfeSearch} and a default 60-minute timeout.
 * The trainer is passed through {@code config}, and the terminate threshold is then forced to
 * negative infinity, so {@code config} can override the timeout but not the threshold.
 *
 * @param log                 notebook receiving the input preview, code and loss plot
 * @param image               tensor to optimize
 * @param targetCategoryIndex index in the target tensor that receives {@code targetValue}
 * @param totalCategories     length of the target tensor's last dimension
 * @param config              hook to customize the trainer before it runs
 * @param network             classifier to dream against; frozen by this method
 * @param lossLayer           loss comparing the network output with the target tensor
 * @param targetValue         desired output value for the target category
 */
public void deepDream(@Nonnull final NotebookOutput log, final Tensor image, final int targetCategoryIndex, final int totalCategories, Function<IterativeTrainer, IterativeTrainer> config, final Layer network, final Layer lossLayer, final double targetValue) {
    final Tensor target = new Tensor(1, 1, totalCategories).set(targetCategoryIndex, targetValue);
    @Nonnull final List<Tensor[]> data = Arrays.<Tensor[]>asList(new Tensor[] { image, target });
    log.code(() -> {
        data.forEach(pair -> ImageClassifier.log.info(log.image(pair[0].toImage(), "") + pair[1]));
    });
    log.code(() -> {
        @Nonnull final ArrayList<StepRecord> steps = new ArrayList<>();
        // Two identical stages: relu(x) -> 255 - x, applied twice, equals min(255, max(0, x)).
        @Nonnull final PipelineNetwork clamp = new PipelineNetwork(1);
        for (int stage = 0; stage < 2; stage++) {
            clamp.add(new ActivationLayer(ActivationLayer.Mode.RELU));
            clamp.add(new LinearActivationLayer().setBias(255).setScale(-1).freeze());
        }
        @Nonnull final PipelineNetwork supervised = new PipelineNetwork(2);
        supervised.wrap(lossLayer, supervised.add(network.freeze(), supervised.wrap(clamp, supervised.getInput(0))), supervised.getInput(1));
        @Nonnull final Trainable objective = new ArrayTrainable(supervised, 1)
            .setVerbose(true)
            .setMask(true, false)
            .setData(data);
        final IterativeTrainer trainer = new IterativeTrainer(objective)
            .setMonitor(getTrainingMonitor(steps, supervised))
            .setOrientation(new QQN())
            .setLineSearchFactory(name -> new ArmijoWolfeSearch())
            .setTimeout(60, TimeUnit.MINUTES);
        config.apply(trainer).setTerminateThreshold(Double.NEGATIVE_INFINITY).runAndFree();
        return TestUtil.plot(steps);
    });
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) ActivationLayer(com.simiacryptus.mindseye.layers.cudnn.ActivationLayer) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) QQN(com.simiacryptus.mindseye.opt.orient.QQN) StepRecord(com.simiacryptus.mindseye.test.StepRecord) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Trainable(com.simiacryptus.mindseye.eval.Trainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable)

Aggregations

Trainable (com.simiacryptus.mindseye.eval.Trainable)25 Nonnull (javax.annotation.Nonnull)25 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)22 Layer (com.simiacryptus.mindseye.lang.Layer)15 EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer)13 SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork)13 SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable)12 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)10 Tensor (com.simiacryptus.mindseye.lang.Tensor)9 List (java.util.List)9 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)8 ArrayList (java.util.ArrayList)8 Map (java.util.Map)8 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)7 ArmijoWolfeSearch (com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch)7 GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent)7 StepRecord (com.simiacryptus.mindseye.test.StepRecord)7 Arrays (java.util.Arrays)7 IntStream (java.util.stream.IntStream)7 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)5