赞
踩
def timsort(arr):
    """Return a new list holding the items of ``arr`` in ascending order.

    The input is split into already-ordered runs. Each run is pushed on a
    stack, and the two topmost runs are merged for as long as
    ``_should_merge`` asks for it; the runs still on the stack at the end
    are merged into one.

    Args:
        arr: Iterable of mutually comparable items, or ``None``.

    Returns:
        A new sorted list; ``[]`` when ``arr`` is ``None`` or empty.
    """
    # Work on a private copy: run detection reorders its argument in place,
    # which would otherwise scramble the caller's list. Copying also lets
    # one-shot iterables (which have no len()) be sorted.
    items = list(arr or ())
    if not items:
        return []

    run_stack = []
    for run in _partition_to_runs(items):
        run_stack.append(run)
        while _should_merge(run_stack):
            _merge_stack(run_stack)

    # Collapse whatever runs are left into a single sorted run.
    while len(run_stack) > 1:
        _merge_stack(run_stack)
    return run_stack[0]
def _partition_to_runs(arr):
    """Yield the consecutive ordered runs of ``arr`` as new ascending lists.

    A run is a maximal non-descending stretch, or a maximal strictly
    descending stretch that is reversed in place before being yielded.
    ``arr`` is therefore modified.

    Args:
        arr: List to split; descending stretches are reversed in place.

    Yields:
        One ascending list per run, in the order the runs occur in ``arr``.
    """
    total = len(arr)
    pos = 0
    while pos < total:
        if total - pos == 1:
            # A lone trailing element is a run on its own.
            next_pos = total
        elif arr[pos + 1] < arr[pos]:
            # Descending runs must be *strictly* descending: reversing a
            # stretch that contains equal elements would swap their
            # relative order and break stability.
            next_pos = _find_desc_boundary(arr, pos)
            _reverse(arr, pos, next_pos)
        else:
            next_pos = _find_asc_boundary(arr, pos)
        yield arr[pos:next_pos]
        pos = next_pos


def _find_desc_boundary(arr, start):
    """Return the end index (exclusive) of the strictly descending run at ``start``.

    Iterative on purpose: a recursive scan uses one stack frame per element
    and overflows the interpreter's recursion limit on long runs.
    """
    last = len(arr) - 1
    pos = start
    # Strict comparison, see the stability note in _partition_to_runs.
    while pos < last and arr[pos + 1] < arr[pos]:
        pos += 1
    return pos + 1


def _reverse(arr, start=0, end=None):
    """Reverse ``arr[start:end]`` in place.

    Args:
        arr: List to modify.
        start: First index of the range to reverse.
        end: Index one past the range; defaults to ``len(arr)``.
    """
    if end is None:
        end = len(arr)
    lo, hi = start, end - 1
    while lo < hi:
        arr[lo], arr[hi] = arr[hi], arr[lo]
        lo += 1
        hi -= 1


def _find_asc_boundary(arr, start):
    """Return the end index (exclusive) of the non-descending run at ``start``.

    Equal neighbours extend the run. Only ``<`` is applied to the elements,
    so they need to define nothing beyond ``__lt__``.
    """
    last = len(arr) - 1
    pos = start
    while pos < last and not (arr[pos + 1] < arr[pos]):
        pos += 1
    return pos + 1
# Whether a merge is needed is checked every time a new run is pushed onto
# run_stack, and the check is repeated after every merge so the stack keeps
# satisfying the invariant. Looking at just the two topmost runs is
# therefore enough here.
def _should_merge(run_stack):
    """Return True when the two topmost runs of ``run_stack`` should be merged.

    That is the case when the run below the top is shorter than twice the
    length of the top run. A stack with fewer than two runs never merges.
    """
    if len(run_stack) < 2:
        return False
    return len(run_stack[-2]) < 2 * len(run_stack[-1])


def _merge(ls1, ls2):
    """Merge two ascending lists into a new ascending list.

    The merge is stable: when elements compare equal, the one from ``ls1``
    comes first. Only ``<`` is applied to the elements.

    Args:
        ls1: Ascending list holding the earlier elements.
        ls2: Ascending list holding the later elements.

    Returns:
        A new list with every element of both inputs in ascending order.
    """
    merged = []
    i = j = 0
    n1, n2 = len(ls1), len(ls2)
    while i < n1 and j < n2:
        # Take from ls2 only when it is strictly smaller, so ties keep
        # their original left-to-right order.
        if ls2[j] < ls1[i]:
            merged.append(ls2[j])
            j += 1
        else:
            merged.append(ls1[i])
            i += 1
    merged.extend(ls1[i:])
    merged.extend(ls2[j:])
    return merged


