use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.
Usage in class BasicTrainable, method eval.
/**
 * Eval point sample.
 * <p>
 * Runs a timed forward/backward pass of the network over the given batch,
 * accumulates the gradient into a fresh {@link DeltaSet}, snapshots the weights
 * into a {@link StateSet}, and returns the normalized result as a
 * {@link PointSample}. Reference-counted intermediates are released before return.
 *
 * @param list    the batch of input tensors (one {@code Tensor[]} per item)
 * @param monitor the monitor used for timing log output; may be null
 * @return the normalized point sample (loss sum, gradient, weights, item count)
 */
@Nonnull
protected PointSample eval(@Nonnull final List<Tensor[]> list, @Nullable final TrainingMonitor monitor) {
  @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
    final Result[] nnContext = BasicTrainable.getNNContext(list, mask);
    final Result result = network.eval(nnContext);
    // The input context is no longer needed once the forward pass has run.
    for (@Nonnull Result nnResult : nnContext) {
      nnResult.getData().freeRef();
      nnResult.freeRef();
    }
    final TensorList resultData = result.getData();
    @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
    // FIX: was annotated @Nonnull while initialized to null. The variable is
    // legitimately null until assigned inside the try block (and is null-checked
    // in the finally block), so @Nullable is the honest contract.
    @Nullable StateSet<Layer> stateSet = null;
    try {
      // Copy each tensor's data out before freeing it, then sum every output element.
      final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
        double[] array = Arrays.stream(x.getData()).toArray();
        x.freeRef();
        return Arrays.stream(array);
      }).summaryStatistics();
      final double sum = statistics.getSum();
      // Backpropagate a unit gradient to populate deltaSet.
      result.accumulate(deltaSet, 1.0);
      stateSet = new StateSet<>(deltaSet);
      return new PointSample(deltaSet, stateSet, sum, 0.0, list.size());
    } finally {
      // Presumably PointSample takes its own references on construction;
      // release the local ones here (TODO confirm against PointSample's ctor).
      if (null != stateSet)
        stateSet.freeRef();
      resultData.freeRefAsync();
      result.freeRefAsync();
      deltaSet.freeRefAsync();
    }
  });
  if (null != monitor && verbosity() > 0) {
    monitor.log(String.format("Device completed %s items in %.3f sec", list.size(), timedResult.timeNanos / 1e9));
  }
  @Nonnull PointSample normalize = timedResult.result.normalize();
  timedResult.result.freeRef();
  return normalize;
}
use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.
Usage in class SparkTrainable, method measure.
/**
 * Measures the current point sample by evaluating the network on every
 * partition of the sampled RDD, then reducing the per-partition results
 * into a single normalized {@link PointSample}.
 *
 * @param monitor the training monitor (unused here beyond the interface contract)
 * @return the normalized point sample for the full sampled data set
 */
@Override
public PointSample measure(final TrainingMonitor monitor) {
  final long startNanos = System.nanoTime();
  // Distribute the evaluation: one PartitionTask per RDD partition.
  final JavaRDD<ReducableResult> partitionResults = sampledRDD.toJavaRDD().mapPartitions(new PartitionTask(network));
  final long mapDoneNanos = System.nanoTime();
  // Fold the per-partition gradients/losses into one combined result.
  final SparkTrainable.ReducableResult reduced = partitionResults.reduce(SparkTrainable.ReducableResult::add);
  if (isVerbose()) {
    log.info(String.format("Measure timing: %.3f / %.3f for %s items", (mapDoneNanos - startNanos) * 1e-9, (System.nanoTime() - mapDoneNanos) * 1e-9, sampledRDD.count()));
  }
  @Nonnull final DeltaSet<Layer> delta = getDelta(reduced);
  final StateSet<Layer> weights = new StateSet<Layer>(delta);
  final PointSample sample = new PointSample(delta, weights, reduced.sum, 0.0, reduced.count);
  return sample.normalize();
}
use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.
Usage in class TensorListTrainable, method eval.
/**
 * Eval point sample.
 * <p>
 * Runs a timed forward/backward pass over the given per-input TensorLists,
 * accumulates the gradient into a DeltaSet, and returns the normalized result.
 *
 * @param list the batch, one TensorList per network input
 * @param monitor the monitor for timing log output; may be null
 * @return the normalized point sample
 */
