Search in sources :

Example 1 with ArrowFieldNode

use of org.apache.arrow.vector.ipc.message.ArrowFieldNode in project twister2 by DSC-SPIDAL.

the class ArrowAllToAll method isComplete.

/**
 * Progresses the pending sends and checks whether the operation is complete.
 *
 * <p>For every entry in {@code inputs} this polls the next pending table (and its target) when
 * the entry is in the {@code HEADER_INIT} state, then tries to insert every buffer of every
 * column of the current table into {@code all}. When an insert is rejected, the column and
 * buffer indices are left untouched and the status stays {@code COLUMN_CONTINUE}, so the next
 * call retries from the same buffer. Once nothing is left to send and {@code finished} is set,
 * {@code all.finish()} is invoked exactly once.
 *
 * @return true if operation is complete
 */
public boolean isComplete() {
    // Completion is latched: once reported, no further work is attempted.
    if (completed) {
        return true;
    }
    boolean isAllEmpty = true;
    for (Map.Entry<Integer, PendingSendTable> t : inputs.entrySet()) {
        PendingSendTable pst = t.getValue();
        // HEADER_INIT: no table is in flight for this entry, so pick up the next pending one.
        if (pst.status == ArrowHeader.HEADER_INIT) {
            if (!pst.pending.isEmpty()) {
                pst.currentTable = pst.pending.poll();
                // Presumably one target is queued per pending table; only verified when
                // assertions are enabled.
                assert !pst.target.isEmpty();
                pst.currentTarget = pst.target.poll();
                pst.status = ArrowHeader.COLUMN_CONTINUE;
            }
        }
        // COLUMN_CONTINUE: send (or resume sending) the buffers of the current table.
        if (pst.status == ArrowHeader.COLUMN_CONTINUE) {
            int noOfColumns = pst.currentTable.getColumns().size();
            boolean canContinue = true;
            while (pst.columnIndex < noOfColumns && canContinue) {
                ArrowColumn col = pst.currentTable.getColumns().get(pst.columnIndex);
                FieldVector vector = col.getVector();
                // The node list is only filled by appendNodes; this method reads just the buffers.
                List<ArrowFieldNode> nodes = new ArrayList<>();
                List<ArrowBuf> bufs = new ArrayList<>();
                appendNodes(vector, nodes, bufs);
                // bufferIndex is not reset here, so a previously rejected buffer is retried first.
                while (pst.bufferIndex < bufs.size()) {
                    ArrowBuf buf = bufs.get(pst.bufferIndex);
                    int[] hdr = new int[HEADER_SIZE];
                    // Header layout: [0] column index, [1] buffer index within the column,
                    // [2] number of buffers in the column, [3] value count of the vector,
                    // [4] buffer length, [5] target.
                    hdr[0] = pst.columnIndex;
                    hdr[1] = pst.bufferIndex;
                    hdr[2] = bufs.size();
                    hdr[3] = vector.getValueCount();
                    // NOTE(review): the buffer capacity is narrowed to int here - confirm that
                    // buffers never exceed Integer.MAX_VALUE bytes.
                    int length = (int) buf.capacity();
                    hdr[4] = length;
                    // target this table is being sent to
                    hdr[5] = pst.currentTarget;
                    boolean accept = all.insert(buf.nioBuffer(), length, hdr, HEADER_SIZE, t.getKey());
                    if (!accept) {
                        // Insert was rejected: stop and keep the indices so the next call
                        // resumes at this exact buffer.
                        canContinue = false;
                        break;
                    }
                    pst.bufferIndex++;
                }
                if (canContinue) {
                    // All buffers of this column were accepted; move on to the next column.
                    pst.bufferIndex = 0;
                    pst.columnIndex++;
                }
            }
            if (canContinue) {
                // The whole table was accepted; reset the cursor and wait for the next table.
                pst.columnIndex = 0;
                pst.bufferIndex = 0;
                pst.status = ArrowHeader.HEADER_INIT;
            }
        }
        // This entry still has queued tables or a partially sent table.
        if (!pst.pending.isEmpty() || pst.status == ArrowHeader.COLUMN_CONTINUE) {
            isAllEmpty = false;
        }
    }
    // Signal finish only once, after everything was handed over and finished was requested.
    if (isAllEmpty && finished && !finishedSent) {
        all.finish();
        finishedSent = true;
    }
    // Complete when nothing is left to send, the underlying operation reports completion and
    // the number of finished sources equals the number of source workers.
    boolean b = isAllEmpty && all.isComplete() && finishedSources.size() == sourceWorkerList.size();
    if (b) {
        completed = true;
    }
    return b;
}
Also used : ArrowBuf(io.netty.buffer.ArrowBuf) ArrayList(java.util.ArrayList) FieldVector(org.apache.arrow.vector.FieldVector) ArrowColumn(edu.iu.dsc.tws.common.table.ArrowColumn) ArrowFieldNode(org.apache.arrow.vector.ipc.message.ArrowFieldNode) HashMap(java.util.HashMap) Map(java.util.Map)

