Usage of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.
Class ProductInputsLayer, method eval:
/**
 * Multiplies all inputs together element-wise.
 * <p>
 * Inputs are combined pairwise, left to right, with {@code Tensor.product}. A batch of length 1 is
 * reused for every item of a longer batch. The backward pass applies the product rule: the
 * gradient for input {@code k} is {@code delta} multiplied by every other input. That gradient is
 * summed over the batch when input {@code k} had a single item, and summed to one element per item
 * when input {@code k}'s tensors held a single element.
 *
 * @param inObj two or more inputs; element counts must agree unless one of them is 1, and batch
 *              lengths must agree unless one of them is 1
 * @return the product result; it retains all inputs and their data until it is freed
 * @throws IllegalArgumentException if element counts or batch lengths are incompatible
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
  assert inObj.length > 1;
  // Validate before taking any references so that a rejected call leaks nothing.
  checkProductCompatibility(inObj);
  Arrays.stream(inObj).forEach(x -> x.getData().addRef());
  Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
  final TensorList output = Arrays.stream(inObj).parallel().map(x -> {
    final TensorList data = x.getData();
    data.addRef();
    return data;
  }).reduce((l, r) -> multiplyPairAndFree(l, r)).get();
  return new Result(output, (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
    for (int inputIndex = 0; inputIndex < inObj.length; inputIndex++) {
      final Result input = inObj[inputIndex];
      if (!input.isAlive()) {
        continue;
      }
      final int target = inputIndex;
      // Product rule: substitute delta at THIS position only. Matching by index (not identity)
      // keeps the gradient correct when the same Result is passed at several positions.
      @Nonnull TensorList passback = IntStream.range(0, inObj.length).parallel().mapToObj(j -> {
        final TensorList factor = j == target ? delta : inObj[j].getData();
        factor.addRef();
        return factor;
      }).reduce((l, r) -> multiplyPairAndFree(l, r)).get();
      final TensorList inputData = input.getData();
      // The input was broadcast across the batch: sum the per-item gradients into one item.
      if (1 == inputData.length() && 1 < passback.length()) {
        TensorArray newValue = TensorArray.wrap(passback.stream().reduce((a, b) -> {
          @Nullable Tensor c = a.addAndFree(b);
          b.freeRef();
          return c;
        }).get());
        passback.freeRef();
        passback = newValue;
      }
      // The input was a single element broadcast across each tensor: sum each gradient to a scalar.
      if (1 == Tensor.length(inputData.getDimensions()) && 1 < Tensor.length(passback.getDimensions())) {
        TensorArray newValue = TensorArray.wrap(passback.stream().map((a) -> {
          @Nonnull Tensor b = new Tensor(a.sum());
          a.freeRef();
          return b;
        }).toArray(i -> new Tensor[i]));
        passback.freeRef();
        passback = newValue;
      }
      input.accumulate(buffer, passback);
    }
  }) {
    @Override
    public boolean isAlive() {
      return Arrays.stream(inObj).anyMatch(element -> element.isAlive());
    }

    @Override
    protected void _free() {
      Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
      Arrays.stream(inObj).forEach(x -> x.getData().freeRef());
    }
  };
}

/**
 * Rejects inputs that cannot be multiplied together.
 * <p>
 * Element counts must match input 0 unless either side is 1. Batch lengths must all match,
 * ignoring inputs whose batch length is 1.
 */
private static void checkProductCompatibility(@Nonnull final Result[] inObj) {
  final int[] dims0 = inObj[0].getData().getDimensions();
  final int dim0 = Tensor.length(dims0);
  for (int i = 1; i < inObj.length; i++) {
    final int[] dimsI = inObj[i].getData().getDimensions();
    final int dimI = Tensor.length(dimsI);
    if (dim0 != 1 && dimI != 1 && dim0 != dimI) {
      throw new IllegalArgumentException(Arrays.toString(dims0) + " != " + Arrays.toString(dimsI));
    }
  }
  int batchLength = 1;
  for (@Nonnull final Result x : inObj) {
    final int length = x.getData().length();
    if (1 == length) {
      continue;
    }
    if (1 != batchLength && batchLength != length) {
      throw new IllegalArgumentException("Incompatible batch lengths: " + batchLength + " != " + length);
    }
    batchLength = length;
  }
}

