[Notes] Binary Heap

Posted by 西维蜀黍的OJ Blog on 2023-09-10, Last Modified on 2025-03-27

Knowledge

Refer to https://swsmile.info/post/data-structure-heap/

面试中碰到的

The minimum total weight of chocolates after d days

from typing import List
import heapq

def findMinWeight(weights: List[int], d: int) -> int:
    """
    Find the minimum total weight of chocolates after d days.

    Each day the heaviest remaining chocolate is picked, floor(w / 2) of it is
    eaten and ceil(w / 2) is left behind. Always halving the current maximum
    removes the most weight per day.

    Parameters:
    weights (list of int): The weights of each chocolate.
    d (int): The number of days.

    Returns:
    int: The minimum total weight of chocolates after d days
    (0 when there are no chocolates).

    Complexity:
    time: O(n + d log n), space: O(n) for the heap copy.
    """
    # Python's heapq module implements a min-heap, so store negated weights
    # to simulate a max-heap.
    max_heap = [-w for w in weights]
    if not max_heap:
        # Nothing to eat (heappop on an empty heap would raise IndexError).
        return 0
    heapq.heapify(max_heap)  # time: O(n)

    for _ in range(d):  # time: O(d log n)
        heaviest = -max_heap[0]  # peek at the heaviest chocolate: O(1)
        eaten = heaviest // 2
        if eaten == 0:
            # The heaviest chocolate can no longer shrink, so no later day
            # can change anything either.
            break
        remaining = heaviest - eaten
        # Pop the top and push the remaining part in a single O(log n) sift.
        heapq.heapreplace(max_heap, -remaining)

    # The total weight is the sum of the negated heap values.
    return -sum(max_heap)  # time: O(n)

# Example usage:
# Day 1: 30 -> 15, day 2: 25 -> 13, day 3: 20 -> 10, day 4: 15 -> 8,
# leaving [13, 10, 8], so the call below evaluates to 31.
# The return value is not stored or printed; wrap the call in print() to see it.
chocolates = [30, 20, 25]
days = 4
findMinWeight(chocolates, days)

解析

215. Kth Largest Element in an Array

  • quick select:如果能保证 p 左边的元素都小于或等于 p 指向的元素,且 p 右边的元素都大于 p 指向的元素,并且 p 右边恰好有 k-1 个元素,则 p 一定指向第 k 大的元素。
# Sort
#   - time: O(n log n)
#   - space: O(1) or O(n) depending on the sorting algorithm.
def findKthLargest(self, nums: List[int], k: int) -> int:
    """Return the k-th largest element of ``nums`` (``k`` is 1-indexed).

    ``nums`` is sorted in place in ascending order; the k-th largest value
    then sits at index len(nums) - k.
    """
    nums.sort()
    return nums[len(nums) - k]

# min-heap: heapify every element, then discard the n - k smallest
#    - time: O(n + (n-k) log n): heapify is O(n), each of the n - k pops is O(log n)
#    - space: O(1) extra (nums is reordered and shrunk in place)
def findKthLargest(self, nums: List[int], k: int) -> int:
    """Return the k-th largest element of ``nums`` (``k`` is 1-indexed).

    ``nums`` is turned into a min-heap in place. After popping until only
    ``k`` elements remain, the heap holds the k largest values and its top
    is the k-th largest.
    """
    heapq.heapify(nums)  # time: O(n)
    while len(nums) > k:  # time: O((n-k) log n)
        heapq.heappop(nums)
    # The minimum of the k largest values is the answer; peek instead of popping.
    return nums[0]

# max-heap (heapq min-heap over negated values)
#    - time: O(n + k log n)
#    - space: O(1) extra (nums is negated and reordered in place)
def findKthLargest(self, nums: List[int], k: int) -> int:
    """Return the k-th largest element of ``nums`` (``k`` is 1-indexed).

    heapq only provides a min-heap, so the values are negated in place to
    simulate a max-heap. After discarding the k - 1 largest values, the top
    of the heap is the k-th largest.

    Note: ``nums`` is left negated and heap-ordered afterwards.
    """
    for i, value in enumerate(nums):  # time: O(n)
        nums[i] = -value
    heapq.heapify(nums)  # time: O(n)

    for _ in range(k - 1):  # time: O(k log n)
        heapq.heappop(nums)
    return -nums[0]
  