def _merge_stack(run_stack):
    """Replace the two topmost runs of ``run_stack`` with their merge.

    The run that was pushed earlier is passed as the left operand so that
    equal elements keep their original order.
    """
    top = run_stack.pop()
    below = run_stack.pop()
    run_stack.append(_merge(below, top))
def _partition_to_runs(arr):
    """Yield the consecutive ordered runs of ``arr`` as new ascending lists.

    Works like plain run detection, except that every run found is handed
    to ``_do_insertion_sort_optimization``, which may lengthen a short run.
    ``arr`` is modified in place along the way.

    Args:
        arr: List to split.

    Yields:
        One ascending list per run, in the order the runs occur in ``arr``.
    """
    total = len(arr)
    pos = 0
    while pos < total:
        if total - pos == 1:
            # A lone trailing element is a run on its own.
            next_pos = total
        else:
            if arr[pos + 1] < arr[pos]:
                # Strictly descending stretch: reverse it into ascending order.
                next_pos = _find_desc_boundary(arr, pos)
                _reverse(arr, pos, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, pos)
            # The only addition compared with plain run detection.
            next_pos = _do_insertion_sort_optimization(arr, pos, next_pos)
        yield arr[pos:next_pos]
        pos = next_pos


def _insertion_sort(arr, start, end):
    """Sort ``arr[start:end]`` in place with a stable insertion sort.

    Elements outside the range are never read or moved. Only ``<`` is
    applied to the elements.
    """
    for i in range(start + 1, end):
        item = arr[i]
        j = i - 1
        # Stop at ``start``, not 0, so the scan stays inside the range.
        while j >= start and item < arr[j]:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = item


def _do_insertion_sort_optimization(arr, start, end):
    """Lengthen a short run to ``INSERTION_SORT_THRESHOLD`` elements.

    When ``arr[start:end]`` is shorter than the threshold, the range is
    extended to the threshold length (capped at the end of ``arr``) and
    insertion-sorted in place.

    Returns:
        The (possibly extended) end index of the run, exclusive.
    """
    if end - start < INSERTION_SORT_THRESHOLD:
        end = min(start + INSERTION_SORT_THRESHOLD, len(arr))
        _insertion_sort(arr, start, end)
    return end
# -*- coding: utf-8 -*-
import functools
from unittest import TestCase

# Runs shorter than this are lengthened to this many elements (where the
# input allows) and ordered with insertion sort.
INSERTION_SORT_THRESHOLD = 6


def _find_desc_boundary(arr, start):
    """Return the end index (exclusive) of the strictly descending run at ``start``.

    Iterative on purpose: a recursive scan uses one stack frame per element
    and overflows the interpreter's recursion limit on long runs.
    """
    last = len(arr) - 1
    pos = start
    # Strictly descending only: reversing a stretch that contains equal
    # elements would swap their relative order and break stability.
    while pos < last and arr[pos + 1] < arr[pos]:
        pos += 1
    return pos + 1


def _reverse(arr, start=0, end=None):
    """Reverse ``arr[start:end]`` in place.

    Args:
        arr: List to modify.
        start: First index of the range to reverse.
        end: Index one past the range; defaults to ``len(arr)``.
    """
    if end is None:
        end = len(arr)
    # Two cursors moving inwards stay inside the range for any ``start``.
    lo, hi = start, end - 1
    while lo < hi:
        arr[lo], arr[hi] = arr[hi], arr[lo]
        lo += 1
        hi -= 1


def _find_asc_boundary(arr, start):
    """Return the end index (exclusive) of the non-descending run at ``start``.

    Equal neighbours extend the run. Only ``<`` is applied to the elements,
    so they need to define nothing beyond ``__lt__``.
    """
    last = len(arr) - 1
    pos = start
    while pos < last and not (arr[pos + 1] < arr[pos]):
        pos += 1
    return pos + 1


