Search in sources:

Example 1 with ProblemRun

Use of com.simiacryptus.mindseye.test.ProblemRun in project MindsEye by SimiaCryptus.

From class TrainingTester, method trainAll.

/**
 * Trains one copy of {@code layer} with each of four optimizers and summarizes the runs.
 *
 * <p>The optimizers are gradient descent, conjugate gradient descent, limited-memory BFGS and the
 * experimental optimizer. Each optimizer run is preceded by an {@code h3} heading in {@code log}. Each run
 * receives its own {@code layer.copy()} and its own {@code copy(trainingInput)}. The runs are then:
 * <ul>
 *   <li>wrapped in {@link ProblemRun}s to build two comparison plots;</li>
 *   <li>reduced to {@link TrainingResult}s stored in a {@link ProblemResult} under the keys
 *       "GD", "CjGD", "LBFGS" and "Experimental".</li>
 * </ul>
 *
 * <p>When {@code verbose} is set, both plots are built inside {@code log.code(...)}. Otherwise they are
 * built directly. {@code layer.freeRef()} is always called before returning, including when an
 * exception is thrown.
 *
 * @param title         prefix of both plot titles (" vs Iteration" / " vs Time" is appended)
 * @param log           notebook that receives the section headings and is passed to {@code train}
 * @param trainingInput training data; copied once per optimizer run
 * @param layer         layer to train; copied once per optimizer run, and this reference is released on exit
 * @param mask          forwarded unchanged to every {@code train} call
 * @return the iteration-axis plot, the time-axis plot and the per-optimizer {@link ProblemResult}
 */
@Nonnull
public TestResult trainAll(CharSequence title, @Nonnull NotebookOutput log, @Nonnull Tensor[][] trainingInput, @Nonnull Layer layer, boolean... mask) {
    try {
        // Every optimizer starts from fresh copies of the layer and the data.
        log.h3("Gradient Descent");
        final List<StepRecord> gdHistory = train(log, this::trainGD, layer.copy(), copy(trainingInput), mask);
        log.h3("Conjugate Gradient Descent");
        final List<StepRecord> cjgdHistory = train(log, this::trainCjGD, layer.copy(), copy(trainingInput), mask);
        log.h3("Limited-Memory BFGS");
        final List<StepRecord> lbfgsHistory = train(log, this::trainLBFGS, layer.copy(), copy(trainingInput), mask);
        log.h3("Experimental Optimizer");
        final List<StepRecord> experimentalHistory = train(log, this::trainMagic, layer.copy(), copy(trainingInput), mask);

        @Nonnull final ProblemRun[] runs = new ProblemRun[] {
            new ProblemRun("GD", gdHistory, Color.GRAY, ProblemRun.PlotType.Line),
            new ProblemRun("CjGD", cjgdHistory, Color.CYAN, ProblemRun.PlotType.Line),
            new ProblemRun("LBFGS", lbfgsHistory, Color.GREEN, ProblemRun.PlotType.Line),
            new ProblemRun("Experimental", experimentalHistory, Color.MAGENTA, ProblemRun.PlotType.Line)
        };

        @Nonnull final ProblemResult result = new ProblemResult();
        result.put("GD", new TrainingResult(getResultType(gdHistory), min(gdHistory)));
        result.put("CjGD", new TrainingResult(getResultType(cjgdHistory), min(cjgdHistory)));
        result.put("LBFGS", new TrainingResult(getResultType(lbfgsHistory), min(lbfgsHistory)));
        result.put("Experimental", new TrainingResult(getResultType(experimentalHistory), min(experimentalHistory)));

        @Nullable final PlotCanvas iterPlot;
        @Nullable final PlotCanvas timePlot;
        if (verbose) {
            // NOTE(review): the lambda bodies below are kept exactly as written; log.code may echo
            // their source text into the notebook, so renaming `title`/`runs` could change the output.
            iterPlot = log.code(() -> {
                return TestUtil.compare(title + " vs Iteration", runs);
            });
            timePlot = log.code(() -> {
                return TestUtil.compareTime(title + " vs Time", runs);
            });
        } else {
            iterPlot = TestUtil.compare(title + " vs Iteration", runs);
            timePlot = TestUtil.compareTime(title + " vs Time", runs);
        }
        return new TestResult(iterPlot, timePlot, result);
    } finally {
        // Release the caller's reference; each optimizer ran on its own copy.
        layer.freeRef();
    }
}
Also used: StepRecord (com.simiacryptus.mindseye.test.StepRecord), ProblemRun (com.simiacryptus.mindseye.test.ProblemRun), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable), PlotCanvas (smile.plot.PlotCanvas)

