Search in sources :

Example 1 with RBLongFloatMatrix

Use of com.tencent.angel.ml.math2.matrix.RBLongFloatMatrix in project Angel by Tencent.

Method apply of class DotMatrixExecutor.

/**
 * Multiplies two matrices, dispatching on their concrete runtime types.
 *
 * <p>Combinations are checked in this order:
 * <ul>
 *   <li>{@code BlasDoubleMatrix x BlasDoubleMatrix} and {@code BlasFloatMatrix x BlasFloatMatrix}:
 *       delegated to {@code applyParallel} when {@code parallel} is {@code true}, otherwise to the
 *       sequential typed overload;</li>
 *   <li>{@code BlasDoubleMatrix x RBIntDoubleMatrix}, {@code BlasDoubleMatrix x RBLongDoubleMatrix},
 *       {@code BlasFloatMatrix x RBIntFloatMatrix}, {@code BlasFloatMatrix x RBLongFloatMatrix},
 *       {@code RBIntDoubleMatrix x BlasDoubleMatrix}, {@code RBIntFloatMatrix x BlasFloatMatrix}:
 *       delegated to the matching typed overload;</li>
 *   <li>any other pair of {@link RowBasedMatrix} instances: only {@code !trans1 && trans2} is
 *       handled. The result is a dense matrix with {@code mat1.getNumRows()} rows and
 *       {@code mat2.getNumRows()} columns, whose column {@code i} is
 *       {@code mat1.dot(mat2.getRow(i))} (i.e. {@code mat1 x mat2^T}).</li>
 * </ul>
 *
 * @param mat1 left operand
 * @param trans1 whether {@code mat1} is used transposed
 * @param mat2 right operand
 * @param trans2 whether {@code mat2} is used transposed
 * @param parallel whether Blas x Blas products use the parallel implementation;
 *     {@code null} is treated as {@code false}
 * @return the product matrix
 * @throws AngelException if the operand types, transpose flags, or row element types are not
 *     supported
 */
