当前位置:   article > 正文

heapq源码解读(四)_heapq.nlargest

heapq.nlargest

Python heapq源码解读计划(四)

本文是解读python heapq 库的最后一节,主要分析的函数为nlargestnsmallest这两个函数。

nlargest函数的实现

nlargest(n,iterable,key=None)这个函数的功能为返回数据集中最大的n个元素,等价于sorted(iterable, key=key, reverse=True)[:n]

n = 1时的处理

n = 1时,其实就是找出这个list中的最大值。这和使用max()函数是一样的,但是max()函数不接受空list,所以得想办法让heap为空的时候,返回一个空list,当heap不为空时,返回一个只有一个元素的list。

首先,需要利用python 的iter()函数来将list变成一个迭代器。

it = iter(iterable)
  • 1

然后生成一个空的object(),这个object()主要是用来处理后面空列表的情况。

sentinel = object()
  • 1

然后来判断key是否为None,最后利用max()函数获取结果。这里以key=None来说明:

result = max(it, default=sentinel)
  • 1

其实和普通的max使用类似,但是多了一个参数default=sentinel。加上这个参数的意义是当这个iter为空时,max返回的是一个object。最后只需要判断一下,如果result是sentinel,那么就返回一个空list,否则就返回[result]

这部分的源码实现如下:

    if n == 1:
        it = iter(iterable)
        sentinel = object()
        if key is None:
            result = max(it, default=sentinel)
        else:
            result = max(it, default=sentinel, key=key)
        return [] if result is sentinel else [result]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

测试:

array = [1,3,2,4]
heap = []
print("heap:",heapq.nlargest(1,heap))
for num in array:
    heapq.heappush(heap,num)
    # print("heap: ", heap)
print("heap:",heapq.nlargest(1,heap))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

结果:

heap: []
heap: [4]
  • 1
  • 2
当n > size时

首先先来比较n和数组长度size哪个更大,如果n >= size,那么直接返回逆序排序的结果即可。

    try:
        size = len(iterable)
    except (TypeError, AttributeError):
        pass
    else:
        if n >= size:
            return sorted(iterable, key=key, reverse=True)[:n]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
当 key 为 None时

n < size时,所需要的做的事情就是找出来这个堆中最大的n个值。首先,和之前一样,利用python的iter()heap变成了一个迭代器。然后从其中读取出两个元素。

it = iter(iterable)
result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
  • 1
  • 2

result为空时,直接将result返回即可。

if not result:
    return result
  • 1
  • 2

如果result不为空,则对result做一个heapify的处理,让其成为一个小根堆。然后提取出result[0][0],令其为top,这个top为result中最小的值。

heapify(result)
top = result[0][0]
  • 1
  • 2

因为it是迭代器,而之前在构造result的时候已经获取了n个元素,所以在读取it中的元素的时候,是接着之前的读取的。当出现了top的值小于elem的值的时候,就说明result中的值并不是前n大的,所以需要利用heapreplace函数来将elem给换到result中去。

_heapreplace = heapq.heapreplace
for elem in it:
    print("order:",order)
    print("elem:",elem)
    if top < elem:
        _heapreplace(result, (elem, order))
        print(result)
        top, _order = result[0]
        order -= 1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

最后将result逆序排列一下,然后只拿出其中的elem来返回即可。

result.sort(reverse=True)
return [elem for (elem, order) in result]
  • 1
  • 2

order的作用应该是为了分辨值相同的元素所带来的影响,即便值相同,但是order是不同的,可以利用order来将值相同的元素分辨出来。

General case

General case和key 为 None唯一的区别就是添加了key的元素其中。

