找回密码
 立即注册
首页 业界区 业界 C++用Mutex实现读写锁

C++用Mutex实现读写锁

巨耗 前天 10:23
近期答辩完成了,想回头看看之前没做过的2PL。
实现2PL有4种方式:

  • 死锁检测。本篇是为了做这个而实现的,做这个事情的原因是c++标准库的shared_mutex无法从外界告知获取锁失败。
  • 如果需要等待,那么马上结束txn。C++中有try_lock这样的方式,如果上锁失败就返回false,这样就可以实现这个了。
  • 如果需要等待,那么杀死当前已经获得锁的一方。
  • 在上锁前对资源排序。这种方式只在极度特殊的情况才能使用,一般是无法实现的。
2和4是最简单的,没什么好说的。3比1略容易一些。
基本思路

一个读写锁应该具有以下特征:

  • 多个读者可以同时访问
  • 写者独占访问
  • 写者与读者互斥
  • 避免写者饥饿或读者饥饿
  • 锁的递归使用
由于实现的锁不能够出现读饿死、写饿死的现象,所以我想到一个很简单的方法:先到先得。当然也许会有其他方案。
先到先得的方式下,如何判断一个线程是否该阻塞?

  • 第一个写请求之前的所有读请求可以进行
  • 如果第一个请求是写请求,那么只有这一个写请求可以进行
  • 如果没有写请求,那么所有读都可以进行
  • 如果没有读请求,那么第一个写请求可以进行。这实际是2的特殊情况
  • 其他请求都不可以进行
我们画图来说明一下。
假定某一刻有这些请求被阻塞,现在考虑挑出来可以执行的线程来执行
1.png

队列中,第一个写请求之前的读都可以进行,所以此时1,2线程是可以执行的。它们读完后释放锁,于是在这个队列中删除了1,2
2.png

1,2删除后,3可以正常执行,3执行后删除了。
3.png

3删除后,6在环检测的时候被要求结束,所以此时所有的读都可以进行。
4.png

实现先到先得,可以通过记录正在进行读的线程数量,正在进行写的线程数量,请求写但是被阻塞的线程数量,请求读但是被阻塞的线程数量,然后根据条件来分配资源给某个线程……维护的信息数量可能不止这些,比如说需要维护哪些线程的读被阻塞了。
而环检测的2PL,我们需要在外界通知线程锁获取失败,所以选择了使用队列来实现,这个队列需要支持:

  • 添加读者、写者(AddReader, AddWriter)
  • 删除读者、写者(RemoveReader, RemoveWriter,为了简化,统一为一个Remove了)
  • 当可以获得锁的时候,提醒可以获得锁的线程。这个可以用condition_variable实现
  • 确定某个线程是否应该阻塞
然而做这样一个队列还是需要费一些功夫的。
队列实现

明确了功能需求后,考虑一下需要什么样的数据结构。普通的队列肯定是不够的,毕竟我们会删除其中任意一个元素,容易想到的是map/set。然后考虑到先到先得的顺序要求,可以考虑额外记录一个逻辑时间timestamp,每当一个请求到达,就递增timestamp。由于加入了timestamp,所以为了支持删除,至少需要tid:timestamp的映射。而为了支持按timestamp查询,至少需要timestamp:tid的映射。此外,需要记录一个请求是读还是写,所以一共需要tid:timestamp的映射和timestamp:的映射。(好像用一个映射加一个链表好像也行?)
timestamp:映射关系,很容易想到通过std::map这种天然自带排序的数据结构来实现,即:

  • 从最小到最大遍历开头的读请求,这部分线程可以直接执行。
  • 如果是写请求开头的,那么这个写可以直接执行。
  • 解锁的时候删除该线程的记录。
