Search in sources:

Example 1 with StateSet

use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.

Method eval in class BasicTrainable.

/**
 * Evaluates the network on a batch of inputs and packages the summed output together with the
 * accumulated gradient into a normalized {@link PointSample}.
 *
 * <p>The intermediate reference-counted objects created here (the input results and their data,
 * the network result and its data, the delta set, the state set and the un-normalized sample) are
 * released (some asynchronously) before returning; only the normalized sample is handed back.
 *
 * @param list    the input samples, each an array of tensors, converted via {@code getNNContext}
 * @param monitor the monitor that receives a timing line when {@code verbosity() > 0}; may be null
 * @return the normalized point sample for this batch
 */
@Nonnull
protected PointSample eval(@Nonnull final List<Tensor[]> list, @Nullable final TrainingMonitor monitor) {
    @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
        final Result[] nnContext = BasicTrainable.getNNContext(list, mask);
        final Result result = network.eval(nnContext);
        // Only `result` is used from here on; release the input contexts and their data.
        for (@Nonnull Result nnResult : nnContext) {
            nnResult.getData().freeRef();
            nnResult.freeRef();
        }
        final TensorList resultData = result.getData();
        @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
        // Assigned only once accumulate() has succeeded, hence @Nullable and the null check in finally.
        @Nullable StateSet<Layer> stateSet = null;
        try {
            final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
                // Copy the values out before releasing the tensor they came from.
                double[] array = Arrays.stream(x.getData()).toArray();
                x.freeRef();
                return Arrays.stream(array);
            }).summaryStatistics();
            final double sum = statistics.getSum();
            result.accumulate(deltaSet, 1.0);
            stateSet = new StateSet<>(deltaSet);
            return new PointSample(deltaSet, stateSet, sum, 0.0, list.size());
        } finally {
            if (null != stateSet)
                stateSet.freeRef();
            resultData.freeRefAsync();
            result.freeRefAsync();
            deltaSet.freeRefAsync();
        }
    });
    if (null != monitor && verbosity() > 0) {
        monitor.log(String.format("Device completed %s items in %.3f sec", list.size(), timedResult.timeNanos / 1e9));
    }
    final PointSample rawSample = timedResult.result;
    try {
        return rawSample.normalize();
    } finally {
        // Release the un-normalized sample even if normalize() throws.
        rawSample.freeRef();
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Result(com.simiacryptus.mindseye.lang.Result) StateSet(com.simiacryptus.mindseye.lang.StateSet) MutableResult(com.simiacryptus.mindseye.lang.MutableResult) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TensorList(com.simiacryptus.mindseye.lang.TensorList) TimedResult(com.simiacryptus.util.lang.TimedResult) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) MutableResult(com.simiacryptus.mindseye.lang.MutableResult) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TimedResult(com.simiacryptus.util.lang.TimedResult) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull)

Example 2 with StateSet

use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.

Method measure in class SparkTrainable.

/**
 * Measures the network on the sampled RDD by evaluating each partition with a {@link PartitionTask}
 * and reducing the per-partition results into a single normalized {@link PointSample}.
 *
 * @param monitor the training monitor (not used by this implementation)
 * @return the normalized point sample aggregated over all partitions
 */
@Override
public PointSample measure(final TrainingMonitor monitor) {
    final long time1 = System.nanoTime();
    // mapPartitions is a lazy Spark transformation: the partition tasks actually run inside
    // reduce(), so the first logged timing covers job setup and the second the computation.
    final JavaRDD<ReducableResult> mapPartitions = sampledRDD.toJavaRDD().mapPartitions(new PartitionTask(network));
    final long time2 = System.nanoTime();
    final SparkTrainable.ReducableResult result = mapPartitions.reduce(SparkTrainable.ReducableResult::add);
    if (isVerbose()) {
        log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (System.nanoTime() - time2) * 1e-9, sampledRDD.count()));
    }
    // NOTE(review): ownership of the DeltaSet returned by getDelta is not visible here, so it is
    // not released; confirm whether the caller of getDelta is expected to free it.
    @Nonnull final DeltaSet<Layer> delta = getDelta(result);
    @Nonnull final StateSet<Layer> stateSet = new StateSet<Layer>(delta);
    final PointSample rawSample;
    try {
        rawSample = new PointSample(delta, stateSet, result.sum, 0.0, result.count);
    } finally {
        // Release the local reference once the state set has been handed to the sample.
        stateSet.freeRef();
    }
    try {
        return rawSample.normalize();
    } finally {
        // Release the un-normalized sample; only the normalized copy is returned.
        rawSample.freeRef();
    }
}
Also used : Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Layer(com.simiacryptus.mindseye.lang.Layer) StateSet(com.simiacryptus.mindseye.lang.StateSet)