# QuickSelect (like quick sort) - myself1 (recursive, Lomuto partition)
# NOTE: the version with a fixed rightmost pivot timed out on submission.
# A random pivot removes the sorted-input worst case. Inputs with many equal
# values can still degrade to O(n^2), because every element equal to the
# pivot goes to the left side, so this may still time out.
# Time Complexity:
#   - Best Case: O(n)
#   - Average Case: O(n), since n + n/2 + n/4 + ... = 2n -> O(n): each partition roughly halves the range
#   - Worst Case: O(n^2), when each partition only shrinks the range by one element
# Space Complexity: O(log n) expected, O(n) worst case, for the recursion stack
def findKthLargest(self, nums, k):
    """
    Return the k-th largest element of nums (k is 1-indexed); nums is reordered in place.

    :type nums: List[int]
    :type k: int
    :rtype: int
    """
    import random

    target = len(nums) - k  # index of the k-th largest in ascending order

    def quickSelect(l, r):
        # Choose a random pivot in [l, r] and move it to r, so the partition
        # below can keep treating nums[r] as the pivot.
        pivotIdx = random.randint(l, r)
        nums[pivotIdx], nums[r] = nums[r], nums[pivotIdx]
        pivot = nums[r]

        # p is the next slot for an element <= pivot inside [l, r).
        p = l
        for i in range(l, r):
            if nums[i] <= pivot:
                nums[i], nums[p] = nums[p], nums[i]
                p += 1
        # Place the pivot between the halves: nums[l:p] <= nums[p] < nums[p+1:r+1].
        # e.g. nums = [3,2,1,5,6,4] with pivot 4 -> [3,2,1,4 (p),6,5]
        nums[r], nums[p] = nums[p], nums[r]

        if p == target:
            # Everything left of p is <= nums[p] and everything right is larger,
            # so nums[p] is in its sorted position: it is the answer.
            return nums[p]
        elif p < target:
            # The answer lies to the right of p.
            return quickSelect(p + 1, r)
        else:
            # The answer lies to the left of p.
            return quickSelect(l, p - 1)

    return quickSelect(0, len(nums) - 1)

# QuickSelect (like quick sort) - myself2 (iterative, Lomuto partition)
# NOTE: the version with a fixed rightmost pivot timed out on submission.
# A random pivot removes the sorted-input worst case. Inputs with many equal
# values can still degrade to O(n^2), so this may still time out.
# Time Complexity:
#   - Best Case: O(n)
#   - Average Case: O(n), since n + n/2 + n/4 + ... = 2n -> O(n): each partition roughly halves the range
#   - Worst Case: O(n^2), when each partition only shrinks the range by one element
# Space Complexity: O(1) -- the while loop replaces recursion
def findKthLargest(self, nums, k):
    """
    Return the k-th largest element of nums (k is 1-indexed); nums is reordered in place.

    :type nums: List[int]
    :type k: int
    :rtype: int
    """
    import random

    def partition(l, r):
        """Partition nums[l..r] around a random pivot and return its final index p.

        Afterwards nums[l:p] <= nums[p] < nums[p+1:r+1].
        """
        pivotIdx = random.randint(l, r)
        nums[pivotIdx], nums[r] = nums[r], nums[pivotIdx]
        pivot = nums[r]

        p = l  # next slot for an element <= pivot
        for i in range(l, r):
            if nums[i] <= pivot:
                nums[i], nums[p] = nums[p], nums[i]
                p += 1
        # e.g. nums = [3,2,1,5,6,4] with pivot 4 -> [3,2,1,4 (p),6,5]
        nums[r], nums[p] = nums[p], nums[r]
        return p

    target = len(nums) - k  # index of the k-th largest in ascending order
    l, r = 0, len(nums) - 1
    # Invariant: target stays inside [l, r]. When the loop exits with l == r,
    # nums[target] is already in its sorted position. This also covers
    # len(nums) == 1, which is why the answer is read after the loop.
    while l < r:
        p = partition(l, r)
        if p == target:
            break
        elif p < target:
            # The answer lies to the right of p.
            l = p + 1
        else:
            r = p - 1
    return nums[target]
      
