Search in sources :

Example 21 with NotebookOutput

use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

the class TrainingTester method testInputLearning.

/**
 * Tests whether the inputs of a frozen, shuffled copy of the component can be learned from its
 * pre-evaluated output.
 *
 * <p>A target input is generated and evaluated through the network to obtain the target output;
 * a second, freshly shuffled input is then trained (via {@code trainAll}) against that output.
 *
 * @param log            the notebook the narrative and results are written to
 * @param component      the layer under test; it is copied, never modified directly
 * @param random         the random source used for shuffling
 * @param inputPrototype the input prototype the shuffled inputs are derived from
 * @return the training result, or {@code null} when the number of output tensors does not match
 *         {@code getBatches()} (reported as "Meta layers not supported")
 */
@Nullable
public TestResult testInputLearning(@Nonnull final NotebookOutput log, @Nonnull final Layer component, final Random random, @Nonnull final Tensor[] inputPrototype) {
    @Nonnull final Layer network = shuffle(random, component.copy()).freeze();
    final Tensor[][] input_target = shuffleCopy(random, inputPrototype);
    log.p("In this apply, we use a network to learn this target input, given it's pre-evaluated output:");
    log.code(() -> {
        return Arrays.stream(input_target).flatMap(x -> Arrays.stream(x)).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
    });
    Result[] array = ConstantResult.batchResultArray(input_target);
    @Nullable Result eval = network.eval(array);
    TensorList result = eval.getData();
    final Tensor[] output_target = result.stream().toArray(i -> new Tensor[i]);
    result.freeRef();
    eval.freeRef();
    // Release the evaluation inputs before the early-return check below; they are not used again,
    // and releasing them only after the check leaked them whenever the method bailed out.
    for (@Nonnull Result nnResult : array) {
        nnResult.getData().freeRef();
        nnResult.freeRef();
    }
    for (@Nonnull Tensor[] tensors : input_target) {
        for (@Nonnull Tensor tensor : tensors) {
            tensor.freeRef();
        }
    }
    if (output_target.length != getBatches()) {
        logger.info(String.format("Meta layers not supported. %d != %d", output_target.length, getBatches()));
        // NOTE(review): output_target tensors are not released on this path; on the normal path
        // they are released through trainingInput below - confirm whether they should be freed here.
        return null;
    }
    // if (output_target.length != inputPrototype.length) return null;
    Tensor[][] trainingInput = append(shuffleCopy(random, inputPrototype), output_target);
    @Nonnull TestResult testResult = trainAll("Input Convergence", log, trainingInput, network, buildMask(inputPrototype.length));
    Arrays.stream(trainingInput).flatMap(x -> Arrays.stream(x)).forEach(x -> x.freeRef());
    return testResult;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult)

Example 22 with NotebookOutput

use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

the class TrainingTester method trainMagic.

/**
 * Trains using an experimental optimizer: an outer {@link IterativeTrainer} with a static
 * learning rate whose orientation is a {@link RecursiveSubspace}, which in turn trains its macro
 * layer with a nested {@link QQN}-oriented trainer.
 *
 * <p>Failures are rethrown wrapped in a {@link RuntimeException} when
 * {@code isThrowExceptions()} is true; otherwise they are logged and the history collected so far
 * is returned.
 *
 * @param log       the notebook the narrative and code output are written to
 * @param trainable the trainable to optimize
 * @return the list of step records gathered by the training monitor
 */
@Nonnull
public List<StepRecord> trainMagic(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("Now we train using an experimental optimizer:");
    @Nonnull final List<StepRecord> history = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable).setLineSearchFactory(label -> new StaticLearningRate(1.0)).setOrientation(new RecursiveSubspace() {

                @Override
                public void train(@Nonnull TrainingMonitor monitor, Layer macroLayer) {
                    @Nonnull Tensor[][] nullData = { { new Tensor() } };
                    @Nonnull BasicTrainable inner = new BasicTrainable(macroLayer);
                    @Nonnull ArrayTrainable trainable1 = new ArrayTrainable(inner, nullData);
                    inner.freeRef();
                    try {
                        new IterativeTrainer(trainable1).setOrientation(new QQN()).setLineSearchFactory(n -> new QuadraticSearch().setCurrentRate(n.equals(QQN.CURSOR_NAME) ? 1.0 : 1e-4)).setMonitor(new TrainingMonitor() {

                            @Override
                            public void log(String msg) {
                                // Indent nested-trainer messages under the outer monitor's output.
                                monitor.log("\t" + msg);
                            }
                        }).setMaxIterations(getIterations()).setIterationsPerSample(getIterations()).runAndFree();
                    } finally {
                        // Release our references even when the nested training run throws.
                        trainable1.freeRef();
                        for (@Nonnull Tensor[] tensors : nullData) {
                            for (@Nonnull Tensor tensor : tensors) {
                                tensor.freeRef();
                            }
                        }
                    }
                }
            }).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setIterationsPerSample(100).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
        });
    } catch (Throwable e) {
        if (isThrowExceptions())
            throw new RuntimeException(e);
        // Best-effort mode: keep going, but do not lose the failure silently.
        logger.warn("Error training with the experimental optimizer", e);
    }
    return history;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) ArrayList(java.util.ArrayList) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) QQN(com.simiacryptus.mindseye.opt.orient.QQN) StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) Nonnull(javax.annotation.Nonnull)