Example 3 with StateSet

use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.

Method eval in class TensorListTrainable.

/**
 * Evaluates the network on the given tensor lists and packages the summed output together with the
 * accumulated gradient into a normalized {@link PointSample}.
 *
 * <p>The intermediate reference-counted objects created here are released before returning,
 * including when computing the output statistics, accumulating the gradient, or normalizing throws.
 *
 * @param list    the input tensor lists (presumably one per network input, like {@code data}),
 *                converted via {@code getNNContext}
 * @param monitor the monitor that receives a timing line when {@code verbosity() > 0}; may be null
 * @return the normalized point sample for this batch
 */
@Nonnull
protected PointSample eval(@Nonnull final TensorList[] list, @Nullable final TrainingMonitor monitor) {
    int inputs = data.length;
    assert 0 < inputs;
    // NOTE(review): the item count comes from the `data` field rather than the `list` argument;
    // confirm callers always pass `data` (or lists of the same length).
    int items = data[0].length();
    assert 0 < items;
    @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
        final Result[] nnContext = TensorListTrainable.getNNContext(list, mask);
        final Result result = network.eval(nnContext);
        // Only `result` is used from here on; release the input contexts and their data.
        for (@Nonnull Result nnResult : nnContext) {
            nnResult.getData().freeRef();
            nnResult.freeRef();
        }
        final TensorList resultData = result.getData();
        @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
        try {
            // Computed inside the try so resultData/result are released even if streaming fails.
            final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
                // Copy the values out before releasing the tensor they came from.
                double[] array = Arrays.stream(x.getData()).toArray();
                x.freeRef();
                return Arrays.stream(array);
            }).summaryStatistics();
            final double sum = statistics.getSum();
            result.accumulate(deltaSet, 1.0);
            @Nonnull final StateSet<Layer> stateSet = new StateSet<>(deltaSet);
            try {
                return new PointSample(deltaSet, stateSet, sum, 0.0, items);
            } finally {
                stateSet.freeRef();
            }
        } finally {
            resultData.freeRef();
            result.freeRef();
            deltaSet.freeRef();
        }
    });
    if (null != monitor && verbosity() > 0) {
        monitor.log(String.format("Device completed %s items in %.3f sec", items, timedResult.timeNanos / 1e9));
    }
    final PointSample rawSample = timedResult.result;
    try {
        return rawSample.normalize();
    } finally {
        // Release the un-normalized sample even if normalize() throws.
        rawSample.freeRef();
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Result(com.simiacryptus.mindseye.lang.Result) StateSet(com.simiacryptus.mindseye.lang.StateSet) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) TimedResult(com.simiacryptus.util.lang.TimedResult) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TimedResult(com.simiacryptus.util.lang.TimedResult) PointSample(com.simiacryptus.mindseye.lang.PointSample) StateSet(com.simiacryptus.mindseye.lang.StateSet) Nonnull(javax.annotation.Nonnull)

Example 4 with StateSet

use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.

Method orient in class QuantifyOrientationWrapper.

/**
 * Delegates orientation to the wrapped strategy and, when the resulting cursor is a
 * {@link SimpleLineSearchCursor}, logs the ratio of direction RMS to weight RMS per id.
 *
 * <p>Weights are grouped by {@code getId}. A group with a single entry is logged as a plain number;
 * larger groups are summarized with {@link DoubleStatistics}. A zero weight RMS is replaced by 1 to
 * avoid dividing by zero, and a layer with no entry in the direction contributes a ratio of 0.
 *
 * @param subject     the trainable being optimized
 * @param measurement the current measurement, passed through to the inner strategy
 * @param monitor     the monitor that receives the log line
 * @return the cursor produced by the inner strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
    if (!(cursor instanceof SimpleLineSearchCursor)) {
        monitor.log(String.format("Non-simple cursor: %s", cursor));
        return cursor;
    }
    final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
    final DeltaSet<Layer> direction = simpleCursor.direction;
    @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
    final Map<CharSequence, CharSequence> dataMap = weights.stream()
        .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
        .entrySet()
        .stream()
        .collect(Collectors.toMap(group -> group.getKey(), group -> {
            final List<Double> ratios = group.getValue().stream().map(state -> {
                final DoubleBuffer<Layer> dirBuffer = direction.getMap().get(state.layer);
                final double weightRms = state.deltaStatistics().rms();
                final double directionRms = (dirBuffer == null) ? 0 : dirBuffer.deltaStatistics().rms();
                final double divisor = (weightRms == 0) ? 1 : weightRms;
                return directionRms / divisor;
            }).collect(Collectors.toList());
            return (ratios.size() == 1)
                ? Double.toString(ratios.get(0))
                : new DoubleStatistics().accept(ratios.stream().mapToDouble(Double::doubleValue).toArray()).toString();
        }));
    monitor.log(String.format("Line search stats: %s", dataMap));
    return cursor;
}
Also used : DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) List(java.util.List)

Example 5 with StateSet

use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.

Method measure in class LocalSparkTrainable.

/**
 * Measures the network on the sampled RDD. Each partition is collected to the driver and evaluated
 * locally with a {@link PartitionTask}; the per-partition results are then reduced into a single
 * normalized {@link PointSample}.
 *
 * @param monitor the training monitor (not used by this implementation)
 * @return the normalized point sample aggregated over all non-empty partitions
 * @throws IllegalStateException if every partition is empty
 */
