Example 6 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Found in class RecursiveSubspace, method orient.

@Nonnull
@Override
public SimpleLineSearchCursor orient(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
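    // Snapshot the starting point so the weights can be restored after the inner training.
    @Nonnull PointSample origin = measurement.copyFull().backup();
    // Build and train the subspace model; training moves the underlying layer weights.
    @Nullable Layer macroLayer = buildSubspace(subject, measurement, monitor);
    train(monitor, macroLayer);
    Result eval = macroLayer.eval((Result) null);
    macroLayer.freeRef();
    eval.getData().freeRef();
    eval.freeRef();
    // Direction = (weights after inner training) - (weights saved at origin).
    @Nonnull StateSet<Layer> backupCopy = origin.weights.backupCopy();
    @Nonnull DeltaSet<Layer> delta = backupCopy.subtract(origin.weights);
    backupCopy.freeRef();
    // Put the original weights back so the line search starts from the measured point.
    origin.restore();
    @Nonnull SimpleLineSearchCursor simpleLineSearchCursor = new SimpleLineSearchCursor(subject, origin, delta);
    delta.freeRef();
    origin.freeRef();
    return simpleLineSearchCursor.setDirectionType(CURSOR_LABEL);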
}
Also used: Nonnull (javax.annotation.Nonnull), SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor), PointSample (com.simiacryptus.mindseye.lang.PointSample), PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer), Layer (com.simiacryptus.mindseye.lang.Layer), Nullable (javax.annotation.Nullable), Result (com.simiacryptus.mindseye.lang.Result)
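As I read `orient`, the search direction is the displacement produced by the inner training: it snapshots the weights, trains, computes `backupCopy.subtract(origin.weights)`, and restores the snapshot before handing that delta to `SimpleLineSearchCursor`. The plain-Java sketch below reproduces only that snapshot/train/subtract/restore sequence on a `double[]`. `SubspaceDirection`, `orient` and the `innerTrain` callback are hypothetical names, not MindsEye API, and reference counting is left out.

```java
import java.util.Arrays;
import java.util.function.Consumer;

/** Hypothetical sketch: direction = (weights after inner training) - (weights before it). */
final class SubspaceDirection {
    private SubspaceDirection() {}

    /**
     * @param weights    live parameter vector; innerTrain mutates it, and it is restored before return
     * @param innerTrain hypothetical inner optimizer that updates the weights in place
     * @return the displacement produced by innerTrain, usable as a line-search direction
     */
    static double[] orient(double[] weights, Consumer<double[]> innerTrain) {
        double[] origin = Arrays.copyOf(weights, weights.length); // snapshot, like copyFull().backup()
        innerTrain.accept(weights);                              // like train(monitor, macroLayer)
        double[] delta = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            delta[i] = weights[i] - origin[i];                   // like backupCopy.subtract(origin.weights)
        }
        System.arraycopy(origin, 0, weights, 0, weights.length); // like origin.restore()
        return delta;
    }
}
```

Restoring before returning means the line search begins at the original point and can scale the delta by any step size, instead of starting from weights the inner training has already moved.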

Example 7 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Found in class BasicTrainable, method eval.

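/**
 * Evaluates the network on a batch, summing every output value and accumulating the gradient,
 * and returns the result normalized via {@link PointSample#normalize()}.
 *
 * @param list    the batch; each entry holds one sample's input tensors
 * @param monitor optional monitor for the timing log; may be null
 * @return the normalized point sample
 */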
@Nonnull
protected PointSample eval(@Nonnull final List<Tensor[]> list, @Nullable final TrainingMonitor monitor) {
    @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
        final Result[] nnContext = BasicTrainable.getNNContext(list, mask);
        final Result result = network.eval(nnContext);
        for (@Nonnull Result nnResult : nnContext) {
            nnResult.getData().freeRef();
            nnResult.freeRef();
        }
        final TensorList resultData = result.getData();
        @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
        @Nullable StateSet<Layer> stateSet = null;
        try {
            final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
                double[] array = Arrays.stream(x.getData()).toArray();
                x.freeRef();
                return Arrays.stream(array);
            }).summaryStatistics();
            final double sum = statistics.getSum();
            result.accumulate(deltaSet, 1.0);
            stateSet = new StateSet<>(deltaSet);
            // log.info(String.format("Evaluated to %s delta buffers, %s mag", deltaSet.getMap().size(), deltaSet.getMagnitude()));
            return new PointSample(deltaSet, stateSet, sum, 0.0, list.size());
        } finally {
            if (null != stateSet)
                stateSet.freeRef();
            resultData.freeRefAsync();
            result.freeRefAsync();
            deltaSet.freeRefAsync();
        }
    });
    if (null != monitor && verbosity() > 0) {
        monitor.log(String.format("Device completed %s items in %.3f sec", list.size(), timedResult.timeNanos / 1e9));
    }
    @Nonnull PointSample normalize = timedResult.result.normalize();
    timedResult.result.freeRef();
    return normalize;
}
Also used: IntStream (java.util.stream.IntStream), Arrays (java.util.Arrays), Tensor (com.simiacryptus.mindseye.lang.Tensor), ReferenceCountingBase (com.simiacryptus.mindseye.lang.ReferenceCountingBase), DoubleSummaryStatistics (java.util.DoubleSummaryStatistics), Result (com.simiacryptus.mindseye.lang.Result), StateSet (com.simiacryptus.mindseye.lang.StateSet), MutableResult (com.simiacryptus.mindseye.lang.MutableResult), List (java.util.List), ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult), TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor), TensorList (com.simiacryptus.mindseye.lang.TensorList), TimedResult (com.simiacryptus.util.lang.TimedResult), Layer (com.simiacryptus.mindseye.lang.Layer), TensorArray (com.simiacryptus.mindseye.lang.TensorArray), DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet), Nonnull (javax.annotation.Nonnull), PointSample (com.simiacryptus.mindseye.lang.PointSample), Nullable (javax.annotation.Nullable)
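`eval` sums every element of every output tensor (`flatMapToDouble(...).summaryStatistics().getSum()`), records `list.size()` as the count, and returns `normalize()`. The sketch below mirrors that arithmetic with plain arrays. `EvalSample` is a hypothetical stand-in for `PointSample`, and the behaviour of `normalize()` shown here (divide value and gradient by the count, then set the count to 1) is my assumption; check it against the actual `PointSample.normalize()` source.

```java
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.List;

/** Hypothetical stand-in for a point sample: summed value, summed gradient, item count. */
final class EvalSample {
    final double sum;
    final double[] gradient;
    final int count;

    EvalSample(double sum, double[] gradient, int count) {
        this.sum = sum;
        this.gradient = gradient;
        this.count = count;
    }

    /** Assumed semantics: turn batch totals into per-item means and mark the sample as a single item. */
    EvalSample normalize() {
        if (count == 0) return this; // nothing to average
        double[] g = Arrays.stream(gradient).map(v -> v / count).toArray();
        return new EvalSample(sum / count, g, 1);
    }

    /** Sums every element of every output row, like the flatMapToDouble/summaryStatistics step. */
    static double sumOutputs(List<double[]> outputs) {
        DoubleSummaryStatistics stats = outputs.stream()
                .flatMapToDouble(Arrays::stream)
                .summaryStatistics();
        return stats.getSum();
    }
}
```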

Example 8 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Found in class CachedTrainable, method measure.

@Override
public PointSample measure(final TrainingMonitor monitor) {
    for (@Nonnull final PointSample result : history) {
        if (!result.weights.isDifferent()) {
            if (isVerbose()) {
                log.info(String.format("Returning cached value; %s buffers unchanged since %s => %s", result.weights.getMap().size(), result.rate, result.getMean()));
            }
            return result.copyFull();
        }
    }
    final PointSample result = super.measure(monitor);
    history.add(result.copyFull());
    while (getHistorySize() < history.size()) {
        history.remove(0);
    }
    return result;
}
Also used: Nonnull (javax.annotation.Nonnull), PointSample (com.simiacryptus.mindseye.lang.PointSample)
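`CachedTrainable.measure` scans a short history, returns a copy of the first sample whose weights are unchanged, and otherwise measures, appends the result, and evicts from the front until the history fits `getHistorySize()`. Below is a minimal sketch of the same bounded, oldest-first cache. It assumes weights can be compared with `Arrays.equals`, standing in for `weights.isDifferent()`; `CachedEvaluator` and its `misses` counter are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.function.ToDoubleFunction;

/** Hypothetical sketch of a bounded measurement cache keyed on a weight snapshot. */
final class CachedEvaluator {
    private static final class Entry {
        final double[] weights; // snapshot of the weights the value was computed at
        final double value;

        Entry(double[] weights, double value) {
            this.weights = weights;
            this.value = value;
        }
    }

    private final Deque<Entry> history = new ArrayDeque<>(); // oldest entry first
    private final int historySize;
    private final ToDoubleFunction<double[]> measure;
    int misses = 0; // number of times the underlying measure actually ran

    CachedEvaluator(int historySize, ToDoubleFunction<double[]> measure) {
        this.historySize = historySize;
        this.measure = measure;
    }

    double measure(double[] weights) {
        for (Entry e : history) {
            if (Arrays.equals(e.weights, weights)) return e.value; // weights unchanged: reuse
        }
        misses++;
        double value = measure.applyAsDouble(weights);
        history.addLast(new Entry(weights.clone(), value)); // store a copy, like result.copyFull()
        while (history.size() > historySize) {
            history.removeFirst(); // evict the oldest entry, like history.remove(0)
        }
        return value;
    }
}
```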

Example 9 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Found in class L12Normalizer, method measure.

@Nonnull
@Override
public PointSample measure(final TrainingMonitor monitor) {
    final PointSample innerMeasure = inner.measure(monitor);
    @Nonnull final DeltaSet<Layer> normalizationVector = new DeltaSet<Layer>();
    double valueAdj = 0;
    for (@Nonnull final Layer layer : getLayers(innerMeasure.delta.getMap().keySet())) {
        final double[] weights = innerMeasure.delta.getMap().get(layer).target;
        @Nullable final double[] gradientAdj = normalizationVector.get(layer, weights).getDelta();
        final double factor_L1 = getL1(layer);
        final double factor_L2 = getL2(layer);
        assert null != gradientAdj;
        for (int i = 0; i < gradientAdj.length; i++) {
            final double sign = weights[i] < 0 ? -1.0 : 1.0;
            gradientAdj[i] += factor_L1 * sign + 2 * factor_L2 * weights[i];
            valueAdj += (factor_L1 * sign + factor_L2 * weights[i]) * weights[i];
        }
    }
    return new PointSample(innerMeasure.delta.add(normalizationVector), innerMeasure.weights, innerMeasure.sum + (hideAdj ? 0 : valueAdj), innerMeasure.rate, innerMeasure.count).normalize();
}
Also used: Nonnull (javax.annotation.Nonnull), DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet), PointSample (com.simiacryptus.mindseye.lang.PointSample), FullyConnectedLayer (com.simiacryptus.mindseye.layers.java.FullyConnectedLayer), Layer (com.simiacryptus.mindseye.lang.Layer), Nullable (javax.annotation.Nullable)
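For each weight w, the loop adds `l1 * sign(w) + 2 * l2 * w` to the gradient and `(l1 * sign(w) + l2 * w) * w`, which equals `l1 * |w| + l2 * w^2`, to the value. When `hideAdj` is set, the reported value omits that penalty while the gradient still includes it. The helper below isolates the arithmetic; `L12Penalty.apply` is a hypothetical name, and like the original it treats sign(0) as +1.

```java
/** Hypothetical helper mirroring the L1/L2 penalty arithmetic in L12Normalizer.measure. */
final class L12Penalty {
    private L12Penalty() {}

    /** Adds the penalty gradient into grad (in place) and returns the total penalty value. */
    static double apply(double[] weights, double[] grad, double l1, double l2) {
        double value = 0;
        for (int i = 0; i < weights.length; i++) {
            double sign = weights[i] < 0 ? -1.0 : 1.0;            // sign(0) treated as +1
            grad[i] += l1 * sign + 2 * l2 * weights[i];           // d/dw of l1*|w| + l2*w^2
            value += (l1 * sign + l2 * weights[i]) * weights[i];  // = l1*|w| + l2*w^2
        }
        return value;
    }
}
```

For weights `{2.0, -1.0}` with `l1 = 0.5` and `l2 = 0.25`, the penalty is `0.5 * 3 + 0.25 * 5 = 2.75` and the gradient becomes `{1.5, -1.0}`.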

Example 10 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Found in class SparkTrainable, method measure.

@Override
public PointSample measure(final TrainingMonitor monitor) {
    final long time1 = System.nanoTime();
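    // mapPartitions is a lazy Spark transformation: it only builds the plan, so time1..time2 measures setup.
    final JavaRDD<ReducableResult> mapPartitions = sampledRDD.toJavaRDD().mapPartitions(new PartitionTask(network));
    final long time2 = System.nanoTime();
    // reduce() is the action that runs PartitionTask on every partition and merges the partial results.
    final SparkTrainable.ReducableResult result = mapPartitions.reduce(SparkTrainable.ReducableResult::add);
    if (isVerbose()) {
        // sampledRDD.count() is itself a Spark action and launches a separate job.
        log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (System.nanoTime() - time2) * 1e-9, sampledRDD.count()));
    }
    @Nonnull final DeltaSet<Layer> delta = getDelta(result);
    return new PointSample(delta, new StateSet<Layer>(delta), result.sum, 0.0, result.count).normalize();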
}
Also used: Nonnull (javax.annotation.Nonnull), PointSample (com.simiacryptus.mindseye.lang.PointSample), Layer (com.simiacryptus.mindseye.lang.Layer), StateSet (com.simiacryptus.mindseye.lang.StateSet)
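`SparkTrainable.measure` maps each partition to a `ReducableResult`, merges them with `ReducableResult::add`, and normalizes the total. Spark's `reduce` expects a commutative and associative function because partitions may be combined in any order and grouping. The sketch below shows the same shape using a Java parallel stream instead of Spark. `PartialResult` is hypothetical, and the divide-by-count normalization is an assumption about `PointSample.normalize()`.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical per-partition result: summed value, summed gradient, item count. */
final class PartialResult {
    final double sum;
    final double[] gradient;
    final int count;

    PartialResult(double sum, double[] gradient, int count) {
        this.sum = sum;
        this.gradient = gradient;
        this.count = count;
    }

    /** Commutative, associative merge: safe to apply in any order or grouping. */
    static PartialResult add(PartialResult a, PartialResult b) {
        double[] g = new double[a.gradient.length];
        for (int i = 0; i < g.length; i++) {
            g[i] = a.gradient[i] + b.gradient[i];
        }
        return new PartialResult(a.sum + b.sum, g, a.count + b.count);
    }

    /** Merges partition results in parallel, then converts totals to per-item means (assumed normalize). */
    static PartialResult reduceAndNormalize(List<PartialResult> partitions) {
        PartialResult total = partitions.parallelStream()
                .reduce(PartialResult::add)
                .orElseThrow(IllegalArgumentException::new); // no partitions: nothing to measure
        double[] g = Arrays.stream(total.gradient).map(v -> v / total.count).toArray();
        return new PartialResult(total.sum / total.count, g, 1);
    }
}
```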

Aggregations

PointSample (com.simiacryptus.mindseye.lang.PointSample): 33
Nonnull (javax.annotation.Nonnull): 24
Layer (com.simiacryptus.mindseye.lang.Layer): 16
Nullable (javax.annotation.Nullable): 14
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 10
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 9
SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor): 9
StateSet (com.simiacryptus.mindseye.lang.StateSet): 8
LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor): 8
List (java.util.List): 8
Trainable (com.simiacryptus.mindseye.eval.Trainable): 7
Arrays (java.util.Arrays): 7
Collectors (java.util.stream.Collectors): 7
IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException): 6
Map (java.util.Map): 6
IntStream (java.util.stream.IntStream): 6
DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer): 5
PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer): 5
FailsafeLineSearchCursor (com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor): 5
LineSearchStrategy (com.simiacryptus.mindseye.opt.line.LineSearchStrategy): 5