Search in sources:

Example 1 with DoubleMatrix

Use of org.jblas.DoubleMatrix in project MultiTypeTree by tgvaughan.

Method getRpowN of class SCMigrationModel.

/**
 * Returns the n-th power of the uniformized matrix R (or of its symmetrized
 * form Rsym), computing and caching any powers that are not yet available.
 *
 * <p>Powers are appended to the cache {@code RpowN} / {@code RsymPowN}. While
 * the cache is extended, each new power is folded element-wise into the running
 * maximum {@code RpowMax} / {@code RsymPowMax}. Every tenth power is compared
 * with its predecessor. If no entry differs (a NaN difference is also accepted),
 * the sequence is flagged as converged and that power is returned. Later requests
 * beyond the cache then return the last cached power.
 *
 * @param n         the exponent (assumed non-negative)
 * @param symmetric if true, work with Rsym and its caches; otherwise with R
 * @return the requested power, or the converged power if convergence was reached first
 */
@Override
public DoubleMatrix getRpowN(int n, boolean symmetric) {
    updateMatrices();

    // Select the cache, base matrix and running-maximum accumulator for this variant.
    final List<DoubleMatrix> powers = symmetric ? RsymPowN : RpowN;
    final DoubleMatrix base = symmetric ? Rsym : R;
    final DoubleMatrix runningMax = symmetric ? RsymPowMax : RpowMax;

    // Already cached: nothing to compute.
    if (n < powers.size()) {
        return powers.get(n);
    }

    // Once convergence has been detected, every higher power equals the last one cached.
    final boolean converged = symmetric ? RsymPowSteady : RpowSteady;
    if (converged) {
        return powers.get(powers.size() - 1);
    }

    for (int k = powers.size(); k <= n; k++) {
        final DoubleMatrix previous = powers.get(k - 1);
        final DoubleMatrix current = previous.mmul(base);
        powers.add(current);
        // In-place element-wise max keeps the shared accumulator field up to date.
        runningMax.maxi(current);

        // Convergence is only tested on every tenth power to limit overhead.
        if (k % 10 != 0) {
            continue;
        }
        double largestChange = 0.0;
        for (double entry : current.sub(previous).toArray()) {
            largestChange = Math.max(largestChange, Math.abs(entry));
        }
        // Written as !(x > 0) so that a NaN difference also counts as converged.
        if (!(largestChange > 0)) {
            if (symmetric) {
                RsymPowSteady = true;
            } else {
                RpowSteady = true;
            }
            return current;
        }
    }
    return powers.get(n);
}
Also used: DoubleMatrix (org.jblas.DoubleMatrix)

Example 2 with DoubleMatrix

Use of org.jblas.DoubleMatrix in project MultiTypeTree by tgvaughan.

Method updateMatrices of class SCMigrationModel.

/**
 * Ensure the rate matrices, uniformized matrices and cached matrix powers are
 * consistent with the current values held by the inputs.
 *
 * <p>Does nothing unless the model is flagged {@code dirty}. Otherwise it
 * rebuilds the following:
 * <ul>
 *   <li>the backward transition rate matrix Q. Off-diagonal entries come from
 *       {@code getBackwardRate(i, j)}; each diagonal entry is minus its row's
 *       total outflow, so rows sum to zero.</li>
 *   <li>its symmetrized counterpart Qsym. Each off-diagonal entry is the mean of
 *       the two directional rates.</li>
 *   <li>the uniformization rates mu / muSym. Each is the largest total outflow of
 *       any type.</li>
 *   <li>the uniformized matrices R = I + Q/mu and Rsym = I + Qsym/muSym.</li>
 * </ul>
 * Cached powers are reset to contain only R^0 = I, the running power maxima are
 * reset to I, and the steady-state flags are cleared.
 *
 * <p>If a uniformization rate is zero (for example a single type, or all rates
 * zero, assuming non-negative rates), Q has no transitions. Dividing by the rate
 * would then fill R with NaN (0 * Infinity), so R is set to the identity instead.
 */
