Usage of com.tencent.angel.ml.math2.matrix.RBIntFloatMatrix in the Angel project by Tencent.
Class DotMatrixExecutor, method apply:
/**
 * Multiplies two matrices, optionally transposing either operand, by dispatching on the runtime
 * classes of {@code mat1} and {@code mat2} to a specialised overload.
 *
 * <p>The {@code parallel} flag is consulted only for the {@code BlasDoubleMatrix x
 * BlasDoubleMatrix} and {@code BlasFloatMatrix x BlasFloatMatrix} pairs. The mixed Blas /
 * row-based pairs always call the sequential overload. For two generic {@code RowBasedMatrix}
 * operands only the {@code !trans1 && trans2} case is implemented. That case builds a dense
 * result whose column {@code i} is {@code mat1.dot(mat2.getRow(i))}. Its element type
 * (double/float) is taken from the first row of each operand.
 *
 * @param mat1 left operand
 * @param trans1 whether {@code mat1} is transposed
 * @param mat2 right operand
 * @param trans2 whether {@code mat2} is transposed
 * @param parallel whether to use the parallel kernel for dense/dense products; {@code null} is
 *     treated as {@code false} instead of being auto-unboxed (which would throw
 *     {@link NullPointerException})
 * @return the product matrix
 * @throws AngelException if the combination of operand classes, element types or transpose
 *     flags is not supported
 */
public static Matrix apply(Matrix mat1, boolean trans1, Matrix mat2, boolean trans2, Boolean parallel) {
  // `if (parallel)` on a null Boolean throws NullPointerException during unboxing;
  // normalise once so a null flag simply selects the sequential path.
  final boolean useParallel = Boolean.TRUE.equals(parallel);

  if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
    if (useParallel) {
      return applyParallel((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
    } else {
      return apply((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
    }
  } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof BlasFloatMatrix) {
    if (useParallel) {
      return applyParallel((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
    } else {
      return apply((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
    }
  } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBIntDoubleMatrix) {
    return apply((BlasDoubleMatrix) mat1, trans1, (RBIntDoubleMatrix) mat2, trans2);
  } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBLongDoubleMatrix) {
    return apply((BlasDoubleMatrix) mat1, trans1, (RBLongDoubleMatrix) mat2, trans2);
  } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBIntFloatMatrix) {
    return apply((BlasFloatMatrix) mat1, trans1, (RBIntFloatMatrix) mat2, trans2);
  } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBLongFloatMatrix) {
    return apply((BlasFloatMatrix) mat1, trans1, (RBLongFloatMatrix) mat2, trans2);
  } else if (mat1 instanceof RBIntDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
    return apply((RBIntDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
  } else if (mat1 instanceof RBIntFloatMatrix && mat2 instanceof BlasFloatMatrix) {
    return apply((RBIntFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
  } else if (mat1 instanceof RowBasedMatrix && mat2 instanceof RowBasedMatrix) {
    // Only the !trans1 && trans2 combination is implemented for generic row-based operands.
    if (trans1 || !trans2) {
      throw new AngelException("the operation is not supported!");
    }
    int outputRow = mat1.getNumRows();
    int outputCol = mat2.getNumRows();
    // The element type is inferred from the first row of each operand only.
    RowType type1 = mat1.getRow(0).getStorage().getType();
    RowType type2 = mat2.getRow(0).getStorage().getType();
    if (type1.isDouble() && type2.isDouble()) {
      BlasDoubleMatrix res = MFactory.denseDoubleMatrix(outputRow, outputCol);
      // Column i of the result is mat1 dotted with row i of mat2.
      for (int i = 0; i < outputCol; i++) {
        Vector row = mat2.getRow(i);
        Vector col = mat1.dot(row);
        res.setCol(i, col);
      }
      return res;
    } else if (type1.isFloat() && type2.isFloat()) {
      BlasFloatMatrix res = MFactory.denseFloatMatrix(outputRow, outputCol);
      for (int i = 0; i < outputCol; i++) {
        Vector row = mat2.getRow(i);
        Vector col = mat1.dot(row);
        res.setCol(i, col);
      }
      return res;
    } else {
      // Mixed double/float element types are not handled.
      throw new AngelException("the operation is not supported!");
    }
  } else {
    throw new AngelException("the operation is not supported!");
  }
}
Usage of com.tencent.angel.ml.math2.matrix.RBIntFloatMatrix in the Angel project by Tencent.
Class DotMatrixExecutor, method apply:
/**
 * Multiplies a dense float matrix by an int-keyed row-based float matrix and returns a
 * row-based float matrix.
 *
 * <p>When {@code trans1} is set, the result has {@code mat1.getNumCols()} rows, and row
 * {@code r} is {@code mat2.transDot(mat1.getCol(r))}. Otherwise the result has
 * {@code mat1.getNumRows()} rows, and row {@code r} is {@code mat2.transDot(mat1.getRow(r))}.
 * (Presumably these are {@code mat1^T * mat2} and {@code mat1 * mat2} respectively; confirm
 * against the {@code transDot} contract.)
 *
 * @param mat1 dense left operand
 * @param trans1 whether {@code mat1} is transposed
 * @param mat2 row-based right operand
 * @param trans2 whether {@code mat2} is transposed; {@code true} is not supported
 * @return the row-based product built by {@code MFactory.rbIntFloatMatrix}
 * @throws AngelException if {@code trans2} is {@code true}
 */
private static Matrix apply(BlasFloatMatrix mat1, boolean trans1, RBIntFloatMatrix mat2, boolean trans2) {
  // Neither supported combination transposes mat2.
  if (trans2) {
    throw new AngelException("the operation is not supported!");
  }

  final int numOut = trans1 ? mat1.getNumCols() : mat1.getNumRows();
  final IntFloatVector[] outRows = new IntFloatVector[numOut];
  for (int r = 0; r < numOut; r++) {
    // Use a column of mat1 when it is transposed, otherwise a row.
    Vector operand = trans1 ? mat1.getCol(r) : mat1.getRow(r);
    outRows[r] = (IntFloatVector) mat2.transDot(operand);
  }
  return MFactory.rbIntFloatMatrix(outRows);
}
Aggregations