# QuickSelect (like quick sort) - optimal
# Iterative quickselect with a median-of-three pivot and a Hoare-style
# two-pointer partition. The partition orders elements in DESCENDING order,
# so the k-th largest ends up at index k - 1.
# Time Complexity:
#   - Best Case: O(n)
#   - Average Case: O(n), since n + n/2 + n/4 + ... = 2n -> O(n)
#   - Worst Case: O(n^2) on adversarial inputs. Using the median of
#     nums[left], nums[mid] and nums[right] as the pivot makes a badly
#     unbalanced split less likely than a fixed end pivot would.
# Space Complexity: O(1) (no recursion; nums is reordered in place)
    def partition(self, nums: List[int], left: int, right: int) -> int:
        """Partition nums[left..right] in descending order around a median-of-three pivot.

        quickSelect only calls this when right - left >= 2.
        Returns the pivot's final index j. Afterwards:
        nums[left..j-1] >= nums[j] >= nums[j+1..right].
        """
        # Move the middle element next to `left`, so the three pivot
        # candidates are nums[left], nums[left + 1] and nums[right].
        mid = (left + right) >> 1
        nums[mid], nums[left + 1] = nums[left + 1], nums[mid]
        
        # Order the three candidates so nums[left] >= nums[left + 1] >= nums[right].
        # nums[left] (>= pivot) and nums[right] (<= pivot) then serve as
        # sentinels that stop the inner scans without bounds checks.
        if nums[left] < nums[right]:
            nums[left], nums[right] = nums[right], nums[left]
        if nums[left + 1] < nums[right]:
            nums[left + 1], nums[right] = nums[right], nums[left + 1]
        if nums[left] < nums[left + 1]:
            nums[left], nums[left + 1] = nums[left + 1], nums[left]
        
        # The median of the three candidates is the pivot, parked at left + 1.
        pivot = nums[left + 1]
        i = left + 1
        j = right
        
        while True:
            # Move i right past elements larger than the pivot.
            while True:
                i += 1
                if not nums[i] > pivot:
                    break
            # Move j left past elements smaller than the pivot.
            while True:
                j -= 1
                if not nums[j] < pivot:
                    break
            # Once the pointers cross, the range is split.
            if i > j:
                break
            # nums[i] <= pivot belongs on the right; nums[j] >= pivot belongs on the left.
            nums[i], nums[j] = nums[j], nums[i]
        
        # Put the pivot into its final slot j.
        nums[left + 1], nums[j] = nums[j], nums[left + 1]
        return j
    
    def quickSelect(self, nums: List[int], k: int) -> int:
        """Return the element that would sit at index k if nums were sorted in descending order.

        nums is reordered in place; only O(1) extra space is used.
        """
        left = 0
        right = len(nums) - 1
        
        while True:
            # At most two elements left: order them descending and answer.
            if right <= left + 1:
                if right == left + 1 and nums[right] > nums[left]:
                    nums[left], nums[right] = nums[right], nums[left]
                return nums[k]
            
            j = self.partition(nums, left, right)
            
            # Keep only the side that contains index k. When j == k, both
            # updates run and leave right < left, so the next pass returns nums[k].
            if j >= k:
                right = j - 1
            if j <= k:
                left = j + 1
    
    def findKthLargest(self, nums: List[int], k: int) -> int:
        # The k-th largest (1-indexed) is at index k - 1 in descending order.
        return self.quickSelect(nums, k - 1)

ref