Example 2 with ArrowFieldNode

use of org.apache.arrow.vector.ipc.message.ArrowFieldNode in project twister2 by DSC-SPIDAL.

the class ArrowAllToAll method appendNodes.

/**
 * Flattens a vector and, recursively, all of its children into the given node and buffer lists.
 *
 * <p>A field node (value count of the vector, null count written as 0) is appended first,
 * followed by the vector's own field buffers; the children are then visited in the order
 * returned by {@code getChildrenFromFields()}.
 *
 * @param vector the vector to flatten
 * @param nodes receives one field node per visited vector
 * @param buffers receives the field buffers of every visited vector
 * @throws IllegalArgumentException if the number of field buffers of a vector differs from the
 *     buffer count reported by {@code TypeLayout} for its type
 */
private void appendNodes(FieldVector vector, List<ArrowFieldNode> nodes, List<ArrowBuf> buffers) {
    ArrowFieldNode node = new ArrowFieldNode(vector.getValueCount(), 0);
    nodes.add(node);

    List<ArrowBuf> ownBuffers = vector.getFieldBuffers();
    int layoutCount = TypeLayout.getTypeBufferCount(vector.getField().getType());
    boolean countMatches = ownBuffers.size() == layoutCount;
    if (!countMatches) {
        String message = String.format("wrong number of buffers for field %s in vector %s. found: %s", vector.getField(), vector.getClass().getSimpleName(), ownBuffers);
        throw new IllegalArgumentException(message);
    }
    buffers.addAll(ownBuffers);

    // Depth-first: each child contributes its node and buffers right after its parent's.
    vector.getChildrenFromFields().forEach(child -> appendNodes(child, nodes, buffers));
}
Also used : ArrowBuf(io.netty.buffer.ArrowBuf) ArrowFieldNode(org.apache.arrow.vector.ipc.message.ArrowFieldNode) FieldVector(org.apache.arrow.vector.FieldVector)

Example 3 with ArrowFieldNode

use of org.apache.arrow.vector.ipc.message.ArrowFieldNode in project twister2 by DSC-SPIDAL.

the class ArrowAllToAll method onReceive.

/**
 * Handles one data buffer received from a source.
 *
 * <p>The buffer is collected on the pending receive table of that source. When the buffer
 * completes the current column ({@code noBuffers == bufferIndex + 1}), the collected buffers are
 * loaded into the column's vector. When every vector of the schema root has been loaded, the
 * vectors are wrapped into columns, a table is built and handed to the receive callback, and the
 * pending table is cleared.
 *
 * @param source the source the buffer came from
 * @param buffer the received buffer; must be an {@code ArrowChannelBuffer}
 * @param length not used by this method
 */