def _insertion_sort(arr, start, end):
    """Sort ``arr[start:end]`` in place with a stable insertion sort.

    Elements outside the range are never read or moved.
    """
    for i in range(start + 1, end):
        item = arr[i]
        j = i - 1
        # Stop at ``start``, not 0: scanning further left would pull
        # elements out of the part of ``arr`` that precedes this range.
        while j >= start and item < arr[j]:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = item


def _do_insertion_sort_optimization(arr, start, end):
    """Lengthen a short run to ``INSERTION_SORT_THRESHOLD`` elements.

    When ``arr[start:end]`` is shorter than the threshold, the range is
    extended to the threshold length (capped at the end of ``arr``) and
    insertion-sorted in place.

    Returns:
        The (possibly extended) end index of the run, exclusive.
    """
    if end - start < INSERTION_SORT_THRESHOLD:
        end = min(start + INSERTION_SORT_THRESHOLD, len(arr))
        _insertion_sort(arr, start, end)
    return end


def _partition_to_runs(arr):
    """Yield the consecutive ordered runs of ``arr`` as new ascending lists.

    Descending stretches are reversed in place and short runs are
    lengthened with insertion sort, so ``arr`` is modified.

    Args:
        arr: List to split.

    Yields:
        One ascending list per run, in the order the runs occur in ``arr``.
    """
    total = len(arr)
    pos = 0
    while pos < total:
        if total - pos == 1:
            # A lone trailing element is a run on its own.
            next_pos = total
        else:
            if arr[pos + 1] < arr[pos]:
                next_pos = _find_desc_boundary(arr, pos)
                _reverse(arr, pos, next_pos)
            else:
                next_pos = _find_asc_boundary(arr, pos)
            next_pos = _do_insertion_sort_optimization(arr, pos, next_pos)
        yield arr[pos:next_pos]
        pos = next_pos


def _should_merge(run_stack):
    """Return True when the two topmost runs of ``run_stack`` should be merged.

    That is the case when the run below the top is shorter than twice the
    length of the top run. A stack with fewer than two runs never merges.
    """
    if len(run_stack) < 2:
        return False
    return len(run_stack[-2]) < 2 * len(run_stack[-1])


def _merge(ls1, ls2, merge_storage=None):
    """Merge two ascending lists into one ascending list.

    The merge is stable: when elements compare equal, the one from ``ls1``
    comes first. Only ``<`` is applied to the elements.

    Args:
        ls1: Ascending list holding the earlier elements.
        ls2: Ascending list holding the later elements.
        merge_storage: Optional list the result is appended to. It is
            tested against ``None`` rather than for truthiness, so an
            empty list supplied by the caller is used and not discarded.

    Returns:
        ``merge_storage`` when one was given, otherwise a new list.
    """
    merged = [] if merge_storage is None else merge_storage
    i = j = 0
    n1, n2 = len(ls1), len(ls2)
    while i < n1 and j < n2:
        # Take from ls2 only when it is strictly smaller, so ties keep
        # their original left-to-right order.
        if ls2[j] < ls1[i]:
            merged.append(ls2[j])
            j += 1
        else:
            merged.append(ls1[i])
            i += 1
    merged.extend(ls1[i:])
    merged.extend(ls2[j:])
    return merged


def _merge_stack(run_stack, merge_storage=None):
    """Replace the two topmost runs of ``run_stack`` with their merge.

    The run that was pushed earlier is passed as the left operand so that
    equal elements keep their original order.
    """
    top = run_stack.pop()
    below = run_stack.pop()
    run_stack.append(_merge(below, top, merge_storage=merge_storage))


def timsort(arr):
    """Return a new list holding the items of ``arr`` in ascending order.

    The input is split into ordered runs. Each run is pushed on a stack,
    and the two topmost runs are merged for as long as ``_should_merge``
    asks for it; the runs still on the stack at the end are merged into
    one. The sort is stable and applies only ``<`` to the elements.

    Args:
        arr: Iterable of mutually comparable items, or ``None``.

    Returns:
        A new sorted list; ``[]`` when ``arr`` is ``None`` or empty.
    """
    # Work on a private copy: run detection reorders its argument in place,
    # which would otherwise scramble the caller's list. Copying also lets
    # one-shot iterables (which have no len()) be sorted.
    items = list(arr or ())
    if not items:
        return []

    run_stack = []
    for run in _partition_to_runs(items):
        run_stack.append(run)
        while _should_merge(run_stack):
            _merge_stack(run_stack)

    # Collapse whatever runs are left into a single sorted run.
    while len(run_stack) > 1:
        _merge_stack(run_stack)
    return run_stack[0]