355. Design Twitter

  • We need a counter to record the creation time of each tweet (self.count; it decreases, so newer tweets have smaller values)
  • self.userId2Tweets (self.tweetMap in the code below) records the tweets that a user posts as [self.count, tweetId]
  • Every time a user posts a tweet, we append [self.count, tweetId] to self.userId2Tweets
    • We can then get the user's latest tweet with index = len(self.userId2Tweets[userId]) - 1; count, tweetId = self.userId2Tweets[userId][index]
      • We loop over all followees of the user (plus the user themself) and take each one's latest tweet
      • and then push them into a heap
      • While we have fewer than 10 tweets, pop the newest one and push the previous tweet of the same author, until reaching 10.
class Twitter:
    """Simplified Twitter.

    Supports posting tweets, following/unfollowing users, and fetching the
    10 most recent tweets from a user and the users they follow.
    """

    def __init__(self):
        # Imported here so the snippet works on its own; the file header only
        # imports heapq and typing.List.
        from collections import defaultdict

        # Timestamp counter. It decreases with every post, so newer tweets
        # get smaller values and a min-heap pops the newest tweet first.
        self.count = 0
        self.tweetMap = defaultdict(list)  # userId -> list of [count, tweetId], oldest first
        self.followMap = defaultdict(set)  # userId -> set of followeeIds

    def postTweet(self, userId: int, tweetId: int) -> None:
        self.tweetMap[userId].append([self.count, tweetId])
        self.count -= 1

    def getNewsFeed(self, userId: int) -> List[int]:
        """Return up to 10 tweet ids from userId and their followees, newest first.

        This is a k-way merge. The heap is seeded with each author's latest
        tweet; every pop then pushes the same author's next-older tweet.
        """
        res = []
        minHeap = []

        # Include the user's own tweets without storing a self-follow in followMap.
        for followeeId in self.followMap[userId] | {userId}:
            tweets = self.tweetMap.get(followeeId)
            if tweets:
                index = len(tweets) - 1
                count, tweetId = tweets[index]
                heapq.heappush(minHeap, (count, tweetId, followeeId, index - 1))

        while minHeap and len(res) < 10:
            _, tweetId, followeeId, index = heapq.heappop(minHeap)
            res.append(tweetId)
            if index >= 0:
                # Queue this author's next-older tweet.
                count, olderTweetId = self.tweetMap[followeeId][index]
                heapq.heappush(minHeap, (count, olderTweetId, followeeId, index - 1))
        return res

    def follow(self, followerId: int, followeeId: int) -> None:
        self.followMap[followerId].add(followeeId)

    def unfollow(self, followerId: int, followeeId: int) -> None:
        # discard() is a no-op when followeeId is not followed.
        self.followMap[followerId].discard(followeeId)

ref

621. Task Scheduler

  • We want to run the most frequent task first, so that the number of idle slots is as small as possible
  • So we use a max-heap (a min-heap of negated counts) to store how many times each task still needs to run
  • After a task runs, we add its remaining count to a queue, together with the earliest time it may run again.
  • After each time unit, we check the queue to see whether a task has become runnable
    • if yes, pop it from the queue and push it back onto the max-heap
  • When both the max-heap and the queue are empty, every task is finished
# maxHeap - myself
#   - time: O(T log 26): one loop iteration per time unit, where
#     T <= len(tasks) * (n + 1) is the answer (e.g. [A, A, A, A] with a large n)
#   - space: O(1), since there are at most 26 different task letters
def leastInterval(self, tasks, n):
    """
    Return the least number of time units needed to run every task, when two
    runs of the same task must be separated by at least n units.

    :type tasks: List[str]
    :type n: int
    :rtype: int
    """
    from collections import Counter, deque

    # heapq is a min-heap, so store negated counts: the most frequent task pops first.
    maxHeap = [-count for count in Counter(tasks).values()]
    heapq.heapify(maxHeap)

    # Tasks on cooldown, in FIFO order: (remaining count, earliest time it may run again).
    cooling = deque()
    time = 0
    while maxHeap or cooling:
        # Ready times are distinct (at most one task runs per time unit),
        # so at most one task can come off cooldown per step.
        if cooling and cooling[0][1] <= time:
            count, _ = cooling.popleft()
            heapq.heappush(maxHeap, -count)

        if maxHeap:
            count = -heapq.heappop(maxHeap) - 1  # run the most frequent task once
            if count > 0:
                cooling.append((count, time + n + 1))
        # If maxHeap was empty, this time unit is idle.
        time += 1

    return time

