归并排序实战:计算右侧小于当前元素个数与翻转对
49. 计算右侧小于当前元素的个数
题目描述
给定一个整数数组 nums,按要求返回一个新数组 counts。其中 counts[i] 表示在 nums[i] 右边且比它小的元素个数。

解法思路
这道题本质上是求逆序对的变种。普通的逆序对统计的是总数,而这里需要记录每个元素对应的数量。
归并排序过程中,元素的下标会发生变化,为了准确统计,我们需要维护一个辅助数组来绑定原始下标。当左半区间的元素大于右半区间的元素时,说明左边的这个元素比右边剩余的所有元素都大,此时可以直接累加计数。
C++ 代码实现
class Solution {
public:
vector<int> counts;
vector<int> index; // 记录 nums 中当前元素的原始下标
vector<int> tmp;
vector<int> tmpi; // 临时记录合并后对应元素位置变化下标随之变化的数组
vector<int> countSmaller(vector<int>& nums) {
counts.resize(nums.size());
index.resize(nums.size());
tmp.resize(nums.size());
tmpi.resize(nums.size());
// 初始化 index,使得 nums 每个元素下标和 index 一一对应
for(int i = 0; i < nums.size(); i++) {
index[i] = i;
}
mergesort(nums, 0, nums.size() - 1);
return counts;
}
void mergesort(vector<int>& nums, int left, int right) {
if(left == right) {
return;
}
int mid = (right - left) / 2 + left;
mergesort(nums, left, mid);
mergesort(nums, mid + 1, right);
int cur1 = left, cur2 = mid + 1, i = 0;
while(cur1 <= mid && cur2 <= right) {
if(nums[cur1] > nums[cur2]) {
counts[index[cur1]] += right - cur2 + 1; // 关键点:左边元素大于右边,则右边剩余元素均小于左边
tmpi[i] = index[cur1];
tmp[i++] = nums[cur1++];
} else {
tmpi[i] = index[cur2];
tmp[i++] = nums[cur2++];
}
}
while(cur1 <= mid) {
tmpi[i] = index[cur1];
tmp[i++] = nums[cur1++];
}
while(cur2 <= right) {
tmpi[i] = index[cur2];
tmp[i++] = nums[cur2++];
}
for(int j = left; j <= right; j++) {
nums[j] = tmp[j - left];
index[j] = tmpi[j - left];
}
}
};
算法流程解析
在合并阶段,如果 nums[cur1] > nums[cur2],由于左右区间已经有序,cur1 及其之后的所有元素都会大于 nums[cur2] 及其之后的元素。因此,对于 index[cur1] 指向的原始位置,其右侧比它小的元素数量增加了 right - cur2 + 1 个。

50. 翻转对
题目描述
给定一个数组 nums,如果 i < j 且 nums[i] > 2 * nums[j],我们就将 (i, j) 称作一个重要翻转对。

解法思路
翻转对与逆序对类似,区别在于判断条件是 nums[i] > 2 * nums[j]。同样可以利用归并排序的分治思想,将问题分解为三部分:左半区间、右半区间、以及跨越左右区间的翻转对。
关键难点在于:普通逆序对可以在合并时顺便统计,但翻转对涉及数值倍数关系,直接合并可能会打乱顺序导致无法快速判断。因此,需要在归并排序之前先完成翻转对的统计,然后再进行正常的归并操作。
C++ 代码实现
class Solution {
public:
vector<int> tmp;
int count = 0;
int reversePairs(vector<int>& nums) {
tmp.resize(nums.size());
mergesort(nums, 0, nums.size() - 1);
return count;
}
void mergesort(vector<int>& nums, int left, int right) {
if(left >= right) {
return;
}
int mid = (right - left) / 2 + left;
mergesort(nums, left, mid);
mergesort(nums, mid + 1, right);
// 先统计翻转对,再合并数组
int cur1 = left, cur2 = mid + 1;
while(cur1 <= mid && cur2 <= right) {
if(nums[cur1] > (long long)2 * nums[cur2]) {
count += right - cur2 + 1;
cur1++;
} else {
cur2++;
}
}
// 正常归并排序
cur1 = left;
cur2 = mid + 1;
int i = 0;
while(cur1 <= mid && cur2 <= right) {
if(nums[cur1] > nums[cur2]) {
tmp[i++] = nums[cur1++];
} else {
tmp[i++] = nums[cur2++];
}
}
while(cur1 <= mid) {
tmp[i++] = nums[cur1++];
}
while(cur2 <= right) {
tmp[i++] = nums[cur2++];
}
for(int i = left; i <= right; i++) {
nums[i] = tmp[i - left];
}
}
};
算法流程解析
注意这里使用了 long long 类型转换,防止 2 * nums[cur2] 发生整数溢出。统计逻辑是:当 nums[cur1] 满足条件时,由于右区间有序,cur1 左侧(相对于 cur2)的所有元素也都满足条件,因此直接累加剩余数量。



