Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba: class RssShuffleWriterSuiteJ, method getUnsafeRowIterator.
private Iterator<Product2<Integer, UnsafeRow>> getUnsafeRowIterator(
    final int size, final AtomicInteger total, final boolean mix) {
  int current = 0;
  ListBuffer<Product2<Integer, UnsafeRow>> list = new ListBuffer<>();
  while (current < size) {
    int key = total.getAndIncrement();
    String value = key + ": " + (mix && rand.nextBoolean() ? GIANT_RECORD : NORMAL_RECORD);
    current += value.length();
    // Build an InternalRow of (key, value) and project it into an UnsafeRow.
    ListBuffer<Object> values = new ListBuffer<>();
    values.$plus$eq(key);
    values.$plus$eq(UTF8String.fromString(value));
    InternalRow row = InternalRow.apply(values.toSeq());
    DataType[] types = new DataType[2];
    types[0] = IntegerType$.MODULE$;
    types[1] = StringType$.MODULE$;
    UnsafeRow unsafeRow = UnsafeProjection.create(types).apply(row);
    // The key doubles as the partition id for the pass-through partitioner.
    list.$plus$eq(new Tuple2<>(key % numPartitions, unsafeRow));
  }
  return list.toIterator();
}
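The pattern above (wrap the column values in an InternalRow, then apply an UnsafeProjection built from the field types) can be reproduced without the Scala ListBuffer interop. Below is a minimal, self-contained sketch assuming only Spark's catalyst and unsafe modules are on the classpath; the class name UnsafeRowExample and the sample values are illustrative.

  import org.apache.spark.sql.catalyst.InternalRow;
  import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
  import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
  import org.apache.spark.sql.types.DataType;
  import org.apache.spark.sql.types.DataTypes;
  import org.apache.spark.unsafe.types.UTF8String;

  public final class UnsafeRowExample {
    public static void main(String[] args) {
      // Field types of the two-column row: (int key, string value).
      DataType[] types = new DataType[] {DataTypes.IntegerType, DataTypes.StringType};

      // GenericInternalRow is a Java-friendly alternative to the ListBuffer-based
      // construction used in the test above; string columns must be UTF8String.
      InternalRow row = new GenericInternalRow(new Object[] {42, UTF8String.fromString("42: hello")});

      // UnsafeProjection.create(DataType[]) builds a projection that copies the row
      // into the contiguous memory region backing an UnsafeRow.
      UnsafeRow unsafeRow = UnsafeProjection.create(types).apply(row);

      System.out.println(unsafeRow.getInt(0) + " / " + unsafeRow.getString(1));
      System.out.println("size in bytes: " + unsafeRow.getSizeInBytes());
    }
  }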
Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba: class SortBasedShuffleWriter, method fastWrite0.
private void fastWrite0(scala.collection.Iterator iterator) throws IOException {
  final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = iterator;
  while (records.hasNext()) {
    final Product2<Integer, UnsafeRow> record = records.next();
    final int partitionId = record._1();
    final UnsafeRow row = record._2();
    final int rowSize = row.getSizeInBytes();
    // Account for the 4-byte length prefix in the serialized size.
    final int serializedRecordSize = 4 + rowSize;
    if (serializedRecordSize > pushBufferSize) {
      // Record is too large for the push buffer: frame it into its own byte[] and push it directly.
      byte[] giantBuffer = new byte[serializedRecordSize];
      Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
      Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(),
          giantBuffer, Platform.BYTE_ARRAY_OFFSET + 4, rowSize);
      pushGiantRecord(partitionId, giantBuffer, serializedRecordSize);
    } else {
      // Normal-sized record: hand the raw row bytes to the sort-based pusher.
      long insertStartTime = System.nanoTime();
      sortBasedPusher.insertRecord(row.getBaseObject(), row.getBaseOffset(), rowSize, partitionId, true);
      writeMetrics.incWriteTime(System.nanoTime() - insertStartTime);
    }
    tmpLengths[partitionId] += serializedRecordSize;
    tmpRecords[partitionId] += 1;
  }
}
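The framing used on the giant-record path (a 4-byte length prefix written with Integer.reverseBytes, followed by the row's raw bytes copied via Platform.copyMemory) can be exercised in isolation. The sketch below, assuming Spark's Platform and catalyst classes are available, frames one UnsafeRow into a byte[] and then points a fresh UnsafeRow back at those bytes; the class name RowFramingSketch and the sample row are illustrative.

  import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
  import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
  import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
  import org.apache.spark.sql.types.DataType;
  import org.apache.spark.sql.types.DataTypes;
  import org.apache.spark.unsafe.Platform;
  import org.apache.spark.unsafe.types.UTF8String;

  public final class RowFramingSketch {
    public static void main(String[] args) {
      DataType[] types = {DataTypes.IntegerType, DataTypes.StringType};
      UnsafeRow row = UnsafeProjection.create(types)
          .apply(new GenericInternalRow(new Object[] {7, UTF8String.fromString("seven")}));

      int rowSize = row.getSizeInBytes();
      byte[] framed = new byte[4 + rowSize];

      // Write a 4-byte length prefix followed by the raw row bytes. Platform.putInt
      // uses the platform's native byte order, so Integer.reverseBytes produces a
      // big-endian prefix on a little-endian JVM.
      Platform.putInt(framed, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
      Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(),
          framed, Platform.BYTE_ARRAY_OFFSET + 4, rowSize);

      // Reading side: recover the length, then point a reusable UnsafeRow at the
      // bytes that follow the prefix (no copy is made).
      int readSize = Integer.reverseBytes(Platform.getInt(framed, Platform.BYTE_ARRAY_OFFSET));
      UnsafeRow decoded = new UnsafeRow(types.length);
      decoded.pointTo(framed, Platform.BYTE_ARRAY_OFFSET + 4, readSize);
      System.out.println(decoded.getInt(0) + " / " + decoded.getString(1));
    }
  }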
Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba: class HashBasedShuffleWriter, method fastWrite0.
private void fastWrite0(scala.collection.Iterator iterator) throws IOException {
  final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = iterator;
  while (records.hasNext()) {
    final Product2<Integer, UnsafeRow> record = records.next();
    final int partitionId = record._1();
    final UnsafeRow row = record._2();
    final int rowSize = row.getSizeInBytes();
    final int serializedRecordSize = 4 + rowSize;
    // Lazily allocate one send buffer per partition.
    byte[] buffer = sendBuffers[partitionId];
    if (buffer == null) {
      buffer = new byte[SEND_BUFFER_SIZE];
      sendBuffers[partitionId] = buffer;
      peakMemoryUsedBytes += SEND_BUFFER_SIZE;
    }
    if (serializedRecordSize > SEND_BUFFER_SIZE) {
      // Record does not fit in a send buffer: frame it separately and push it as a giant record.
      byte[] giantBuffer = new byte[serializedRecordSize];
      Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
      Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(),
          giantBuffer, Platform.BYTE_ARRAY_OFFSET + 4, rowSize);
      pushGiantRecord(partitionId, giantBuffer, serializedRecordSize);
    } else {
      // Append the length prefix and row bytes at the partition's current offset.
      int offset = getOrUpdateOffset(partitionId, buffer, serializedRecordSize);
      Platform.putInt(buffer, Platform.BYTE_ARRAY_OFFSET + offset, Integer.reverseBytes(rowSize));
      Platform.copyMemory(row.getBaseObject(), row.getBaseOffset(),
          buffer, Platform.BYTE_ARRAY_OFFSET + offset + 4, rowSize);
      sendOffsets[partitionId] = offset + serializedRecordSize;
    }
    tmpLengths[partitionId] += serializedRecordSize;
    tmpRecords[partitionId] += 1;
  }
}
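The surrounding buffer management (one byte[] per partition, an offset cursor per partition, and a separate path for records larger than a buffer) can be summarized in a small standalone class. The sketch below is a hypothetical simplification of that idea, not the RemoteShuffleService implementation; the class name, the SEND_BUFFER_SIZE value, and the push stub are illustrative.

  /**
   * Hypothetical simplification of the per-partition buffering pattern in fastWrite0:
   * one byte[] per partition, an offset cursor, and a flush when a record no longer fits.
   * This only illustrates the pattern and omits metrics, framing, and error handling.
   */
  public final class PartitionBufferSketch {
    private static final int SEND_BUFFER_SIZE = 64 * 1024;

    private final byte[][] sendBuffers;
    private final int[] sendOffsets;

    public PartitionBufferSketch(int numPartitions) {
      sendBuffers = new byte[numPartitions][];
      sendOffsets = new int[numPartitions];
    }

    /** Appends a serialized record; flushes the partition buffer first if it would overflow. */
    public void append(int partitionId, byte[] serializedRecord) {
      if (serializedRecord.length > SEND_BUFFER_SIZE) {
        push(partitionId, serializedRecord, serializedRecord.length); // "giant record" path
        return;
      }
      byte[] buffer = sendBuffers[partitionId];
      if (buffer == null) {
        buffer = new byte[SEND_BUFFER_SIZE];
        sendBuffers[partitionId] = buffer;
      }
      if (sendOffsets[partitionId] + serializedRecord.length > SEND_BUFFER_SIZE) {
        push(partitionId, buffer, sendOffsets[partitionId]);
        sendOffsets[partitionId] = 0;
      }
      System.arraycopy(serializedRecord, 0, buffer, sendOffsets[partitionId], serializedRecord.length);
      sendOffsets[partitionId] += serializedRecord.length;
    }

    private void push(int partitionId, byte[] data, int length) {
      // Stand-in for pushing the bytes to the remote shuffle service.
      System.out.println("push partition " + partitionId + ": " + length + " bytes");
    }
  }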
Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project BigDataSourceCode by baolibin: class UnsafeFixedWidthAggregationMap, method iterator.
/**
* Returns an iterator over the keys and values in this map. This uses destructive iterator of
* BytesToBytesMap. So it is illegal to call any other method on this map after `iterator()` has
* been called.
*
* For efficiency, each call returns the same object.
*/
public KVIterator<UnsafeRow, UnsafeRow> iterator() {
  return new KVIterator<UnsafeRow, UnsafeRow>() {
    private final BytesToBytesMap.MapIterator mapLocationIterator = map.destructiveIterator();
    private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length());
    private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length());

    @Override
    public boolean next() {
      if (mapLocationIterator.hasNext()) {
        final BytesToBytesMap.Location loc = mapLocationIterator.next();
        key.pointTo(loc.getKeyBase(), loc.getKeyOffset(), loc.getKeyLength());
        value.pointTo(loc.getValueBase(), loc.getValueOffset(), loc.getValueLength());
        return true;
      } else {
        return false;
      }
    }

    @Override
    public UnsafeRow getKey() {
      return key;
    }

    @Override
    public UnsafeRow getValue() {
      return value;
    }

    @Override
    public void close() {
      // Do nothing.
    }
  };
}
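Because getKey() and getValue() return the same reusable UnsafeRow on every call, a caller that needs to keep rows beyond the next() call must copy them. The following is a minimal consumption sketch assuming Spark's KVIterator and UnsafeRow classes are on the classpath; the helper name collectValues is illustrative.

  import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
  import org.apache.spark.unsafe.KVIterator;

  import java.io.IOException;
  import java.util.ArrayList;
  import java.util.List;

  final class KvIteratorConsumer {
    /** Drains the iterator, copying each value so it survives the next() call. */
    static List<UnsafeRow> collectValues(KVIterator<UnsafeRow, UnsafeRow> iterator) throws IOException {
      List<UnsafeRow> retained = new ArrayList<>();
      try {
        while (iterator.next()) {
          // copy() materializes the row into its own buffer; without it, every list
          // element would point at the same mutable backing object.
          retained.add(iterator.getValue().copy());
        }
      } finally {
        iterator.close();
      }
      return retained;
    }
  }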
Use of org.apache.spark.sql.catalyst.expressions.UnsafeRow in project RemoteShuffleService by alibaba: class RssShuffleWriterSuiteJ, method check.
private void check(final int approximateSize, final RssConf conf, final Serializer serializer) throws Exception {
  final boolean useUnsafe = serializer instanceof UnsafeRowSerializer;
  final Partitioner partitioner =
      useUnsafe ? new PartitionIdPassthrough(numPartitions) : new HashPartitioner(numPartitions);
  Mockito.doReturn(partitioner).when(dependency).partitioner();
  Mockito.doReturn(serializer).when(dependency).serializer();

  final File tempFile = new File(tempDir, UUID.randomUUID().toString());
  final RssShuffleHandle<Integer, String, String> handle =
      new RssShuffleHandle<>(appId, host, port, shuffleId, numMaps, dependency);
  final ShuffleClient client = new DummyShuffleClient(tempFile);
  final HashBasedShuffleWriter<Integer, String, String> writer =
      new HashBasedShuffleWriter<>(handle, taskContext, conf, client);

  // The row-based fast write path should be used exactly when the serializer is UnsafeRowSerializer.
  assertEquals(useUnsafe, writer.canUseFastWrite());

  AtomicInteger total = new AtomicInteger(0);
  Iterator iterator = getIterator(approximateSize, total, useUnsafe, false);
  // The expected checksum is the XOR of all generated keys 0..total-1.
  int expectChecksum = 0;
  for (int i = 0; i < total.intValue(); ++i) {
    expectChecksum ^= i;
  }

  writer.write(iterator);
  Option<MapStatus> status = writer.stop(true);
  client.shutDown();

  assertNotNull(status);
  assertTrue(status.isDefined());
  assertEquals(bmId, status.get().location());

  ShuffleWriteMetrics metrics = taskContext.taskMetrics().shuffleWriteMetrics();
  assertEquals(metrics.recordsWritten(), total.intValue());
  assertEquals(metrics.bytesWritten(), tempFile.length());

  // Read everything back from the file written by DummyShuffleClient and verify keys and values.
  try (FileInputStream fis = new FileInputStream(tempFile)) {
    Iterator it = serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
    int checksum = 0;
    while (it.hasNext()) {
      Product2<Integer, ?> record;
      if (useUnsafe) {
        record = (Product2<Integer, UnsafeRow>) it.next();
      } else {
        record = (Product2<Integer, String>) it.next();
      }

      assertNotNull(record);
      assertNotNull(record._1());
      assertNotNull(record._2());

      int key;
      String value;
      if (useUnsafe) {
        UnsafeRow row = (UnsafeRow) record._2();
        key = row.getInt(0);
        value = row.getString(1);
      } else {
        key = record._1();
        value = (String) record._2();
      }

      checksum ^= key;
      total.decrementAndGet();

      assertTrue("value should equals to normal record or giant record with key.",
          value.equals(key + ": " + NORMAL_RECORD) || value.equals(key + ": " + GIANT_RECORD));
    }
    // Every written record was read back exactly once, and the keys match.
    assertEquals(0, total.intValue());
    assertEquals(expectChecksum, checksum);
  } catch (Exception e) {
    e.printStackTrace();
    fail("Should read with no exception.");
  }
}
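The assertion on canUseFastWrite() above pairs the serializer with the partitioner: the fast path is expected only when records are already UnsafeRows and the shuffle key is already the partition id. The sketch below is a hypothetical illustration of that condition, not the project's actual implementation; the class and method names are invented for the example, and the pass-through partitioner class is passed in by the caller because its package differs across versions.

  import org.apache.spark.Partitioner;
  import org.apache.spark.serializer.Serializer;
  import org.apache.spark.sql.execution.UnsafeRowSerializer;

  final class FastWriteEligibilitySketch {
    /**
     * Illustrative condition only: rows are already UnsafeRows (UnsafeRowSerializer)
     * and the key is already the partition id (a pass-through partitioner), so the
     * writer can skip generic serialization and key hashing.
     */
    static boolean canUseFastWrite(Serializer serializer,
                                   Partitioner partitioner,
                                   Class<? extends Partitioner> passthroughPartitionerClass) {
      return serializer instanceof UnsafeRowSerializer
          && passthroughPartitionerClass.isInstance(partitioner);
    }
  }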