# maxHeap - NeetCode
#   - time: O(T log 26): one loop iteration per time unit, where
#     T <= len(tasks) * (n + 1) is the answer (e.g. [A, A, A, A] with a large n)
#   - space: O(1), since there are at most 26 different task letters
def leastInterval(self, tasks: List[str], n: int) -> int:
    """Return the least number of time units needed to run every task, when
    two runs of the same task must be separated by at least n units."""
    from collections import Counter, deque

    # Negated frequencies: heapq is a min-heap, so the most frequent task pops first.
    frequencies = [-count for count in Counter(tasks).values()]
    heapq.heapify(frequencies)

    time = 0
    # (negated count still to run, time from which it is runnable again), FIFO order
    queue = deque()
    while frequencies or queue:
        time += 1
        if frequencies:
            count_to_compute = heapq.heappop(frequencies) + 1  # run it once
            if count_to_compute < 0:
                queue.append((count_to_compute, time + n))

        # Move a task whose cooldown has expired back onto the heap.
        if queue and queue[0][1] <= time:
            heapq.heappush(frequencies, queue.popleft()[0])
    return time

ref

703. Kth Largest Element in a Stream

  • use min heap with the size of K to memorise the k largest elements
  • 更像 medium
# brute force - sorted list with binary-search insertion
class KthLargest(object):
    """Track the k-th largest value of a stream by keeping the k largest values sorted ascending."""

    def __init__(self, k, nums):
        """
        :type k: int
        :type nums: List[int]
        """
        self.k = k
        # sorted() makes a copy, so later add() calls never modify the caller's list.
        arr = sorted(nums)  # time: O(n log n), space: O(n)
        # Keep only the k largest values; anything smaller can never be the
        # k-th largest. space: O(k)
        self.arr = arr[len(arr) - k:] if len(arr) > k else arr

    # time: O(k) (O(log k) binary search + O(k) list shift), space: O(1)
    def add(self, val):
        """
        :type val: int
        :rtype: int
        """
        import bisect

        # The original appended val without re-sorting, so self.arr could go
        # out of order and return a wrong answer.
        bisect.insort(self.arr, val)  # keeps self.arr sorted ascending
        if len(self.arr) > self.k:
            # The list held at most k values before this insert, so exactly
            # one value (the smallest) must go.
            del self.arr[0]

        return self.arr[len(self.arr) - self.k]

# priority queue / min-heap
#  - Keep a min-heap of size k holding the k largest values seen so far. Its
#    top (the smallest of those k) is the k-th largest value.
#  - On add(): while the heap has fewer than k values, push. Otherwise, if the
#    new value beats the top, replace the top.
#  - Why a min-heap? The value we may need to evict is always the smallest
#    retained one, and anything smaller than the k-th largest can never
#    affect the answer.
#  - Why is the top the k-th largest? The heap always holds exactly the k
#    largest values seen so far, so its minimum is the k-th largest.
class KthLargest(object):

    # __init__: O(n log k) time and O(k) space, where n = len(nums); nums is not modified
    def __init__(self, k, nums):
        """
        :type k: int
        :type nums: List[int]
        """
        self.k = k
        # nlargest returns a new list, so the caller's nums is neither aliased
        # nor mutated by later add() calls.
        self.minHeap = heapq.nlargest(k, nums)
        heapq.heapify(self.minHeap)

    # add: O(log k) time per call; O(k) space for the heap
    def add(self, val):
        """
        :type val: int
        :rtype: int
        """
        if len(self.minHeap) < self.k:
            heapq.heappush(self.minHeap, val)
        elif val > self.minHeap[0]:
            # val displaces the current k-th largest: pop and push in one sift.
            heapq.heapreplace(self.minHeap, val)
        # Otherwise val is too small to change the k largest values.
        return self.minHeap[0]  # the smallest retained value is the k-th largest