@Override
public void onReceive(int source, ChannelBuffer buffer, int length) {
    PendingReceiveTable pending = receives.get(source);
    receivedBuffers++;

    ArrowChannelBuffer channelBuffer = (ArrowChannelBuffer) buffer;
    ArrowBuf received = channelBuffer.getArrowBuf();
    pending.buffers.add(received);

    boolean atFirstBuffer = pending.bufferIndex == 0;
    if (atFirstBuffer) {
        pending.fieldNodes.add(new ArrowFieldNode(pending.noArray, 0));
    }

    VectorSchemaRoot root = pending.root;
    List<FieldVector> vectors = root.getFieldVectors();

    // Nothing more to do until every buffer of this array has arrived.
    boolean arrayComplete = pending.noBuffers == pending.bufferIndex + 1;
    if (!arrayComplete) {
        return;
    }

    FieldVector loaded = vectors.get(pending.columnIndex);
    loadBuffers(loaded, loaded.getField(), pending.buffers.iterator(), pending.fieldNodes.iterator());
    pending.arrays.add(loaded);
    pending.buffers.clear();

    // The table is only assembled once a vector was loaded for every field of the root.
    if (pending.arrays.size() != root.getFieldVectors().size()) {
        return;
    }

    List<ArrowColumn> columns = new ArrayList<>();
    for (FieldVector v : vectors) {
        columns.add(columnForReceivedVector(v));
    }
    Table assembled = new ArrowTable(root.getSchema(), pending.noArray, columns);
    LOG.info("Received table from source " + source + " to " + pending.target + " count" + assembled.rowCount());
    recvCallback.onReceive(source, pending.target, assembled);
    pending.clear();
}

/**
 * Wraps a received vector in the column type matching its concrete vector class.
 *
 * @param v the vector to wrap
 * @return the column backed by {@code v}
 * @throws RuntimeException if the vector class is not one of the supported types
 */
private static ArrowColumn columnForReceivedVector(FieldVector v) {
    if (v instanceof BaseFixedWidthVector) {
        if (v instanceof IntVector) {
            return new Int4Column((IntVector) v);
        }
        if (v instanceof Float4Vector) {
            return new Float4Column((Float4Vector) v);
        }
        if (v instanceof Float8Vector) {
            return new Float8Column((Float8Vector) v);
        }
        if (v instanceof UInt8Vector) {
            return new Int8Column((UInt8Vector) v);
        }
        if (v instanceof UInt2Vector) {
            return new UInt2Column((UInt2Vector) v);
        }
    } else if (v instanceof BaseVariableWidthVector) {
        if (v instanceof VarCharVector) {
            return new StringColumn((VarCharVector) v);
        }
        if (v instanceof VarBinaryVector) {
            return new BinaryColumn((VarBinaryVector) v);
        }
    }
    // Every unmatched vector class ends up here with the same message.
    throw new RuntimeException("Un-supported type : " + v.getClass().getName());
}
Also used : BaseFixedWidthVector(org.apache.arrow.vector.BaseFixedWidthVector) VectorSchemaRoot(org.apache.arrow.vector.VectorSchemaRoot) ArrowBuf(io.netty.buffer.ArrowBuf) Float4Vector(org.apache.arrow.vector.Float4Vector) BinaryColumn(edu.iu.dsc.tws.common.table.arrow.BinaryColumn) ArrayList(java.util.ArrayList) VarBinaryVector(org.apache.arrow.vector.VarBinaryVector) ArrowColumn(edu.iu.dsc.tws.common.table.ArrowColumn) BaseVariableWidthVector(org.apache.arrow.vector.BaseVariableWidthVector) ArrowFieldNode(org.apache.arrow.vector.ipc.message.ArrowFieldNode) Int8Column(edu.iu.dsc.tws.common.table.arrow.Int8Column) StringColumn(edu.iu.dsc.tws.common.table.arrow.StringColumn) Table(edu.iu.dsc.tws.common.table.Table) ArrowTable(edu.iu.dsc.tws.common.table.arrow.ArrowTable) IntVector(org.apache.arrow.vector.IntVector) UInt2Column(edu.iu.dsc.tws.common.table.arrow.UInt2Column) Float8Vector(org.apache.arrow.vector.Float8Vector) VarCharVector(org.apache.arrow.vector.VarCharVector) FieldVector(org.apache.arrow.vector.FieldVector) Float4Column(edu.iu.dsc.tws.common.table.arrow.Float4Column) UInt8Vector(org.apache.arrow.vector.UInt8Vector) Float8Column(edu.iu.dsc.tws.common.table.arrow.Float8Column) Int4Column(edu.iu.dsc.tws.common.table.arrow.Int4Column) ArrowTable(edu.iu.dsc.tws.common.table.arrow.ArrowTable) UInt2Vector(org.apache.arrow.vector.UInt2Vector)