Example 23 with NotebookOutput

use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

the class TrainingTester method testCompleteLearning.

/**
 * Tests whether both the model state and the inputs can be learned together: a randomized, frozen
 * copy of the component produces a target output from a shuffled target input, and a second
 * shuffled copy of the component is then trained (via {@code trainAll}) against that output.
 *
 * <p>Reference handling follows the same sequence as {@code testInputLearning}: the evaluation
 * data is read before the evaluation result is released, and the inputs, target tensors and
 * training tensors are released once they are no longer needed.
 *
 * @param log            the notebook the narrative and results are written to
 * @param component      the layer under test; it is copied, never modified directly
 * @param random         the random source used for shuffling
 * @param inputPrototype the input prototype the shuffled inputs are derived from
 * @return the result returned by {@code trainAll}
 */
@Nonnull
public TestResult testCompleteLearning(@Nonnull final NotebookOutput log, @Nonnull final Layer component, final Random random, @Nonnull final Tensor[] inputPrototype) {
    @Nonnull final Layer network_target = shuffle(random, component.copy()).freeze();
    final Tensor[][] input_target = shuffleCopy(random, inputPrototype);
    log.p("In this apply, attempt to train a network to emulate a randomized network given an example input/output. The target state is:");
    log.code(() -> {
        return network_target.state().stream().map(Arrays::toString).reduce((a, b) -> a + "\n" + b).orElse("");
    });
    log.p("We simultaneously regress this target input:");
    log.code(() -> {
        return Arrays.stream(input_target).flatMap(x -> Arrays.stream(x)).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
    });
    log.p("Which produces the following output:");
    Result[] inputs = ConstantResult.batchResultArray(input_target);
    Result eval = network_target.eval(inputs);
    network_target.freeRef();
    TensorList result = eval.getData();
    // Extract the output tensors before releasing the evaluation result and its data.
    final Tensor[] output_target = result.stream().toArray(i -> new Tensor[i]);
    result.freeRef();
    eval.freeRef();
    // The evaluation inputs and the target input tensors are not used past this point.
    for (@Nonnull Result input : inputs) {
        input.getData().freeRef();
        input.freeRef();
    }
    for (@Nonnull Tensor[] tensors : input_target) {
        for (@Nonnull Tensor tensor : tensors) {
            tensor.freeRef();
        }
    }
    log.code(() -> {
        return Stream.of(output_target).map(x -> x.prettyPrint()).reduce((a, b) -> a + "\n" + b).orElse("");
    });
    // if (output_target.length != inputPrototype.length) return null;
    Tensor[][] trainingInput = append(shuffleCopy(random, inputPrototype), output_target);
    @Nonnull TestResult testResult = trainAll("Integrated Convergence", log, trainingInput, shuffle(random, component.copy()), buildMask(inputPrototype.length));
    Arrays.stream(trainingInput).flatMap(x -> Arrays.stream(x)).forEach(x -> x.freeRef());
    return testResult;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Arrays(java.util.Arrays) TensorList(com.simiacryptus.mindseye.lang.TensorList) Layer(com.simiacryptus.mindseye.lang.Layer) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) Nonnull(javax.annotation.Nonnull)

Example 24 with NotebookOutput

use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

the class TrainingTester method test.

/**
 * Runs the training tests that apply to the given component and reports them in the notebook.
 *
 * <p>Input learning runs when any prototype tensor has a non-zero length, model learning runs
 * when the component has state, and composite learning runs when both apply. A phase that does
 * not run contributes {@code null} to the returned result.
 *
 * @param log            the notebook the headers, grid and results are written to
 * @param component      the layer under test
 * @param inputPrototype the input prototype tensors
 * @return the combined result of the phases that ran
 * @throws AssertionError if the component has state and all weights are zero, or if all inputs
 *                        are zero
 */