ref

973. K Closest Points to Origin

  • Python uses the first element to sort
# sort by squared distance to the origin
#   - time: O(n log n)
#   - space: O(1) or O(n) depending on the sorting algorithm.
class Solution:
    def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
        """Return the k points closest to (0, 0); sorts ``points`` in place.

        Squared distance orders points the same way as Euclidean distance,
        so no square root is needed.
        """

        def squared_distance(point):
            return point[0] ** 2 + point[1] ** 2

        points.sort(key=squared_distance)
        closest = points[:k]
        return closest
      
# min-heap
#   - time: O(n + k log n)
#   - space: O(n)
class Solution:
    def kClosest(self, points, k):
        """
        Return the k points closest to the origin, nearest first.

        :type points: List[List[int]]
        :type k: int
        :rtype: List[List[int]]
        """
        # heapq compares lists element by element, so the squared distance in
        # slot 0 decides the order (ties fall back to x, then y).
        heap = [[x * x + y * y, x, y] for x, y in points]  # space: O(n)
        heapq.heapify(heap)  # time: O(n)

        # k pops at O(log n) each; [1:] drops the distance and keeps [x, y].
        return [heapq.heappop(heap)[1:] for _ in range(k)]
      
# quick select
#   - time: O(n) in average case, O(n^2) in worst case
#   - space: O(1) (points is reordered in place)
def kClosest(self, points, k):
    """
    Return k points closest to the origin, in no particular order.

    ``points`` is partitioned in place until a pivot lands at index k. At
    that point every point in points[:k] is at least as close to the origin
    as every point after it. If k >= len(points), all points are returned.

    :type points: List[List[int]]
    :type k: int
    :rtype: List[List[int]]
    """
    import random

    def squaredDist(point):
        return point[0] ** 2 + point[1] ** 2

    def partition(l, r):
        # A random pivot avoids the O(n^2) behaviour of a fixed rightmost
        # pivot on inputs that are already ordered by distance.
        pivotIdx = random.randint(l, r)
        points[pivotIdx], points[r] = points[r], points[pivotIdx]
        pivotDist = squaredDist(points[r])
        i = l  # next slot for a point no farther than the pivot
        for j in range(l, r):
            if squaredDist(points[j]) <= pivotDist:
                points[i], points[j] = points[j], points[i]
                i += 1
        points[i], points[r] = points[r], points[i]
        return i

    if k >= len(points):
        # Nothing to select; the partition loop below would otherwise run past the array.
        return points[:k]

    L, R = 0, len(points) - 1
    # Invariant: L <= k <= R, so the loop ends once a pivot lands exactly at k.
    while True:
        pivot = partition(L, R)
        if pivot == k:
            return points[:k]
        if pivot < k:
            L = pivot + 1
        else:
            R = pivot - 1

ref

1046. Last Stone Weight

  • 因为 Python 的 heapq 里没有 maxHeap,只有 minHeap
    • 所以可以把所有的元素都取反(变成负数),用 minHeap 来模拟 maxHeap
# sort (re-sort after every smash)
#  - time: O(n^2 log n): up to n - 1 rounds, each sorting O(n) stones
#  - space: O(1) or O(n) depending on the sorting algorithm.
def lastStoneWeight(self, stones: List[int]) -> int:
    """Return the weight of the last remaining stone, or 0 if none is left.

    Each round smashes the two heaviest stones and puts back any non-zero
    difference. ``stones`` is consumed in place.
    """
    while len(stones) > 1:
        stones.sort()
        # After an ascending sort, the two heaviest stones are at the end.
        cur = stones.pop() - stones.pop()
        if cur:
            stones.append(cur)

    return stones[0] if stones else 0
      
