Skip to content

门泊吴船亦已谋

快排的多线程优化

讲真我从高中到现在就没有手写过一次快排(是 sort 不够好吗)

然后最近看书的时候看到划分数据块并行优化的思路,发现快排和堆排都能多线程优化

于是写一个试试

思路

注意到快排在每次迭代后原数组都会被分成没有交集的两块,那么他们就可以并行;堆排也同理

注意

有一个问题是不能每次迭代都新开线程(我 CPU 就四个核心线程多也没用)
那就开一个线程池,把迭代后产生的排序任务放到队列里,有一点类似我们递归爆栈时会用到的手写栈

还有一个问题是,在区间长度比较小的时候,由于多线程的调度成本,直接在本线程内递归的性能反而更好

线程池

其实比较常见,随便写一个就行

class ThreadPool {
   private:
    typedef function<void()> Task;
    int threadCnt;
    bool isRunning;
    vector<thread> workers;
    queue<Task> tasks;
    mutex mtx;
    condition_variable cv;
    void ProcessWorker() {
        while (isRunning) {
            Task task;
            {
                unique_lock<mutex> lock(mtx);
                if (tasks.empty()) cv.wait(lock);
                if (!tasks.empty()) {
                    task = tasks.front();
                    tasks.pop();
                }
            }
            if (task) task();
        }
    }

   public:
    explicit ThreadPool(int cnt) : threadCnt(cnt), isRunning(0) {}
    ~ThreadPool() {
        if (isRunning) Stop();
    }
    void Start() {
        isRunning = 1;
        for (int i = 0; i < threadCnt; i++)
            workers.emplace_back(thread(&ThreadPool::ProcessWorker, this));
    }
    void Push(const Task &task) {
        if (isRunning) {
            lock_guard<mutex> lock(mtx);
            tasks.push(task);
            cv.notify_one();
        }
    }
    void Stop() {
        {
            lock_guard<mutex> lock(mtx);
            isRunning = false;
            cv.notify_all();
        }
        for (auto &thd : workers) thd.join();
    }
};

核心代码

每一个 SortFunc 对象都是快排中的一次迭代,给他传一个区间就行

struct SortFunc {
    ThreadPool &pool;
    vector<int> &arr;
    int l, r;

    SortFunc(ThreadPool &pool, vector<int> &arr, int l, int r) : pool(pool), arr(arr), l(l), r(r) {}
    void operator()() {
        if (l >= r) return;
        int v = arr[l];
        int ll = l;
        int rr = r;
        while (l < r) {
            while (l <= rr && arr[l] < v) l++;
            while (r >= ll && arr[r] > v) r--;
            if (l <= r)
                swap(arr[l], arr[r]), l++, r--;
        }
        if (r - ll + 1 > 64)
            pool.Push(SortFunc(pool, arr, ll, r));
        else
            SortFunc(pool, arr, ll, r)();
        if (rr - l + 1 > 64)
            pool.Push(SortFunc(pool, arr, l, rr));
        else
            SortFunc(pool, arr, l, rr)();
    }
};

int main() {
    vector<int> originArr;
    for (int i = 0; i < 1000000; i++)
        originArr.push_back(rand() % 2000);
    ThreadPool pool(4);
    pool.Start();
    pool.Push(SortFunc(pool, originArr, 0, originArr.size() - 1));
    getchar();
    return 0;
}