Search in sources :

Example 86 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

the class StaticScalarLossLayer method eval.

/**
 * Evaluates the absolute difference between the first element of every input
 * tensor and {@code getTarget()}.
 *
 * <p>Forward: for each batch index the output is a one-element tensor holding
 * {@code |input[0] - target|}. Backward: the incoming delta's first element is
 * multiplied by {@code -1} when {@code input[0] - target} is negative and by
 * {@code 1} otherwise, and the product is accumulated into the single input.
 *
 * @param inObj exactly one input result
 * @return a result wrapping the per-sample distances
 * @throws IllegalArgumentException if the number of inputs is not one
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    if (inObj.length != 1) {
        throw new IllegalArgumentException();
    }
    // Hold a reference on every input for the lifetime of the returned Result.
    for (final Result input : inObj) {
        input.addRef();
    }
    final Result in0 = inObj[0];
    final TensorList indata = in0.getData();
    indata.addRef();
    final Tensor[] distances = IntStream.range(0, indata.length()).parallel().mapToObj(row -> {
        @Nullable final Tensor sample = indata.get(row);
        final double distance = Math.abs(sample.get(0) - getTarget());
        sample.freeRef();
        return new Tensor(new double[] { distance }, 1);
    }).toArray(Tensor[]::new);
    return new Result(TensorArray.wrap(distances), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (!in0.isAlive()) {
            // Nothing upstream wants a gradient.
            return;
        }
        final Tensor[] gradients = IntStream.range(0, data.length()).parallel().mapToObj(row -> {
            @Nullable final Tensor sample = indata.get(row);
            final Tensor upstream = data.get(row);
            final double upstreamValue = upstream.get(0);
            // Slope of |x - target|: -1 below the target, +1 at or above it.
            final double slope = sample.get(0) - getTarget() < 0 ? -1 : 1;
            final double deriv = upstreamValue * slope;
            upstream.freeRef();
            sample.freeRef();
            return new Tensor(new double[] { deriv }, 1);
        }).toArray(Tensor[]::new);
        @Nonnull final TensorArray tensorArray = TensorArray.wrap(gradients);
        in0.accumulate(buffer, tensorArray);
    }) {

        @Override
        protected void _free() {
            // Release everything that eval() retained above.
            indata.freeRef();
            for (final Result input : inObj) {
                input.freeRef();
            }
        }

        @Override
        public boolean isAlive() {
            return in0.isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 87 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

the class LayerRateDiagnosticTrainer method run.

/**
 * Runs the per-layer learning-rate diagnostic and returns the collected layer statistics.
 *
 * <p>The outer loop repeats until the wall-clock deadline derived from {@code timeout}
 * passes, {@code measure.sum} is no longer above {@code terminateThreshold}, or
 * {@code currentIteration} exceeds {@code maxIterations}. Each sub-iteration first logs
 * a rate estimate per layer obtained from two probes along the orientation (step
 * {@code 0.0} and a very small step), then runs the configured line search once per
 * layer along a direction filtered to that layer, recording a {@code LayerStats} entry
 * in {@code getLayerRates()}.
 *
 * @return the map returned by {@code getLayerRates()} after the loop ends
 */
