Use of com.simiacryptus.mindseye.lang.LayerBase in project MindsEye by SimiaCryptus.
From the class StandardLayerTests, method getInvocations.
/**
 * Records which inner layers of a network are invoked, and with which input dimensions, during a
 * single evaluation.
 *
 * <p>The given layer is copied and the copy is cast to {@code DAGNetwork}. Every node layer of the
 * copy is replaced by a recording wrapper that delegates to the original layer and adds an
 * {@code Invocation} (layer plus the dimensions of each input) to the returned collection. The
 * copy is then evaluated once on freshly constructed tensors of the requested dimensions.
 *
 * <p>Nodes that report no layer are left untouched. The network copy and the input tensors are
 * released before returning, including when evaluation fails.
 *
 * @param smallLayer the network to inspect; its copy must be a {@code DAGNetwork}
 * @param smallDims  the dimensions of each input tensor fed to the network
 * @return the distinct invocations observed during the evaluation
 */
@Nonnull
public Collection<Invocation> getInvocations(@Nonnull Layer smallLayer, @Nonnull int[][] smallDims) {
  @Nonnull DAGNetwork smallCopy = (DAGNetwork) smallLayer.copy();
  try {
    @Nonnull HashSet<Invocation> invocations = new HashSet<>();
    smallCopy.visitNodes(node -> {
      @Nullable Layer inner = node.getLayer();
      if (null == inner) {
        // Nothing to record for a node without a layer; wrapping it would only defer the failure.
        return;
      }
      // The wrapper keeps its own reference to the delegate and drops it again in _free().
      inner.addRef();
      @Nonnull Layer wrapper = new LayerBase() {
        @Nullable
        @Override
        public Result eval(@Nonnull Result... array) {
          @Nullable Result result = inner.eval(array);
          invocations.add(new Invocation(inner, Arrays.stream(array).map(x -> x.getData().getDimensions()).toArray(int[][]::new)));
          return result;
        }

        @Override
        public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
          return inner.getJson(resources, dataSerializer);
        }

        @Nullable
        @Override
        public List<double[]> state() {
          return inner.state();
        }

        @Override
        protected void _free() {
          inner.freeRef();
        }
      };
      node.setLayer(wrapper);
      // Drop the local reference after handing the wrapper to the node.
      wrapper.freeRef();
    });
    Tensor[] input = Arrays.stream(smallDims).map(i -> new Tensor(i)).toArray(Tensor[]::new);
    try {
      Result eval = smallCopy.eval(input);
      // Release the payload while the result is still referenced, then the result itself.
      eval.getData().freeRef();
      eval.freeRef();
      return invocations;
    } finally {
      Arrays.stream(input).forEach(ReferenceCounting::freeRef);
    }
  } finally {
    // Covers failures while wrapping nodes or building inputs, not just during eval.
    smallCopy.freeRef();
  }
}
Use of com.simiacryptus.mindseye.lang.LayerBase in project MindsEye by SimiaCryptus.
From the class SimpleConvolutionLayer, method getCompatibilityLayer.
/**
 * Builds a stand-in layer that evaluates through the aparapi {@code ConvolutionLayer}.
 *
 * <p>The delegate receives a copy of this layer's kernel in which the third coordinate is
 * re-indexed: a band {@code bandX + bandY * bands} reads from band {@code bandY + bandX * bands}
 * of the source kernel, where {@code bands} is the integer square root of the kernel's third
 * dimension.
 *
 * <p>The returned layer is evaluation-only: its results report {@code isAlive() == false}, their
 * accumulator throws {@link IllegalStateException}, and so do {@code getJson} and {@code state}.
 *
 * @return the compatibility layer
 */
@Nonnull
public Layer getCompatibilityLayer() {
  log.info(String.format("Using compatibility layer for %s", this));
  final int[] kernelDims = this.kernel.getDimensions();
  final int bands = (int) Math.sqrt(kernelDims[2]);
  @Nonnull final com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer delegate =
      new com.simiacryptus.mindseye.layers.aparapi.ConvolutionLayer(kernelDims[0], kernelDims[1], kernelDims[2], true);
  // NOTE(review): this temporary tensor is never released here - confirm whether kernel.set copies it.
  @Nonnull final Tensor reindexed = new Tensor(kernelDims);
  reindexed.setByCoord(coordinate -> {
    final int[] coords = coordinate.getCoords();
    final int band = coords[2];
    final int bandX = band % bands;
    final int bandY = (band - bandX) / bands;
    assert band == bandX + bandY * bands;
    // Swap the roles of the two band components when reading the source kernel.
    final int bandT = bandY + bandX * bands;
    return kernel.get(coords[0], coords[1], bandT);
  });
  delegate.kernel.set(reindexed);
  return new LayerBase() {
    @Nonnull
    @Override
    public Result eval(@Nonnull Result... array) {
      // Hold the inputs until the wrapping result below is freed.
      for (Result input : array) {
        input.addRef();
      }
      // NOTE(review): the delegate's Result is not released in this block - verify ownership.
      @Nonnull Result delegateResult = delegate.eval(array);
      return new Result(delegateResult.getData(), (DeltaSet<Layer> buffer, TensorList data) -> {
        throw new IllegalStateException();
      }) {
        @Override
        protected void _free() {
          for (Result input : array) {
            input.freeRef();
          }
        }

        @Override
        public boolean isAlive() {
          return false;
        }
      };
    }

    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
      throw new IllegalStateException();
    }

    @Nonnull
    @Override
    public List<double[]> state() {
      throw new IllegalStateException();
    }
  };
}
Use of com.simiacryptus.mindseye.lang.LayerBase in project MindsEye by SimiaCryptus.
From the class RecursiveSubspace, method buildSubspace.
/**
 * Builds a layer whose evaluation re-measures {@code subject} after moving along the negated
 * gradient of {@code measurement}, using the {@code weights} array as per-layer step sizes.
 *
 * <p>All {@code PlaceholderLayer} entries of the gradient share {@code weights[0]}; every other
 * layer gets its own slot. {@code weights} is reallocated when it is null or its length does not
 * match the number of slots.
 *
 * <p>Each evaluation restores the backed-up origin, calls {@code accumulate} on every gradient
 * entry with its weight, measures the subject, and returns the measured mean as a single-value
 * result. The result's accumulator adds, for each slot, {@code b.dot(a) / max(sqrt(a.dot(a)), 1e-8)}
 * where {@code a} is the direction delta and {@code b} the freshly measured delta for that layer
 * (summed over all placeholder layers for slot 0).
 *
 * @param subject     the trainable that is re-measured on every evaluation
 * @param measurement the sample supplying the origin state and the gradient
 * @param monitor     receives diagnostic log lines
 * @return the subspace layer
 */
