
Example 1 with QQN

Use of com.simiacryptus.mindseye.opt.orient.QQN in the project MindsEye by SimiaCryptus.

From the class ImageClassifier, method deepDream.

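/**
 * Deep dream: optimizes the input image against the output of the frozen network returned by
 * getNetwork(), with pixel values clamped to [0, 255], using the QQN orientation and an
 * Armijo-Wolfe line search.
 *
 * @param log   the notebook output that records the code and the training plot
 * @param image the image to optimize
 */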
public void deepDream(@Nonnull final NotebookOutput log, final Tensor image) {
    log.code(() -> {
        @Nonnull ArrayList<StepRecord> history = new ArrayList<>();
        // Clamp pixel values to [0, 255]: ReLU, then 255 - x, then ReLU, then 255 - x again
        @Nonnull PipelineNetwork clamp = new PipelineNetwork(1);
        clamp.add(new ActivationLayer(ActivationLayer.Mode.RELU));
        clamp.add(new LinearActivationLayer().setBias(255).setScale(-1).freeze());
        clamp.add(new ActivationLayer(ActivationLayer.Mode.RELU));
        clamp.add(new LinearActivationLayer().setBias(255).setScale(-1).freeze());
        // Feed the clamped image through the frozen network, so only the image is optimized
        @Nonnull PipelineNetwork supervised = new PipelineNetwork(1);
        supervised.add(getNetwork().freeze(), supervised.wrap(clamp, supervised.getInput(0)));
        // Disabled alternative: keep the input resident on the GPU
        // CudaTensorList gpuInput = CudnnHandle.apply(gpu -> {
        // Precision precision = Precision.Float;
        // return CudaTensorList.wrap(gpu.getPtr(TensorArray.wrap(image), precision, MemoryType.Managed), 1, image.getDimensions(), precision);
        // });
        // @Nonnull Trainable trainable = new TensorListTrainable(supervised, gpuInput).setVerbosity(1).setMask(true);
        // setMask(true, false) marks the image input as trainable
        @Nonnull Trainable trainable = new ArrayTrainable(supervised, 1)
                .setVerbose(true)
                .setMask(true, false)
                .setData(Arrays.<Tensor[]>asList(new Tensor[] { image }));
        // QQN picks the search direction; ArmijoWolfeSearch picks the step size
        new IterativeTrainer(trainable)
                .setMonitor(getTrainingMonitor(history, supervised))
                .setOrientation(new QQN())
                .setLineSearchFactory(name -> new ArmijoWolfeSearch())
                .setTimeout(60, TimeUnit.MINUTES)
                .runAndFree();
        return TestUtil.plot(history);
    });
}
Also used : ActivationLayer(com.simiacryptus.mindseye.layers.cudnn.ActivationLayer) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) Tensor(com.simiacryptus.mindseye.lang.Tensor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) QQN(com.simiacryptus.mindseye.opt.orient.QQN) StepRecord(com.simiacryptus.mindseye.test.StepRecord) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Trainable(com.simiacryptus.mindseye.eval.Trainable)
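The four-layer `clamp` network in this example acts as a clamp to [0, 255]: the first ReLU removes negative values, `setBias(255).setScale(-1)` maps x to 255 - x, the second ReLU removes anything that was above 255 (now negative), and the second flip restores the original orientation. The sketch below replays that chain on plain doubles. It assumes `LinearActivationLayer` computes scale * x + bias elementwise; `ClampSketch` and its helper methods are hypothetical names, not MindsEye API.

```java
// Hypothetical standalone sketch: replays the clamp PipelineNetwork's chain
// (ReLU, 255 - x, ReLU, 255 - x) on plain doubles. Not MindsEye code.
class ClampSketch {

    // ReLU: max(x, 0)
    static double relu(double x) {
        return Math.max(x, 0.0);
    }

    // Assumed behavior of LinearActivationLayer with bias 255 and scale -1: 255 - x
    static double flip(double x) {
        return 255.0 - x;
    }

    // relu -> flip -> relu -> flip, the same order as the layers in the example
    static double clamp(double x) {
        return flip(relu(flip(relu(x))));
    }

    public static void main(String[] args) {
        double[] inputs = { -40.0, 0.0, 12.5, 255.0, 300.0 };
        for (double x : inputs) {
            System.out.println(x + " -> " + clamp(x));
        }
    }
}
```

For every input this matches `Math.min(Math.max(x, 0), 255)`. Building the clamp from layers lets it sit inside the `PipelineNetwork`, so gradients flow through it back to the image (and are zero where the clamp saturates).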

Example 2 with QQN

Use of com.simiacryptus.mindseye.opt.orient.QQN in the project MindsEye by SimiaCryptus.

From the class TrainingTester, method trainMagic.

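/**
 * Trains with an experimental optimizer: an outer RecursiveSubspace orientation with a fixed
 * step size, whose inner subspace problem is solved with QQN and a QuadraticSearch.
 *
 * @param log       the notebook output that records the code
 * @param trainable the trainable to optimize
 * @return the recorded training history
 */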
@Nonnull
public List<StepRecord> trainMagic(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("Now we train using an experimental optimizer:");
    @Nonnull final List<StepRecord> history = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
    try {
        log.code(() -> {
            // Outer trainer: fixed step size (StaticLearningRate 1.0) along a RecursiveSubspace orientation
            return new IterativeTrainer(trainable).setLineSearchFactory(label -> new StaticLearningRate(1.0)).setOrientation(new RecursiveSubspace() {

                @Override
                public void train(@Nonnull TrainingMonitor monitor, Layer macroLayer) {
                    // A single empty tensor serves as placeholder input data for the macro layer
                    @Nonnull Tensor[][] nullData = { { new Tensor() } };
                    @Nonnull BasicTrainable inner = new BasicTrainable(macroLayer);
                    @Nonnull ArrayTrainable trainable1 = new ArrayTrainable(inner, nullData);
                    inner.freeRef();
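                    // Inner trainer: QQN with a QuadraticSearch starting at rate 1.0 on the QQN cursor, 1e-4 on any other
                    new IterativeTrainer(trainable1).setOrientation(new QQN()).setLineSearchFactory(n -> new QuadraticSearch().setCurrentRate(n.equals(QQN.CURSOR_NAME) ? 1.0 : 1e-4)).setMonitor(new TrainingMonitor() {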

                        @Override
                        public void log(String msg) {
                            monitor.log("\t" + msg);
                        }
                    }).setMaxIterations(getIterations()).setIterationsPerSample(getIterations()).runAndFree();
                    // Release the reference-counted inner trainable and placeholder tensors
                    trainable1.freeRef();
                    for (@Nonnull Tensor[] tensors : nullData) {
                        for (@Nonnull Tensor tensor : tensors) {
                            tensor.freeRef();
                        }
                    }
                }
            }).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setIterationsPerSample(100).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
        });
    } catch (Throwable e) {
        if (isThrowExceptions())
            throw new RuntimeException(e);
    }
    return history;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing)
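The inner trainer in this example pairs `QQN` with a `QuadraticSearch`, starting at a rate of 1.0 on the `QQN.CURSOR_NAME` cursor and 1e-4 otherwise. As a rough illustration of the idea behind a quadratic line search, the sketch below performs one generic parabolic-interpolation step: it samples the objective at three step sizes, fits a parabola, and jumps to its vertex if that improves on the samples. It is not MindsEye's `QuadraticSearch` implementation; `ParabolicStep`, `vertex`, and `search` are hypothetical names, and `rate` only loosely mirrors the initial rate passed to `setCurrentRate`.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical sketch of one parabolic-interpolation step, the classic idea
// behind quadratic line searches. Not MindsEye's QuadraticSearch.
class ParabolicStep {

    // Vertex of the parabola through (a, fa), (b, fb), (c, fc).
    // Returns NaN when the three points are collinear (no unique vertex).
    static double vertex(double a, double fa, double b, double fb, double c, double fc) {
        double p = (b - a) * (fb - fc);
        double q = (b - c) * (fb - fa);
        double denom = p - q;
        if (denom == 0.0) {
            return Double.NaN;
        }
        double numer = (b - a) * p - (b - c) * q;
        return b - 0.5 * numer / denom;
    }

    // Sample phi at step sizes 0, rate and 2 * rate, jump to the fitted vertex,
    // and keep it only if it beats the best sampled point.
    static double search(DoubleUnaryOperator phi, double rate) {
        double[] ts = { 0.0, rate, 2.0 * rate };
        double[] fs = new double[3];
        int best = 0;
        for (int i = 0; i < 3; i++) {
            fs[i] = phi.applyAsDouble(ts[i]);
            if (fs[i] < fs[best]) {
                best = i;
            }
        }
        double t = vertex(ts[0], fs[0], ts[1], fs[1], ts[2], fs[2]);
        if (!Double.isNaN(t) && phi.applyAsDouble(t) < fs[best]) {
            return t;
        }
        return ts[best];
    }

    public static void main(String[] args) {
        // phi(t) = (t - 3)^2 + 1 has its minimum at t = 3
        DoubleUnaryOperator phi = t -> (t - 3.0) * (t - 3.0) + 1.0;
        System.out.println(search(phi, 1.0)); // prints 3.0
    }
}
```

On an exactly quadratic objective a single step lands on the minimum. A practical line search would iterate, bracket the minimum, and reject downward-opening fits; the improvement check here is only a minimal stand-in for that.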

Example 3 with QQN

Use of com.simiacryptus.mindseye.opt.orient.QQN in the project MindsEye by SimiaCryptus.

From the class ImageClassifier, method deepDream.

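/**
 * Targeted deep dream: optimizes the input image so that the network's output for
 * targetCategoryIndex moves toward targetValue, as scored by lossLayer. Pixel values are
 * clamped to [0, 255], and the network is frozen so that only the image changes.
 *
 * @param log                 the notebook output that records the code, images, and training plot
 * @param image               the image to optimize
 * @param targetCategoryIndex the index of the category to drive toward targetValue
 * @param totalCategories     the number of categories in the network's output
 * @param config              a hook that customizes the IterativeTrainer before it runs
 * @param network             the classifier network (frozen during optimization)
 * @param lossLayer           the loss comparing the network output with the target tensor
 * @param targetValue         the desired output value for targetCategoryIndex
 */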
public void deepDream(@Nonnull final NotebookOutput log, final Tensor image, final int targetCategoryIndex, final int totalCategories, Function<IterativeTrainer, IterativeTrainer> config, final Layer network, final Layer lossLayer, final double targetValue) {
    @Nonnull List<Tensor[]> data = Arrays.<Tensor[]>asList(new Tensor[] { image, new Tensor(1, 1, totalCategories).set(targetCategoryIndex, targetValue) });
    log.code(() -> {
        for (Tensor[] tensors : data) {
            ImageClassifier.log.info(log.image(tensors[0].toImage(), "") + tensors[1]);
        }
    });
    log.code(() -> {
        @Nonnull ArrayList<StepRecord> history = new ArrayList<>();
        // Clamp pixel values to [0, 255]: ReLU, then 255 - x, then ReLU, then 255 - x again
        @Nonnull PipelineNetwork clamp = new PipelineNetwork(1);
        clamp.add(new ActivationLayer(ActivationLayer.Mode.RELU));
        clamp.add(new LinearActivationLayer().setBias(255).setScale(-1).freeze());
        clamp.add(new ActivationLayer(ActivationLayer.Mode.RELU));
        clamp.add(new LinearActivationLayer().setBias(255).setScale(-1).freeze());
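        // Graph: lossLayer(network(clamp(input 0)), input 1); the frozen network leaves only the image to optimize
        @Nonnull PipelineNetwork supervised = new PipelineNetwork(2);
        supervised.wrap(lossLayer, supervised.add(network.freeze(), supervised.wrap(clamp, supervised.getInput(0))), supervised.getInput(1));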
        // TensorList[] gpuInput = data.stream().map(data1 -> {
        // return CudnnHandle.apply(gpu -> {
        // Precision precision = Precision.Float;
        // return CudaTensorList.wrap(gpu.getPtr(TensorArray.wrap(data1), precision, MemoryType.Managed), 1, image.getDimensions(), precision);
        // });
        // }).toArray(i -> new TensorList[i]);
        // @Nonnull Trainable trainable = new TensorListTrainable(supervised, gpuInput).setVerbosity(1).setMask(true);
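        // setMask(true, false): the image input is trainable, the target tensor is not
        @Nonnull Trainable trainable = new ArrayTrainable(supervised, 1).setVerbose(true).setMask(true, false).setData(data);
        // config customizes the trainer first; the terminate threshold of negative infinity is set
        // afterwards, so config cannot override it and no loss value triggers early termination
        config.apply(new IterativeTrainer(trainable)
                .setMonitor(getTrainingMonitor(history, supervised))
                .setOrientation(new QQN())
                .setLineSearchFactory(name -> new ArmijoWolfeSearch())
                .setTimeout(60, TimeUnit.MINUTES))
                .setTerminateThreshold(Double.NEGATIVE_INFINITY)
                .runAndFree();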
        return TestUtil.plot(history);
    });
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) ActivationLayer(com.simiacryptus.mindseye.layers.cudnn.ActivationLayer) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) QQN(com.simiacryptus.mindseye.opt.orient.QQN) StepRecord(com.simiacryptus.mindseye.test.StepRecord) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Trainable(com.simiacryptus.mindseye.eval.Trainable)
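Unlike Example 1, this variant is supervised: `data` pairs the image with a target tensor of shape 1 x 1 x `totalCategories` that holds `targetValue` at `targetCategoryIndex`, and `lossLayer` scores the network output against it. The sketch below mirrors that target in plain Java and scores a prediction with mean squared error. It assumes a newly constructed `Tensor` is zero-filled, and it picks mean squared error as one plausible `lossLayer` (MindsEye has a `MeanSqLossLayer`, which appears in Example 2's "Also used" list); `TargetLossSketch` and its methods are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical sketch: the target vector Example 3 builds, plus a
// mean-squared-error loss as one assumed choice of lossLayer.
class TargetLossSketch {

    // All zeros (assumed Tensor default) except targetValue at targetCategoryIndex
    static double[] target(int totalCategories, int targetCategoryIndex, double targetValue) {
        double[] t = new double[totalCategories];
        t[targetCategoryIndex] = targetValue;
        return t;
    }

    // Mean of the squared differences between prediction and target
    static double meanSquaredError(double[] prediction, double[] target) {
        if (prediction.length != target.length) {
            throw new IllegalArgumentException("length mismatch");
        }
        double sum = 0.0;
        for (int i = 0; i < prediction.length; i++) {
            double d = prediction[i] - target[i];
            sum += d * d;
        }
        return sum / prediction.length;
    }

    public static void main(String[] args) {
        double[] t = target(4, 2, 1.0);
        double[] prediction = { 0.1, 0.2, 0.6, 0.1 };
        System.out.println(Arrays.toString(t));
        System.out.println(meanSquaredError(prediction, t));
    }
}
```

Minimizing such a loss with respect to the image pushes the chosen category toward `targetValue` and, under the zero-fill assumption, the other categories toward zero.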

Aggregations

ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 3
Trainable (com.simiacryptus.mindseye.eval.Trainable): 3
Tensor (com.simiacryptus.mindseye.lang.Tensor): 3
PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 3
IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer): 3
ArmijoWolfeSearch (com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch): 3
QQN (com.simiacryptus.mindseye.opt.orient.QQN): 3
StepRecord (com.simiacryptus.mindseye.test.StepRecord): 3
ArrayList (java.util.ArrayList): 3
Nonnull (javax.annotation.Nonnull): 3
ActivationLayer (com.simiacryptus.mindseye.layers.cudnn.ActivationLayer): 2
LinearActivationLayer (com.simiacryptus.mindseye.layers.java.LinearActivationLayer): 2
BasicTrainable (com.simiacryptus.mindseye.eval.BasicTrainable): 1
ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult): 1
Layer (com.simiacryptus.mindseye.lang.Layer): 1
ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting): 1
Result (com.simiacryptus.mindseye.lang.Result): 1
TensorList (com.simiacryptus.mindseye.lang.TensorList): 1
MeanSqLossLayer (com.simiacryptus.mindseye.layers.java.MeanSqLossLayer): 1
DAGNode (com.simiacryptus.mindseye.network.DAGNode): 1