Search in sources:

Example 1 with StyleTransfer

Use of com.simiacryptus.mindseye.applications.StyleTransfer in the project MindsEye by SimiaCryptus.

From class StyleTransfer_VGG19, method run:

/**
 * Runs a multi-phase VGG19 style transfer and records it in the given notebook.
 *
 * <p>The canvas starts from the capuchin-monkey photo loaded at {@code imageSize}. It is
 * resized down to 25 and back up, which discards fine detail, and then perturbed with small
 * random noise. Style statistics are measured once from two Van Gogh portraits. The same
 * photo, resized to the canvas dimensions, serves as the content image. Style transfer then
 * runs for {@code trainingMinutes} per phase: phase 0 at the initial size, followed by
 * further phases that each enlarge the canvas by {@code growthFactor}. Finally the notebook
 * front-matter property {@code status} is set to {@code OK}.
 *
 * <p>NOTE(review): all image paths are absolute Windows paths on the author's machine, so
 * this method will not run elsewhere without editing them.
 *
 * @param log the notebook output; receives the phase headings and the final status
 *            property, and is passed to {@code init} and to every style-transfer phase
 */
public void run(@Nonnull NotebookOutput log) {
    StyleTransfer.VGG19 styleTransfer = new StyleTransfer.VGG19();
    init(log);
    Precision precision = Precision.Float;
    int imageSize = 400;
    // Total number of phases, including the initial phase 0.
    int phaseCount = 10;
    styleTransfer.parallelLossFunctions = true;
    // Linear growth per phase; presumably chosen so the pixel count grows ~1.5x per phase.
    double growthFactor = Math.sqrt(1.5);
    String monkey = "H:\\SimiaCryptus\\Artistry\\capuchin-monkey-2759768_960_720.jpg";
    CharSequence vanGogh1 = "H:\\SimiaCryptus\\Artistry\\portraits\\vangogh\\Van_Gogh_-_Portrait_of_Pere_Tanguy_1887-8.jpg";
    CharSequence vanGogh2 = "H:\\SimiaCryptus\\Artistry\\portraits\\vangogh\\800px-Vincent_van_Gogh_-_Dr_Paul_Gachet_-_Google_Art_Project.jpg";
    // Each key is a group of style images that share one set of style coefficients.
    Map<List<CharSequence>, StyleTransfer.StyleCoefficients> styles = new HashMap<>();
    double coeff_mean = 1e1;
    double coeff_cov = 1e0;
    styles.put(Arrays.asList(vanGogh1, vanGogh2),
        new StyleTransfer.StyleCoefficients(StyleTransfer.CenteringMode.Origin)
            .set(MultiLayerVGG19.LayerType.Layer_1b, coeff_mean, coeff_cov)
            .set(MultiLayerVGG19.LayerType.Layer_1d, coeff_mean, coeff_cov));
    StyleTransfer.ContentCoefficients contentCoefficients =
        new StyleTransfer.ContentCoefficients().set(MultiLayerVGG19.LayerType.Layer_1c, 1e0);
    int trainingMinutes = 90;

    log.h1("Phase 0");
    BufferedImage canvasImage = ArtistryUtil.load(monkey, imageSize);
    canvasImage = TestUtil.resize(canvasImage, imageSize, true);
    // Shrinking to 25 and scaling back up removes fine detail from the starting canvas.
    canvasImage = TestUtil.resize(TestUtil.resize(canvasImage, 25, true), imageSize, true);
    // Adds per-value noise; assuming random() is uniform in [0, 1), the offset lies in [-1, 1).
    canvasImage = ArtistryUtil.randomize(canvasImage, x -> x + 2 * (FastRandom.INSTANCE.random() - 0.5));

    // Load every distinct style file exactly once. distinct() plus the merge function keep
    // toMap from throwing IllegalStateException when a file appears in several style groups.
    Map<CharSequence, BufferedImage> styleImages = styles.keySet().stream()
        .flatMap(List::stream)
        .distinct()
        .collect(Collectors.toMap(
            file -> file,
            file -> ArtistryUtil.load(file),
            (first, second) -> first,
            HashMap::new));

    // The content target is the same photo, resized to the canvas dimensions.
    BufferedImage contentImage = ArtistryUtil.load(monkey, canvasImage.getWidth(), canvasImage.getHeight());
    StyleTransfer.StyleSetup styleSetup =
        new StyleTransfer.StyleSetup(precision, contentImage, contentCoefficients, styleImages, styles);
    // Style statistics are measured once and reused for every phase.
    StyleTransfer.NeuralSetup measureStyle = styleTransfer.measureStyle(styleSetup);
    canvasImage = styleTransfer.styleTransfer(server, log, canvasImage, styleSetup, trainingMinutes, measureStyle);

    for (int phase = 1; phase < phaseCount; phase++) {
        log.h1("Phase " + phase);
        imageSize = (int) (imageSize * growthFactor);
        canvasImage = TestUtil.resize(canvasImage, imageSize, true);
        canvasImage = styleTransfer.styleTransfer(server, log, canvasImage, styleSetup, trainingMinutes, measureStyle);
    }
    log.setFrontMatterProperty("status", "OK");
}
Also used: Arrays (java.util.Arrays), BufferedImage (java.awt.image.BufferedImage), ArtistryUtil (com.simiacryptus.mindseye.applications.ArtistryUtil), StyleTransfer (com.simiacryptus.mindseye.applications.StyleTransfer), HashMap (java.util.HashMap), TestUtil (com.simiacryptus.mindseye.test.TestUtil), FastRandom (com.simiacryptus.util.FastRandom), Collectors (java.util.stream.Collectors), Precision (com.simiacryptus.mindseye.lang.cudnn.Precision), VGG19 (com.simiacryptus.mindseye.models.VGG19), List (java.util.List), Map (java.util.Map), NotebookOutput (com.simiacryptus.util.io.NotebookOutput), Nonnull (javax.annotation.Nonnull), MultiLayerVGG19 (com.simiacryptus.mindseye.models.MultiLayerVGG19)

Aggregations

ArtistryUtil (com.simiacryptus.mindseye.applications.ArtistryUtil): 1; StyleTransfer (com.simiacryptus.mindseye.applications.StyleTransfer): 1; Precision (com.simiacryptus.mindseye.lang.cudnn.Precision): 1; MultiLayerVGG19 (com.simiacryptus.mindseye.models.MultiLayerVGG19): 1; VGG19 (com.simiacryptus.mindseye.models.VGG19): 1; TestUtil (com.simiacryptus.mindseye.test.TestUtil): 1; FastRandom (com.simiacryptus.util.FastRandom): 1; NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 1; BufferedImage (java.awt.image.BufferedImage): 1; Arrays (java.util.Arrays): 1; HashMap (java.util.HashMap): 1; List (java.util.List): 1; Map (java.util.Map): 1; Collectors (java.util.stream.Collectors): 1; Nonnull (javax.annotation.Nonnull): 1