当前位置:   article > 正文

[LeetCode] 分治之 Median of Two Sorted Arrays 课后题算法实现 Hard_leetcode median 分治

leetcode median 分治

写在前面

这道题目是我曾经在LeetCode上就见到过的,是很少的限定了时间复杂度的题目之一。这道题目的难度可以说是蛮高的,Hard 22.1%的通过率,可以说是在leetcode的所有题目来说也是非常难的了。当时学习分治法的时候曾经看过题目,也看了discuss但是并没有很好的思路,在O(log(m+n))的时间复杂度来说还是算比较难的。后来在老师讲课后习题的时候,恰好在算法概论第二章的一题中提到了这个类似的题目,要求找到两个有序数组中第K大的数,并且要求的时间复杂度是O(log(m)+log(n)),时间复杂度相比起来会更低。当时我在课上提出了一个符合题目时间复杂度的思路,而且也基本可行,时间复杂度也符合,但是因为终止条件什么的都还没有提出,所以实现的话还需要另外丰富这个算法。于是课后就专门找了这题来实现。


题目

There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Example 1 :

nums1 = [1, 3]
nums2 = [2]

The median is 2.0

Example 2 :

nums1 = [1, 2]
nums2 = [3, 4]

The median is (2 + 3)/2 = 2.5

分析

当时我在课上提出的算法是针对于任意给定的k(k < nums1.size() + nums2.size())的,对于这题来说,k是中位数,也就是k = (nusm1.size() + nums2.size())/2,这只是一种特殊情况,下面我将对于任意给定的k来进行设计算法。由于时间比较紧张,可能有更简单的,但是我没有考虑那么多的优化。我的思路是递归的,但是由于是尾递归,可以在实现的时候转成循环,这种转换是简单的。

算法流程

首先将两个array一个命名为A,另一个为B,A的size是m,B的size是n

  1. 如果A或B是空的,那么直接返回另一个数组的中位数
  2. 如果发现其中某个数组的最大值小于另外一个数组的最小值,那么直接求得中位数
  3. 如果发现其中某个数组的size == 1,那么直接将其二分查找插入另一个数组,并求得中位数
  4. 比较A[m/2]和B[n/2]的大小,假设A[m/2] >= B[n/2]
    4.1. 如果k<=(m+n)/2,则去掉A的后半部分,并回到步骤2
    4.2 如果k>(m+n)/2,则去掉B的前半部分,并将k赋值为k-n/2,并回到步骤2
  5. 比较A[m/2]和B[n/2]的大小,假设A[m/2] <= B[n/2]
    5.1 如果k<=(m+n)/2,则去掉B的后半部分,并回到步骤2
    5.2 如果k>(m+n)/2,则去掉A的前半部分,并将k赋值为k-m/2,并回到步骤2

实际题目会因为奇偶数,导致中位数的求法需要另外考虑,但是这并不要紧,只是稍微麻烦了一些而已。

时间复杂度分析

时间复杂度之前已经说过了,是O(log(n)+log(m)),这是易于发现的,因为每一轮递归都至少有一个数组的求解范围会缩短一半,除非遇到终止条件。

代码