nlargest整体代码:
def nlargest(n, iterable, key=None):
    # Short-cut for n==1 is to use max()
    if n == 1:
        it = iter(iterable)
        sentinel = object()
        if key is None:
            result = max(it, default=sentinel)
        else:
            result = max(it, default=sentinel, key=key)
        return [] if result is sentinel else [result]

    # When n>=size, it's faster to use sorted()
    try:
        size = len(iterable)
    except (TypeError, AttributeError):
        pass
    else:
        if n >= size:
            return sorted(iterable, key=key, reverse=True)[:n]

    # When key is none, use simpler decoration
    if key is None:
        it = iter(iterable)
        result = [(elem, i) for i, elem in zip(range(0, -n, -1), it)]
        if not result:
            return result
        heapify(result)
        top = result[0][0]
        order = -n
        _heapreplace = heapreplace
        for elem in it:
            if top < elem:
                _heapreplace(result, (elem, order))
                top, _order = result[0]
                order -= 1
        result.sort(reverse=True)
        return [elem for (elem, order) in result]

    # General case, slowest method
    it = iter(iterable)
    result = [(key(elem), i, elem) for i, elem in zip(range(0, -n, -1), it)]
    if not result:
        return result
    heapify(result)
    top = result[0][0]
    order = -n
    _heapreplace = heapreplace
    for elem in it:
        k = key(elem)
        if top < k:
            _heapreplace(result, (k, order, elem))
            top, _order, _elem = result[0]
            order -= 1
    result.sort(reverse=True)
    return [elem for (k, order, elem) in result]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55

nsmallest函数实现

nsmallest函数的整体实现和nlargest函数的实现是类似的。

n == 1时

这一部分的具体实现和nlargerst基本一样,只不过这里使用的是min函数

源码:

if n == 1:
    it = iter(iterable)
    sentinel = object()
    if key is None:
        result = min(it, default=sentinel)
    else:
        result = min(it, default=sentinel, key=key)
    return [] if result is sentinel else [result]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
当n >= size时

这部分和nlargest函数一样。

当 k 为 None时:

这里的结构也是和nlargest函数是一样的,但是有这么几点不同:

  • nlargest函数中使用heapify的部分,这里使用的是_heapify_max

  • nlargest函数中使用heapreplace的部分,这里使用的是_heapreplace_max

_heapify_max的作用其实就是构建一个大根堆。_heapreplace_max也是在大根堆中处理元素替代的问题。

源码:

def nsmallest(n, iterable, key=None):
    """Find the n smallest elements in a dataset.

    Equivalent to:  sorted(iterable, key=key)[:n]
    """

    # Short-cut for n==1 is to use min()
    if n == 1:
        it = iter(iterable)
        sentinel = object()
        if key is None:
            result = min(it, default=sentinel)
        else:
            result = min(it, default=sentinel, key=key)
        return [] if result is sentinel else [result]

    # When n>=size, it's faster to use sorted()
    try:
        size = len(iterable)
    except (TypeError, AttributeError):
        pass
    else:
        if n >= size:
            return sorted(iterable, key=key)[:n]

    # When key is none, use simpler decoration
    if key is None:
        it = iter(iterable)
        # put the range(n) first so that zip() doesn't
        # consume one too many elements from the iterator
        result = [(elem, i) for i, elem in zip(range(n), it)]
        if not result:
            return result
        _heapify_max(result)
        top = result[0][0]
        order = n
        _heapreplace = _heapreplace_max
        for elem in it:
            if elem < top:
                _heapreplace(result, (elem, order))
                top, _order = result[0]
                order += 1
        result.sort()
        return [elem for (elem, order) in result]

    # General case, slowest method
    it = iter(iterable)
    result = [(key(elem), i, elem) for i, elem in zip(range(n), it)]
    if not result:
        return result
    _heapify_max(result)
    top = result[0][0]
    order = n
    _heapreplace = _heapreplace_max
    for elem in it:
        k = key(elem)
        if k < top:
            _heapreplace(result, (k, order, elem))
            top, _order, _elem = result[0]
            order += 1
    result.sort()
    return [elem for (k, order, elem) in result]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/204691
推荐阅读
相关标签
  

闽ICP备14008679号