public void updateMatrices() {
    if (!dirty)
        return;
    mu = 0.0;
    muSym = 0.0;
    Q = new DoubleMatrix(nTypes, nTypes);
    Qsym = new DoubleMatrix(nTypes, nTypes);
    // Transition rate matrices Q and Qsym: off-diagonals hold the rates, and each
    // diagonal is the negated sum of its row's off-diagonals.
    for (int i = 0; i < nTypes; i++) {
        double diag = 0.0;
        double diagSym = 0.0;
        for (int j = 0; j < nTypes; j++) {
            if (i == j)
                continue;
            final double rate = getBackwardRate(i, j);
            final double rateSym = 0.5 * (rate + getBackwardRate(j, i));
            Q.put(i, j, rate);
            Qsym.put(i, j, rateSym);
            diag -= rate;
            diagSym -= rateSym;
        }
        Q.put(i, i, diag);
        Qsym.put(i, i, diagSym);
        if (-diag > mu)
            mu = -diag;
        if (-diagSym > muSym)
            muSym = -diagSym;
    }
    // Set up uniformized backward transition rate matrices R and Rsym. A zero
    // uniformization rate means there are no transitions, so R = I; dividing
    // would instead produce 0 * Infinity = NaN in every entry.
    R = mu > 0.0
            ? Q.mul(1.0 / mu).add(DoubleMatrix.eye(nTypes))
            : DoubleMatrix.eye(nTypes);
    Rsym = muSym > 0.0
            ? Qsym.mul(1.0 / muSym).add(DoubleMatrix.eye(nTypes))
            : DoubleMatrix.eye(nTypes);
    // Clear cached powers of R and Rsym and steady state flag:
    RpowN.clear();
    RsymPowN.clear();
    RpowSteady = false;
    RsymPowSteady = false;
    // Power sequences initially contain R^0 = I
    RpowN.add(DoubleMatrix.eye(nTypes));
    RsymPowN.add(DoubleMatrix.eye(nTypes));
    RpowMax = DoubleMatrix.eye(nTypes);
    RsymPowMax = DoubleMatrix.eye(nTypes);
    dirty = false;
}
Also used: DoubleMatrix (org.jblas.DoubleMatrix)

Example 3 with DoubleMatrix

Use of org.jblas.DoubleMatrix in project MultiTypeTree by tgvaughan.

Method main of class SCMigrationModel.

/**
 * Debugging entry point.
 *
 * <p>Builds a 10x10 matrix whose (i, j) entry is {@code i * 10 + j}, prints
 * the matrix exponential of {@code 0.001} times that matrix, then prints the
 * matrix itself.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    final int dim = 10;
    final DoubleMatrix matrix = new DoubleMatrix(dim, dim);
    // Visit entries in row-major order; the running index k equals row * dim + col.
    for (int k = 0; k < dim * dim; k++) {
        matrix.put(k / dim, k % dim, k);
    }
    final DoubleMatrix scaled = matrix.mul(0.001);
    MatrixFunctions.expm(scaled).print();
    matrix.print();
}
Also used: DoubleMatrix (org.jblas.DoubleMatrix)

Example 4 with DoubleMatrix

Use of org.jblas.DoubleMatrix in project MindsEye by SimiaCryptus.

Method multiply of class FullyConnectedLayer.

/**
 * Computes {@code out = M * in}. M is the {@code out.length x in.length}
 * matrix whose entries are read from {@code matrix}, using jblas's
 * (column-major) layout.
 *
 * <p>The result lands in {@code out} only because the jblas constructor used
 * here takes the passed arrays as backing storage rather than copying them.
 *
 * @param matrix the matrix entries; expected length {@code out.length * in.length}
 * @param in     the input vector
 * @param out    receives the product; its previous contents are overwritten
 */
public static void multiply(final double[] matrix, @Nonnull final double[] in, @Nonnull final double[] out) {
    final int rows = out.length;
    final int cols = in.length;
    // Same construction order as the product arguments: weights, then input, then output.
    @Nonnull final DoubleMatrix weights = new DoubleMatrix(rows, cols, matrix);
    @Nonnull final DoubleMatrix source = new DoubleMatrix(cols, 1, in);
    @Nonnull final DoubleMatrix target = new DoubleMatrix(rows, 1, out);
    weights.mmuli(source, target);
}
Also used: DoubleMatrix (org.jblas.DoubleMatrix), Nonnull (javax.annotation.Nonnull)

Example 5 with DoubleMatrix

Use of org.jblas.DoubleMatrix in project MindsEye by SimiaCryptus.

Method eval of class FullyConnectedLayer.