@Nonnull
protected PointSample eval(@Nonnull final TensorList[] list, @Nullable final TrainingMonitor monitor) {
// Sanity-check the stored batch: at least one input column and one item row.
int inputs = data.length;
assert 0 < inputs;
int items = data[0].length();
assert 0 < items;
@Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
final Result[] nnContext = TensorListTrainable.getNNContext(list, mask);
final Result result = network.eval(nnContext);
// The input context is no longer needed once the forward pass has run.
for (@Nonnull Result nnResult : nnContext) {
nnResult.getData().freeRef();
nnResult.freeRef();
}
final TensorList resultData = result.getData();
// Copy each tensor's data out before freeing it, then sum every output element.
final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
double[] array = Arrays.stream(x.getData()).toArray();
x.freeRef();
return Arrays.stream(array);
}).summaryStatistics();
final double sum = statistics.getSum();
@Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
@Nonnull PointSample pointSample;
try {
// Backpropagate a unit gradient to populate deltaSet.
result.accumulate(deltaSet, 1.0);
// log.info(String.format("Evaluated to %s delta buffers, %s mag", DeltaSet<LayerBase>.getMap().size(), DeltaSet<LayerBase>.getMagnitude()));
// Presumably PointSample takes its own references on construction, so the
// local stateSet ref can be released immediately — TODO confirm.
@Nonnull StateSet<Layer> stateSet = new StateSet<>(deltaSet);
pointSample = new PointSample(deltaSet, stateSet, sum, 0.0, items);
stateSet.freeRef();
} finally {
// Release locally-held references even if accumulate() throws.
resultData.freeRef();
result.freeRef();
deltaSet.freeRef();
}
return pointSample;
});
if (null != monitor && verbosity() > 0) {
monitor.log(String.format("Device completed %s items in %.3f sec", items, timedResult.timeNanos / 1e9));
}
// normalize() yields a new PointSample; free the timed intermediate afterwards.
@Nonnull PointSample normalize = timedResult.result.normalize();
timedResult.result.freeRef();
return normalize;
}
use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.
Usage in class QuantifyOrientationWrapper, method orient.
/**
 * Delegates orientation to the wrapped strategy, then — when the cursor is a
 * SimpleLineSearchCursor — logs, per layer-id group, the ratio of the search
 * direction's RMS magnitude to the current weight delta's RMS magnitude.
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
if (cursor instanceof SimpleLineSearchCursor) {
final DeltaSet<Layer> direction = ((SimpleLineSearchCursor) cursor).direction;
@Nonnull final StateSet<Layer> weights = ((SimpleLineSearchCursor) cursor).origin.weights;
// Group the weight states by layer id, then map each group to a summary string.
final Map<CharSequence, CharSequence> dataMap = weights.stream().collect(Collectors.groupingBy(x -> getId(x), Collectors.toList())).entrySet().stream().collect(Collectors.toMap(x -> x.getKey(), list -> {
final List<Double> doubleList = list.getValue().stream().map(weightDelta -> {
// Ratio of direction magnitude to weight magnitude for this layer;
// a missing direction entry contributes 0, a zero denominator is treated as 1.
final DoubleBuffer<Layer> dirDelta = direction.getMap().get(weightDelta.layer);
final double denominator = weightDelta.deltaStatistics().rms();
final double numerator = null == dirDelta ? 0 : dirDelta.deltaStatistics().rms();
return numerator / (0 == denominator ? 1 : denominator);
}).collect(Collectors.toList());
// Single-entry groups print the bare ratio; larger groups print aggregate statistics.
if (1 == doubleList.size())
return Double.toString(doubleList.get(0));
return new DoubleStatistics().accept(doubleList.stream().mapToDouble(x -> x).toArray()).toString();
}));
monitor.log(String.format("Line search stats: %s", dataMap));
} else {
// Other cursor types carry no per-layer direction data worth quantifying.
monitor.log(String.format("Non-simple cursor: %s", cursor));
}
return cursor;
}
use of com.simiacryptus.mindseye.lang.StateSet in project MindsEye by SimiaCryptus.
Usage in class LocalSparkTrainable, method measure.
/**
 * Measures the current point sample by collecting each RDD partition locally,
 * evaluating the network on it in-process, and reducing the per-partition
 * results into a single normalized {@link PointSample}.
 *
 * @param monitor the training monitor (unused here beyond the interface contract)
 * @return the normalized point sample for the full sampled data set
 * @throws IllegalStateException if no partition yielded any data
 */
@Nonnull
@Override
public PointSample measure(final TrainingMonitor monitor) {
  final long time1 = System.nanoTime();
  final JavaRDD<Tensor[]> javaRDD = sampledRDD.toJavaRDD();
  assert !javaRDD.isEmpty();
  final List<ReducableResult> mapPartitions = javaRDD.partitions().stream().map(partition -> {
    try {
      final List<Tensor[]>[] array = javaRDD.collectPartitions(new int[] { partition.index() });
      assert 0 < array.length;
      // Empty partitions contribute nothing; they are filtered out below.
      if (0 == Arrays.stream(array).mapToInt((@Nonnull final List<Tensor[]> x) -> x.size()).sum()) {
        return null;
      }
      assert 0 < Arrays.stream(array).mapToInt(x -> x.stream().mapToInt(y -> y.length).sum()).sum();
      final Stream<Tensor[]> stream = Arrays.stream(array).flatMap(i -> i.stream());
      @Nonnull final Iterator<Tensor[]> iterator = stream.iterator();
      return new PartitionTask(network).call(iterator).next();
    } catch (@Nonnull final RuntimeException e) {
      throw e;
    } catch (@Nonnull final Exception e) {
      throw new RuntimeException(e);
    }
  }).filter(x -> null != x).collect(Collectors.toList());
  final long time2 = System.nanoTime();
  // FIX: was Optional.get() — if every partition was empty the reduce yields an
  // empty Optional and get() threw a bare NoSuchElementException. Fail with a
  // diagnostic message instead.
  @Nonnull final SparkTrainable.ReducableResult result = mapPartitions.stream()
      .reduce(SparkTrainable.ReducableResult::add)
      .orElseThrow(() -> new IllegalStateException("No partition produced any results"));
  if (isVerbose()) {
    log.info(String.format("Measure timing: %.3f / %.3f for %s items", (time2 - time1) * 1e-9, (System.nanoTime() - time2) * 1e-9, sampledRDD.count()));
  }
  @Nonnull final DeltaSet<Layer> xxx = getDelta(result);
  return new PointSample(xxx, new StateSet<Layer>(xxx), result.sum, 0.0, result.count).normalize();
}
Aggregations