Example 2 with ProblemRun

Use of com.simiacryptus.mindseye.test.ProblemRun in project MindsEye by SimiaCryptus.

From class Research, method compare.

/**
 * Runs several optimizers through {@code test} and logs a comparison of their convergence.
 *
 * <p>Runs and their plot labels:
 * <ul>
 *   <li>recursive subspace — "SS";</li>
 *   <li>recursive subspace 2 — "SS+QQN";</li>
 *   <li>QQN — "QQN";</li>
 *   <li>the research L-BFGS — "LB-2";</li>
 *   <li>the textbook L-BFGS, run twice — "LB-1" and "LBFGS-0".</li>
 * </ul>
 * Each run gets its own {@code h2} section in {@code log}.
 *
 * <p>Before each run, {@code fwdFactory} is assigned. {@code MnistTests.fwd_conv_1} is used for the
 * "(Un-Normalized)" sections and {@code MnistTests.fwd_conv_1_n} for the "(Normalized)" ones.
 * NOTE(review): presumably {@code test} reads {@code fwdFactory} to build the network under test,
 * which is why it is re-assigned before every run — confirm.
 *
 * <p>Finally, both a convergence-vs-iteration plot ({@code TestUtil.compare}) and a convergence-vs-time
 * plot ({@code TestUtil.compareTime}) are produced through {@code log.code(...)}.
 *
 * @param log  notebook receiving the section headings and both comparison plots
 * @param test maps an optimization strategy to the list of {@code StepRecord}s it produced
 */
@Override
public void compare(@Nonnull final NotebookOutput log, @Nonnull final Function<OptimizationStrategy, List<StepRecord>> test) {
    log.h1("Research Optimizer Comparison");
    log.h2("Recursive Subspace (Un-Normalized)");
    fwdFactory = MnistTests.fwd_conv_1;
    @Nonnull final ProblemRun subspace_1 = new ProblemRun("SS", test.apply(Research.recursive_subspace), Color.LIGHT_GRAY, ProblemRun.PlotType.Line);
    // This section runs recursive_subspace_2 (plotted as "SS+QQN"). It previously reused the heading above,
    // which produced two identically titled sections for different runs.
    log.h2("Recursive Subspace + QQN (Un-Normalized)");
    fwdFactory = MnistTests.fwd_conv_1;
    @Nonnull final ProblemRun subspace_2 = new ProblemRun("SS+QQN", test.apply(Research.recursive_subspace_2), Color.RED, ProblemRun.PlotType.Line);
    log.h2("QQN (Normalized)");
    fwdFactory = MnistTests.fwd_conv_1_n;
    @Nonnull final ProblemRun qqn1 = new ProblemRun("QQN", test.apply(Research.quadratic_quasi_newton), Color.DARK_GRAY, ProblemRun.PlotType.Line);
    log.h2("L-BFGS (Strong Line Search) (Normalized)");
    fwdFactory = MnistTests.fwd_conv_1_n;
    @Nonnull final ProblemRun lbfgs_2 = new ProblemRun("LB-2", test.apply(Research.limited_memory_bfgs), Color.MAGENTA, ProblemRun.PlotType.Line);
    log.h2("L-BFGS (Normalized)");
    fwdFactory = MnistTests.fwd_conv_1_n;
    @Nonnull final ProblemRun lbfgs_1 = new ProblemRun("LB-1", test.apply(TextbookOptimizers.limited_memory_bfgs), Color.GREEN, ProblemRun.PlotType.Line);
    log.h2("L-BFGS-0 (Un-Normalized)");
    fwdFactory = MnistTests.fwd_conv_1;
    @Nonnull final ProblemRun rawlbfgs = new ProblemRun("LBFGS-0", test.apply(TextbookOptimizers.limited_memory_bfgs), Color.CYAN, ProblemRun.PlotType.Line);
    log.h2("Comparison");
    log.code(() -> {
        return TestUtil.compare("Convergence Plot", subspace_1, subspace_2, rawlbfgs, lbfgs_1, lbfgs_2, qqn1);
    });
    log.code(() -> {
        return TestUtil.compareTime("Convergence Plot", subspace_1, subspace_2, rawlbfgs, lbfgs_1, lbfgs_2, qqn1);
    });
}
Also used: ProblemRun (com.simiacryptus.mindseye.test.ProblemRun), Nonnull (javax.annotation.Nonnull)

