找回密码
 立即注册
首页 业界区 业界 DAPO代码实现浅析

DAPO代码实现浅析

羡渥蛛 2025-10-20 01:15:01
参考verl对dapo的实现,首先咱们看一下入口.sh和.py文件,在./recipe/dapo/文件夹中有以下目录
  1. .
  2. ├── config
  3. │   ├── dapo_megatron_trainer.yaml
  4. │   └── dapo_trainer.yaml
  5. ├── dapo_ray_trainer.py
  6. ├── main_dapo.py
  7. ├── prepare_dapo_data.sh
  8. ├── README.md
  9. ├── run_dapo_qwen2.5_32b.sh
复制代码
整体的执行顺序:

  • main_dapo.py:数据加载初始化、初始化actor_rollout model、rm model,加载reward_manager
  • dapo_ray_trainer.py:RL训练流程

    • 对batch进行repeate,每个q采样n次
    • 记录每个采样的log,以及对应的reward_score 和 advantage

      • filter掉一个q的所有sample的score都是1或都是0,继续获取新的q进行采样,直到满足要求的batch的大小达到train_prompt_bsz。(值得注意的是,batch大小是gen_prompt_bsz=3*train_prompt_bsz,通过提高采样q的个数,避免满足要求的q不到train_prompt_bsz)。

    • 每mini_batch的data进行模型更新

      • 每micro_batch的data进行前向传播(token-mean loss)与梯度计算