/**
 * Multiplies two lists item by item, reusing a length-1 list for every item of the other list.
 * Releases both arguments.
 */
private static TensorList multiplyPairAndFree(final TensorList l, final TensorList r) {
  final TensorArray productArray = TensorArray.wrap(IntStream.range(0, Math.max(l.length(), r.length())).parallel().mapToObj(i -> {
    @Nullable final Tensor left = l.get(1 == l.length() ? 0 : i);
    @Nullable final Tensor right = r.get(1 == r.length() ? 0 : i);
    final Tensor product = Tensor.product(left, right);
    left.freeRef();
    right.freeRef();
    return product;
  }).toArray(i -> new Tensor[i]));
  l.freeRef();
  r.freeRef();
  return productArray;
}
Usage of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.
Class ReLuActivationLayer, method eval:
/**
 * Applies a scaled rectifier: each output element is {@code max(0, weight * x)}, where
 * {@code weight} is {@code weights.get(0)}.
 * <p>
 * Backward pass:
 * <ul>
 *   <li>unless the layer is frozen, {@code sum(delta * x)} over active elements is accumulated
 *       into this layer's weight delta;</li>
 *   <li>if the input is alive, it receives {@code delta * weight} on active elements and 0
 *       elsewhere.</li>
 * </ul>
 * Whether an element is active follows the sign of {@code weight * x} (see {@link #isActive}).
 *
 * @param inObj only {@code inObj[0]} is used
 * @return the activation result; it retains the input, its data and the weights until freed
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
  final Result input = inObj[0];
  final TensorList indata = input.getData();
  // Retained for the backward pass; released in _free().
  input.addRef();
  indata.addRef();
  weights.addRef();
  final int itemCnt = indata.length();
  return new Result(TensorArray.wrap(IntStream.range(0, itemCnt).parallel().mapToObj(dataIndex -> {
    @Nullable Tensor tensorElement = indata.get(dataIndex);
    @Nonnull final Tensor tensor = tensorElement.multiply(weights.get(0));
    tensorElement.freeRef();
    @Nullable final double[] outputData = tensor.getData();
    for (int i = 0; i < outputData.length; i++) {
      if (outputData[i] < 0) {
        outputData[i] = 0;
      }
    }
    return tensor;
  }).toArray(i -> new Tensor[i])), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
    if (!isFrozen()) {
      final double weight = weights.getData()[0];
      IntStream.range(0, delta.length()).parallel().forEach(dataIndex -> {
        @Nullable Tensor deltaTensor = delta.get(dataIndex);
        @Nullable final double[] deltaData = deltaTensor.getData();
        @Nullable Tensor inputTensor = indata.get(dataIndex);
        @Nullable final double[] inputData = inputTensor.getData();
        @Nonnull final Tensor weightDelta = new Tensor(weights.getDimensions());
        @Nullable final double[] weightDeltaData = weightDelta.getData();
        for (int i = 0; i < deltaData.length; i++) {
          // d/dw max(0, w*x) is x on the active region and 0 elsewhere.
          weightDeltaData[0] += isActive(inputData[i], weight) ? deltaData[i] * inputData[i] : 0;
        }
        buffer.get(ReLuActivationLayer.this, weights.getData()).addInPlace(weightDeltaData).freeRef();
        deltaTensor.freeRef();
        inputTensor.freeRef();
        weightDelta.freeRef();
      });
    }
    if (input.isAlive()) {
      final double weight = weights.getData()[0];
      @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, delta.length()).parallel().mapToObj(dataIndex -> {
        @Nullable Tensor deltaTensor = delta.get(dataIndex);
        @Nullable final double[] deltaData = deltaTensor.getData();
        @Nullable Tensor inTensor = indata.get(dataIndex);
        @Nullable final double[] inputData = inTensor.getData();
        @Nonnull final int[] dims = inTensor.getDimensions();
        @Nonnull final Tensor passback = new Tensor(dims);
        for (int i = 0; i < passback.length(); i++) {
          // d/dx max(0, w*x) is w on the active region and 0 elsewhere.
          passback.set(i, isActive(inputData[i], weight) ? deltaData[i] * weight : 0);
        }
        inTensor.freeRef();
        deltaTensor.freeRef();
        return passback;
      }).toArray(i -> new Tensor[i]));
      input.accumulate(buffer, tensorArray);
    }
  }) {
    @Override
    protected void _free() {
      input.freeRef();
      indata.freeRef();
      weights.freeRef();
    }

    @Override
    public boolean isAlive() {
      return input.isAlive() || !isFrozen();
    }
  };
}

/**
 * Returns whether {@code max(0, weight * input)} passes gradient at this element.
 * <p>
 * For {@code weight >= 0} (and NaN) this is exactly the historical rule "not (input < 0)", which
 * treats 0 as active and lets NaN propagate. For negative weights the active region is mirrored
 * to "not (input > 0)", because the forward output there is non-zero only for non-positive inputs.
 */
