赞
踩
from binary_tree import Node # 完全二叉树的叶子节点只存在于最下面两层 # 且最下面一层的叶子节点全部靠左紧密排列 # 节点编号为从0开始,从上往下,从左到右依次编号 # 某个节点i的父节点是(i-1)//2,左节点是2*i+1,有节点是2*i+2 class PerfectBinaryTree(Node): def is_empty(self): if not self.item: return True def add_one(self, data): if self.is_empty(): self.item = data return queue = [self] while queue: size = len(queue) for _ in range(size): cur = queue.pop(0) if cur.lchild: queue.append(cur.lchild) else: cur.lchild = Node(data) return if cur.rchild: queue.append(cur.rchild) else: cur.rchild = Node(data) return def add_batch(self, arr): if self.is_empty(): self.item = arr[0] arr.pop(0) queue = [self] while arr and queue: size = len(queue) for _ in range(size): cur = queue.pop(0) if cur.lchild: queue.append(cur.lchild) elif arr: cur.lchild = Node(arr[0]) arr.pop(0) queue.append(cur.lchild) if cur.rchild: queue.append(cur.rchild) elif arr: cur.rchild = Node(arr[0]) arr.pop(0) queue.append(cur.rchild) def bfs(self): queue = [self] while queue: size = len(queue) for _ in range(size): cur = queue.pop(0) if cur.item: print(cur.item) if cur.lchild: queue.append(cur.lchild) if cur.rchild: queue.append(cur.rchild) from perfect_binary_tree import PerfectBinaryTree """ 二叉堆首先是一颗完全二叉树(结构上) 二叉堆节点间关系满足: 堆中根节点<=子树中的节点值(最小堆) 堆中根节点>=子树中的节点值(最大堆) 注意:在最大堆中,只能保证当前根节点大于等于子树的所有节点(任意子树都满足),节点大小与所处层数无关 """ class Heap(PerfectBinaryTree): def __init__(self, type='big'): self.type = type self.arr = [] if self.type == 'big': self._operator = lambda a, b: a > b else: self._operator = lambda a, b: a < b def visualize(self): super().__init__() super().add_batch(self.arr.copy()) def _swap(self, i, j): self.arr[i], self.arr[j] = self.arr[j], self.arr[i] # //为向下取整 def _get_parent_index(self, i): return (i - 1) // 2 def _get_lchild_index(self, i): return 2 * i + 1 def _get_rchild_index(self, i): return 2 * i + 2 def _get_right_bottom_most_nonleaf_index(self): return self._get_parent_index(len(self.arr) - 1) def add_one(self, data): """ 上浮操作发生在追加节点的时候,在数据的最后加节点,由于必须保证堆的性质,因此该节点要上浮 """ self.arr.append(data) self._siftup(len(self.arr) - 1) def _siftup(self, i): # use this function only when self.arr is already heapified # i = 0 indicates root. Sift up until root while (i > 0) and (self._operator(self.arr[i], self.arr[self._get_parent_index(i)])): self._swap(i, self._get_parent_index(i)) i = self._get_parent_index(i) def add_batch(self, arr): """ 下浮发生在删除根节点的时候,此时把最后一个元素挪到根节点上,因此要下沉 或者根节点被替换的时候,也要下沉 下浮法建堆的复杂度: 从最下面最右边的非叶子节点开始遍历到根节点 设完全二叉树总共有 h 层,如果当前遍历节点位于第 k 层,则需要下浮 h-k 次,这一层有 2^(k-1)个节点 则这一层总共需要 (h-k) * 2^(k-1) 次下层操作 遍历需要 S = (h-(h-1)) * 2^(h-1-1) + 2 * 2^(h-3) + 3 * 2^(h-4) .... (h-1) * 2^0 = 1 * 2^(h-2) + 2 * 2^(h-3) + .... (h-1) * 2^0 2S = 2 * 2^(h-1) + 2 * 2^(h-2) + .... (h-1) * 2 2S-S = S = 2 * 2^(h-1) + [2^(h-2) + 2^(h-3) + ... 2^1] - (h-1) * 2^0 = 2 * 2^(h-1) - (h-1) * 2^0 + [2^(h-2) + 2^(h-3) + ... 2^1] = 2 * 2^(h-1) - (h-1) * 2^0 + O(2^h) = O(2^h) = O(n) """ self.arr += arr # heapify the arr for i in range(self._get_right_bottom_most_nonleaf_index(), -1, -1): self._siftdown(i, len(self.arr)-1) def _siftdown(self, i, max_index): # use this function only when self.arr is already heapified # Sift down until max_index while i <= max_index: next_index = i if (self._get_lchild_index(i) <= max_index) and \ (self._operator(self.arr[self._get_lchild_index(i)], self.arr[i])): next_index = self._get_lchild_index(i) if (self._get_rchild_index(i) <= max_index) and \ (self._operator(self.arr[self._get_rchild_index(i)], self.arr[next_index])): next_index = self._get_rchild_index(i) if next_index != i: self._swap(i, next_index) i = next_index else: break def sort(self): """ 从后往前遍历到第 i个节点时,从i开始到堆顶的堆高为log(i), 因此堆顶需要下降 log(i)次 总共次数 S = log(n-1) + log(n-2) + ..log(i).. + 1 <= integral~[1, n-1](logx) 注意 logx 的原函数为 logx * x - x = logx * x - x |x=(n-1) - log x * x - x |x=1 = O(log(n-1) * (n-1)) = O(logn * n) """ # note that the sorted self.arr is in reverse relationship with self.type for max_index in range(len(self.arr) - 1, 0, -1): self._swap(0, max_index) self._siftdown(0, max_index-1) arr = [1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10] a = Heap(type='small') a.add_batch(arr[:first_x]) # complexity O(n) for i in range(first_x, len(arr)): if arr[i] > a.arr[0]: a.arr[0] = arr[i] # complexity: O(log(first_x)) a._siftdown(0, first_x-1) # total complexity: O(n * log(first_x)) print(a.arr)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。