[Notes] Binary Heap

Posted by 西维蜀黍的OJ Blog on 2023-09-10, Last Modified on 2024-02-01

Knowledge

Refer to https://swsmile.info/post/data-structure-heap/

堆化(heapifying)

def heapify(arr, n, i):
	"""
	Sift arr[i] down so the subtree rooted at index i satisfies the max-heap property.

	arr is treated as an implicit binary tree: the children of index j are
	2*j + 1 and 2*j + 2. Only the first n elements belong to the heap.
	Only the single path that arr[i] travels downward is repaired, so the
	result is a valid max-heap only if both child subtrees of i already were.
	arr is modified in place; nothing is returned.
	"""
	root = i
	while True:
		left = 2 * root + 1
		right = left + 1
		biggest = root
		# Pick the largest of the root and its children that lie inside the heap.
		if left < n and arr[left] > arr[biggest]:
			biggest = left
		if right < n and arr[right] > arr[biggest]:
			biggest = right
		if biggest == root:
			return  # root already dominates its children
		arr[root], arr[biggest] = arr[biggest], arr[root]
		root = biggest  # keep sifting from the child position we swapped into

# Example usage
arr = [12, 11, 13, 5, 6, 7]
n = len(arr)

# Build a max-heap bottom-up: sift down every internal node, starting from
# the last internal node (index n // 2 - 1) and walking back to the root.
for i in reversed(range(n // 2)):
	heapify(arr, n, i)

面试中碰到的

The minimum total weight of chocolates after d days (each day, half of the heaviest chocolate, rounded down, is eaten)

from typing import List
import heapq

def findMinWeight(weights: List[int], d: int) -> int:
    """
    Find the minimum total weight of chocolates after d days.

    Each day the heaviest remaining chocolate is picked and half of it
    (rounded down) is eaten; the rest (rounded up) is put back.

    Parameters:
    weights (list of int): The weights of each chocolate. The list is not modified.
    d (int): The number of days.

    Returns:
    int: The minimum total weight of chocolates after d days (0 for no chocolates).

    Complexity: O(n) to build the heap plus at most d iterations of O(log n).
    """
    if not weights:
        return 0  # nothing to eat; also avoids indexing an empty heap

    # Python's heapq module implements a min-heap, so store negated weights
    # to simulate a max-heap.
    max_heap = [-w for w in weights]
    heapq.heapify(max_heap)  # time: O(n)

    for _ in range(d):
        heaviest = -max_heap[0]  # peek at the heaviest chocolate: O(1)
        eaten = heaviest // 2
        if eaten == 0:
            # The heaviest piece weighs 0 or 1, so nothing is eaten today and
            # the heap would stay identical on every remaining day.
            break
        # Replace the heaviest piece by what is left of it: one O(log n) sift.
        heapq.heapreplace(max_heap, -(heaviest - eaten))

    # The total weight is the sum of the negated heap values.
    return -sum(max_heap)  # time: O(n)

# Example usage: 4 days of halving [30, 20, 25] leaves [8, 10, 13].
chocolates = [30, 20, 25]
days = 4
findMinWeight(weights=chocolates, d=days)  # evaluates to 31

解析

215. Kth Largest Element in an Array

# Solution 1 - Sort
# - time: O(nlog(n))
def findKthLargest(self, nums: List[int], k: int) -> int:
	"""Return the k-th largest element of nums by sorting (nums is sorted in place)."""
	nums.sort()
	# After an ascending sort, the k-th largest element sits k positions from the end.
	return nums[len(nums) - k]

# Solution 2 - heap
# minHeap: heapify the whole array in place, then shrink it to its k largest
# values; the root of that k-element min-heap is the answer.
# - time: O(n + (n-k)logn)
# - space: O(1) extra (works inside nums)
def findKthLargest(self, nums: List[int], k: int) -> int:
	"""Return the k-th largest value; nums is reordered and shrunk in place."""
	heapq.heapify(nums)  # O(n)
	# Discard the n-k smallest values, leaving exactly the k largest.
	for _ in range(len(nums) - k):
		heapq.heappop(nums)
	# The smallest of the k largest values is the k-th largest overall.
	return heapq.heappop(nums)

# maxHeap 
# - time: O(n + klogn) (negating and heapifying are O(n), then k-1 pops of O(logn))
# - space: O(1)
	def findKthLargest(self, nums: List[int], k: int) -> int:
		"""Return the k-th largest value using a max-heap of negated values.

		nums is overwritten in place with negated values and partially popped.
		"""
		# heapq only provides a min-heap, so negate every value to pop the largest first.
		nums[:] = [-value for value in nums]  # O(n)
		heapq.heapify(nums)  # O(n)
		# Remove the k-1 largest values; the new root is then the k-th largest.
		for _ in range(k - 1):
			heapq.heappop(nums)
		return -nums[0]
  
# Solution 3 - QuickSelect (like quick sort) - 写法 1
# Time Complexity:
#   - Best Case: O(n)
#   - Average Case: O(n),因为 n + n/2 + n/4 +... = 2n -> O(n)
#   - Worst Case: O(n^2), 如果是 [1,2,3,4,5,6],每次做完会变成 [1,2,3,4,5], [1,2,3,4] ... -> 所以是O(n^2)
# Extra Space Complexity: O(1)
	def partition(self, nums, l, r):
		"""Lomuto-partition nums[l..r] (inclusive) around the pivot nums[r].

		Afterwards every element left of the returned index is <= the pivot,
		every element right of it is > the pivot, and the pivot sits at that index.
		Example: [3,2,1,5,6,4] with l=0, r=5 becomes [3,2,1,4,6,5] and 3 is returned,
		so only [6, 5] remains to be examined on the right.
		"""
		pivot = nums[r]
		boundary = l  # next slot for an element <= pivot
		for j in range(l, r):
			if nums[j] > pivot:
				continue
			# Move every element <= pivot into the left region.
			nums[boundary], nums[j] = nums[j], nums[boundary]
			boundary += 1
		# Place the pivot between the "<= pivot" and "> pivot" regions.
		nums[boundary], nums[r] = nums[r], nums[boundary]
		return boundary

	def findKthLargest(self, nums: List[int], k: int) -> int:
		"""Return the k-th largest value with iterative quickselect (nums is reordered in place)."""
		target = len(nums) - k  # index the answer would occupy in ascending order
		lo, hi = 0, len(nums) - 1
		# nums[lo..hi] is the range that still contains the answer.
		while lo < hi:
			p = self.partition(nums, lo, hi)
			if p == target:
				break
			if p < target:
				# Everything right of p is > nums[p], and the answer is to the right.
				lo = p + 1
			else:
				# The pivot landed right of the target, so the answer is to the left.
				hi = p - 1
		return nums[target]

# Solution 3 - QuickSelect (like quick sort) - 写法 2
# Time Complexity:
#   - Best Case: O(n)
#   - Average Case: O(n),因为 n + n/2 + n/4 +... = 2n -> O(n)
#   - Worst Case: O(n^2), 如果是 [1,2,3,4,5,6],每次做完会变成 [1,2,3,4,5], [1,2,3,4] ... -> 所以是O(n^2)
# Extra Space Complexity: O(1)      
	def findKthLargest(self, nums: List[int], k: int) -> int:
		"""Return the k-th largest value of nums using quickselect (nums is reordered in place).

		Every recursive call of the former nested helper was a tail call, but
		on already-sorted input it recursed once per element and exceeded
		Python's default recursion limit (1000 frames). It is written as a
		loop here, so the extra space is O(1).

		Raises:
			ValueError: if k is not in the range 1..len(nums). Without this
				check an invalid k would make the loop below never finish.
		"""
		if not 1 <= k <= len(nums):
			raise ValueError("k must be between 1 and len(nums)")
		target = len(nums) - k  # index of the answer in ascending order

		left, right = 0, len(nums) - 1
		while True:
			# Lomuto partition of nums[left..right] around pivot nums[right].
			pivot = nums[right]
			p = left
			for idx in range(left, right):
				if nums[idx] <= pivot:
					nums[idx], nums[p] = nums[p], nums[idx]
					p += 1
			nums[right], nums[p] = nums[p], nums[right]
			# Now nums[left:p] <= nums[p] < nums[p+1:right+1],
			# e.g. [3,2,1,5,6,4] -> [3,2,1,4 (p),6,5]; only [6, 5] may still matter.

			if p == target:
				return nums[target]
			if p > target:
				right = p - 1  # the answer is left of the pivot
			else:
				left = p + 1  # the answer is right of the pivot

ref

355. Design Twitter

  • We need a counter to record the creation time of each tweet; it is decremented after every post, so a newer tweet has a smaller count and comes out of a min-heap first.
  • self.tweetMap records the tweets that a user posts as [self.count, tweetId] pairs.
  • Every time a user posts a tweet, we append [self.count, tweetId] to self.tweetMap[userId].
    • We can then get the latest tweet of this user with index = len(self.tweetMap[userId]) - 1; count, tweetId = self.tweetMap[userId][index]
      • To build a feed, loop over every user that the user follows (plus the user themself) and take each one's latest tweet.
      • Push those tweets onto a min-heap.
      • Repeatedly pop the newest tweet and push that author's previous tweet (if any), until 10 tweets are collected or the heap is empty.
class Twitter:
    """Simplified Twitter: post tweets, follow/unfollow users, read a 10-tweet news feed.

    Each tweet is stamped with a strictly decreasing counter, so a smaller
    stamp means a newer tweet and a min-heap pops the newest tweet first.
    """

    def __init__(self):
        from collections import defaultdict  # local import: not imported elsewhere in this note

        self.count = 0  # stamp for the next tweet; decremented after every post
        self.tweetMap = defaultdict(list)  # userId -> list of [count, tweetId], oldest first
        self.followMap = defaultdict(set)  # userId -> set of followeeId

    def postTweet(self, userId: int, tweetId: int) -> None:
        """Record tweetId as the newest tweet of userId."""
        self.tweetMap[userId].append([self.count, tweetId])
        self.count -= 1

    def getNewsFeed(self, userId: int) -> List[int]:
        """Return up to 10 tweet ids from userId and their followees, newest first.

        This is a k-way merge: the heap holds at most one candidate per author
        (that author's newest tweet not yet emitted).
        """
        res = []
        minHeap = []

        # A user always sees their own tweets. Build the source set without
        # touching followMap, so reading the feed never changes follow state.
        sources = self.followMap.get(userId, set()) | {userId}
        for followeeId in sources:
            tweets = self.tweetMap.get(followeeId)
            if tweets:
                index = len(tweets) - 1
                count, tweetId = tweets[index]
                # index - 1 is the position of this author's next-older tweet.
                heapq.heappush(minHeap, [count, tweetId, followeeId, index - 1])

        while minHeap and len(res) < 10:
            count, tweetId, followeeId, index = heapq.heappop(minHeap)
            res.append(tweetId)
            if index >= 0:
                count, tweetId = self.tweetMap[followeeId][index]
                heapq.heappush(minHeap, [count, tweetId, followeeId, index - 1])
        return res

    def follow(self, followerId: int, followeeId: int) -> None:
        """Make followerId follow followeeId (idempotent)."""
        self.followMap[followerId].add(followeeId)

    def unfollow(self, followerId: int, followeeId: int) -> None:
        """Stop followerId following followeeId; a no-op if they were not following."""
        self.followMap.get(followerId, set()).discard(followeeId)

ref

621. Task Scheduler

  • We want to run the task with the most remaining copies first, so that the number of idle slots is as small as possible.
  • So we use a max-heap (a min-heap of negated counts, since heapq only provides a min-heap) to store how many times each task still needs to run.
  • After a task runs, if it still has copies left, we add its remaining count to a queue together with the time at which it becomes runnable again.
  • After each time unit, we check the queue to see whether a task's cooldown has finished.
    • If yes, pop it from the queue and add it back to the max-heap.
  • If nothing is left in the max-heap and nothing is left in the queue, all tasks are finished.
# Solution 1, variant 1 (NeetCode)
# - time: O(T log k), one loop iteration per time unit: T is the answer (idle units
#   included, at most len(tasks) * (n + 1)) and k is the number of distinct tasks
# - space: O(k)
def leastInterval(self, tasks: List[str], n: int) -> int:
	"""Return the minimum number of time units needed to run all tasks with cooldown n.

	Greedy simulation: each time unit runs the available task with the most
	remaining copies; a task that still has copies left waits in a FIFO queue
	until its cooldown has passed.
	"""
	from collections import Counter, deque

	# Max-heap of remaining counts, stored negated because heapq is a min-heap.
	frequencies = [-count for count in Counter(tasks).values()]
	heapq.heapify(frequencies)

	time = 0
	queue = deque()  # (negated remaining count, time after which it may re-enter the heap)
	while frequencies or queue:
		time += 1
		if frequencies:
			count_to_compute = heapq.heappop(frequencies) + 1  # run one copy
			if count_to_compute < 0:  # copies remain: start this task's cooldown
				queue.append((count_to_compute, time + n))

		# Ready times are appended in increasing order, one per time unit at most,
		# so only the front entry can be due now.
		if queue and queue[0][1] <= time:
			heapq.heappush(frequencies, queue.popleft()[0])
	return time

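# solution 1, 写法2 (by myself)
# - time: O(T log k), where T is the number of time units in the answer (idle units included) and k is the number of distinct tasks; e.g. for [A, A, A, A] with cooldown 10 the loop runs 34 times although there are only 4 tasks
# - space: O(k)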
	def leastInterval(self, tasks: List[str], n: int) -> int:
		"""Return the minimum number of time units needed to run all tasks with cooldown n.

		Each time unit first returns any task whose cooldown has ended to the
		heap, then runs the task with the most remaining copies (or idles).
		"""
		from collections import Counter, deque

		# Max-heap of remaining counts, stored negated because heapq is a min-heap.
		maxHeap = [-count for count in Counter(tasks).values()]
		heapq.heapify(maxHeap)

		time = 0
		stage = deque()  # [earliest time it may run again, negated remaining count], FIFO
		while maxHeap or stage:
			# Ready times are appended in increasing order, so only the front can be due.
			if stage and stage[0][0] <= time:
				heapq.heappush(maxHeap, stage.popleft()[1])

			if maxHeap:
				remaining = -heapq.heappop(maxHeap) - 1  # run one copy
				if remaining > 0:
					stage.append([time + n + 1, -remaining])
			time += 1  # one unit passes whether a task ran or the CPU idled

		return time

ref

703. Kth Largest Element in a Stream

  • use min heap with the size of K to memorise the k largest elements
# time: __init__ O(n + (n-k)logn), add O(logk)
# space: O(k) once constructed (the constructor works on a copy of nums)
class KthLargest:
	"""Track the k-th largest value of a stream with a min-heap of at most k values."""

	def __init__(self, k: int, nums: List[int]):
		self.k = k
		# Copy so the caller's list is not heapified and truncated behind their back.
		self.nums = list(nums)
		heapq.heapify(self.nums)  # minHeap, time: O(n)
		# Keep only the k largest values; the heap root is then the k-th largest.
		while len(self.nums) > k:  # time: O((n-k)logn)
			heapq.heappop(self.nums)

	def add(self, val: int) -> int:
		"""Insert val and return the current k-th largest value (the heap's root)."""
		if len(self.nums) < self.k:
			heapq.heappush(self.nums, val)  # heap not full yet, time: O(logk)
		else:
			# Push val and drop the smallest value in a single O(logk) operation.
			heapq.heappushpop(self.nums, val)
		return self.nums[0]

ref

973. K Closest Points to Origin

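  • heapq compares the [distance², x, y] lists element by element, so the squared distance (the first element) decides the order; ties fall back to x, then y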
class Solution:
	def kClosest(self, points: List[List[int]], k: int) -> List[List[int]]:
		"""Return the k points closest to the origin, nearest first.

		Heap entries are [x*x + y*y, x, y]. heapq compares lists element by
		element, so the squared distance decides the order and ties fall back
		to x, then y. Squared distance orders points exactly like the true
		distance without taking a square root.
		"""
		heap = [[x * x + y * y, x, y] for x, y in points]  # space: O(n)
		heapq.heapify(heap)  # time: O(n)

		closest = []
		for _ in range(k):  # time: O(klogn)
			_, x, y = heapq.heappop(heap)
			closest.append([x, y])
		return closest

ref

1046. Last Stone Weight

  • 把所有的元素都变成负数,因为Python里没有maxHeap,只有minHeap
	def lastStoneWeight(self, stones: List[int]) -> int:
		"""Smash the two heaviest stones together until at most one is left.

		Returns the weight of the remaining stone, or 0 if none remain.
		stones itself is reused as the heap, so the caller's list is modified.
		"""
		# heapq only provides a min-heap, so store negated weights to pop the heaviest first.
		stones[:] = [-weight for weight in stones]
		heapq.heapify(stones)  # O(n)

		while len(stones) >= 2:
			heaviest = -heapq.heappop(stones)  # O(logn)
			second = -heapq.heappop(stones)  # O(logn)
			if heaviest > second:
				# The lighter stone is destroyed; the heavier one keeps the difference.
				heapq.heappush(stones, second - heaviest)  # O(logn)

		return -stones[0] if stones else 0

ref

Ref