@Override
public ComponentResult test(@Nonnull final NotebookOutput log, @Nonnull final Layer component, @Nonnull final Tensor... inputPrototype) {
    printHeader(log);
    final boolean testModel = !component.state().isEmpty();
    if (testModel && isZero(component.state().stream().flatMapToDouble(x1 -> Arrays.stream(x1)))) {
        throw new AssertionError("Weights are all zero?");
    }
    if (isZero(Arrays.stream(inputPrototype).flatMapToDouble(x -> Arrays.stream(x.getData())))) {
        throw new AssertionError("Inputs are all zero?");
    }
    @Nonnull final Random random = new Random();
    final boolean testInput = Arrays.stream(inputPrototype).anyMatch(x -> x.length() > 0);
    @Nullable TestResult inputLearning;
    if (testInput) {
        log.h2("Input Learning");
        inputLearning = testInputLearning(log, component, random, inputPrototype);
    } else {
        inputLearning = null;
    }
    @Nullable TestResult modelLearning;
    if (testModel) {
        log.h2("Model Learning");
        modelLearning = testModelLearning(log, component, random, inputPrototype);
    } else {
        modelLearning = null;
    }
    @Nullable TestResult completeLearning;
    if (testInput && testModel) {
        log.h2("Composite Learning");
        completeLearning = testCompleteLearning(log, component, random, inputPrototype);
    } else {
        completeLearning = null;
    }
    log.h2("Results");
    log.code(() -> {
        return grid(inputLearning, modelLearning, completeLearning);
    });
    ComponentResult result = log.code(() -> {
        return new ComponentResult(null == inputLearning ? null : inputLearning.value, null == modelLearning ? null : modelLearning.value, null == completeLearning ? null : completeLearning.value);
    });
    log.setFrontMatterProperty("training_analysis", result.toString());
    if (throwExceptions) {
        // Phases that did not run were passed to the result as null; only check the ones present
        // so a skipped phase does not surface as a NullPointerException.
        assert null == result.complete || result.complete.map.values().stream().allMatch(x -> x.type == ResultType.Converged);
        assert null == result.input || result.input.map.values().stream().allMatch(x -> x.type == ResultType.Converged);
        assert null == result.model || result.model.map.values().stream().allMatch(x -> x.type == ResultType.Converged);
    }
    return result;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) Random(java.util.Random) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)

Example 25 with NotebookOutput

use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.

the class StyleTransfer_VGG19 method run.

/**
 * Runs a multi-phase VGG19 style transfer and writes each phase to the notebook.
 *
 * <p>Phase 0 starts from a blurred, noise-perturbed copy of the content image at 400 pixels;
 * phases 1 through 9 each enlarge the canvas by a factor of sqrt(1.5) and run the transfer again.
 * The front-matter property {@code status} is set to {@code OK} once all phases finish.
 *
 * @param log the notebook the phase headers and transfer output are written to
 */