public static Matrix apply(Matrix mat1, boolean trans1, Matrix mat2, boolean trans2, Boolean parallel) {
    // Resolve the nullable flag once. Unboxing a null Boolean directly in an if-condition
    // would throw a NullPointerException, so null is treated as "sequential".
    boolean useParallel = Boolean.TRUE.equals(parallel);

    if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
        if (useParallel) {
            return applyParallel((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
        } else {
            return apply((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
        }
    } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof BlasFloatMatrix) {
        if (useParallel) {
            return applyParallel((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
        } else {
            return apply((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
        }
    } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBIntDoubleMatrix) {
        return apply((BlasDoubleMatrix) mat1, trans1, (RBIntDoubleMatrix) mat2, trans2);
    } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBLongDoubleMatrix) {
        return apply((BlasDoubleMatrix) mat1, trans1, (RBLongDoubleMatrix) mat2, trans2);
    } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBIntFloatMatrix) {
        return apply((BlasFloatMatrix) mat1, trans1, (RBIntFloatMatrix) mat2, trans2);
    } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBLongFloatMatrix) {
        return apply((BlasFloatMatrix) mat1, trans1, (RBLongFloatMatrix) mat2, trans2);
    } else if (mat1 instanceof RBIntDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
        return apply((RBIntDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
    } else if (mat1 instanceof RBIntFloatMatrix && mat2 instanceof BlasFloatMatrix) {
        return apply((RBIntFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
    } else if (mat1 instanceof RowBasedMatrix && mat2 instanceof RowBasedMatrix) {
        // Only mat1 x mat2^T is supported for generic row-based operands.
        if (!trans1 && trans2) {
            int outputRow = mat1.getNumRows();
            int outputCol = mat2.getNumRows();
            // The element type is taken from row 0 of each operand only.
            // NOTE(review): this assumes both matrices have at least one (non-null) row.
            RowType type1 = mat1.getRow(0).getStorage().getType();
            RowType type2 = mat2.getRow(0).getStorage().getType();
            if (type1.isDouble() && type2.isDouble()) {
                BlasDoubleMatrix res = MFactory.denseDoubleMatrix(outputRow, outputCol);
                // Column i of mat1 x mat2^T is mat1 multiplied by row i of mat2.
                for (int i = 0; i < outputCol; i++) {
                    Vector row = mat2.getRow(i);
                    Vector col = mat1.dot(row);
                    res.setCol(i, col);
                }
                return res;
            } else if (type1.isFloat() && type2.isFloat()) {
                BlasFloatMatrix res = MFactory.denseFloatMatrix(outputRow, outputCol);
                for (int i = 0; i < outputCol; i++) {
                    Vector row = mat2.getRow(i);
                    Vector col = mat1.dot(row);
                    res.setCol(i, col);
                }
                return res;
            } else {
                // Mixed double/float (or other element types) are not supported.
                throw new AngelException("the operation is not supported!");
            }
        } else {
            throw new AngelException("the operation is not supported!");
        }
    } else {
        throw new AngelException("the operation is not supported!");
    }
}
Also used : RBLongDoubleMatrix(com.tencent.angel.ml.math2.matrix.RBLongDoubleMatrix) AngelException(com.tencent.angel.exception.AngelException) RBLongFloatMatrix(com.tencent.angel.ml.math2.matrix.RBLongFloatMatrix) RBIntFloatMatrix(com.tencent.angel.ml.math2.matrix.RBIntFloatMatrix) RowBasedMatrix(com.tencent.angel.ml.math2.matrix.RowBasedMatrix) RBIntDoubleMatrix(com.tencent.angel.ml.math2.matrix.RBIntDoubleMatrix) RowType(com.tencent.angel.ml.matrix.RowType) BlasDoubleMatrix(com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix) BlasFloatMatrix(com.tencent.angel.ml.math2.matrix.BlasFloatMatrix) IntLongVector(com.tencent.angel.ml.math2.vector.IntLongVector) IntFloatVector(com.tencent.angel.ml.math2.vector.IntFloatVector) LongDoubleVector(com.tencent.angel.ml.math2.vector.LongDoubleVector) Vector(com.tencent.angel.ml.math2.vector.Vector) LongFloatVector(com.tencent.angel.ml.math2.vector.LongFloatVector) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector) IntIntVector(com.tencent.angel.ml.math2.vector.IntIntVector) IntDummyVector(com.tencent.angel.ml.math2.vector.IntDummyVector)

Example 2 with RBLongFloatMatrix

Use of com.tencent.angel.ml.math2.matrix.RBLongFloatMatrix in project Angel by Tencent.

Method apply of class DotMatrixExecutor.

/**
 * Multiplies a dense float matrix by a long-keyed, row-based float matrix.
 *
 * <p>When {@code trans1} is set, row {@code r} of the result is built from column {@code r}
 * of {@code mat1}. Otherwise it is built from row {@code r} of {@code mat1}. In both cases
 * the row is {@code mat2.transDot(v)}, presumably {@code v^T x mat2}. This yields
 * {@code mat1^T x mat2} or {@code mat1 x mat2} respectively.
 *
 * @param mat1 dense left operand
 * @param trans1 whether {@code mat1} is used transposed
 * @param mat2 row-based right operand; must not be transposed
 * @param trans2 must be {@code false}
 * @return a row-based long/float matrix assembled from the computed rows
 * @throws AngelException if {@code trans2} is {@code true}
 */
private static Matrix apply(BlasFloatMatrix mat1, boolean trans1, RBLongFloatMatrix mat2, boolean trans2) {
    // A transposed right operand is not handled by this overload.
    if (trans2) {
        throw new AngelException("the operation is not supported!");
    }

    int resultRowCount = trans1 ? mat1.getNumCols() : mat1.getNumRows();
    LongFloatVector[] resultRows = new LongFloatVector[resultRowCount];
    for (int r = 0; r < resultRowCount; r++) {
        // Pick the r-th column (transposed) or r-th row (plain) of the left operand.
        Vector lhs = trans1 ? mat1.getCol(r) : mat1.getRow(r);
        // Assumes transDot on an RBLongFloatMatrix yields a LongFloatVector.
        resultRows[r] = (LongFloatVector) mat2.transDot(lhs);
    }
    return MFactory.rbLongFloatMatrix(resultRows);
}
Also used : AngelException(com.tencent.angel.exception.AngelException) LongFloatVector(com.tencent.angel.ml.math2.vector.LongFloatVector) IntLongVector(com.tencent.angel.ml.math2.vector.IntLongVector) IntFloatVector(com.tencent.angel.ml.math2.vector.IntFloatVector) LongDoubleVector(com.tencent.angel.ml.math2.vector.LongDoubleVector) Vector(com.tencent.angel.ml.math2.vector.Vector) IntDoubleVector(com.tencent.angel.ml.math2.vector.IntDoubleVector) IntIntVector(com.tencent.angel.ml.math2.vector.IntIntVector) IntDummyVector(com.tencent.angel.ml.math2.vector.IntDummyVector)

Aggregations

AngelException (com.tencent.angel.exception.AngelException): 2; IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector): 2; IntDummyVector (com.tencent.angel.ml.math2.vector.IntDummyVector): 2; IntFloatVector (com.tencent.angel.ml.math2.vector.IntFloatVector): 2; IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector): 2; IntLongVector (com.tencent.angel.ml.math2.vector.IntLongVector): 2; LongDoubleVector (com.tencent.angel.ml.math2.vector.LongDoubleVector): 2; LongFloatVector (com.tencent.angel.ml.math2.vector.LongFloatVector): 2; Vector (com.tencent.angel.ml.math2.vector.Vector): 2; BlasDoubleMatrix (com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix): 1; BlasFloatMatrix (com.tencent.angel.ml.math2.matrix.BlasFloatMatrix): 1; RBIntDoubleMatrix (com.tencent.angel.ml.math2.matrix.RBIntDoubleMatrix): 1; RBIntFloatMatrix (com.tencent.angel.ml.math2.matrix.RBIntFloatMatrix): 1; RBLongDoubleMatrix (com.tencent.angel.ml.math2.matrix.RBLongDoubleMatrix): 1; RBLongFloatMatrix (com.tencent.angel.ml.math2.matrix.RBLongFloatMatrix): 1; RowBasedMatrix (com.tencent.angel.ml.math2.matrix.RowBasedMatrix): 1; RowType (com.tencent.angel.ml.matrix.RowType): 1