Example 1 with IntArrayList

Use of org.apache.flink.runtime.util.IntArrayList in project flink by apache.

From the class CompactingHashTable, method resizeHashTable:

/**
 * Attempts to double the number of buckets.
 *
 * @return true on success
 * @throws IOException
 */
private boolean resizeHashTable() throws IOException {
    final int newNumBuckets = 2 * this.numBuckets;
    final int bucketsPerSegment = this.bucketsPerSegmentMask + 1;
    final int newNumSegments = (newNumBuckets + (bucketsPerSegment - 1)) / bucketsPerSegment;
    final int additionalSegments = newNumSegments - this.buckets.length;
    final int numPartitions = this.partitions.size();
    if (this.availableMemory.size() < additionalSegments) {
        for (int i = 0; i < numPartitions; i++) {
            compactPartition(i);
            if (this.availableMemory.size() >= additionalSegments) {
                break;
            }
        }
    }
    if (this.availableMemory.size() < additionalSegments || this.closed) {
        return false;
    } else {
        this.isResizing = true;
        // allocate new buckets
        final int startOffset = (this.numBuckets * HASH_BUCKET_SIZE) % this.segmentSize;
        final int oldNumBuckets = this.numBuckets;
        final int oldNumSegments = this.buckets.length;
        MemorySegment[] mergedBuckets = new MemorySegment[newNumSegments];
        System.arraycopy(this.buckets, 0, mergedBuckets, 0, this.buckets.length);
        this.buckets = mergedBuckets;
        this.numBuckets = newNumBuckets;
        // initialize all new buckets
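        // a non-zero start offset means the last old segment still has room, so the first new buckets go there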
        boolean oldSegment = (startOffset != 0);
        final int startSegment = oldSegment ? (oldNumSegments - 1) : oldNumSegments;
        for (int i = startSegment, bucket = oldNumBuckets; i < newNumSegments && bucket < this.numBuckets; i++) {
            MemorySegment seg;
            int bucketOffset;
            if (oldSegment) {
                // the first couple of new buckets may be located on an old segment
                seg = this.buckets[i];
                for (int k = (oldNumBuckets % bucketsPerSegment); k < bucketsPerSegment && bucket < this.numBuckets; k++, bucket++) {
                    bucketOffset = k * HASH_BUCKET_SIZE;
                    // initialize the header fields
                    seg.put(bucketOffset + HEADER_PARTITION_OFFSET, assignPartition(bucket, (byte) numPartitions));
                    seg.putInt(bucketOffset + HEADER_COUNT_OFFSET, 0);
                    seg.putLong(bucketOffset + HEADER_FORWARD_OFFSET, BUCKET_FORWARD_POINTER_NOT_SET);
                }
            } else {
                seg = getNextBuffer();
                // go over all buckets in the segment
                for (int k = 0; k < bucketsPerSegment && bucket < this.numBuckets; k++, bucket++) {
                    bucketOffset = k * HASH_BUCKET_SIZE;
                    // initialize the header fields
                    seg.put(bucketOffset + HEADER_PARTITION_OFFSET, assignPartition(bucket, (byte) numPartitions));
                    seg.putInt(bucketOffset + HEADER_COUNT_OFFSET, 0);
                    seg.putLong(bucketOffset + HEADER_FORWARD_OFFSET, BUCKET_FORWARD_POINTER_NOT_SET);
                }
            }
            this.buckets[i] = seg;
            // we write on at most one old segment
            oldSegment = false;
        }
        int hashOffset;
        int hash;
        int pointerOffset;
        long pointer;
        IntArrayList hashList = new IntArrayList(NUM_ENTRIES_PER_BUCKET);
        LongArrayList pointerList = new LongArrayList(NUM_ENTRIES_PER_BUCKET);
        IntArrayList overflowHashes = new IntArrayList(64);
        LongArrayList overflowPointers = new LongArrayList(64);
        // go over all buckets and split them between old and new buckets
        for (int i = 0; i < numPartitions; i++) {
            InMemoryPartition<T> partition = this.partitions.get(i);
            final MemorySegment[] overflowSegments = partition.overflowSegments;
            int posHashCode;
            for (int j = 0, bucket = i; j < this.buckets.length && bucket < oldNumBuckets; j++) {
                MemorySegment segment = this.buckets[j];
                // go over all buckets in the segment belonging to the partition
                for (int k = bucket % bucketsPerSegment; k < bucketsPerSegment && bucket < oldNumBuckets; k += numPartitions, bucket += numPartitions) {
                    int bucketOffset = k * HASH_BUCKET_SIZE;
                    if ((int) segment.get(bucketOffset + HEADER_PARTITION_OFFSET) != i) {
                        throw new IOException("Accessed wrong bucket! wanted: " + i + " got: " + segment.get(bucketOffset + HEADER_PARTITION_OFFSET));
                    }
                    // loop over all segments that are involved in the bucket (original bucket plus overflow buckets)
                    int countInSegment = segment.getInt(bucketOffset + HEADER_COUNT_OFFSET);
                    int numInSegment = 0;
                    pointerOffset = bucketOffset + BUCKET_POINTER_START_OFFSET;
                    hashOffset = bucketOffset + BUCKET_HEADER_LENGTH;
                    while (true) {
                        while (numInSegment < countInSegment) {
                            hash = segment.getInt(hashOffset);
                            if ((hash % this.numBuckets) != bucket && (hash % this.numBuckets) != (bucket + oldNumBuckets)) {
                                throw new IOException("wanted: " + bucket + " or " + (bucket + oldNumBuckets) + " got: " + hash % this.numBuckets);
                            }
                            pointer = segment.getLong(pointerOffset);
                            hashList.add(hash);
                            pointerList.add(pointer);
                            pointerOffset += POINTER_LEN;
                            hashOffset += HASH_CODE_LEN;
                            numInSegment++;
                        }
                        // this segment is done. check if there is another chained bucket
                        final long forwardPointer = segment.getLong(bucketOffset + HEADER_FORWARD_OFFSET);
                        if (forwardPointer == BUCKET_FORWARD_POINTER_NOT_SET) {
                            break;
                        }
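                        // follow the chain: the upper 32 bits of the forward pointer hold the overflow segment index, the lower 32 bits the bucket offset within it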
                        final int overflowSegNum = (int) (forwardPointer >>> 32);
                        segment = overflowSegments[overflowSegNum];
                        bucketOffset = (int) forwardPointer;
                        countInSegment = segment.getInt(bucketOffset + HEADER_COUNT_OFFSET);
                        pointerOffset = bucketOffset + BUCKET_POINTER_START_OFFSET;
                        hashOffset = bucketOffset + BUCKET_HEADER_LENGTH;
                        numInSegment = 0;
                    }
                    segment = this.buckets[j];
                    bucketOffset = k * HASH_BUCKET_SIZE;
                    // reset bucket for re-insertion
                    segment.putInt(bucketOffset + HEADER_COUNT_OFFSET, 0);
                    segment.putLong(bucketOffset + HEADER_FORWARD_OFFSET, BUCKET_FORWARD_POINTER_NOT_SET);
                    // refill table
                    if (hashList.size() != pointerList.size()) {
                        throw new IOException("Pointer and hash counts do not match. hashes: " + hashList.size() + " pointer: " + pointerList.size());
                    }
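                    // after doubling, entries from this bucket either stay put or move to the sibling bucket at (bucket + oldNumBuckets)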
                    int newSegmentIndex = (bucket + oldNumBuckets) / bucketsPerSegment;
                    MemorySegment newSegment = this.buckets[newSegmentIndex];
                    // we need to avoid overflows in the first run
                    int oldBucketCount = 0;
                    int newBucketCount = 0;
                    while (!hashList.isEmpty()) {
                        hash = hashList.removeLast();
                        pointer = pointerList.removeLong(pointerList.size() - 1);
                        posHashCode = hash % this.numBuckets;
                        if (posHashCode == bucket && oldBucketCount < NUM_ENTRIES_PER_BUCKET) {
                            bucketOffset = (bucket % bucketsPerSegment) * HASH_BUCKET_SIZE;
                            insertBucketEntryFromStart(segment, bucketOffset, hash, pointer, partition.getPartitionNumber());
                            oldBucketCount++;
                        } else if (posHashCode == (bucket + oldNumBuckets) && newBucketCount < NUM_ENTRIES_PER_BUCKET) {
                            bucketOffset = ((bucket + oldNumBuckets) % bucketsPerSegment) * HASH_BUCKET_SIZE;
                            insertBucketEntryFromStart(newSegment, bucketOffset, hash, pointer, partition.getPartitionNumber());
                            newBucketCount++;
                        } else if (posHashCode == (bucket + oldNumBuckets) || posHashCode == bucket) {
                            overflowHashes.add(hash);
                            overflowPointers.add(pointer);
                        } else {
                            throw new IOException("Accessed wrong bucket. Target: " + bucket + " or " + (bucket + oldNumBuckets) + " Hit: " + posHashCode);
                        }
                    }
                    hashList.clear();
                    pointerList.clear();
                }
            }
            // reset partition's overflow buckets and reclaim their memory
            this.availableMemory.addAll(partition.resetOverflowBuckets());
            // clear overflow lists
            int bucketArrayPos;
            int bucketInSegmentPos;
            MemorySegment bucket;
            while (!overflowHashes.isEmpty()) {
                hash = overflowHashes.removeLast();
                pointer = overflowPointers.removeLong(overflowPointers.size() - 1);
                posHashCode = hash % this.numBuckets;
                bucketArrayPos = posHashCode >>> this.bucketsPerSegmentBits;
                bucketInSegmentPos = (posHashCode & this.bucketsPerSegmentMask) << NUM_INTRA_BUCKET_BITS;
                bucket = this.buckets[bucketArrayPos];
                insertBucketEntryFromStart(bucket, bucketInSegmentPos, hash, pointer, partition.getPartitionNumber());
            }
            overflowHashes.clear();
            overflowPointers.clear();
        }
        this.isResizing = false;
        return true;
    }
}
Also used: java.io.IOException, org.apache.flink.core.memory.MemorySegment, org.apache.flink.runtime.util.IntArrayList, org.apache.flink.runtime.util.LongArrayList
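
For reference, here is a minimal, self-contained sketch of the IntArrayList/LongArrayList buffering pattern that resizeHashTable relies on. The bucket indices, hash values, and pointer values below are made up purely for illustration; only list operations that already appear in the snippet above are used.

import org.apache.flink.runtime.util.IntArrayList;
import org.apache.flink.runtime.util.LongArrayList;

public class BucketSplitSketch {

    public static void main(String[] args) {
        // hypothetical sizes for illustration only
        final int oldNumBuckets = 4;
        final int newNumBuckets = 2 * oldNumBuckets;
        final int bucket = 1; // the old bucket being split

        // buffer the (hash, pointer) pairs of one bucket, as resizeHashTable
        // does before re-inserting them
        IntArrayList hashList = new IntArrayList(8);
        LongArrayList pointerList = new LongArrayList(8);
        for (int hash : new int[] { 1, 5, 9, 13 }) { // all congruent to 1 mod oldNumBuckets
            hashList.add(hash);
            pointerList.add(0x100L + hash); // fake record pointers
        }

        // drain from the end and split: after doubling, an entry with hash h
        // stays in `bucket` if h % newNumBuckets == bucket, otherwise it moves
        // to the sibling bucket at (bucket + oldNumBuckets)
        while (!hashList.isEmpty()) {
            int hash = hashList.removeLast();
            long pointer = pointerList.removeLong(pointerList.size() - 1);
            int target = hash % newNumBuckets;
            System.out.println("hash " + hash + " -> bucket " + target
                    + " (pointer 0x" + Long.toHexString(pointer) + ")");
        }
        hashList.clear();
        pointerList.clear();
    }
}

Draining from the end via removeLast() and removeLong(size - 1) removes elements in O(1) without shifting the backing arrays, which is why the method consumes both lists back to front.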
