【专题】快速排序

简介

快速排序算法是对冒泡排序算法的一种改进。平均时间复杂度\(O(nlogn)\),最坏时间复杂度\(O(n^2)\)

快速排序的基本思想:每趟排序选择一个基准值pivot,使得小于pivot的元素大于pivot的元素分隔于pivot两侧,即每一趟确定了一个元素的位置。然后对基准值两侧的区间进行递归,以达到整个序列有序。

可以看出时间复杂度与递归的层数相关,提升效率的关键就在于partition,划分时的实现。

单路快排

最基础的实现版本,前后双指针法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// 对左闭右开区间a[l, r),单路快排,前后双指针法
void quickSort(int *a, int l, int r){
if(l>=r) return;

int pivot = a[l];
int j = l;
for(int i=l+1; i<r; i++){
if(a[i] < pivot)
swap(a[++j], a[i]);
}
swap(a[l], a[j]);

quickSort(a, l, j);
quickSort(a, j+1, r);
}

这种方法有许多缺陷,但胜在简短,而且简单修改就可以变成链表版本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 单链表采用直接交换数据区,修改指针的话没什么好想法
void quickSort(ListNode *head, ListNode *last){
if(head==nullptr || head==last) return;

int pivot = head->val;
ListNode *pre = head;
for(ListNode *cur=head; cur->next!=last; cur=cur->next){
if(cur->next->val < pivot){
swap(cur->next->val, pre->next->val);
pre = pre->next;
}
}
swap(head->val, pre->val);

quickSort(head, pre);
quickSort(pre->next, last);
}

双路快排

单路快排会使得与基准值pivot相等的元素总是归为一侧,存在大量相同元素时,时间复杂度退化为$ O(n^2) $。左右双指针法可以让与基准值相等的元素随机交换至两侧,但是还需处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// 左闭右开区间a[l, r),双路快排,左右双指针法
void quickSort(int *a, int l, int r){
if(l>=r) return;

int pivot = a[l];
int i=l+1, j=r-1;
while(i<=j){
while(a[i]<pivot && i<r) i++;
while(a[j]>pivot && j>l) j--;
if(i>j) break;
swap(a[i++], a[j--]);
}
swap(a[l], a[j]);

quickSort(a, l, j);
quickSort(a, j+1, r);
}

三路快排

要真正优化存在大量相同元素情况下快排的效率时,还是需要使用三路快排,将序列分为小于pivot的 $ [l, lt) $ , 等于pivot的 $ [lt, i] $ , 大于pivot的 $ [rt, r) $ 三部分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// 左闭右开区间a[l, r),三路快排
void quickSort(int *a, int l, int r){
if(l>=r) return;

int pivot = a[l];
int lt = l, rt = r;
for(int i=l+1; i<rt; ){
if(a[i] < pivot)
swap(a[i++], a[++lt]);
else if(a[i] > pivot)
swap(a[i], a[--rt]);
else
i++;
}
swap(a[l], a[lt]);

quickSort(a, l, lt);
quickSort(a, rt, r);
}

更多优化

更多优化就可以参考STL::sort()函数的实现。例如,

在基准值pivot的选取中,如果每次选取的恰好是当前序列中的最大或最小元素,划分的结果是最坏情况,递归层数大大上升。对此,我们可以使用随机选取或者三数取中法

1
2
3
4
5
6
7
8
// 左闭右开区间a[l, r),随机选取
srand(time(NULL));
int pivot = a[rand()%(r-l)+l];

// 左闭右开区间a[l, r),三数取中法
int pivot = a[l + (r-l)/2];
if(pivot<a[l] && pivot<a[r]) pivot=min(a[l], a[r]);
if(pivot>a[l] && pivot>a[r]) pivot=max(a[l], a[r]);

当序列较短时(<16),可以采用直接插入排序,减少递归深度。

1
2
3
4
5
// 还需要具体实现
if(r-l < 16){
insertionSort(a, l, r);
return;
}

当递归次数大于限制时,采用堆排序等。

代码

测试代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <bits/stdc++.h>
using namespace std;

void showArray(int *a, int n){
for(int i=0; i<n; i++){
printf("%d ", a[i]);
}
printf("\n");
}

// 单路快排,前后双指针法
void quickSort(int *a, int l, int r){
if(l>=r) return;

int pivot = a[l];
int j = l;
for(int i=l+1; i<r; i++){
if(a[i] < pivot)
swap(a[++j], a[i]);
}
swap(a[l], a[j]);

quickSort(a, l, j);
quickSort(a, j+1, r);
}

// 双路快排,左右双指针法
void quickSort(int *a, int l, int r){
if(l>=r) return;

int pivot = a[l];
int i=l+1, j=r-1;
while(i<=j){
while(a[i]<pivot && i<r) i++;
while(a[j]>pivot && j>l) j--;
if(i>j) break;
swap(a[i++], a[j--]);
}
swap(a[l], a[j]);

quickSort(a, l, j);
quickSort(a, j+1, r);
}

// 三路快排
void quickSort(int *a, int l, int r){
if(l>=r) return;

int pivot = a[l];
int lt = l, rt = r;
for(int i=l+1; i<rt; ){
if(a[i] < pivot)
swap(a[i++], a[++lt]);
else if(a[i] > pivot)
swap(a[i], a[--rt]);
else
i++;
}
swap(a[l], a[lt]);

quickSort(a, l, lt);
quickSort(a, rt, r);
}

struct ListNode {
int val;
ListNode *next;
ListNode(int x) : val(x), next(NULL) {}
};

void showList(ListNode *head){
ListNode *p = head;
while(p!=nullptr){
printf("%d ", p->val);
p = p->next;
}
printf("\n");
}

// 链表快排
void quickSort(ListNode *head, ListNode *last){
if(head==nullptr || head==last) return;

int pivot = head->val;
ListNode *pre = head;
for(ListNode *cur=head; cur->next!=last; cur=cur->next){
if(cur->next->val < pivot){
swap(cur->next->val, pre->next->val);
pre = pre->next;
}
}
swap(head->val, pre->val);

quickSort(head, pre);
quickSort(pre->next, last);
}

int main()
{
int n = 10;
int a[11] = {4,1,7,6,9,2,8,0,3,5};

ListNode *h[10];
for(int i=0; i<n; i++){
h[i] = new ListNode(a[i]);
if(i) h[i-1]->next = h[i];
}
ListNode *head = h[0];

showArray(a, n);
quickSort(a, 0, n);
showArray(a, n);

showList(head);
quickSort(head, nullptr);
showList(head);

return 0;
}