@Nonnull
@Override
public PointSample measure(final TrainingMonitor monitor) {
    final long time1 = System.nanoTime();
    final JavaRDD<Tensor[]> javaRDD = sampledRDD.toJavaRDD();
    assert !javaRDD.isEmpty();
    final List<ReducableResult> mapPartitions = javaRDD.partitions().stream().map(partition -> {
        try {
            final List<Tensor[]>[] array = javaRDD.collectPartitions(new int[] { partition.index() });
            assert 0 < array.length;
            // Empty partitions yield null and are filtered out below.
            if (0 == Arrays.stream(array).mapToInt((@Nonnull final List<Tensor[]> x) -> x.size()).sum()) {
                return null;
            }
            assert 0 < Arrays.stream(array).mapToInt(x -> x.stream().mapToInt(y -> y.length).sum()).sum();
            final Stream<Tensor[]> stream = Arrays.stream(array).flatMap(i -> i.stream());
            @Nonnull final Iterator<Tensor[]> iterator = stream.iterator();
            return new PartitionTask(network).call(iterator).next();
        } catch (@Nonnull final RuntimeException e) {
            throw e;
        } catch (@Nonnull final Exception e) {
            throw new RuntimeException(e);
        }
    }).filter(x -> null != x).collect(Collectors.toList());
    final long time2 = System.nanoTime();
    // The assert above is disabled in production, so fail with a descriptive error instead of a
    // bare NoSuchElementException when no partition produced a result.
    @Nonnull final SparkTrainable.ReducableResult result = mapPartitions.stream()
        .reduce(SparkTrainable.ReducableResult::add)
        .orElseThrow(() -> new IllegalStateException("No non-empty partitions to measure"));
    if (isVerbose()) {
        log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (System.nanoTime() - time2) * 1e-9, sampledRDD.count()));
    }
    // NOTE(review): ownership of the DeltaSet returned by getDelta is not visible here, so it is
    // not released; confirm whether the caller of getDelta is expected to free it.
    @Nonnull final DeltaSet<Layer> delta = getDelta(result);
    @Nonnull final StateSet<Layer> stateSet = new StateSet<Layer>(delta);
    final PointSample rawSample;
    try {
        rawSample = new PointSample(delta, stateSet, result.sum, 0.0, result.count);
    } finally {
        // Release the local reference once the state set has been handed to the sample.
        stateSet.freeRef();
    }
    try {
        return rawSample.normalize();
    } finally {
        // Release the un-normalized sample; only the normalized copy is returned.
        rawSample.freeRef();
    }
}
Also used : Arrays(java.util.Arrays) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) List(java.util.List) Stream(java.util.stream.Stream) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Layer(com.simiacryptus.mindseye.lang.Layer) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) RDD(org.apache.spark.rdd.RDD) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) JavaRDD(org.apache.spark.api.java.JavaRDD) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) List(java.util.List) PointSample(com.simiacryptus.mindseye.lang.PointSample) StateSet(com.simiacryptus.mindseye.lang.StateSet) Nonnull(javax.annotation.Nonnull)

Aggregations

Layer (com.simiacryptus.mindseye.lang.Layer)5 PointSample (com.simiacryptus.mindseye.lang.PointSample)5 StateSet (com.simiacryptus.mindseye.lang.StateSet)5 Nonnull (javax.annotation.Nonnull)5 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)4 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)4 Tensor (com.simiacryptus.mindseye.lang.Tensor)3 Arrays (java.util.Arrays)3 List (java.util.List)3 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)2 ReferenceCountingBase (com.simiacryptus.mindseye.lang.ReferenceCountingBase)2 Result (com.simiacryptus.mindseye.lang.Result)2 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)2 TensorList (com.simiacryptus.mindseye.lang.TensorList)2 TimedResult (com.simiacryptus.util.lang.TimedResult)2 DoubleSummaryStatistics (java.util.DoubleSummaryStatistics)2 Collectors (java.util.stream.Collectors)2 IntStream (java.util.stream.IntStream)2 Nullable (javax.annotation.Nullable)2 Trainable (com.simiacryptus.mindseye.eval.Trainable)1