/**
 * Forward pass of the fully connected layer.
 *
 * <p>The weight array is viewed as an (input length x output length) matrix and
 * transposed via {@code FullyConnectedLayer.transpose}. Each flattened input
 * tensor is multiplied by that transposed matrix, producing one tensor of shape
 * {@code outputDims} per item.
 *
 * <p>The returned {@link Result} carries a backward callback that does two things:
 * <ul>
 *   <li>Unless the layer is frozen, it computes one weight delta per item with
 *       {@code crossMultiplyT}, sums them, and adds the sum into this layer's
 *       entry of the {@code DeltaSet}.</li>
 *   <li>If {@code inObj[0]} is alive, it multiplies each output delta by the
 *       weight matrix and accumulates the result into {@code inObj[0]}.</li>
 * </ul>
 *
 * <p>Reference counting: this method takes references on the input data, on every
 * element of {@code inObj}, on the layer itself and on its weights. All of them
 * are released in the returned result's {@code _free()}.
 *
 * @param inObj the layer inputs. Only {@code inObj[0]}'s data is used; all inputs
 *              are retained and consulted by {@code isAlive()}.
 * @return the layer output together with its backpropagation callback
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    final TensorList indata = inObj[0].getData();
    // Keep the input data, every input and this layer referenced for the lifetime
    // of the returned Result; the matching freeRef calls are in _free().
    indata.addRef();
    for (@Nonnull Result result : inObj) {
        result.addRef();
    }
    FullyConnectedLayer.this.addRef();
    assert Tensor.length(indata.getDimensions()) == Tensor.length(this.inputDims) : Arrays.toString(indata.getDimensions()) + " == " + Arrays.toString(this.inputDims);
    // View the weight array as an (input length x output length) matrix and take
    // its transpose (output length x input length) for the forward product.
    @Nonnull DoubleMatrix doubleMatrix = new DoubleMatrix(Tensor.length(indata.getDimensions()), Tensor.length(outputDims), this.weights.getData());
    @Nonnull final DoubleMatrix matrixObj = FullyConnectedLayer.transpose(doubleMatrix);
    // Forward pass, one item per parallel task. mmuli writes the product into a
    // DoubleMatrix wrapping output.getData() (assumed to be the tensor's backing
    // array — TODO confirm).
    @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, indata.length()).parallel().mapToObj(dataIndex -> {
        @Nullable final Tensor input = indata.get(dataIndex);
        @Nullable final Tensor output = new Tensor(outputDims);
        matrixObj.mmuli(new DoubleMatrix(input.length(), 1, input.getData()), new DoubleMatrix(output.length(), 1, output.getData()));
        input.freeRef();
        return output;
    }).toArray(i -> new Tensor[i]));
    // Hand the transposed matrix's storage back to the pool (assumes transpose()
    // allocated it from RecycleBin.DOUBLES — TODO confirm). toArray() above is a
    // terminal operation, so no forward task still uses matrixObj at this point.
    RecycleBin.DOUBLES.recycle(matrixObj.data, matrixObj.data.length);
    // The backward callback reads the weights; released in _free().
    this.weights.addRef();
    return new Result(tensorArray, (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (!isFrozen()) {
            final Delta<Layer> deltaBuffer = buffer.get(FullyConnectedLayer.this, this.weights.getData());
            // Items are split across `threads` parallel buckets by index modulo;
            // each bucket streams its items sequentially.
            final int threads = 4;
            IntStream.range(0, threads).parallel().mapToObj(x -> x).flatMap(thread -> {
                @Nullable Stream<Tensor> stream = IntStream.range(0, indata.length()).filter(i -> thread == i % threads).mapToObj(dataIndex -> {
                    // One weight-delta tensor per item, filled by crossMultiplyT
                    // from the output delta and the stored input.
                    @Nonnull final Tensor weightDelta = new Tensor(Tensor.length(inputDims), Tensor.length(outputDims));
                    Tensor deltaTensor = delta.get(dataIndex);
                    Tensor inputTensor = indata.get(dataIndex);
                    FullyConnectedLayer.crossMultiplyT(deltaTensor.getData(), inputTensor.getData(), weightDelta.getData());
                    inputTensor.freeRef();
                    deltaTensor.freeRef();
                    return weightDelta;
                });
                return stream;
            }).reduce((a, b) -> {
                // Sum the per-item deltas. addAndFree presumably releases `a`
                // (TODO confirm); `b` is released explicitly here.
                @Nullable Tensor c = a.addAndFree(b);
                b.freeRef();
                return c;
            }).map(data -> {
                // reduce(BinaryOperator) returns an Optional, so this map runs
                // eagerly, and only when there was at least one item. It exists
                // for its side effect of adding the summed delta into deltaBuffer;
                // the mapped value is discarded.
                @Nonnull Delta<Layer> layerDelta = deltaBuffer.addInPlace(data.getData());
                data.freeRef();
                return layerDelta;
            });
            deltaBuffer.freeRef();
        }
        if (inObj[0].isAlive()) {
            // Pass back: multiply each output delta by the weight matrix, viewed by
            // multiply() as (input length x output length), to get a tensor with
            // the input's dimensions.
            @Nonnull final TensorList tensorList = TensorArray.wrap(IntStream.range(0, indata.length()).parallel().mapToObj(dataIndex -> {
                Tensor deltaTensor = delta.get(dataIndex);
                @Nonnull final Tensor passback = new Tensor(indata.getDimensions());
                FullyConnectedLayer.multiply(this.weights.getData(), deltaTensor.getData(), passback.getData());
                deltaTensor.freeRef();
                return passback;
            }).toArray(i -> new Tensor[i]));
            inObj[0].accumulate(buffer, tensorList);
        }
    }) {

        // Releases every reference taken in eval(): input data, layer, inputs, weights.
        @Override
        protected void _free() {
            indata.freeRef();
            FullyConnectedLayer.this.freeRef();
            for (@Nonnull Result result : inObj) {
                result.freeRef();
            }
            FullyConnectedLayer.this.weights.freeRef();
        }

        // Alive when the layer is not frozen or any input is alive.
        @Override
        public boolean isAlive() {
            return !isFrozen() || Arrays.stream(inObj).anyMatch(x -> x.isAlive());
        }
    };
}
Also used: IntStream (java.util.stream.IntStream), JsonObject (com.google.gson.JsonObject), Coordinate (com.simiacryptus.mindseye.lang.Coordinate), Arrays (java.util.Arrays), LoggerFactory (org.slf4j.LoggerFactory), Tensor (com.simiacryptus.mindseye.lang.Tensor), Result (com.simiacryptus.mindseye.lang.Result), DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer), JsonUtil (com.simiacryptus.util.io.JsonUtil), Delta (com.simiacryptus.mindseye.lang.Delta), Map (java.util.Map), Layer (com.simiacryptus.mindseye.lang.Layer), DoubleMatrix (org.jblas.DoubleMatrix), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable), Util (com.simiacryptus.util.Util), Logger (org.slf4j.Logger), IntToDoubleFunction (java.util.function.IntToDoubleFunction), FastRandom (com.simiacryptus.util.FastRandom), ToDoubleBiFunction (java.util.function.ToDoubleBiFunction), RecycleBin (com.simiacryptus.mindseye.lang.RecycleBin), List (java.util.List), LayerBase (com.simiacryptus.mindseye.lang.LayerBase), Stream (java.util.stream.Stream), ToDoubleFunction (java.util.function.ToDoubleFunction), TensorList (com.simiacryptus.mindseye.lang.TensorList), DoubleSupplier (java.util.function.DoubleSupplier), TensorArray (com.simiacryptus.mindseye.lang.TensorArray), DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)

Aggregations

DoubleMatrix (org.jblas.DoubleMatrix): 6, Nonnull (javax.annotation.Nonnull): 3, JsonObject (com.google.gson.JsonObject): 1, Coordinate (com.simiacryptus.mindseye.lang.Coordinate): 1, DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 1, Delta (com.simiacryptus.mindseye.lang.Delta): 1, DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 1, Layer (com.simiacryptus.mindseye.lang.Layer): 1, LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 1, RecycleBin (com.simiacryptus.mindseye.lang.RecycleBin): 1, Result (com.simiacryptus.mindseye.lang.Result): 1, Tensor (com.simiacryptus.mindseye.lang.Tensor): 1, TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 1, TensorList (com.simiacryptus.mindseye.lang.TensorList): 1, FastRandom (com.simiacryptus.util.FastRandom): 1, Util (com.simiacryptus.util.Util): 1, JsonUtil (com.simiacryptus.util.io.JsonUtil): 1, Arrays (java.util.Arrays): 1, List (java.util.List): 1, Map (java.util.Map): 1