class Test(TestCase):
    """Unit tests for the helper functions and for ``timsort``."""

    class Elem(object):
        """Value ordered by ``n`` alone.

        ``seq_no`` records creation order, which makes it visible whether
        equal elements kept their relative order.
        """

        # Class-level counter handing out the next creation number.
        seq_no = 0

        def __init__(self, n):
            Elem = Test.Elem
            self.n = n
            self.seq_no = Elem.seq_no
            Elem.seq_no += 1

        def __eq__(self, other):
            return self.n == other.n

        # Equality is by value on a mutable object, so opt out of hashing.
        __hash__ = None

        def __lt__(self, other):
            return self.n < other.n

        def __str__(self):
            return "E" + str(self.n) + "S" + str(self.seq_no)

    Elem = functools.total_ordering(Elem)

    def setUp(self):
        Test.Elem.seq_no = 0

    def test_reverse(self):
        arr = [3, 2, 1, 4, 7, 5, 6]
        _reverse(arr)
        self.assertEqual(arr, [6, 5, 7, 4, 1, 2, 3])
        arr = [3, 2, 1]
        _reverse(arr)
        self.assertEqual(arr, [1, 2, 3])

    def test_reverse_subrange(self):
        arr = [0, 1, 2, 3, 4, 5]
        _reverse(arr, 2, 5)
        self.assertEqual(arr, [0, 1, 4, 3, 2, 5])

    def test_find_asc_boundary(self):
        arr = [1, 2, 3, 3, 2]
        self.assertEqual(_find_asc_boundary(arr, 0), 4)
        arr = [1, 2, 3, 3]
        self.assertEqual(_find_asc_boundary(arr, 0), 4)

    def test_find_desc_boundary(self):
        arr = [3, 2, 1]
        self.assertEqual(_find_desc_boundary(arr, 0), 3)
        arr = [3, 2, 1, 1]
        self.assertEqual(_find_desc_boundary(arr, 0), 3)

    def test_insertion_sort_stays_in_range(self):
        arr = [0, 9, 5, 3, 4, 1]
        _insertion_sort(arr, 2, 5)
        self.assertEqual(arr, [0, 9, 3, 4, 5, 1])

    def test_merge_stack(self):
        arr1 = [1, 2, 3]
        arr2 = [2, 3, 4]
        stack = [arr1, arr2]
        _merge_stack(stack)
        self.assertEqual(stack, [[1, 2, 2, 3, 3, 4]])

    def test_merge_stability(self):
        Elem = Test.Elem
        arr1 = [Elem(e) for e in [1, 2, 3]]
        arr2 = [Elem(e) for e in [2, 3, 4]]
        stack = [arr1, arr2]
        _merge_stack(stack)
        self.assertEqual(
            [[str(e) for e in run] for run in stack],
            [['E1S0', 'E2S1', 'E2S3', 'E3S2', 'E3S4', 'E4S5']])

    def test_timsort(self):
        Elem = Test.Elem
        arr = [Elem(e) for e in [3, 1, 2, 2, 7, 5]]
        ret = timsort(arr)
        self.assertEqual(
            [str(e) for e in ret],
            ['E1S1', 'E2S2', 'E2S3', 'E3S0', 'E5S5', 'E7S4'])
        self.assertEqual(timsort([]), [])
        self.assertEqual(timsort(None), [])

    def test_timsort_leaves_input_untouched(self):
        arr = [5, 6, 7, 8, 9, 10, 11, 3, 1, 2]
        self.assertEqual(timsort(arr), [1, 2, 3, 5, 6, 7, 8, 9, 10, 11])
        self.assertEqual(arr, [5, 6, 7, 8, 9, 10, 11, 3, 1, 2])

    def test_timsort_long_runs(self):
        self.assertEqual(timsort(list(range(3000))), list(range(3000)))
        self.assertEqual(timsort(list(range(3000, 0, -1))),
                         list(range(1, 3001)))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。