@Nonnull
public Map<Layer, LayerStats> run() {
    // Absolute deadline (epoch millis) after which no new outer iteration starts.
    final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
    PointSample measure = measure();
    // Snapshot of the layers that appear in the first measurement's weight map.
    @Nonnull final ArrayList<Layer> layers = new ArrayList<>(measure.weights.getMap().keySet());
    while (timeoutMs > System.currentTimeMillis() && measure.sum > terminateThreshold) {
        if (currentIteration.get() > maxIterations) {
            break;
        }
        // Fresh measurement; its sum is the baseline for the LayerStats deltas below.
        final PointSample initialPhasePoint = measure();
        measure = initialPhasePoint;
        for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
            if (currentIteration.incrementAndGet() > maxIterations) {
                break;
            }
            {
                // Probe phase: compare the delta at step 0 with the delta at a tiny step
                // to estimate a rate per layer. The results are only logged.
                @Nonnull final SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
                final double stepSize = 1e-12 * orient.origin.sum;
                @Nonnull final DeltaSet<Layer> pointB = orient.step(stepSize, monitor).point.delta.copy();
                @Nonnull final DeltaSet<Layer> pointA = orient.step(0.0, monitor).point.delta.copy();
                @Nonnull final DeltaSet<Layer> d1 = pointA;
                // (pointA - pointB) / stepSize.
                // NOTE(review): assumes add()/scale() return new sets and leave d1 untouched — confirm.
                @Nonnull final DeltaSet<Layer> d2 = d1.add(pointB.scale(-1)).scale(1.0 / stepSize);
                @Nonnull final Map<Layer, Double> steps = new HashMap<>();
                final double overallStepEstimate = d1.getMagnitude() / d2.getMagnitude();
                for (final Layer layer : layers) {
                    final DoubleBuffer<Layer> a = d2.get(layer, (double[]) null);
                    final DoubleBuffer<Layer> b = d1.get(layer, (double[]) null);
                    final double bmag = Math.sqrt(b.deltaStatistics().sumSq());
                    final double amag = Math.sqrt(a.deltaStatistics().sumSq());
                    // Dot product normalised by both magnitudes; not finite when either magnitude is 0.
                    final double dot = a.dot(b) / (amag * bmag);
                    final double idealSize = bmag / (amag * dot);
                    steps.put(layer, idealSize);
                    monitor.log(String.format("Layers stats: %s (%s, %s, %s) => %s", layer, amag, bmag, dot, idealSize));
                }
                monitor.log(String.format("Estimated ideal rates for layers: %s (%s overall; probed at %s)", steps, overallStepEstimate, stepSize));
            }
            // Candidate remembered across layers in strict mode; both are set together.
            @Nullable SimpleLineSearchCursor bestOrient = null;
            @Nullable PointSample bestPoint = null;
            layerLoop: for (@Nonnull final Layer layer : layers) {
                @Nonnull SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
                // Direction restricted to the current layer via filterDirection.
                @Nonnull final DeltaSet<Layer> direction = filterDirection(orient.direction, layer);
                if (direction.getMagnitude() == 0) {
                    monitor.log(String.format("Zero derivative for layer %s; skipping", layer));
                    continue layerLoop;
                }
                orient = new SimpleLineSearchCursor(orient.subject, orient.origin, direction);
                final PointSample previous = measure;
                measure = getLineSearchStrategy().step(orient, monitor);
                if (isStrict()) {
                    // Strict mode: record the result, then step back to 0 and restore the previous sample.
                    monitor.log(String.format("Iteration %s reverting. Error: %s", currentIteration.get(), measure.sum));
                    monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
                    // NOTE(review): this keeps the candidate with the LARGER sum, while the outer loop
                    // treats sum as something to drive below terminateThreshold — confirm the comparison
                    // direction is intended.
                    if (null == bestPoint || bestPoint.sum < measure.sum) {
                        bestOrient = orient;
                        bestPoint = measure;
                    }
                    getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
                    orient.step(0, monitor);
                    measure = previous;
                } else if (previous.sum == measure.sum) {
                    // Line search did not change the sum; nothing is recorded for this layer.
                    monitor.log(String.format("Iteration %s failed. Error: %s", currentIteration.get(), measure.sum));
                } else {
                    monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), measure.sum));
                    monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
                    getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
                }
            }
            monitor.log(String.format("Ideal rates: %s", getLayerRates()));
            if (null != bestPoint) {
                // Only reachable in strict mode: re-apply the remembered candidate's step.
                // The returned sample is discarded, so measure is not updated here.
                bestOrient.step(bestPoint.rate, monitor);
            }
            monitor.onStepComplete(new Step(measure, currentIteration.get()));
        }
    }
    return getLayerRates();
}
Also used : DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) PointSample(com.simiacryptus.mindseye.lang.PointSample) HashMap(java.util.HashMap) Map(java.util.Map) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 88 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

the class NthPowerActivationLayer method eval.

/**
 * Applies the element-wise power function selected by {@code power} to every
 * tensor of the first input.
 *
 * <p>Forward: each sample is handed to one of the static kernels
 * ({@code square}, {@code squareRoot}, {@code unity} or {@code nthPower}),
 * which receive the input, gradient and output arrays. The per-sample
 * gradient tensor is kept for the backward pass. Backward: every finite
 * stored gradient entry is multiplied by the matching delta entry; entries
 * with a non-finite gradient are left unset in the passback tensor.
 *
 * @param inObj the inputs; only {@code inObj[0]} is read
 * @return a result wrapping the transformed tensors
 */