这里给出了递归和非递归的两种实现。实现的时候千万要注意的一点就在于各种边界的+1,-1,以及不要忘记我们只是缩小了下标的范围,并没有真实的缩小两个数组,因此在比较的时候必须注意什么时候需要加上al、bl,由于一直是静态debug,所以这种错误有时候会因为思路是跟着算法走的,实际上实现的时候就出现了纰漏,最后花了很久时间在找这两个错误上。

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        a = nums1;
        b = nums2;
        isOdd = (nums1.size() + nums2.size()) % 2;
        k = (nums1.size() + nums2.size()) >> 1;
        // consider the condition that one of the two arrays is empty
        if (nums1.size() == 0) {
            if (isOdd) ans = nums2[k];
            else ans = 1.0*(nums2[k -1] + nums2[k]) /2;
            return ans;
        } else if (nums2.size() == 0) {
            if (isOdd) ans = nums1[k];
            else ans = 1.0*(nums1[k -1] + nums1[k]) /2;
            return ans;
        }
        test(0, nums1.size() -1, 0, nums2.size() -1);
        return ans;
    }
    void test(int al, int ar, int bl, int br) {
        // the end conditon
        // the whole array a smaller than array b
        int m = ar - al +1, n = br - bl+1;
        if (a[ar] <= b[bl]) {
            if (isOdd) ans = m > k ? a[al + k] : b[bl + k - m];
            else {
                int left = m > k -1 ? a[al + k -1] : b[bl + k -1 - m];
                int right = m > k ? a[al + k] : b[bl + k - m];
                ans = 1.0*(left + right) / 2;
            }
            return;
        }
        // the whole array b smaller than array a
        else if (b[br] <= a[al]) {
            if (isOdd) ans = n > k ? b[bl + k] : a[al + k - n];
            else {
                int left = n > k -1 ? b[bl + k -1] : a[al + k -1- n];
                int right = n > k ? b[bl + k] : a[al + k -n];
                ans = 1.0*(left + right) / 2;
            }
            return;
        }
        if (ar - al == 0) {
            vector<int> temp(b.begin() + bl, b.begin() + br +1);
            temp.push_back(a[al]);
            sort(temp.begin(), temp.end());
            if (isOdd) ans = temp[k];
            else ans = 1.0*(temp[k -1] + temp[k]) /2;
            return;
        } else if (br - bl == 0) {
            vector<int> temp(a.begin() + al, a.begin() + ar +1);
            temp.push_back(b[bl]);
            sort(temp.begin(), temp.end());
            if (isOdd) ans = temp[k];
            else ans = 1.0*(temp[k -1] + temp[k]) /2;
            return;
        }
        // m is the length of array a, n is the length of array b
        if (a[al + m/2] >= b[bl + n/2]) {;
            if (k <= ((m + n)>>1)) {
                ar = ar - m/2;
            } else {
                bl = bl + n/2;
                k = k - n/2;
            }
        } else {
            if (k <= ((m + n) >> 1)) {
                br = br - n/2;
            } else {
                al = al + m/2;
                k = k - m/2;
            }
        }
        test(al,ar,bl,br);
    }

    void testWithoutRecursion(int al, int ar, int bl, int br) {
        while (1) {
            // the end conditon
            // the whole array a smaller than array b
            int m = ar - al +1, n = br - bl+1;
            if (a[ar] <= b[bl]) {
                if (isOdd) ans = m > k ? a[al + k] : b[bl + k - m];
                else {
                    int left = m > k -1 ? a[al + k -1] : b[bl + k -1 - m];
                    int right = m > k ? a[al + k] : b[bl + k - m];
                    ans = 1.0*(left + right) / 2;
                }
                return;
            }
            // the whole array b smaller than array a
            else if (b[br] <= a[al]) {
                if (isOdd) ans = n > k ? b[bl + k] : a[al + k - n];
                else {
                    int left = n > k -1 ? b[bl + k -1] : a[al + k -1- n];
                    int right = n > k ? b[bl + k] : a[al + k -n];
                    ans = 1.0*(left + right) / 2;
                }
                return;
            }
            if (ar - al == 0) {
                vector<int> temp(b.begin() + bl, b.begin() + br +1);
                temp.push_back(a[al]);
                sort(temp.begin(), temp.end());
                if (isOdd) ans = temp[k];
                else ans = 1.0*(temp[k -1] + temp[k]) /2;
                return;
            } else if (br - bl == 0) {
                vector<int> temp(a.begin() + al, a.begin() + ar +1);
                temp.push_back(b[bl]);
                sort(temp.begin(), temp.end());
                if (isOdd) ans = temp[k];
                else ans = 1.0*(temp[k -1] + temp[k]) /2;
                return;
            }
            // m is the length of array a, n is the length of array b
            if (a[al + m/2] >= b[bl + n/2]) {
                if (k <= ((m + n)>>1)) {
                    ar = ar - m/2;
                } else {
                    bl = bl + n/2;
                    k = k - n/2;
                }
            } else {
                if (k <= ((m + n) >> 1)) {
                    br = br - n/2;
                } else {
                    al = al + m/2;
                    k = k - m/2;
                }
            }
        }
    }

    vector<int> a;
    vector<int> b;
    int k;
    double ans;
    // 0 means even, 1 means odd
    bool isOdd;
};
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号