Example 3 with ProblemRun

Use of com.simiacryptus.mindseye.test.ProblemRun in project MindsEye by SimiaCryptus.

From class TextbookOptimizers, method compare.

/**
 * Runs the five textbook optimizers through {@code test} and logs a comparison of their convergence.
 *
 * <p>The optimizers, in order, are:
 * <ul>
 *   <li>simple gradient descent;</li>
 *   <li>stochastic gradient descent;</li>
 *   <li>conjugate gradient descent;</li>
 *   <li>L-BFGS;</li>
 *   <li>orthant-wise quasi-Newton.</li>
 * </ul>
 * Each run is announced with its own {@code h2} heading. Its {@code StepRecord} history is wrapped in a
 * {@link ProblemRun} with a fixed label and colour. Afterwards, an iteration-axis plot
 * ({@code TestUtil.compare}) and a time-axis plot ({@code TestUtil.compareTime}) are produced through
 * {@code log.code(...)}.
 *
 * @param log  notebook receiving the section headings and both comparison plots
 * @param test maps an optimization strategy to the list of {@code StepRecord}s it produced
 */
@Override
public void compare(@Nonnull final NotebookOutput log, @Nonnull final Function<OptimizationStrategy, List<StepRecord>> test) {
    log.h1("Textbook Optimizer Comparison");

    log.h2("GD");
    final List<StepRecord> gdSteps = test.apply(TextbookOptimizers.simple_gradient_descent);
    @Nonnull final ProblemRun gd = new ProblemRun("GD", gdSteps, Color.BLACK, ProblemRun.PlotType.Line);

    log.h2("SGD");
    final List<StepRecord> sgdSteps = test.apply(TextbookOptimizers.stochastic_gradient_descent);
    @Nonnull final ProblemRun sgd = new ProblemRun("SGD", sgdSteps, Color.GREEN, ProblemRun.PlotType.Line);

    log.h2("CGD");
    final List<StepRecord> cgdSteps = test.apply(TextbookOptimizers.conjugate_gradient_descent);
    @Nonnull final ProblemRun cgd = new ProblemRun("CjGD", cgdSteps, Color.BLUE, ProblemRun.PlotType.Line);

    log.h2("L-BFGS");
    final List<StepRecord> lbfgsSteps = test.apply(TextbookOptimizers.limited_memory_bfgs);
    @Nonnull final ProblemRun lbfgs = new ProblemRun("L-BFGS", lbfgsSteps, Color.MAGENTA, ProblemRun.PlotType.Line);

    log.h2("OWL-QN");
    final List<StepRecord> owlqnSteps = test.apply(TextbookOptimizers.orthantwise_quasi_newton);
    @Nonnull final ProblemRun owlqn = new ProblemRun("OWL-QN", owlqnSteps, Color.ORANGE, ProblemRun.PlotType.Line);

    log.h2("Comparison");
    // NOTE(review): the lambda bodies below are kept exactly as written (including the captured run
    // names); log.code may echo their source text into the notebook.
    log.code(() -> {
        return TestUtil.compare("Convergence Plot", gd, sgd, cgd, lbfgs, owlqn);
    });
    log.code(() -> {
        return TestUtil.compareTime("Convergence Plot", gd, sgd, cgd, lbfgs, owlqn);
    });
}
Also used: ProblemRun (com.simiacryptus.mindseye.test.ProblemRun), Nonnull (javax.annotation.Nonnull)

Aggregations

ProblemRun (com.simiacryptus.mindseye.test.ProblemRun): 3 · Nonnull (javax.annotation.Nonnull): 3 · StepRecord (com.simiacryptus.mindseye.test.StepRecord): 1 · Nullable (javax.annotation.Nullable): 1 · PlotCanvas (smile.plot.PlotCanvas): 1