找回密码
 立即注册
首页 业界区 科技 3D Gaussian splatting 07: 代码阅读-训练载入数据和保 ...

3D Gaussian splatting 07: 代码阅读-训练载入数据和保存结果

泠邸 前天 22:04
目录


  • 3D Gaussian splatting 01: 环境搭建
  • 3D Gaussian splatting 02: 快速评估
  • 3D Gaussian splatting 03: 用户数据训练和结果查看
  • 3D Gaussian splatting 04: 代码阅读-提取相机位姿和稀疏点云
  • 3D Gaussian splatting 05: 代码阅读-训练整体流程
  • 3D Gaussian splatting 06: 代码阅读-训练参数
  • 3D Gaussian splatting 07: 代码阅读-训练载入数据和保存结果
  • 3D Gaussian splatting 08: 代码阅读-渲染
训练载入数据

在 train.py 中载入数据对应的方法调用栈如下, 因为convert.py预处理使用的是colmap, 读取数据最终调用的是 readColmapSceneInfo 方法
  1. Scene(dataset, gaussians)
  2. └─sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
  3.   └─readColmapSceneInfo(path, images, depths, eval, train_test_exp, llffhold=8)
复制代码
读取流程是

  • 从 images.bin, cameras.bin 读取相机参数和每一帧的位姿
  • 区分训练集和测试集
  • 从 points3D.bin 读取3D点云
  1. def read_points3D_binary(path_to_model_file):
  2.     """
  3.     see: src/base/reconstruction.cc
  4.         void Reconstruction::ReadPoints3DBinary(const std::string& path)
  5.         void Reconstruction::WritePoints3DBinary(const std::string& path)
  6.     """
  7.     with open(path_to_model_file, "rb") as fid:
  8.         num_points = read_next_bytes(fid, 8, "Q")[0]
  9.         # 创建未初始化的 n * 3 数组, 随机值
  10.         xyzs = np.empty((num_points, 3))
  11.         rgbs = np.empty((num_points, 3))
  12.         errors = np.empty((num_points, 1))
  13.         for p_id in range(num_points):
  14.             binary_point_line_properties = read_next_bytes(
  15.                 fid, num_bytes=43, format_char_sequence="QdddBBBd")
  16.             xyz = np.array(binary_point_line_properties[1:4])
  17.             rgb = np.array(binary_point_line_properties[4:7])
  18.             error = np.array(binary_point_line_properties[7])
  19.             track_length = read_next_bytes(
  20.                 fid, num_bytes=8, format_char_sequence="Q")[0]
  21.             track_elems = read_next_bytes(
  22.                 fid, num_bytes=8*track_length,
  23.                 format_char_sequence="ii"*track_length)
  24.             xyzs[p_id] = xyz
  25.             rgbs[p_id] = rgb
  26.             errors[p_id] = error
  27.     return xyzs, rgbs, errors
复制代码
里面用到的read_next_bytes方法, 读取一段二进制字节, 使用struct.unpack按指定的格式, 转为对应的变量
[code]def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="
您需要登录后才可以回帖 登录 | 立即注册