具体代码实例:
main_dapo.py
  1. # Copyright 2024 Bytedance Ltd. and/or its affiliates
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. #     http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
  16. """
  17. import os
  18. import socket
  19. import hydra
  20. import ray
  21. from omegaconf import OmegaConf
  22. from verl.trainer.ppo.reward import load_reward_manager
  23. from verl.utils.device import is_cuda_available
  24. from .dapo_ray_trainer import RayDAPOTrainer
  25. @hydra.main(config_path="config", config_name="dapo_trainer", version_base=None)
  26. def main(config):
  27.     run_ppo(config)
  28. #################################################################
  29. # RL训练入口
  30. #################################################################
  31. def run_ppo(config) -> None:
  32.     if not ray.is_initialized():
  33.         # this is for local ray cluster
  34.         default_runtime_env = {
  35.             "env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
  36.         }
  37.         ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
  38.         runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
  39.         runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
  40.         ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
  41.         print(f"ray init kwargs: {ray_init_kwargs}")
  42.         ray.init(**OmegaConf.to_container(ray_init_kwargs))
  43.     try:
  44.         if (
  45.             is_cuda_available
  46.             and config.global_profiler.tool == "nsys"
  47.             and OmegaConf.select(config.global_profiler, "steps") is not None
  48.             and len(OmegaConf.select(config.global_profiler, "steps")) > 0
  49.         ):
  50.             nsight_options = OmegaConf.to_container(
  51.                 config.global_profiler.global_tool_config.nsys.controller_nsight_options
  52.             )
  53.             runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
  54.         else:
  55.             runner = TaskRunner.remote()
  56.         ray.get(runner.run.remote(config))
  57.     finally:
  58.         if ray.is_initialized():
  59.             ray.shutdown()
  60. @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
  61. class TaskRunner:
  62.     def run(self, config):
  63.         # print initial config
  64.         from pprint import pprint
  65.         from omegaconf import OmegaConf
  66.         from verl.utils.fs import copy_to_local
  67.         print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
  68.         pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
  69.         OmegaConf.resolve(config)
  70.         # download the checkpoint from hdfs
  71.         local_path = copy_to_local(config.actor_rollout_ref.model.path)
  72.         # instantiate tokenizer
  73.         from verl.utils import hf_processor, hf_tokenizer
  74.         tokenizer = hf_tokenizer(local_path)
  75.         processor = hf_processor(local_path, use_fast=True)  # used for multimodal LLM, could be none
  76.         from verl.single_controller.ray import RayWorkerGroup
  77.         #################################################################
  78.         # 加载actor worker
  79.         #################################################################
  80.         # define worker classes
  81.         if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
  82.             assert config.critic.strategy in {"fsdp", "fsdp2"}
  83.             from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
  84.             ray_worker_group_cls = RayWorkerGroup
  85.         elif config.actor_rollout_ref.actor.strategy == "megatron":
  86.             assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
  87.             from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
  88.             ray_worker_group_cls = RayWorkerGroup
  89.         else:
  90.             raise NotImplementedError
  91.         from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
  92.         role_worker_mapping = {
  93.             Role.ActorRollout: ray.remote(ActorRolloutRefWorker),
  94.             Role.Critic: ray.remote(CriticWorker),
  95.         }
  96.         global_pool_id = "global_pool"
  97.         resource_pool_spec = {
  98.             global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
  99.         }
  100.         mapping = {
  101.             Role.ActorRollout: global_pool_id,
  102.             Role.Critic: global_pool_id,
  103.         }
  104.         # we should adopt a multi-source reward function here
  105.         # - for rule-based rm, we directly call a reward score
  106.         # - for model-based rm, we call a model
  107.         # - for code related prompt, we send to a sandbox if there are test cases
  108.         # - finally, we combine all the rewards together
  109.         # - The reward type depends on the tag of the data
  110.         if config.reward_model.enable:
  111.             if config.reward_model.strategy in {"fsdp", "fsdp2"}:
  112.                 from verl.workers.fsdp_workers import RewardModelWorker
  113.             elif config.reward_model.strategy == "megatron":
  114.                 from verl.workers.megatron_workers import RewardModelWorker
  115.             else:
  116.                 raise NotImplementedError
  117.             role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
  118.             mapping[Role.RewardModel] = global_pool_id
  119.         # reference model
  120.         if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
  121.             role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
  122.             mapping[Role.RefPolicy] = global_pool_id
  123.         #################################################################
  124.         # 加载reward manager函数。用于根据data计算对应的reward score
  125.         #################################################################
  126.         reward_fn = load_reward_manager(
  127.             config,
  128.             tokenizer,
  129.             0,
  130.             max_resp_len=config.data.max_response_length,
  131.             overlong_buffer_cfg=config.reward_model.overlong_buffer,
  132.         )
  133.         # Note that we always use function-based RM for validation
  134.         val_reward_fn = load_reward_manager(
  135.             config,
  136.             tokenizer,
  137.             1,
  138.             max_resp_len=config.data.max_response_length,
  139.             overlong_buffer_cfg=config.reward_model.overlong_buffer,
  140.         )
  141.         resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
  142.         #################################################################
  143.         # 加载主要的DAPO RL训练类,并运行.fit()
  144.         #################################################################
  145.         trainer = RayDAPOTrainer(
  146.             config=config,
  147.             tokenizer=tokenizer,
  148.             processor=processor,
  149.             role_worker_mapping=role_worker_mapping,
  150.             resource_pool_manager=resource_pool_manager,
  151.             ray_worker_group_cls=ray_worker_group_cls,
  152.             reward_fn=reward_fn,
  153.             val_reward_fn=val_reward_fn,
  154.         )
  155.         trainer.init_workers()
  156.         trainer.fit()
  157. if __name__ == "__main__":
  158.     main()
复制代码
我们紧接着来看一下from verl.trainer.ppo.reward import load_reward_manager。
配置文件中verl/recipe/dapo/run_dapo_qwen2.5_32b.sh给出了reward的类型
  1. enable_overlong_buffer=True
  2. overlong_buffer_len=$((1024 * 4)) # overlong soft
  3. overlong_penalty_factor=1.0
  4. reward_model.reward_manager=dapo \
  5. reward_model.overlong_buffer.enable=${enable_overlong_buffer} \
  6. reward_model.overlong_buffer.len=${overlong_buffer_len} \
  7. reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor} \
复制代码
verl.trainer.ppo.reward.py
  1. def load_reward_manager(
  2.     config: DictConfig, tokenizer: Any, num_examine: int, **reward_kwargs: Any
  3. ) -> AbstractRewardManager:
  4.     """
  5.     Load and initialize a reward manager based on the configuration.
  6.     Args:
  7.         config: PPO trainer configuration object containing reward_model fields.
  8.         tokenizer: Tokenizer object used for processing text.
  9.         num_examine: Number of samples to examine.
  10.         **reward_kwargs: Additional keyword arguments for the reward manager.
  11.     Returns:
  12.         An instance of the specified reward manager class.
  13.     """
  14.     # Try to get a custom reward function based on the configuration
  15.     # user defined reward manager can be registered in custom_reward_fn
  16.     compute_score = get_custom_reward_fn(config)
  17.     final_compute_score = compute_score
  18.     # The list of pre-defined reward managers are defined in `verl/workers/reward_manager/`:
  19.     # naive: NaiveRewardManager
  20.     # prime: PrimeRewardManager
  21.     # batch: BatchRewardManager
  22.     # dapo: DAPORewardManager
  23.     # Note(haibin.lin): For custom reward managers, please make sure they are imported and
  24.     # registered via `verl.workers.reward_manager.register`
  25.     # By default reward_manager is set to naive (NaiveRewardManager)
  26.     #################################################################
  27.     # 在这里加载具体的reward_manager
  28.     #################################################################
  29.     reward_manager_name = config.reward_model.get("reward_manager", "naive")
  30.     reward_manager_cls = get_reward_manager_cls(reward_manager_name)
  31.     if compute_score is None:
  32.         sandbox_config = config.reward_model.get("sandbox_fusion")
  33.         sandbox_url = sandbox_config.get("url") if sandbox_config else None
  34.         memory_limit_mb = sandbox_config.get("memory_limit_mb", 1024)
  35.         if sandbox_url:
  36.             sandbox_manager = multiprocessing.Manager()
  37.             # Create a semaphore to control concurrent access to the sandbox
  38.             _concurrent_semaphore = sandbox_manager.Semaphore(sandbox_config.get("max_concurrent", 64))
  39.             final_compute_score = partial(
  40.                 default_compute_score,
  41.                 sandbox_fusion_url=sandbox_url,
  42.                 concurrent_semaphore=_concurrent_semaphore,
  43.                 memory_limit_mb=memory_limit_mb,
  44.             )
  45.         else:
  46.             final_compute_score = default_compute_score
  47.     #################################################################
  48.     # 这里的reward_manager_cls 其实是DAPO,
  49.     #################################################################
  50.     # Instantiate and return the reward manager with the specified parameters
  51.     return reward_manager_cls(
  52.         tokenizer=tokenizer,
  53.         num_examine=num_examine,
  54.         compute_score=final_compute_score,
  55.         reward_fn_key=config.data.reward_fn_key,
  56.         **reward_kwargs,
  57.     )
复制代码
这里需要知道dapo的reward_manager_cls 具体是什么,因为reward需要batch数据才能计算,因此对于reward manager咱们先按下不表(其实dapo对应的reward_manager_cls是在verl/verl/workers/reward_manager/dapo.py),先去dapo_ray_trainer.py看一下batch是怎么采样的,再回来仔细阅读reward的具体计算方法。
dapo_ray_trainer.py
  1. #################################################################
  2. # RayDAPOTrainer继承于RayPPOTrainer
  3. # fit()函数:执行dapo的训练,包括(1)动态采样(2)overlong soft reward计算(3)token-level loss
  4. #################################################################
  5. class RayDAPOTrainer(RayPPOTrainer):
  6.     """
  7.     Note that this trainer runs on the driver process on a single CPU/GPU node.
  8.     """
  9.     def fit(self):
  10.         """
  11.         The training loop of PPO.
  12.         The driver process only need to call the compute functions of the worker group through RPC
  13.         to construct the PPO dataflow.
  14.         The light-weight advantage computation is done on the driver process.
  15.         """
  16.         from omegaconf import OmegaConf
  17.         from verl.utils.tracking import Tracking
  18.         logger = Tracking(
  19.             project_name=self.config.trainer.project_name,
  20.             experiment_name=self.config.trainer.experiment_name,
  21.             default_backend=self.config.trainer.logger,
  22.             config=OmegaConf.to_container(self.config, resolve=True),
  23.         )
  24.         self.global_steps = 0
  25.         self.gen_steps = 0
  26.         # load checkpoint before doing anything
  27.         self._load_checkpoint()
  28.         # perform validation before training
  29.         # currently, we only support validation using the reward_function.
  30.         if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
  31.             val_metrics = self._validate()
  32.             assert val_metrics, f"{val_metrics=}"
  33.             pprint(f"Initial validation metrics: {val_metrics}")
  34.             logger.log(data=val_metrics, step=self.global_steps)
  35.             if self.config.trainer.get("val_only", False):
  36.                 return
  37.         if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
  38.             rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
  39.             rollout_skip.wrap_generate_sequences()
  40.         # add tqdm
  41.         progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
  42.         # we start from step 1
  43.         self.global_steps += 1
  44.         self.gen_steps += 1
  45.         last_val_metrics = None
  46.         prev_step_profile = False
  47.         curr_step_profile = (
  48.             self.global_steps in self.config.global_profiler.steps
  49.             if self.config.global_profiler.steps is not None
  50.             else False
  51.         )
  52.         next_step_profile = False
  53.         timing_raw = defaultdict(float)
  54.         batch = None
  55.         #################################################################
  56.         # num_prompt_in_batch:记录filter后,std不等于0的q的个数,当模型更新后重新赋值为0
  57.         # num_gen_batches: 记录当前使用了多少个gen_batch,当模型更新后重新赋值为0
  58.         #################################################################
  59.         num_prompt_in_batch = 0
  60.         num_gen_batches = 0
  61.         #################################################################
  62.         # 正式开始训练,循环每个epoch后,循环每个gen_batch
  63.         #################################################################
  64.         for epoch in range(self.config.trainer.total_epochs):
  65.             for batch_dict in self.train_dataloader:
  66.                 metrics = {}
  67.                 with marked_timer("start_profile", timing_raw):
  68.                     self._start_profiling(
  69.                         not prev_step_profile and curr_step_profile
  70.                         if self.config.global_profiler.profile_continuous_steps
  71.                         else curr_step_profile
  72.                     )
  73.                 #################################################################
  74.                 # new_batch 是DataProto类型(具体见verl/verl/protocol.py),
  75.                 # new_batch.batch是TensorDict类型
  76.                 # new_batch中q的数量是可训练batch大小的3倍(增加采样的batch的q的个数)
  77.                 #################################################################
  78.                 new_batch: DataProto = DataProto.from_single_dict(batch_dict)
  79.                 num_gen_batches += 1
  80.                 # pop those keys for generation
  81.                 if "multi_modal_data" in new_batch.non_tensor_batch.keys():
  82.                     gen_batch = new_batch.pop(
  83.                         batch_keys=["input_ids", "attention_mask", "position_ids"],
  84.                         non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
  85.                     )
  86.                 else:
  87.                     # 从new_batch中提取对应的key,构建gen_batch
  88.                     gen_batch = new_batch.pop(
  89.                         batch_keys=["input_ids", "attention_mask", "position_ids"],
  90.                         non_tensor_batch_keys=["raw_prompt_ids"],
  91.                     )
  92.                 # 这里为什么要repeate呢,因为每个prompt要采样n次,所以repeat n次。这里的interleave=True
  93.                 # gen_batch: (bsz, response_length),
  94.                                                                 # gen_batch_output: (bsz*n, response_length)
  95.                 gen_batch_output = gen_batch.repeat(
  96.                     repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
  97.                 )
  98.                 is_last_step = self.global_steps >= self.total_training_steps
  99.                 with marked_timer("step", timing_raw):
  100.                     # generate a batch
  101.                     with marked_timer("gen", timing_raw, "red"):
  102.                         gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
  103.                         timing_raw.update(gen_batch_output.meta_info["timing"])
  104.                         gen_batch_output.meta_info.pop("timing", None)
  105.                     # 这个advatange 可以先忽略。RMAX需要先计算 贪心采样的sample的logits作为后序adv计算的baseline
  106.                     if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
  107.                         with marked_timer("gen_max", timing_raw, "red"):
  108.                             gen_baseline_batch = deepcopy(gen_batch)
  109.                             # 这里是贪心采样的baseline,do_sample = False
  110.                             gen_baseline_batch.meta_info["do_sample"] = False
  111.                             gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
  112.                             new_batch = new_batch.union(gen_baseline_output)
  113.                             # compute reward model score on new_batch
  114.                             rm_scores = None
  115.                             if self.use_rm and "rm_scores" not in new_batch.batch.keys():
  116.                                 rm_scores = self.rm_wg.compute_rm_score(new_batch)
  117.                                 new_batch = new_batch.union(rm_scores)
  118.                             reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)
  119.                             reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
  120.                             keys_to_pop = set(gen_baseline_output.batch.keys())
  121.                             if rm_scores is not None:
  122.                                 keys_to_pop.update(rm_scores.batch.keys())
  123.                             new_batch.pop(batch_keys=list(keys_to_pop))
  124.                             new_batch.batch["reward_baselines"] = reward_baseline_tensor
  125.                             del rm_scores, gen_baseline_batch, gen_baseline_output
  126.                     #################################################################
  127.                     # new_batch的大小是gen_prompt_bsz
  128.                     # 对每一个prompt设置一个专属的标识 uid
  129.                                                                                 # 之所以设置uid,是因为之后对sample计算reward时,需要对同一个q的n个sample的reward标准化
  130.                     #################################################################
  131.                     new_batch.non_tensor_batch["uid"] = np.array(
  132.                         [str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
  133.                     )
  134.                     # 对batch中的每个key进行repeat(这里应该主要是对uid进行repeat)
  135.                     # repeat to align with repeated responses in rollout
  136.                     new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
  137.                     # 把采样完的放到new_batch中
  138.                     new_batch = new_batch.union(gen_batch_output)
  139.                     with marked_timer("reward", timing_raw, "yellow"):
  140.                         # compute scores. Support both model and function-based.
  141.                         # We first compute the scores using reward model. Then, we call reward_fn to combine
  142.                         # the results from reward model and rule-based results.
  143.                         if self.use_rm and "rm_scores" not in new_batch.batch.keys():
  144.                             # we first compute reward model score
  145.                             reward_tensor = self.rm_wg.compute_rm_score(new_batch)
  146.                             new_batch = new_batch.union(reward_tensor)
  147.                         # 计算new_batch各个采样的reward,根据设置好的self.reward_fn
  148.                         # we combine with rule-based rm
  149.                         reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn)
  150.                         new_batch.batch["token_level_scores"] = reward_tensor
  151.                         if reward_extra_infos_dict:
  152.                             new_batch.non_tensor_batch.update(
  153.                                 {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
  154.                             )
  155.                         # compute rewards. apply_kl_penalty if available
  156.                         if self.config.algorithm.use_kl_in_reward:
  157.                             new_batch, kl_metrics = apply_kl_penalty(
  158.                                 new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
  159.                             )
  160.                             metrics.update(
  161.                                 kl_metrics
  162.                             )  # TODO: This will be cleared if we use multiple genenration batches
  163.                         else:
  164.                             new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
  165.                                                                                
  166.                     #################################################################
  167.                     # dapo的filter(dynamic sample)部分
  168.                     #################################################################
  169.                     if not self.config.algorithm.filter_groups.enable:
  170.                         batch = new_batch
  171.                     else:  # NOTE: When prompts after filtering is less than train batch size,
  172.                         # we skip to the next generation batch
  173.                         metric_name = self.config.algorithm.filter_groups.metric
  174.                         if metric_name == "seq_final_reward":
  175.                             # Turn to numpy for easier filtering
  176.                             new_batch.non_tensor_batch["seq_final_reward"] = (
  177.                                 new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
  178.                             )
  179.                         elif metric_name == "seq_reward":
  180.                             new_batch.non_tensor_batch["seq_reward"] = (
  181.                                 new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
  182.                             )
  183.                         # {uid: [r1,r2,r3,...,rn], uid: [...], ...},记录每个轨迹所有采样的reward
  184.                         # Collect the sequence reward for each trajectory
  185.                         prompt_uid2metric_vals = defaultdict(list)
  186.                         for uid, metric_val in zip(
  187.                             new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
  188.                         ):
  189.                             prompt_uid2metric_vals[uid].append(metric_val)
  190.                         # 每个q的reward的std
  191.                         prompt_uid2metric_std = {}
  192.                         for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
  193.                             prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
  194.                         # 保留reward std不是0的q的uid
  195.                         kept_prompt_uids = [
  196.                             uid
  197.                             for uid, std in prompt_uid2metric_std.items()
  198.                             if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
  199.                         ]
  200.                         # 累积std不是0的q
  201.                         num_prompt_in_batch += len(kept_prompt_uids)
  202.                         # 记录留下来的q的sample的idx
  203.                         kept_traj_idxs = []
  204.                         for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
  205.                             if traj_from_prompt_uid in kept_prompt_uids:
  206.                                 kept_traj_idxs.append(idx)
  207.                         # 基于traj的id,检索对应的new_batch
  208.                         new_batch = new_batch[kept_traj_idxs]
  209.                         # batch是留下的traj数据的累积
  210.                         batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
  211.                         # .sh文件配置的 可以训练的batch的最小大小(q的数量)
  212.                         prompt_bsz = self.config.data.train_batch_size
  213.                         # 如果现有的累积filter出来的q的数量小于 配置的最小数量,则continue继续使用下一个new_batch进行累积
  214.                         if num_prompt_in_batch < prompt_bsz:
  215.                             print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
  216.                             max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
  217.                             # max_num_gen_batches是最多可以使用的gen_batch的个数
  218.                             # 如果其小于0的话,即没有限制;若num_gen_batches < max_num_gen_batches则继续continue
  219.                             if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
  220.                                 print(f"{num_gen_batches=}. Keep generating...")
  221.                                 self.gen_steps += 1
  222.                                 is_last_step = self.global_steps >= self.total_training_steps
  223.                                 continue
  224.                             else:
  225.                                 raise ValueError(
  226.                                     f"{num_gen_batches=} >= {max_num_gen_batches=}."
  227.                                     + " Generated too many. Please check if your data are too difficult."
  228.                                     + " You could also try set max_num_gen_batches=0 to enable endless trials."
  229.                                 )
  230.                       # 累积的符合的q个个数>=最小的可以训练的batch的大小  
  231.                       else:
  232.                             # Align the batch
  233.                             traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
  234.                             #################################################################
  235.                             # 对齐一下,多余的轨迹会被抛弃,不知道会不会导致采样的利用效率不高,
  236.                             # 会不会导致一些轨迹根本不会被训练到
  237.                             #################################################################
  238.                             batch = batch[:traj_bsz]
  239.                     #################################################################
  240.                     # actor模型更新
  241.                     #################################################################
  242.                     # === Updating ===
  243.                     batch.batch["response_mask"] = compute_response_mask(batch)
  244.                     # Balance the number of valid tokens across DP ranks.
  245.                     # NOTE: This usually changes the order of data in the `batch`,
  246.                     # which won't affect the advantage calculation (since it's based on uid),
  247.                     # but might affect the loss calculation (due to the change of mini-batching).
  248.                     # TODO: Decouple the DP balancing and mini-batching.
  249.                     if self.config.trainer.balance_batch:
  250.                         self._balance_batch(batch, metrics=metrics)
  251.                     # compute global_valid tokens
  252.                     batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
  253.                     #################################################################
  254.                     # 记录filter后的batch的每个traj的采样时的logtis(token-level)
  255.                     # 用于计算重要性采样的比值
  256.                     #################################################################
  257.                     # recompute old_log_probs
  258.                     with marked_timer("old_log_prob", timing_raw, "blue"):
  259.                         old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
  260.                         entropys = old_log_prob.batch["entropys"]
  261.                         response_masks = batch.batch["response_mask"]
  262.                         loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
  263.                         # 这里dapo的loss_agg_mode是“token_mean”
  264.                         entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
  265.                         old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
  266.                         metrics.update(old_log_prob_metrics)
  267.                         old_log_prob.batch.pop("entropys")
  268.                         batch = batch.union(old_log_prob)
  269.                     if self.use_reference_policy:
  270.                         # compute reference log_prob
  271.                         with marked_timer("ref", timing_raw, "olive"):
  272.                             ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
  273.                             batch = batch.union(ref_log_prob)
  274.                     # compute values
  275.                     if self.use_critic:
  276.                         with marked_timer("values", timing_raw, "cyan"):
  277.                             values = self.critic_wg.compute_values(batch)
  278.                             batch = batch.union(values)
  279.                     # 计算token_level的重要性采样
  280.                     # Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
  281.                     batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
  282.                     # IS and mismatch metrics already have mismatch/ prefix
  283.                     metrics.update(is_metrics)
  284.                     #################################################################
  285.                     # 计算advantage
  286.                     #################################################################
  287.                     with marked_timer("adv", timing_raw, "brown"):
  288.                         # compute advantages, executed on the driver process
  289.                         norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
  290.                         batch = compute_advantage(
  291.                             batch,
  292.                             adv_estimator=self.config.algorithm.adv_estimator,
  293.                             gamma=self.config.algorithm.gamma,
  294.                             lam=self.config.algorithm.lam,
  295.                             num_repeat=self.config.actor_rollout_ref.rollout.n,
  296.                             norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
  297.                         )
  298.                     # update critic
  299.                     if self.use_critic:
  300.                         with marked_timer("update_critic", timing_raw, "pink"):
  301.                             critic_output = self.critic_wg.update_critic(batch)
  302.                         critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
  303.                         metrics.update(critic_output_metrics)
  304.                     # implement critic warmup
  305.                     if self.config.trainer.critic_warmup <= self.global_steps:
  306.                         #################################################################
  307.                         # 更新actor model(batch的大小是train_prompt_size)
  308.                         # 每个mini_bsz 更新一次模型(参数-累积梯度)
  309.                         # 每个micro_bsz 累积一次梯度
  310.                         #################################################################
  311.                         # update actor
  312.                         with marked_timer("update_actor", timing_raw, "red"):
  313.                             actor_output = self.actor_rollout_wg.update_actor(batch)
  314.                         actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
  315.                         metrics.update(actor_output_metrics)
  316.                     # Log rollout generations if enabled
  317.                     rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
  318.                     if rollout_data_dir:
  319.                         self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)
  320.                 # validate
  321.                 if (
  322.                     self.val_reward_fn is not None
  323.                     and self.config.trainer.test_freq > 0
  324.                     and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
  325.                 ):
  326.                     with marked_timer("testing", timing_raw, "green"):
  327.                         val_metrics: dict = self._validate()
  328.                         if is_last_step:
  329.                             last_val_metrics = val_metrics
  330.                     metrics.update(val_metrics)
  331.                 if self.config.trainer.save_freq > 0 and (
  332.                     is_last_step or self.global_steps % self.config.trainer.save_freq == 0
  333.                 ):
  334.                     with marked_timer("save_checkpoint", timing_raw, "green"):
  335.                         self._save_checkpoint()
  336.                 with marked_timer("stop_profile", timing_raw):
  337.                     next_step_profile = (
  338.                         self.global_steps + 1 in self.config.global_profiler.steps
  339.                         if self.config.global_profiler.steps is not None
  340.                         else False
  341.                     )
  342.                     self._stop_profiling(
  343.                         curr_step_profile and not next_step_profile
  344.                         if self.config.global_profiler.profile_continuous_steps
  345.                         else curr_step_profile
  346.                     )
  347.                     prev_step_profile = curr_step_profile
  348.                     curr_step_profile = next_step_profile
  349.                 # collect metrics
  350.                 metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
  351.                 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
  352.                 # TODO: implement actual tflpo and theoretical tflpo
  353.                 n_gpus = self.resource_pool_manager.get_n_gpus()
  354.                 metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
  355.                 timing_raw = defaultdict(float)  # clear timing
  356.                 metrics["train/num_gen_batches"] = num_gen_batches
  357.                 batch = None
  358.                 num_prompt_in_batch = 0
  359.                 num_gen_batches = 0
  360.                 # TODO: make a canonical logger that supports various backend
  361.                 logger.log(data=metrics, step=self.global_steps)
  362.                 if is_last_step:
  363.                     pprint(f"Final validation metrics: {last_val_metrics}")
  364.                     progress_bar.close()
  365.                     return
  366.                 progress_bar.update(1)
  367.                 self.global_steps += 1
  368.                 self.gen_steps += 1
  369.         # check if last step checkpint exists
  370.         checkpoint_dir = os.path.join(self.config.trainer.default_local_dir, f"global_step_{self.global_steps}")
  371.         if not os.path.exists(checkpoint_dir):
  372.             # save last step checkpoint
  373.             timing_raw = defaultdict(float)
  374.             with marked_timer("save_checkpoint", timing_raw, "green"):
  375.                 self._save_checkpoint()
  376.             metrics = {f"timing/{k}": v for k, v in timing_raw.items()}
  377.             logger.log(data=metrics, step=self.global_steps)
复制代码
这时候咱们再看一下dapo的reward manager实现:主要和ppo的区别在于使用了overlong_buffer,计算长度的reward
verl/verl/workers/reward_manager/dapo.py
  1. #################################################################
  2. # 这里使用dapo注册了DAPORewardManager,因此可以用
  3. # reward_manager_cls = get_reward_manager_cls(reward_manager_name)得到
  4. #################################################################
  5. @register("dapo")
  6. class DAPORewardManager(AbstractRewardManager):
  7.     """The reward manager."""
  8.     def __init__(
  9.         self,
  10.         tokenizer,
  11.         num_examine,
  12.         compute_score=None,
  13.         reward_fn_key="data_source",
  14.         max_resp_len=None,
  15.         overlong_buffer_cfg=None,
  16.     ) -> None:
  17.         self.tokenizer = tokenizer
  18.         self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
  19.         self.compute_score = compute_score or default_compute_score
  20.         self.reward_fn_key = reward_fn_key
  21.         self.overlong_buffer_cfg = overlong_buffer_cfg
  22.         self.max_resp_len = max_resp_len
  23.         if self.overlong_buffer_cfg is not None:
  24.             assert self.max_resp_len is not None, (
  25.                 f"max_resp_len must be provided if {overlong_buffer_cfg=}, but got None"
  26.             )
  27.             assert self.max_resp_len >= self.overlong_buffer_cfg.len, (
  28.                 "max_resp_len must be larger than overlong_buffer.len"
  29.             )
  30.     #################################################################
  31.     # DAPO reward manager的主要函数
  32.     #################################################################
  33.     def __call__(self, data: DataProto, return_dict: bool = False):
  34.         """We will expand this function gradually based on the available datasets"""
  35.         # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
  36.         if "rm_scores" in data.batch.keys():
  37.             if return_dict:
  38.                 reward_extra_keys = data.meta_info.get("reward_extra_keys", [])
  39.                 reward_extra_info = {key: data.non_tensor_batch[key] for key in reward_extra_keys}
  40.                 return {"reward_tensor": data.batch["rm_scores"], "reward_extra_info": reward_extra_info}
  41.             else:
  42.                 return data.batch["rm_scores"]
  43.         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
  44.         reward_extra_info = defaultdict(list)
  45.         already_print_data_sources = {}
  46.         for i in range(len(data)):
  47.             data_item = data[i]  # DataProtoItem
  48.             prompt_ids = data_item.batch["prompts"]
  49.             prompt_length = prompt_ids.shape[-1]
  50.             ########################################################
  51.             # 值得注意的是。prompt_ids是左填充的
  52.             # response_ids是右填充的
  53.             ########################################################
  54.             valid_prompt_length = data_item.batch["attention_mask"][:prompt_length].sum()
  55.             valid_prompt_ids = prompt_ids[-valid_prompt_length:]
  56.             response_ids = data_item.batch["responses"]
  57.             valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
  58.             valid_response_ids = response_ids[:valid_response_length]
  59.             # decode
  60.             prompt_str = self.tokenizer.decode(valid_prompt_ids, skip_special_tokens=True)
  61.             response_str = self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
  62.             eos_token = self.tokenizer.eos_token
  63.             if response_str.endswith(eos_token):
  64.                 response_str = response_str[: -len(eos_token)]
  65.             ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
  66.             data_source = data_item.non_tensor_batch[self.reward_fn_key]
  67.             extra_info = data_item.non_tensor_batch.get("extra_info", {})
  68.             rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})
  69.             extra_info["rollout_reward_scores"] = rollout_reward_scores
  70.             result = self.compute_score(
  71.                 data_source=data_source,
  72.                 solution_str=response_str,
  73.                 ground_truth=ground_truth,
  74.                 extra_info=extra_info,
  75.             )
  76.             score: float
  77.             if isinstance(result, dict):
  78.                 score = result["score"]
  79.                 # Store the information including original reward
  80.                 for key, value in result.items():
  81.                     reward_extra_info[key].append(value)
  82.             else:
  83.                 score = result
  84.                 reward_extra_info["acc"].append(score)
  85.             reward = score
  86.             ########################################################
  87.             # 这里是overlong reward的计算
  88.             ########################################################
  89.             if self.overlong_buffer_cfg.enable:
  90.                 overlong_buffer_len = self.overlong_buffer_cfg.len
  91.                 expected_len = self.max_resp_len - overlong_buffer_len
  92.                 exceed_len = valid_response_length - expected_len
  93.                 overlong_penalty_factor = self.overlong_buffer_cfg.penalty_factor
  94.                 overlong_reward = min(-exceed_len / overlong_buffer_len * overlong_penalty_factor, 0)
  95.                 reward += overlong_reward
  96.                 if self.overlong_buffer_cfg.log:
  97.                     reward_extra_info["overlong_reward"].append(overlong_reward)
  98.                     reward_extra_info["overlong"].append(overlong_reward < 0)
  99.             reward_tensor[i, valid_response_length - 1] = reward
  100.             if data_source not in already_print_data_sources:
  101.                 already_print_data_sources[data_source] = 0
  102.             if already_print_data_sources[data_source] < self.num_examine:
  103.                 already_print_data_sources[data_source] += 1
  104.                 print("[prompt]", prompt_str)
  105.                 print("[response]", response_str)
  106.                 print("[ground_truth]", ground_truth)
  107.                 if isinstance(result, dict):
  108.                     for key, value in result.items():
  109.                         print(f"[{key}]", value)
  110.                 else:
  111.                     print("[score]", score)
  112.         if return_dict:
  113.             return {
  114.                 "reward_tensor": reward_tensor,
  115.                 "reward_extra_info": reward_extra_info,
  116.             }
  117.         else:
  118.             return reward_tensor
复制代码
dapo和ppo的具体区别可进一步参考:dapo readme

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

2025-12-1 02:35:32

举报

您需要登录后才可以回帖 登录 | 立即注册