public void run(@Nonnull NotebookOutput log) {
    StyleTransfer.VGG19 transfer = new StyleTransfer.VGG19();
    init(log);
    transfer.parallelLossFunctions = true;

    Precision precision = Precision.Float;
    double growthFactor = Math.sqrt(1.5);
    int trainingMinutes = 90;
    int imageSize = 400;

    // Source images; those not referenced below appear to be kept as alternatives for experiments.
    CharSequence lakeAndForest = "H:\\SimiaCryptus\\Artistry\\Owned\\IMG_20170624_153541213-EFFECTS.jpg";
    String monkey = "H:\\SimiaCryptus\\Artistry\\capuchin-monkey-2759768_960_720.jpg";
    CharSequence vanGogh1 = "H:\\SimiaCryptus\\Artistry\\portraits\\vangogh\\Van_Gogh_-_Portrait_of_Pere_Tanguy_1887-8.jpg";
    CharSequence vanGogh2 = "H:\\SimiaCryptus\\Artistry\\portraits\\vangogh\\800px-Vincent_van_Gogh_-_Dr_Paul_Gachet_-_Google_Art_Project.jpg";
    CharSequence threeMusicians = "H:\\SimiaCryptus\\Artistry\\portraits\\picasso\\800px-Pablo_Picasso,_1921,_Nous_autres_musiciens_(Three_Musicians),_oil_on_canvas,_204.5_x_188.3_cm,_Philadelphia_Museum_of_Art.jpg";
    CharSequence maJolie = "H:\\SimiaCryptus\\Artistry\\portraits\\picasso\\Ma_Jolie_Pablo_Picasso.jpg";

    // Style definition: one group of style images with its per-layer coefficients.
    double coeff_mean = 1e1;
    double coeff_cov = 1e0;
    Map<List<CharSequence>, StyleTransfer.StyleCoefficients> styles = new HashMap<>();
    List<CharSequence> styleSources = Arrays.asList(vanGogh1, vanGogh2);
    StyleTransfer.StyleCoefficients styleCoefficients = new StyleTransfer.StyleCoefficients(StyleTransfer.CenteringMode.Origin)
        .set(MultiLayerVGG19.LayerType.Layer_1b, coeff_mean, coeff_cov)
        .set(MultiLayerVGG19.LayerType.Layer_1d, coeff_mean, coeff_cov);
    styles.put(styleSources, styleCoefficients);
    StyleTransfer.ContentCoefficients contentCoefficients = new StyleTransfer.ContentCoefficients()
        .set(MultiLayerVGG19.LayerType.Layer_1c, 1e0);

    log.h1("Phase 0");
    // Build the starting canvas: load, fit to size, blur through a 25-pixel round trip, add noise.
    BufferedImage canvas = ArtistryUtil.load(monkey, imageSize);
    canvas = TestUtil.resize(canvas, imageSize, true);
    BufferedImage thumbnail = TestUtil.resize(canvas, 25, true);
    canvas = TestUtil.resize(thumbnail, imageSize, true);
    canvas = ArtistryUtil.randomize(canvas, x -> x + 2 * 1 * (FastRandom.INSTANCE.random() - 0.5));

    Map<CharSequence, BufferedImage> styleImages = new HashMap<>();
    styleImages.putAll(styles.keySet().stream().flatMap(x -> x.stream()).collect(Collectors.toMap(x -> x, file -> ArtistryUtil.load(file))));
    StyleTransfer.StyleSetup styleSetup = new StyleTransfer.StyleSetup(precision, ArtistryUtil.load(monkey, canvas.getWidth(), canvas.getHeight()), contentCoefficients, styleImages, styles);
    StyleTransfer.NeuralSetup measureStyle = transfer.measureStyle(styleSetup);
    canvas = transfer.styleTransfer(server, log, canvas, styleSetup, trainingMinutes, measureStyle);

    for (int phase = 1; phase < 10; phase++) {
        log.h1("Phase " + phase);
        imageSize = (int) (imageSize * growthFactor);
        BufferedImage enlarged = TestUtil.resize(canvas, imageSize, true);
        canvas = transfer.styleTransfer(server, log, enlarged, styleSetup, trainingMinutes, measureStyle);
    }
    log.setFrontMatterProperty("status", "OK");
}
Also used : Arrays(java.util.Arrays) BufferedImage(java.awt.image.BufferedImage) ArtistryUtil(com.simiacryptus.mindseye.applications.ArtistryUtil) StyleTransfer(com.simiacryptus.mindseye.applications.StyleTransfer) HashMap(java.util.HashMap) TestUtil(com.simiacryptus.mindseye.test.TestUtil) FastRandom(com.simiacryptus.util.FastRandom) Collectors(java.util.stream.Collectors) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) VGG19(com.simiacryptus.mindseye.models.VGG19) List(java.util.List) Map(java.util.Map) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) MultiLayerVGG19(com.simiacryptus.mindseye.models.MultiLayerVGG19) VGG19(com.simiacryptus.mindseye.models.VGG19) MultiLayerVGG19(com.simiacryptus.mindseye.models.MultiLayerVGG19) HashMap(java.util.HashMap) BufferedImage(java.awt.image.BufferedImage) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) List(java.util.List) StyleTransfer(com.simiacryptus.mindseye.applications.StyleTransfer)

Aggregations

NotebookOutput (com.simiacryptus.util.io.NotebookOutput)48 Nonnull (javax.annotation.Nonnull)48 Tensor (com.simiacryptus.mindseye.lang.Tensor)46 Nullable (javax.annotation.Nullable)40 Layer (com.simiacryptus.mindseye.lang.Layer)39 Arrays (java.util.Arrays)38 List (java.util.List)37 IntStream (java.util.stream.IntStream)31 TestUtil (com.simiacryptus.mindseye.test.TestUtil)25 Logger (org.slf4j.Logger)25 LoggerFactory (org.slf4j.LoggerFactory)25 Stream (java.util.stream.Stream)23 Collectors (java.util.stream.Collectors)22 ArrayList (java.util.ArrayList)21 HashMap (java.util.HashMap)21 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)20 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)19 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)19 TimeUnit (java.util.concurrent.TimeUnit)19 StepRecord (com.simiacryptus.mindseye.test.StepRecord)18