# binary search insertion (keep stones sorted instead of re-sorting)
#   - time: O(n^2): each of the up to n - 1 insertions shifts O(n) elements
#   - space: O(1) or O(n) depending on the sorting algorithm.
def lastStoneWeight(self, stones: List[int]) -> int:
    """Return the weight of the last remaining stone, or 0 if none is left.

    ``stones`` is sorted once. After each smash, the non-zero difference is
    inserted back into its sorted position. ``stones`` is consumed in place.
    """
    import bisect

    stones.sort()
    while len(stones) > 1:
        # The two heaviest stones are at the end of the ascending list.
        cur = stones.pop() - stones.pop()
        if cur > 0:
            # O(log n) binary search for the slot + O(n) shift to insert.
            bisect.insort(stones, cur)

    return stones[0] if stones else 0
      
# heap
#   - time: O(n log n)
#   - space: O(n) for the temporary list built during negation; the heap itself reuses stones
class Solution(object):

    def lastStoneWeight(self, stones):
        """
        Smash the two heaviest stones together until at most one is left,
        and return its weight (0 if none remain).

        heapq only implements a min-heap, so the weights are negated in place
        to get max-heap behaviour.

        :type stones: List[int]
        :rtype: int
        """
        stones[:] = [-weight for weight in stones]  # negate in place (same list object)
        heapq.heapify(stones)  # time: O(n)

        # Each round removes two stones and pushes one back: n - 1 rounds.
        while len(stones) > 1:
            heaviest = -heapq.heappop(stones)  # O(log n)
            next_heaviest = -heapq.heappop(stones)  # O(log n)
            # heaviest >= next_heaviest, so the difference is never negative.
            # An equal pair pushes back a 0-weight stone, which is harmless.
            heapq.heappush(stones, next_heaviest - heaviest)  # O(log n)

        return -stones[0] if stones else 0

# bucket sort (counting stones per weight)
class Solution(object):
    # time: O(n + w) per the reference solution, where n is the length of the
    # stones array and w is the maximum value in the stones array.
    # NOTE(review): the bound is not obvious from the code. Each pass of the
    # outer loop either lowers `first` or removes one stone overall, but the
    # inner scan for `second` restarts at `first` each time and can revisit
    # empty buckets. Confirm the bound before relying on it.
    # space: O(w), w is the maximum value in the stones array (the buckets list).
    def lastStoneWeight(self, stones):
        """
        Return the weight of the last remaining stone, or 0 if none remain.

        Counts stones per weight, then walks `first` (the heaviest weight that
        may still hold stones) downward, smashing one stone of weight `first`
        against the next heaviest stone. Raises ValueError on an empty list,
        because of max([]).

        :type stones: List[int]
        :rtype: int
        """
        maxStone = max(stones)
        buckets = [0] * (maxStone + 1)  # space: O(w); buckets[x] = number of stones of weight x
        for stone in stones:  # time: O(n)
            buckets[stone] += 1

        first = maxStone
        while first > 0:  # each pass lowers `first` or removes one stone overall
            if buckets[first] % 2 == 0:  # an even number of stones of weight `first` destroy each other in pairs, so skip this weight
                first -= 1
                continue

            buckets[first] -= 1  # take one stone of weight `first`
            # Find the next heaviest non-empty bucket; this may be `first`
            # itself if stones of that weight remain.
            i = first
            while i > 0 and buckets[i] == 0:
                i -= 1
            if i > 0:  # found a non-empty bucket with a positive weight
                second = i
                buckets[second] -= 1  # take one stone of weight `second`
                buckets[first - second] += 1  # smash: the difference goes back into its bucket
                # Cannot simply set first = second, because first - second may be larger than second.
                first = max(first - second, second)  # never larger than the previous `first`
            else:
                # No other stone left: the taken stone of weight `first` is the last one.
                return first
        return 0

ref

Ref