Example 1 with UnsafeRow

Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba.

Class RssShuffleWriterSuiteJ, method getUnsafeRowIterator:

private Iterator<Product2<Integer, UnsafeRow>> getUnsafeRowIterator(final int size, final AtomicInteger total, final boolean mix) {
    int current = 0;
    ListBuffer<Product2<Integer, UnsafeRow>> list = new ListBuffer<>();
    // Keep generating records until roughly `size` bytes of values have been produced.
    while (current < size) {
        int key = total.getAndIncrement();
        // When `mix` is set, randomly interleave giant records with normal ones.
        String value = key + ": " + (mix && rand.nextBoolean() ? GIANT_RECORD : NORMAL_RECORD);
        current += value.length();
        // Build a two-column InternalRow (int key, string value) via Scala interop.
        ListBuffer<Object> values = new ListBuffer<>();
        values.$plus$eq(key);
        values.$plus$eq(UTF8String.fromString(value));
        InternalRow row = InternalRow.apply(values.toSeq());
        // Project the generic row into the contiguous UnsafeRow binary format.
        DataType[] types = new DataType[2];
        types[0] = IntegerType$.MODULE$;
        types[1] = StringType$.MODULE$;
        UnsafeRow unsafeRow = UnsafeProjection.create(types).apply(row);
        // Assign the partition by key, mirroring what a hash partitioner would do.
        list.$plus$eq(new Tuple2<>(key % numPartitions, unsafeRow));
    }
    return list.toIterator();
}
Also used: Product2 (scala.Product2), ListBuffer (scala.collection.mutable.ListBuffer), DataType (org.apache.spark.sql.types.DataType), UTF8String (org.apache.spark.unsafe.types.UTF8String), UnsafeRow (org.apache.spark.sql.catalyst.expressions.UnsafeRow), InternalRow (org.apache.spark.sql.catalyst.InternalRow)
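
The Scala ListBuffer interop above ($plus$eq, toSeq) is what the project's test uses, but it is verbose from Java. A minimal alternative sketch (not from the project, same two-column schema): GenericInternalRow (org.apache.spark.sql.catalyst.expressions.GenericInternalRow) accepts the column values as an Object[] directly, and string columns must be passed as UTF8String rather than java.lang.String.

static UnsafeRow toUnsafeRow(int key, String value) {
    // GenericInternalRow wraps the boxed column values directly; no Scala
    // collection interop is needed from Java.
    InternalRow row = new GenericInternalRow(
        new Object[] { key, UTF8String.fromString(value) });
    DataType[] types = { IntegerType$.MODULE$, StringType$.MODULE$ };
    // UnsafeProjection compiles a converter for the schema and copies the
    // generic row into the contiguous UnsafeRow binary format.
    return UnsafeProjection.create(types).apply(row);
}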

Example 2 with UnsafeRow

Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba.

Class SortBasedShuffleWriter, method fastWrite0:

private void fastWrite0(scala.collection.Iterator iterator) throws IOException {
    final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = iterator;
    while (records.hasNext()) {
        final Product2<Integer, UnsafeRow> record = records.next();
        final int partitionId = record._1();
        final UnsafeRow row = record._2();
        final int rowSize = row.getSizeInBytes();
        // Each record is framed as a 4-byte length prefix plus the raw row bytes.
        final int serializedRecordSize = 4 + rowSize;
        if (serializedRecordSize > pushBufferSize) {
            // Records too large for the push buffer bypass the sorter: frame them
            // into a dedicated buffer and push directly.
            byte[] giantBuffer = new byte[serializedRecordSize];
            Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
            Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), giantBuffer, Platform.BYTE_ARRAY_OFFSET + 4, rowSize);
            pushGiantRecord(partitionId, giantBuffer, serializedRecordSize);
        } else {
            // Normal records go through the sort-based pusher; only the insert
            // time is charged to the shuffle write metrics.
            long insertStartTime = System.nanoTime();
            sortBasedPusher.insertRecord(row.getBaseObject(), row.getBaseOffset(), rowSize, partitionId, true);
            writeMetrics.incWriteTime(System.nanoTime() - insertStartTime);
        }
        // Per-partition bookkeeping of bytes and record counts.
        tmpLengths[partitionId] += serializedRecordSize;
        tmpRecords[partitionId] += 1;
    }
}
Also used: Product2 (scala.Product2), UnsafeRow (org.apache.spark.sql.catalyst.expressions.UnsafeRow)
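
The framing is a 4-byte length prefix followed by the raw UnsafeRow bytes. Platform.putInt writes in native byte order, so on little-endian hardware the Integer.reverseBytes call yields a big-endian prefix, matching what a DataInputStream-based reader (for example Spark's UnsafeRowSerializer) expects. A minimal decoding sketch (not from the project), assuming a byte[] holding one frame written as above and a field count known from the schema:

static UnsafeRow readFrame(byte[] buf, int offset, int numFields) {
    // Reverse the bytes again to undo the writer's Integer.reverseBytes.
    int rowSize = Integer.reverseBytes(
        Platform.getInt(buf, Platform.BYTE_ARRAY_OFFSET + offset));
    UnsafeRow row = new UnsafeRow(numFields);
    // pointTo creates a view over the payload after the 4-byte prefix;
    // no bytes are copied.
    row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset + 4, rowSize);
    return row;
}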

Example 3 with UnsafeRow

Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba.

Class HashBasedShuffleWriter, method fastWrite0:

private void fastWrite0(scala.collection.Iterator iterator) throws IOException {
    final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = iterator;
    while (records.hasNext()) {
        final Product2<Integer, UnsafeRow> record = records.next();
        final int partitionId = record._1();
        final UnsafeRow row = record._2();
        final int rowSize = row.getSizeInBytes();
        // 4-byte length prefix plus the raw row bytes, same framing as above.
        final int serializedRecordSize = 4 + rowSize;
        // Lazily allocate the per-partition send buffer on first use.
        byte[] buffer = sendBuffers[partitionId];
        if (buffer == null) {
            buffer = new byte[SEND_BUFFER_SIZE];
            sendBuffers[partitionId] = buffer;
            peakMemoryUsedBytes += SEND_BUFFER_SIZE;
        }
        if (serializedRecordSize > SEND_BUFFER_SIZE) {
            // Records that can never fit in a send buffer are pushed directly.
            byte[] giantBuffer = new byte[serializedRecordSize];
            Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
            Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), giantBuffer, Platform.BYTE_ARRAY_OFFSET + 4, rowSize);
            pushGiantRecord(partitionId, giantBuffer, serializedRecordSize);
        } else {
            // Append the framed record at the partition's current write offset
            // (a hypothetical sketch of getOrUpdateOffset follows this example),
            // then advance the offset past the record.
            int offset = getOrUpdateOffset(partitionId, buffer, serializedRecordSize);
            Platform.putInt(buffer, Platform.BYTE_ARRAY_OFFSET + offset, Integer.reverseBytes(rowSize));
            Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(), buffer, Platform.BYTE_ARRAY_OFFSET + offset + 4, rowSize);
            sendOffsets[partitionId] = offset + serializedRecordSize;
        }
        // Per-partition bookkeeping of bytes and record counts.
        tmpLengths[partitionId] += serializedRecordSize;
        tmpRecords[partitionId] += 1;
    }
}
Also used: Product2 (scala.Product2), UnsafeRow (org.apache.spark.sql.catalyst.expressions.UnsafeRow)
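
getOrUpdateOffset is not shown on this page. A hypothetical sketch of one plausible shape, under the assumption that it flushes the partition's buffer when the incoming record would overflow it and returns the offset to copy at (flushSendBuffer is an invented name, not the project's API):

private int getOrUpdateOffset(int partitionId, byte[] buffer, int recordSize) {
    int offset = sendOffsets[partitionId];
    if (offset + recordSize > SEND_BUFFER_SIZE) {
        // Hypothetical flush hook: ship the buffered bytes for this partition,
        // then restart writing from the beginning of the buffer.
        flushSendBuffer(partitionId, buffer, offset);
        sendOffsets[partitionId] = 0;
        offset = 0;
    }
    return offset;
}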

Example 4 with UnsafeRow

Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project BigDataSourceCode by baolibin.

Class UnsafeFixedWidthAggregationMap, method iterator:

/**
 * Returns an iterator over the keys and values in this map. This uses the destructive
 * iterator of BytesToBytesMap, so it is illegal to call any other method on this map
 * after `iterator()` has been called.
 *
 * For efficiency, each call returns the same object.
 */
public KVIterator<UnsafeRow, UnsafeRow> iterator() {
    return new KVIterator<UnsafeRow, UnsafeRow>() {

        // Destructive: entries are freed as iteration advances, which is why the
        // map must not be touched again after iterator() is called.
        private final BytesToBytesMap.MapIterator mapLocationIterator = map.destructiveIterator();

        private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length());

        private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length());

        @Override
        public boolean next() {
            if (mapLocationIterator.hasNext()) {
                final BytesToBytesMap.Location loc = mapLocationIterator.next();
                // Re-point the reusable key/value rows at the entry's backing
                // memory; nothing is copied, so callers must copy() retained rows.
                key.pointTo(loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength());
                value.pointTo(loc.getValueBase(), loc.getValueOffset(), loc.getValueLength());
                return true;
            } else {
                return false;
            }
        }

        @Override
        public UnsafeRow getKey() {
            return key;
        }

        @Override
        public UnsafeRow getValue() {
            return value;
        }

        @Override
        public void close() {
        // Do nothing.
        }
    };
}
Also used: BytesToBytesMap (org.apache.spark.unsafe.map.BytesToBytesMap), UnsafeRow (org.apache.spark.sql.catalyst.expressions.UnsafeRow), KVIterator (org.apache.spark.unsafe.KVIterator)
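
Because next() re-points the same two UnsafeRow instances at every entry, a consumer that retains rows across iterations must copy() them. A minimal consumption sketch (not from the project); note that KVIterator.next() is declared to throw IOException:

static void drain(UnsafeFixedWidthAggregationMap map) throws IOException {
    KVIterator<UnsafeRow, UnsafeRow> it = map.iterator();
    try {
        while (it.next()) {
            // getKey()/getValue() return views that are re-pointed on the next
            // call to next(); copy() materializes standalone rows.
            UnsafeRow key = it.getKey().copy();
            UnsafeRow value = it.getValue().copy();
            // ... consume key and value ...
        }
    } finally {
        // Iteration is destructive, so the map must not be used afterwards.
        it.close();
    }
}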

Example 5 with UnsafeRow

Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba.

Class RssShuffleWriterSuiteJ, method check. When the serializer is an UnsafeRowSerializer, the keys in the record stream are already partition ids, so the test wires in PartitionIdPassthrough (Spark's identity partitioner for pre-computed partition ids) instead of HashPartitioner:

private void check(final int approximateSize, final RssConf conf, final Serializer serializer) throws Exception {
    final boolean useUnsafe = serializer instanceof UnsafeRowSerializer;
    final Partitioner partitioner = useUnsafe ? new PartitionIdPassthrough(numPartitions) : new HashPartitioner(numPartitions);
    Mockito.doReturn(partitioner).when(dependency).partitioner();
    Mockito.doReturn(serializer).when(dependency).serializer();
    final File tempFile = new File(tempDir, UUID.randomUUID().toString());
    final RssShuffleHandle<Integer, String, String> handle = new RssShuffleHandle<>(appId, host, port, shuffleId, numMaps, dependency);
    final ShuffleClient client = new DummyShuffleClient(tempFile);
    final HashBasedShuffleWriter<Integer, String, String> writer = new HashBasedShuffleWriter<>(handle, taskContext, conf, client);
    // The fast (copy-based) write path is available only for UnsafeRow records.
    assertEquals(useUnsafe, writer.canUseFastWrite());
    AtomicInteger total = new AtomicInteger(0);
    Iterator iterator = getIterator(approximateSize, total, useUnsafe, false);
    // Expected checksum: XOR of every generated key (keys run from 0 to total-1).
    int expectChecksum = 0;
    for (int i = 0; i < total.intValue(); ++i) {
        expectChecksum ^= i;
    }
    writer.write(iterator);
    Option<MapStatus> status = writer.stop(true);
    client.shutDown();
    assertNotNull(status);
    assertTrue(status.isDefined());
    assertEquals(bmId, status.get().location());
    ShuffleWriteMetrics metrics = taskContext.taskMetrics().shuffleWriteMetrics();
    assertEquals(metrics.recordsWritten(), total.intValue());
    assertEquals(metrics.bytesWritten(), tempFile.length());
    try (FileInputStream fis = new FileInputStream(tempFile)) {
        Iterator it = serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
        int checksum = 0;
        while (it.hasNext()) {
            Product2<Integer, ?> record;
            if (useUnsafe) {
                record = (Product2<Integer, UnsafeRow>) it.next();
            } else {
                record = (Product2<Integer, String>) it.next();
            }
            assertNotNull(record);
            assertNotNull(record._1());
            assertNotNull(record._2());
            int key;
            String value;
            if (useUnsafe) {
                UnsafeRow row = (UnsafeRow) record._2();
                key = row.getInt(0);
                value = row.getString(1);
            } else {
                key = record._1();
                value = (String) record._2();
            }
            // Fold the key into the checksum and count down the records read back.
            checksum ^= key;
            total.decrementAndGet();
            assertTrue("value should equal the normal or giant record for its key.", value.equals(key + ": " + NORMAL_RECORD) || value.equals(key + ": " + GIANT_RECORD));
        }
        assertEquals(0, total.intValue());
        assertEquals(expectChecksum, checksum);
    } catch (Exception e) {
        e.printStackTrace();
        fail("Should read with no exception.");
    }
}
Also used: PartitionIdPassthrough (org.apache.spark.sql.execution.PartitionIdPassthrough), DummyShuffleClient (com.aliyun.emr.rss.client.DummyShuffleClient), ShuffleClient (com.aliyun.emr.rss.client.ShuffleClient), UTF8String (org.apache.spark.unsafe.types.UTF8String), ShuffleWriteMetrics (org.apache.spark.executor.ShuffleWriteMetrics), UnsafeRow (org.apache.spark.sql.catalyst.expressions.UnsafeRow), UnsafeRowSerializer (org.apache.spark.sql.execution.UnsafeRowSerializer), MapStatus (org.apache.spark.scheduler.MapStatus), HashPartitioner (org.apache.spark.HashPartitioner), Partitioner (org.apache.spark.Partitioner), Iterator (scala.collection.Iterator), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), File (java.io.File), FileInputStream (java.io.FileInputStream), IOException (java.io.IOException)
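
The XOR checksum works because XOR folding is commutative, associative, and self-inverse: the accumulated value does not depend on the order in which records come back, so together with the record count it verifies that exactly the keys 0..total-1 round-tripped through the shuffle. A tiny illustration:

// Order-independent: both folds yield the same checksum (0 ^ 1 ^ 2 ^ 3 == 0).
int forward = 0, shuffled = 0;
for (int k : new int[] { 0, 1, 2, 3 }) forward ^= k;
for (int k : new int[] { 3, 1, 0, 2 }) shuffled ^= k;
assert forward == shuffled;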

Aggregations

UnsafeRow (org.apache.spark.sql.catalyst.expressions.UnsafeRow): 6
UTF8String (org.apache.spark.unsafe.types.UTF8String): 3
Product2 (scala.Product2): 3
DummyShuffleClient (com.aliyun.emr.rss.client.DummyShuffleClient): 2
ShuffleClient (com.aliyun.emr.rss.client.ShuffleClient): 2
File (java.io.File): 2
FileInputStream (java.io.FileInputStream): 2
IOException (java.io.IOException): 2
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 2
HashPartitioner (org.apache.spark.HashPartitioner): 2
Partitioner (org.apache.spark.Partitioner): 2
ShuffleWriteMetrics (org.apache.spark.executor.ShuffleWriteMetrics): 2
MapStatus (org.apache.spark.scheduler.MapStatus): 2
PartitionIdPassthrough (org.apache.spark.sql.execution.PartitionIdPassthrough): 2
UnsafeRowSerializer (org.apache.spark.sql.execution.UnsafeRowSerializer): 2
Iterator (scala.collection.Iterator): 2
InternalRow (org.apache.spark.sql.catalyst.InternalRow): 1
DataType (org.apache.spark.sql.types.DataType): 1
KVIterator (org.apache.spark.unsafe.KVIterator): 1
BytesToBytesMap (org.apache.spark.unsafe.map.BytesToBytesMap): 1