Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.
Class MultilayerPerceptronTrainBatchOp, method getMaxAbsVector.
/**
 * Computes, for every feature, the maximum absolute value observed over all rows of {@code data}.
 *
 * <p>Each partition folds its rows into one partial result. A single group reduce then merges the
 * partial results with an element-wise maximum.
 *
 * @param data            input rows.
 * @param featureColNames numeric feature columns; only used when {@code vectorColName} is null or blank.
 * @param vectorColName   vector column (anything {@code VectorUtil.getVector} accepts); when set it takes
 *                        precedence over {@code featureColNames}.
 * @param vecSize         length of the result vector in vector-input mode.
 * @return a data set with one vector of per-feature maximum absolute values, or no element when
 * {@code data} has no rows.
 */
private static DataSet<DenseVector> getMaxAbsVector(BatchOperator<?> data, final String[] featureColNames,
	final String vectorColName, final int vecSize) {
	final boolean isVectorInput = !StringUtils.isNullOrWhitespaceOnly(vectorColName);
	final int vectorColIdx = isVectorInput
		? TableUtil.findColIndexWithAssertAndHint(data.getColNames(), vectorColName) : -1;
	final int[] featureColIdx = isVectorInput
		? null : TableUtil.findColIndicesWithAssertAndHint(data.getSchema(), featureColNames);

	return data.getDataSet().mapPartition(new MapPartitionFunction<Row, DenseVector>() {
		private static final long serialVersionUID = 7200866630508717163L;

		@Override
		public void mapPartition(Iterable<Row> iterable, Collector<DenseVector> collector) throws Exception {
			DenseVector maxAbs = null;
			for (Row value : iterable) {
				if (maxAbs == null) {
					// Absolute values are never negative, so an all-zero vector is a neutral start for max().
					maxAbs = new DenseVector(isVectorInput ? vecSize : featureColIdx.length);
				}
				if (isVectorInput) {
					Vector vec = VectorUtil.getVector(value.getField(vectorColIdx));
					if (vec instanceof DenseVector) {
						// Iterate the row's own length on every row, so shorter dense rows are handled
						// consistently instead of only on the first row.
						double[] vals = ((DenseVector) vec).getData();
						for (int i = 0; i < vals.length; ++i) {
							maxAbs.set(i, Math.max(maxAbs.get(i), Math.abs(vals[i])));
						}
					} else {
						// Read stored values directly instead of looking each index up again.
						SparseVector sv = (SparseVector) vec;
						int[] indices = sv.getIndices();
						double[] vals = sv.getValues();
						for (int i = 0; i < indices.length; ++i) {
							maxAbs.set(indices[i], Math.max(maxAbs.get(indices[i]), Math.abs(vals[i])));
						}
					}
				} else {
					for (int i = 0; i < featureColIdx.length; i++) {
						double v = ((Number) value.getField(featureColIdx[i])).doubleValue();
						maxAbs.set(i, Math.max(maxAbs.get(i), Math.abs(v)));
					}
				}
			}
			// An empty partition contributes nothing to the reduce.
			if (maxAbs != null) {
				collector.collect(maxAbs);
			}
		}
	}).reduceGroup(new GroupReduceFunction<DenseVector, DenseVector>() {
		private static final long serialVersionUID = 880634306611878638L;

		@Override
		public void reduce(Iterable<DenseVector> iterable, Collector<DenseVector> collector) throws Exception {
			DenseVector maxAbs = null;
			for (DenseVector vec : iterable) {
				if (maxAbs == null) {
					// Copy rather than alias: with object reuse enabled the runtime may recycle input
					// instances, so holding and mutating one is unsafe.
					maxAbs = new DenseVector(vec.getData().clone());
				} else {
					// Partial results are already absolute values, so no abs() is needed here.
					for (int i = 0; i < maxAbs.size(); ++i) {
						maxAbs.set(i, Math.max(maxAbs.get(i), vec.get(i)));
					}
				}
			}
			// Never emit null: with no input rows there is simply no result element.
			if (maxAbs != null) {
				collector.collect(maxAbs);
			}
		}
	});
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.
Class AftSurvivalRegTrainBatchOp, method svStd.
/**
 * Scales a sparse vector by dividing every stored value by the entry of {@code std} at its index.
 *
 * <p>The vector is modified in place; the same instance is returned for convenience.
 *
 * @param tmpVector vector to scale; must be a {@code SparseVector}.
 * @param std       per-dimension divisors, indexed by vector index.
 * @return {@code tmpVector} itself, after scaling.
 */
private static Vector svStd(Vector tmpVector, double[] std) {
	final SparseVector sparse = (SparseVector) tmpVector;
	final int[] idx = sparse.getIndices();
	final double[] vals = sparse.getValues();
	int k = 0;
	while (k < idx.length) {
		vals[k] = vals[k] / std[idx[k]];
		k++;
	}
	return sparse;
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.
Class LibSvmSinkBatchOp, method formatLibSvm.
/**
 * Formats a label and a vector as one LibSVM-style text line: {@code "<label> <vector>"}.
 *
 * <p>A string vector wrapped in square brackets has the brackets stripped before parsing. Dense vectors
 * are converted to sparse form, and every index is shifted by {@code startIndex}. A null label or null
 * vector contributes an empty string, so the separating space is always present.
 *
 * <p>The caller's vector object is never modified.
 *
 * @param label      label value, or null.
 * @param vector     vector object or string accepted by {@code VectorUtil.getVector}, or null.
 * @param startIndex offset added to every vector index (presumably 1 for standard LibSVM output).
 * @return the formatted line.
 */
public static String formatLibSvm(Object label, Object vector, int startIndex) {
	String labelStr = label == null ? "" : String.valueOf(label);
	String vectorStr = "";
	if (vector != null) {
		Object vectorObj = vector;
		if (vectorObj instanceof String) {
			String s = (String) vectorObj;
			if (s.startsWith("[") && s.endsWith("]")) {
				vectorObj = s.substring(1, s.length() - 1);
			}
		}
		Vector v = VectorUtil.getVector(vectorObj);
		SparseVector sv;
		if (v instanceof DenseVector) {
			sv = toSparseVector((DenseVector) v);
		} else if (v == vector) {
			// getVector handed back the caller's own object. Shifting its indices in place would corrupt
			// the caller's data (and shift again on every call), so work on a copy.
			sv = ((SparseVector) v).clone();
		} else {
			sv = (SparseVector) v;
		}
		int[] indices = sv.getIndices();
		for (int i = 0; i < indices.length; i++) {
			indices[i] += startIndex;
		}
		vectorStr = sv.toString();
	}
	return labelStr + " " + vectorStr;
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.
Class MultiHotModelMapper, method map.
/**
 * Multi-hot encodes the selected columns of one sample.
 *
 * <p>With {@code Encode.ASSEMBLED_VECTOR}, all selected columns are encoded together into result slot 0.
 * With {@code Encode.VECTOR}, each selected column is encoded into its own result slot. A slot is only
 * written when at least one index is active; otherwise it is left untouched.
 */
@Override
protected void map(SlicedSelectedSample selection, SlicedResult result) throws Exception {
	if (encode.equals(Encode.ASSEMBLED_VECTOR)) {
		SparseVector assembled = toAllOnesSparseVector(getIndicesAndSize(selection));
		if (assembled != null) {
			result.set(0, assembled);
		}
		return;
	}
	if (!encode.equals(Encode.VECTOR)) {
		return;
	}
	for (int col = 0; col < selection.length(); col++) {
		String token = (String) selection.get(col);
		SparseVector single = toAllOnesSparseVector(getSingleIndicesAndSize(selectedCols[col], token));
		if (single != null) {
			result.set(col, single);
		}
	}
}

/**
 * Builds a sparse vector of size {@code sizeAndIndices.f0} with value 1.0 at each index in
 * {@code sizeAndIndices.f1}; returns null when there are no indices.
 */
private SparseVector toAllOnesSparseVector(Tuple2<Integer, int[]> sizeAndIndices) {
	int[] active = sizeAndIndices.f1;
	double[] ones = new double[active.length];
	Arrays.fill(ones, 1.0);
	return active.length == 0 ? null : new SparseVector(sizeAndIndices.f0, active, ones);
}
Usage of com.alibaba.alink.common.linalg.SparseVector in project Alink by alibaba.
Class ManHattanDistance, method calc.
/**
 * Computes Manhattan (L1) distances from {@code left} to every sample of {@code right} into {@code res}.
 *
 * <p>{@code right} appears to be stored as an inverted index: for each dimension, the ids of the right
 * samples with a stored value there, plus those values.
 *
 * <p>For each dimension where both sides have an entry, {@code |r| + |l|} is replaced by {@code |r - l|}.
 * The label terms are then added. NOTE(review): this presumes {@code right.getLabel()} holds each right
 * sample's sum of absolute values and {@code left.label.get(0)} the left one's — confirm.
 */
@Override
void calc(FastDistanceVectorData left, FastDistanceSparseData right, double[] res) {
	Arrays.fill(res, 0.0);
	final int[][] rIdx = right.getIndices();
	final double[][] rVal = right.getValues();
	if (left.getVector() instanceof DenseVector) {
		final double[] lData = ((DenseVector) left.getVector()).getData();
		for (int dim = 0; dim < lData.length; dim++) {
			final int[] rows = rIdx[dim];
			if (rows == null) {
				continue;
			}
			final double[] vals = rVal[dim];
			final double l = lData[dim];
			for (int k = 0; k < rows.length; k++) {
				final int row = rows[k];
				res[row] = res[row] - Math.abs(vals[k]) - Math.abs(l) + Math.abs(vals[k] - l);
			}
		}
	} else {
		final SparseVector sparse = (SparseVector) left.getVector();
		final int[] lIdx = sparse.getIndices();
		final double[] lVals = sparse.getValues();
		for (int p = 0; p < lIdx.length; p++) {
			final int dim = lIdx[p];
			final int[] rows = rIdx[dim];
			if (rows == null) {
				continue;
			}
			final double[] vals = rVal[dim];
			final double l = lVals[p];
			for (int k = 0; k < rows.length; k++) {
				final int row = rows[k];
				res[row] = res[row] - Math.abs(vals[k]) - Math.abs(l) + Math.abs(vals[k] - l);
			}
		}
	}
	final double[] rLabel = right.getLabel().getData();
	final double lLabel = left.label.get(0);
	for (int row = 0; row < res.length; row++) {
		res[row] += rLabel[row] + lLabel;
	}
}
Aggregations