private static boolean isActive(final double input, final double weight) {
  if (weight < 0) {
    return !(input > 0);
  }
  return !(input < 0);
}
Usage of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.
Class ScaleMetaLayer, method eval:
/**
 * Multiplies every item of {@code inObj[0]} element-wise by item 0 of {@code inObj[1]}.
 * <p>
 * Backward pass:
 * <ul>
 *   <li>{@code inObj[0]} receives {@code delta * scale};</li>
 *   <li>item 0 of {@code inObj[1]} receives the per-element sum over the batch of
 *       {@code delta * input0};</li>
 *   <li>any further items of {@code inObj[1]} receive zeros.</li>
 * </ul>
 *
 * @param inObj {@code inObj[0]} is the data to scale; item 0 of {@code inObj[1]} holds the scale
 * @return the scaled result; it retains both inputs and their data until freed
 */
@Nullable
@Override
public Result eval(@Nonnull final Result... inObj) {
  final int itemCnt = inObj[0].getData().length();
  // Retained for the backward pass; released in _free().
  Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
  Arrays.stream(inObj).forEach(x -> x.getData().addRef());
  // Fetch the scale tensor once and release it, instead of fetching (and leaking) it per element.
  @Nullable final Tensor scale = inObj[1].getData().get(0);
  final Tensor[] tensors = IntStream.range(0, itemCnt).mapToObj(dataIndex -> {
    @Nullable final Tensor item = inObj[0].getData().get(dataIndex);
    @Nullable final Tensor scaled = item.mapIndex((v, c) -> v * scale.get(c));
    item.freeRef();
    return scaled;
  }).toArray(i -> new Tensor[i]);
  scale.freeRef();
  // tensor0 only supplies the output shape for the scale gradient; it is released in _free().
  final Tensor tensor0 = tensors[0];
  tensor0.addRef();
  return new Result(TensorArray.wrap(tensors), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
    if (inObj[0].isAlive()) {
      @Nullable final Tensor t1 = inObj[1].getData().get(0);
      @Nonnull TensorArray tensorArray = TensorArray.wrap(data.stream().map(t -> {
        @Nullable Tensor tensor = t.mapIndex((v, c) -> v * t1.get(c));
        t.freeRef();
        return tensor;
      }).toArray(i -> new Tensor[i]));
      t1.freeRef();
      inObj[0].accumulate(buffer, tensorArray);
    }
    if (inObj[1].isAlive()) {
      // Fetch each item once rather than once per coordinate; release them afterwards.
      final Tensor[] deltas = IntStream.range(0, itemCnt).mapToObj(i -> data.get(i)).toArray(i -> new Tensor[i]);
      final Tensor[] inputs = IntStream.range(0, itemCnt).mapToObj(i -> inObj[0].getData().get(i)).toArray(i -> new Tensor[i]);
      @Nullable final Tensor passback = tensor0.mapIndex((v, c) ->
          IntStream.range(0, itemCnt).mapToDouble(i -> deltas[i].get(c) * inputs[i].get(c)).sum());
      Arrays.stream(deltas).forEach(t -> t.freeRef());
      Arrays.stream(inputs).forEach(t -> t.freeRef());
      @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, inObj[1].getData().length()).mapToObj(i -> i == 0 ? passback : passback.map(v -> 0)).toArray(i -> new Tensor[i]));
      inObj[1].accumulate(buffer, tensorArray);
    }
  }) {
    @Override
    protected void _free() {
      tensor0.freeRef();
      Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
      Arrays.stream(inObj).forEach(x -> x.getData().freeRef());
    }

    @Override
    public boolean isAlive() {
      return inObj[0].isAlive() || inObj[1].isAlive();
    }
  };
}
Usage of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.
Class QuantifyOrientationWrapper, method orient:
/**
 * Orients using the wrapped strategy and logs diagnostic step-size ratios.
 * <p>
 * When the inner cursor is a {@code SimpleLineSearchCursor}, the origin weights are grouped by
 * {@code getId(...)}. For each weight entry, the RMS of the matching direction delta (0 if none)
 * is divided by the entry's own RMS (a zero divisor is replaced by 1). A group with one entry logs
 * that ratio; larger groups log {@code DoubleStatistics} over their ratios. Any other cursor type
 * is only logged.
 *
 * @param subject     forwarded to the inner strategy
 * @param measurement forwarded to the inner strategy
 * @param monitor     forwarded to the inner strategy and used for the diagnostic log line
 * @return the inner strategy's cursor, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  final LineSearchCursor innerCursor = inner.orient(subject, measurement, monitor);
  if (!(innerCursor instanceof SimpleLineSearchCursor)) {
    monitor.log(String.format("Non-simple cursor: %s", innerCursor));
    return innerCursor;
  }
  final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) innerCursor;
  final DeltaSet<Layer> direction = simpleCursor.direction;
  @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
  final Map<CharSequence, CharSequence> dataMap = weights.stream()
      .collect(Collectors.groupingBy(x -> getId(x), Collectors.toList()))
      .entrySet().stream()
      .collect(Collectors.toMap(group -> group.getKey(), group -> {
        final List<Double> ratios = group.getValue().stream().map(entry -> {
          final DoubleBuffer<Layer> directionBuffer = direction.getMap().get(entry.layer);
          final double weightRms = entry.deltaStatistics().rms();
          final double directionRms;
          if (null == directionBuffer) {
            directionRms = 0;
          } else {
            directionRms = directionBuffer.deltaStatistics().rms();
          }
          final double divisor;
          if (0 == weightRms) {
            divisor = 1;
          } else {
            divisor = weightRms;
          }
          return directionRms / divisor;
        }).collect(Collectors.toList());
        if (ratios.size() != 1) {
          final double[] values = ratios.stream().mapToDouble(Double::doubleValue).toArray();
          return new DoubleStatistics().accept(values).toString();
        }
        return Double.toString(ratios.get(0));
      }));
  monitor.log(String.format("Line search stats: %s", dataMap));
  return innerCursor;
}
Usage of com.simiacryptus.mindseye.lang.DeltaSet in the MindsEye project by SimiaCryptus.
Class TrustRegionStrategy, method orient:
/**
 * Wraps the inner strategy's cursor so that candidate steps are projected into the per-layer
 * trust region returned by getRegionPolicy(layer), for layers that have one.
 * <p>
 * The given origin is prepended to {@code history}, which is then trimmed to {@code maxHistory}
 * entries. Each layer's arrays from those historical samples are passed to TrustRegion.project.
 *
 * @param subject the trainable being optimized; the returned cursor's step() measures it
 * @param origin  the current point sample; recorded in the history and passed to the inner strategy
 * @param monitor passed to the inner strategy
 * @return a cursor whose direction type is the inner cursor's type with "+Trust" appended
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
// Most recent sample first; drop the oldest entries beyond maxHistory.
history.add(0, origin);
while (history.size() > maxHistory) {
history.remove(history.size() - 1);
}
final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
return new LineSearchCursorBase() {
@Nonnull
@Override
public CharSequence getDirectionType() {
return cursor.getDirectionType() + "+Trust";
}
/**
 * Returns the inner cursor's position for alpha after projecting it in place.
 * NOTE(review): the adjusted direction returned by project() is discarded here and never
 * released; confirm whether DeltaSet needs freeRef() in this reference-counted codebase.
 */