笔者在此前做了CMU15445,里面的GC的watermark和这个非常显相似。CMU15445中作者提到了可以使用unordered_map来将时间复杂度从O(logn)优化到O(1),这种做法我想到了,所以这里的队列使用的都是unordered_map。
  1. #pragma once
  2. #ifndef READER_WRITER_QUEUE_H
  3. #define READER_WRITER_QUEUE_H
  4. // INSPIRED BY CMU15445 fall2023 watermark
  5. #include <cassert>
  6. #include <unordered_map>
  7. class ReaderWriterQueue {
  8. public:
  9.   void AddReader(int tid) {
  10.     assert(tid_ts.count(tid) == 0);
  11.     ts_tt[next_timestamp] = {tid, TidType::kRead};
  12.     tid_ts[tid] = next_timestamp;
  13.     next_timestamp++;
  14.   }
  15.   void AddWriter(int tid) {
  16.     assert(tid_ts.count(tid) == 0);
  17.     ts_tt[next_timestamp] = {tid, TidType::kWrite};
  18.     tid_ts[tid] = next_timestamp;
  19.     next_timestamp++;
  20.   }
  21.   void Remove(int tid) {
  22.     auto ts = tid_ts.find(tid);
  23.     if (ts == tid_ts.end()) return;
  24.     assert(ts_tt.count(ts->second) == 1);
  25.     ts_tt.erase(ts->second);
  26.     tid_ts.erase(ts);
  27.   }
  28.   bool ShallBlock(int tid) {
  29.     ResetMinWriteTimestamp();
  30.     ResetMinTimestamp(); // 这两个timestamp处理可以合并
  31.     auto iter = tid_ts.find(tid);
  32.     assert(iter != tid_ts.end());
  33.     assert(ts_tt.count(iter->second) == 1);
  34.     auto ts = iter->second;
  35.     auto [_, type] = ts_tt[ts];
  36.     // 如果读者之前有写者,那么就需要阻塞等待
  37.     if (type == TidType::kRead) return ts > min_write_ts;
  38.     // 如果写者之前有读者,那么就需要阻塞等待
  39.     if (min_ts < min_write_ts) return true;
  40.     // 如果写者之前有写者,那么就需要阻塞等待
  41.     return ts_tt[min_write_ts].tid != tid;
  42.   }
  43. private:
  44.   void ResetMinWriteTimestamp() {
  45.     for (; min_write_ts < next_timestamp; min_write_ts++) {
  46.       auto iter = ts_tt.find(min_write_ts);
  47.       if (iter == ts_tt.end()) {
  48.         continue;
  49.       } else if (iter->second.type == TidType::kWrite) {
  50.         break;
  51.       } else { // iter->second.type == TidType::kRead
  52.         continue;
  53.       }
  54.     }
  55.   }
  56.   void ResetMinTimestamp() {
  57.     for (; min_ts < next_timestamp; min_ts++) {
  58.       auto iter = ts_tt.find(min_ts);
  59.       if (iter != ts_tt.end())
  60.         break;
  61.     }
  62.   }
  63.   long next_timestamp = 0;
  64.   long min_write_ts = 0;
  65.   long min_ts = 0;
  66.   struct TidType {
  67.     int tid;
  68.     enum LockType {kRead, kWrite} type;
  69.     bool operator==(const TidType &rhs) const {
  70.       return tid == rhs.tid && type == rhs.type;
  71.     }
  72.   };
  73.   std::unordered_map<long, TidType> ts_tt;
  74.   std::unordered_map<int, long> tid_ts;
  75. };
  76. #endif // READER_WRITER_QUEUE_H
复制代码
将队列封装为读写锁

这一步封装已经非常容易了,一个请求到来,添加到队列中。如果需要阻塞,那么就通过condition_variable等待通知。解锁的时候,不仅仅需要在队列中进行移除,还需要notify_all。notify_all还可以优化,但是这不是那么容易的事情了,不考虑。
cv.wait可能因为意外而结束等待,所以写这个条件的时候需要进行循环判定,或者使用第二个pred参数。
  1. #pragma once
  2. #ifndef SIMPLE_SHARED_MUTEX_H
  3. #define SIMPLE_SHARED_MUTEX_H
  4. #include <condition_variable>
  5. #include <ctime>
  6. #include <cstdio>
  7. #include <mutex>
  8. #include <unistd.h>
  9. #include "reader_writer_queue.h"
  10. class SimpleSharedMutex {
  11. public:
  12.   void lock() {
  13.     std::unique_lock lock{mtx};
  14.     auto tid = ::gettid();
  15.     queue.AddWriter(tid);
  16.     while (queue.ShallBlock(tid)) cv.wait(lock);
  17.     // printf("lock %d\n", tid);
  18.   }
  19.   void shared_lock() {
  20.     std::unique_lock lock{mtx};
  21.     auto tid = ::gettid();
  22.     queue.AddReader(tid);
  23.     while (queue.ShallBlock(tid)) cv.wait(lock);
  24.     // printf("slock %d\n", tid);
  25.   }
  26.   void unlock() {
  27.     std::unique_lock lock{mtx};
  28.     queue.Remove(::gettid());
  29.     cv.notify_all();
  30.     // printf("ulock %d\n", ::gettid());
  31.   }
  32.   void shared_unlock() {
  33.     std::unique_lock lock{mtx};
  34.     queue.Remove(::gettid());
  35.     cv.notify_all();
  36.     // printf("uslock %d\n", ::gettid());
  37.   }
  38. private:
  39.   std::mutex mtx;
  40.   ReaderWriterQueue queue;
  41.   std::condition_variable cv;
  42. };
  43. #endif // SIMPLE_SHARED_MUTEX_H
复制代码
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册