Example 4 with ArrowFieldNode

use of org.apache.arrow.vector.ipc.message.ArrowFieldNode in project twister2 by DSC-SPIDAL.

the class ArrowAllToAll method loadBuffers.

/**
 * Loads buffers into a vector and, recursively, into its child vectors.
 *
 * <p>One field node is consumed from {@code nodes} and as many buffers as
 * {@code TypeLayout} reports for the field's type are consumed from {@code buffers}; both
 * iterators are shared with the recursive calls for the children, so they advance in
 * depth-first order.
 *
 * @param vector the vector to load
 * @param field the field describing {@code vector}
 * @param buffers source of the buffers, consumed in order
 * @param nodes source of the field nodes, consumed in order
 * @throws IllegalArgumentException if loading the buffers into the vector fails with a
 *     {@code RuntimeException}
 */
private void loadBuffers(FieldVector vector, Field field, Iterator<ArrowBuf> buffers, Iterator<ArrowFieldNode> nodes) {
    checkArgument(nodes.hasNext(), "no more field nodes for for field %s and vector %s", field, vector);
    ArrowFieldNode node = nodes.next();

    // Pull exactly the number of buffers this field's type layout calls for.
    int remaining = TypeLayout.getTypeBufferCount(field.getType());
    List<ArrowBuf> vectorBuffers = new ArrayList<>(remaining);
    while (remaining > 0) {
        vectorBuffers.add(buffers.next());
        remaining--;
    }

    try {
        vector.loadFieldBuffers(node, vectorBuffers);
    } catch (RuntimeException e) {
        throw new IllegalArgumentException("Could not load buffers for field " + field + ". error message: " + e.getMessage(), e);
    }

    List<Field> childFields = field.getChildren();
    if (childFields.isEmpty()) {
        return;
    }

    List<FieldVector> childVectors = vector.getChildrenFromFields();
    checkArgument(childFields.size() == childVectors.size(), "should have as many children as in the schema: found %s expected %s", childVectors.size(), childFields.size());
    for (int pos = 0; pos < childVectors.size(); pos++) {
        Field childField = childFields.get(pos);
        FieldVector childVector = childVectors.get(pos);
        loadBuffers(childVector, childField, buffers, nodes);
    }
}
Also used : Field(org.apache.arrow.vector.types.pojo.Field) ArrowBuf(io.netty.buffer.ArrowBuf) ArrayList(java.util.ArrayList) ArrowFieldNode(org.apache.arrow.vector.ipc.message.ArrowFieldNode) FieldVector(org.apache.arrow.vector.FieldVector)

Aggregations

ArrowBuf (io.netty.buffer.ArrowBuf)4 FieldVector (org.apache.arrow.vector.FieldVector)4 ArrowFieldNode (org.apache.arrow.vector.ipc.message.ArrowFieldNode)4 ArrayList (java.util.ArrayList)3 ArrowColumn (edu.iu.dsc.tws.common.table.ArrowColumn)2 Table (edu.iu.dsc.tws.common.table.Table)1 ArrowTable (edu.iu.dsc.tws.common.table.arrow.ArrowTable)1 BinaryColumn (edu.iu.dsc.tws.common.table.arrow.BinaryColumn)1 Float4Column (edu.iu.dsc.tws.common.table.arrow.Float4Column)1 Float8Column (edu.iu.dsc.tws.common.table.arrow.Float8Column)1 Int4Column (edu.iu.dsc.tws.common.table.arrow.Int4Column)1 Int8Column (edu.iu.dsc.tws.common.table.arrow.Int8Column)1 StringColumn (edu.iu.dsc.tws.common.table.arrow.StringColumn)1 UInt2Column (edu.iu.dsc.tws.common.table.arrow.UInt2Column)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 BaseFixedWidthVector (org.apache.arrow.vector.BaseFixedWidthVector)1 BaseVariableWidthVector (org.apache.arrow.vector.BaseVariableWidthVector)1 Float4Vector (org.apache.arrow.vector.Float4Vector)1 Float8Vector (org.apache.arrow.vector.Float8Vector)1