@Override
public Result eval(@Nonnull final Result... inObj) {
    final int itemCnt = inObj[0].getData().length();
    assert 0 < itemCnt;
    for (final Result input : inObj) {
        input.addRef();
    }
    // One gradient tensor per sample, filled in by the forward pass below.
    @Nonnull final Tensor[] inputGradientA = new Tensor[itemCnt];
    final Tensor[] outputs = IntStream.range(0, itemCnt).parallel().mapToObj(row -> {
        @Nullable final Tensor source = inObj[0].getData().get(row);
        @Nonnull final Tensor result = new Tensor(inObj[0].getData().getDimensions());
        @Nonnull final Tensor slope = new Tensor(source.length());
        @Nullable final double[] sourceValues = source.getData();
        @Nullable final double[] slopeValues = slope.getData();
        @Nullable final double[] resultValues = result.getData();
        inputGradientA[row] = slope;
        if (power == 2) {
            NthPowerActivationLayer.square(source, sourceValues, slopeValues, resultValues);
        } else if (power == 0.5) {
            NthPowerActivationLayer.squareRoot(source, sourceValues, slopeValues, resultValues);
        } else if (power == 0.0) {
            NthPowerActivationLayer.unity(source, sourceValues, slopeValues, resultValues);
        } else {
            NthPowerActivationLayer.nthPower(power, source, sourceValues, slopeValues, resultValues);
        }
        source.freeRef();
        return result;
    }).toArray(Tensor[]::new);
    return new Result(TensorArray.wrap(outputs), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (!inObj[0].isAlive()) {
            return;
        }
        final Tensor[] passbacks = IntStream.range(0, itemCnt).parallel().mapToObj(row -> {
            @Nonnull final Tensor passback = new Tensor(data.getDimensions());
            @Nullable final Tensor upstream = data.get(row);
            @Nullable final double[] upstreamValues = upstream.getData();
            @Nullable final double[] slopeValues = inputGradientA[row].getData();
            final int size = passback.length();
            for (int i = 0; i < size; i++) {
                final double v = slopeValues[i];
                // Skip NaN / infinite slopes; the passback entry is left as created.
                if (Double.isFinite(v)) {
                    passback.set(i, upstreamValues[i] * v);
                }
            }
            upstream.freeRef();
            return passback;
        }).toArray(Tensor[]::new);
        @Nonnull final TensorArray tensorArray = TensorArray.wrap(passbacks);
        inObj[0].accumulate(buffer, tensorArray);
    }) {

        @Override
        protected void _free() {
            // Release the inputs first, then the cached gradient tensors.
            for (final Result input : inObj) {
                input.freeRef();
            }
            for (final Tensor slope : inputGradientA) {
                slope.freeRef();
            }
        }

        @Override
        public boolean isAlive() {
            if (0.0 == power) {
                return false;
            }
            return inObj[0].isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable)

Example 89 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

the class ProductLayer method eval.

/**
 * Multiplies together every element of every input tensor, per batch index,
 * producing a one-element tensor holding that product.
 *
 * <p>Backward: the partial derivative of the product with respect to one
 * element is the product of all <em>other</em> elements. For a non-zero
 * element this is computed as {@code product / element}. For an element that
 * is exactly zero that quotient would be {@code 0 / 0 = NaN}, so the product
 * of the remaining factors is recomputed explicitly instead.
 *
 * @param inObj one or more input results; all are read at each batch index
 * @return a result wrapping the per-sample products
 * @throws IllegalArgumentException if no inputs are supplied
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    if (0 == inObj.length) {
        // Previously surfaced as an ArrayIndexOutOfBoundsException at inObj[0].
        throw new IllegalArgumentException("ProductLayer requires at least one input");
    }
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    Arrays.stream(inObj).forEach(x -> x.getData().addRef());
    final Result in0 = inObj[0];
    // Forward products, cached per batch index for reuse in the backward pass.
    @Nonnull final double[] sum_A = new double[in0.getData().length()];
    final Tensor[] outputA = IntStream.range(0, in0.getData().length()).mapToObj(dataIndex -> {
        double sum = 1;
        for (@Nonnull final Result element : inObj) {
            Tensor tensor = element.getData().get(dataIndex);
            @Nullable final double[] input = tensor.getData();
            for (final double element2 : input) {
                sum *= element2;
            }
            tensor.freeRef();
        }
        sum_A[dataIndex] = sum;
        return new Tensor(new double[] { sum }, 1);
    }).toArray(i -> new Tensor[i]);
    return new Result(TensorArray.wrap(outputA), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        for (int inputIndex = 0; inputIndex < inObj.length; inputIndex++) {
            final Result in_l = inObj[inputIndex];
            // Effectively-final copy so the stream lambda can capture it.
            final int currentInput = inputIndex;
            if (in_l.isAlive()) {
                @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, delta.length()).mapToObj(dataIndex -> {
                    Tensor dataTensor = delta.get(dataIndex);
                    Tensor lTensor = in_l.getData().get(dataIndex);
                    @Nonnull final Tensor passback = new Tensor(lTensor.getDimensions());
                    for (int i = 0; i < lTensor.length(); i++) {
                        final double factor = lTensor.getData()[i];
                        if (0.0 != factor) {
                            // Same expression as before for the ordinary, non-zero case.
                            passback.set(i, dataTensor.get(0) * sum_A[dataIndex] / factor);
                        } else {
                            // product / 0 is NaN (0/0); use the product of the other factors.
                            passback.set(i, dataTensor.get(0) * productExcluding(inObj, dataIndex, currentInput, i));
                        }
                    }
                    dataTensor.freeRef();
                    lTensor.freeRef();
                    return passback;
                }).toArray(i -> new Tensor[i]));
                in_l.accumulate(buffer, tensorArray);
            }
        }
    }) {

        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
            Arrays.stream(inObj).forEach(x -> x.getData().freeRef());
        }

        @Override
        public boolean isAlive() {
            for (@Nonnull final Result element : inObj) if (element.isAlive()) {
                return true;
            }
            return false;
        }
    };
}

/**
 * Product of every element of every input at {@code dataIndex}, leaving out
 * the single element {@code skipElement} of input {@code skipInput}.
 */
private static double productExcluding(@Nonnull final Result[] inObj, final int dataIndex, final int skipInput, final int skipElement) {
    double product = 1;
    for (int inputIndex = 0; inputIndex < inObj.length; inputIndex++) {
        final Tensor tensor = inObj[inputIndex].getData().get(dataIndex);
        final double[] values = tensor.getData();
        for (int i = 0; i < values.length; i++) {
            if (inputIndex != skipInput || i != skipElement) {
                product *= values[i];
            }
        }
        tensor.freeRef();
    }
    return product;
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 90 with DeltaSet

use of com.simiacryptus.mindseye.lang.DeltaSet in project MindsEye by SimiaCryptus.

the class ReshapeLayer method evalAndFree.

/**
 * Wraps the single input's data in a {@code ReshapedTensorList} using
 * {@code outputDims}, and on the backward pass wraps the incoming delta in a
 * {@code ReshapedTensorList} with the input's original dimensions before
 * accumulating it into the input.
 *
 * @param inObj exactly one input result (checked by assertion only)
 * @return a result exposing the reshaped view
 */
@Nullable
@Override
public Result evalAndFree(@Nonnull final Result... inObj) {
    assert 1 == inObj.length;
    final TensorList source = inObj[0].getData();
    // Remember the incoming shape so deltas can be mapped back to it.
    @Nonnull final int[] originalDims = source.getDimensions();
    final ReshapedTensorList reshaped = new ReshapedTensorList(source, outputDims);
    source.freeRef();
    return new Result(reshaped, (DeltaSet<Layer> buffer, TensorList delta) ->
            inObj[0].accumulate(buffer, new ReshapedTensorList(delta, originalDims))) {

        @Override
        protected void _free() {
            for (int i = 0; i < inObj.length; i++) {
                inObj[i].freeRef();
            }
        }

        @Override
        public boolean isAlive() {
            return inObj[0].isAlive();
        }
    };
}
Also used : ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) Nullable(javax.annotation.Nullable)

Aggregations

DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)98 Nonnull (javax.annotation.Nonnull)98 Result (com.simiacryptus.mindseye.lang.Result)90 Nullable (javax.annotation.Nullable)88 Layer (com.simiacryptus.mindseye.lang.Layer)86 TensorList (com.simiacryptus.mindseye.lang.TensorList)86 Arrays (java.util.Arrays)77 List (java.util.List)75 Tensor (com.simiacryptus.mindseye.lang.Tensor)73 Map (java.util.Map)66 IntStream (java.util.stream.IntStream)65 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)64 JsonObject (com.google.gson.JsonObject)62 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)61 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)60 Logger (org.slf4j.Logger)47 LoggerFactory (org.slf4j.LoggerFactory)47 ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting)23 CudaTensor (com.simiacryptus.mindseye.lang.cudnn.CudaTensor)22 CudaTensorList (com.simiacryptus.mindseye.lang.cudnn.CudaTensorList)22