@Nonnull
@Override
public DeltaSet<Layer> position(final double alpha) {
reset();
@Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
project(adjustedPosVector, new TrainingMonitor());
return adjustedPosVector;
}
/**
 * Projects {@code deltaIn} in place onto each layer's trust region.
 * <p>
 * Returns a copy of the inner cursor's direction in which, for every layer the region actually
 * moved, the direction's component along the projection normal has been removed.
 * <p>
 * Side effect: whenever a region changes the proposed point, the layer's delta array inside
 * {@code deltaIn} is overwritten with (projected position - current position).
 *
 * @param deltaIn proposed step, keyed by layer; modified in place
 * @param monitor not used (only referenced by the commented-out log statement below)
 * @return the adjusted copy of {@code cursor.direction}
 */
@Nonnull
public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
@Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
deltaIn.getMap().forEach((layer, buffer) -> {
@Nullable final double[] delta = buffer.getDelta();
// Layers without a delta are left untouched.
if (null == delta)
return;
// NOTE(review): buffer.target is presumably the layer's live weight array; confirm.
final double[] currentPosition = buffer.target;
@Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
@Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
@Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
final TrustRegion region = getRegionPolicy(layer);
if (null != region) {
// This layer's array from each historical sample (samples lacking the layer are skipped below).
final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample x) -> {
final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
@Nullable final double[] z = null == d ? null : d.getDelta();
return z;
});
final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
// Reference comparison: assumes region.project returns its input array when no projection is
// needed. TODO confirm this holds for every TrustRegion implementation.
if (projectedPosition != proposedPosition) {
// Rewrite the step in place so it lands on the projected point.
for (int i = 0; i < projectedPosition.length; i++) {
delta[i] = projectedPosition[i] - currentPosition[i];
}
@Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
final double normalMagSq = ArrayUtil.dot(normal, normal);
// Only adjust the direction when the projection actually displaced the point.
if (0 < normalMagSq) {
// a = (unnormalized) component of the original direction along the normal.
final double a = ArrayUtil.dot(originalAlphaD, normal);
// NOTE(review): exact comparison of a dot product against -1 has no evident meaning;
// possibly a different guard was intended. Confirm.
if (a != -1) {
// tangent = d - (d.n / |n|^2) n: the original direction with its normal component removed.
@Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
for (int i = 0; i < tangent.length; i++) {
newAlphaD[i] = tangent[i];
}
// Disabled sanity checks and debug logging:
// double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
// double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
// assert(newAlphaDerivSq <= originalAlphaDerivSq);
// assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
// monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", layer, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
}
}
}
}
});
return newAlphaDerivative;
}
@Override
public void reset() {
cursor.reset();
}
/**
 * Evaluates the objective at alpha.
 * <p>
 * Resets the inner cursor, projects its position for alpha, applies that step with
 * accumulate(1), and measures the subject. The reported line-search derivative is the projected
 * direction dotted with sample.delta (presumably the gradient at the new point; confirm).
 * <p>
 * NOTE(review): adjustedPosVector and adjustedGradient are not released here; confirm whether
 * they need freeRef().
 */
@Nonnull
@Override
public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
cursor.reset();
@Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
@Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
adjustedPosVector.accumulate(1);
@Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
return new LineSearchPoint(sample, adjustedGradient.dot(sample.delta));
}
// Releases the wrapped inner cursor.
@Override
public void _free() {
cursor.freeRef();
}
};
}
Aggregations