@Nullable
public Layer buildSubspace(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
  @Nonnull PointSample origin = measurement.copyFull().backup();
  @Nonnull final DeltaSet<Layer> direction = measurement.delta.scale(-1);
  final double magnitude = direction.getMagnitude();
  final double absMagnitude = Math.abs(magnitude);
  if (absMagnitude < 1e-10) {
    monitor.log(String.format("Zero gradient: %s", magnitude));
  } else if (absMagnitude < 1e-5) {
    monitor.log(String.format("Low gradient: %s", magnitude));
  }
  final boolean hasPlaceholders = direction.getMap().keySet().stream().anyMatch(key -> key instanceof PlaceholderLayer);
  final List<Layer> deltaLayers = direction.getMap().keySet().stream()
      .filter(key -> !(key instanceof PlaceholderLayer))
      .collect(Collectors.toList());
  // Slot 0 is reserved for the shared placeholder weight when placeholders exist.
  final int offset = hasPlaceholders ? 1 : 0;
  final int size = deltaLayers.size() + offset;
  if (null == weights || weights.length != size) {
    weights = new double[size];
  }
  return new LayerBase() {
    @Nonnull
    Layer self = this;

    /** Computes the measured delta's dot product with the direction delta, divided by the direction's clamped norm. */
    private double projection(PointSample sample, Layer layer) {
      Delta<Layer> a = direction.getMap().get(layer);
      Delta<Layer> b = sample.delta.getMap().get(layer);
      return b.dot(a) / Math.max(Math.sqrt(a.dot(a)), 1e-8);
    }

    @Nonnull
    @Override
    public Result eval(Result... array) {
      assertAlive();
      origin.restore();
      final int layerCount = deltaLayers.size();
      for (int index = 0; index < layerCount; index++) {
        direction.getMap().get(deltaLayers.get(index)).accumulate(weights[index + offset]);
      }
      if (hasPlaceholders) {
        direction.getMap().entrySet().stream()
            .filter(entry -> entry.getKey() instanceof PlaceholderLayer)
            .distinct()
            .forEach(entry -> entry.getValue().accumulate(weights[0]));
      }
      PointSample measure = subject.measure(monitor);
      double mean = measure.getMean();
      monitor.log(String.format("RecursiveSubspace: %s <- %s", mean, Arrays.toString(weights)));
      // The returned result holds its own reference to the direction until it is freed.
      direction.addRef();
      return new Result(TensorArray.wrap(new Tensor(mean)), (DeltaSet<Layer> buffer, TensorList data) -> {
        // Lazy: the per-layer terms are only evaluated by toArray() below.
        DoubleStream layerTerms = deltaLayers.stream().mapToDouble(layer -> projection(measure, layer));
        DoubleStream gradient = layerTerms;
        if (hasPlaceholders) {
          double placeholderTerm = direction.getMap().keySet().stream()
              .filter(key -> key instanceof PlaceholderLayer)
              .distinct()
              .mapToDouble(layer -> projection(measure, layer))
              .sum();
          gradient = DoubleStream.concat(DoubleStream.of(placeholderTerm), layerTerms);
        }
        buffer.get(self, weights).addInPlace(gradient.toArray()).freeRef();
      }) {
        @Override
        protected void _free() {
          measure.freeRef();
          direction.freeRef();
        }

        @Override
        public boolean isAlive() {
          return true;
        }
      };
    }

    @Override
    protected void _free() {
      direction.freeRef();
      origin.freeRef();
      super._free();
    }

    @Nonnull
    @Override
    public JsonObject getJson(Map<CharSequence, byte[]> resources, DataSerializer dataSerializer) {
      throw new IllegalStateException();
    }

    @Nullable
    @Override
    public List<double[]> state() {
      return null;
    }
  };
}
Aggregations