找回密码
 立即注册
首页 业界区 业界 强化学习框架:OpenRLHF源码解读,模型处理 ...

强化学习框架:OpenRLHF源码解读,模型处理

阎一禾 前天 22:00
强化学习框架:OpenRLHF源码解读,模型处理

本文主要介绍 强化学习框架:OpenRLHF源码解读,模型处理
models框架设计

了解一下 OpenRLHF的模型框架设计范式:
1.png

From:https://arxiv.org/pdf/2405.11143
可以知道一个大概的流程:输入Pormpt通过Actor model输出回复 Response,而后将两部分进行拼接再去由其他模型进行处理
1、actor.py

https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/actor.py
这部分主要为加载所需要的模型
  1. class Actor(nn.Module):
  2.     def __init__(...):
  3.         if isinstance(pretrain_or_model, str):
  4.             ...
  5.             self.model = model_class.from_pretrained(
  6.                 pretrain_or_model,
  7.                 trust_remote_code=True,
  8.                 attn_implementation=attn_implementation,
  9.                 quantization_config=nf4_config,
  10.                 torch_dtype=torch.bfloat16 if bf16 else "auto",
  11.                 device_map=device_map,
  12.             )
  13.             if lora_rank > 0:
  14.                 self.model.enable_input_require_grads()
  15.                 lora_config = LoraConfig(
  16.                     task_type=TaskType.CAUSAL_LM,
  17.                     r=lora_rank,
  18.                     lora_alpha=lora_alpha,
  19.                     target_modules=target_modules,
  20.                     lora_dropout=lora_dropout,
  21.                     bias="none",
  22.                 )
  23.                 self.model = get_peft_model(self.model, lora_config)
  24.                 ...
  25.         else:
  26.             self.model = pretrain_or_model
  27.     @torch.no_grad()
  28.     def generate(self, input_ids: torch.Tensor, **kwargs):
  29.         ...
  30.         sequences = self.model.generate(**generate_args)
  31.         eos_token_id = generate_args["eos_token_id"]
  32.         pad_token_id = generate_args["pad_token_id"]
  33.         return self.process_sequences(sequences, input_ids.size(1), eos_token_id, pad_token_id)
  34.     def forward(...):
  35.         ...
  36.         output["logits"] = output["logits"].to(torch.float32) # 得到每一个token概率
  37.         ...
  38.         log_probs = log_probs_from_logits(
  39.                     output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature
  40.                 )
  41.         ...
  42.         action_log_probs = log_probs[:, -num_actions:]
复制代码
这个actor比较简单,首先从huggingface加载需要的模型,并且对模型进行部分设置如:量化/lora微调。或者直接加载自己预训练好的模型。
1、generate:模块则是根据输入的内容(比如说被 tokenizer处理好的文本)input_ids通过模型输出新的内容(根据 **kwargs获取生成文本参数设置比如说:top_k等)
2、forward:根据输入的 token 序列(sequences),计算模型在生成最后若干个 token(即 "动作")时的对数概率(log probs),之所以要这么处理是因为,在强化学习模型中(PPO、DPO等)一般而言模型的输出是一个序列,但优化目标不是“能不能生成这个序列”,而是:这个序列中,哪些 token 是“好”的?模型对这些 token 的概率应该更高!比如说在 DPO中:

\[L(θ) = E[ min(r(θ) * A, clip(r(θ), 1-ε, 1+ε) * A) ]\]
里面的

\[r(\theta)=\pi_{\theta}(a|s)/\pi_{old}(a|s)\]
就是概率比值,上面代码中:
  1. log_probs_from_logits(output["logits"][:, :-1, :], sequences[:, 1:], temperature=self.temperature)
复制代码
计算的就是:\(log(\pi_{\theta}(a|s))\),在具体代码中:
  1. def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
  2.     if temperature != 1.0:
  3.         logits.div_(temperature)
  4.     if logits.dtype in [torch.float32, torch.float64]:
  5.         batch_dim = logits.shape[:-1]
  6.         last_dim = logits.shape[-1]
  7.         try:
  8.             from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
  9.             output = cross_entropy_loss(logits.reshape(-1, last_dim), labels.reshape(-1))
  10.             log_probs_labels = -output[0].view(*batch_dim)
  11.         except ImportError:
  12.             logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
  13.             logsumexp_values = _logsumexp_by_chunk(logits.reshape(-1, last_dim))
  14.             logsumexp_values = logsumexp_values.view(*batch_dim)
  15.             log_probs_labels = logits_labels - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
  16.     else:
  17.         log_probs_labels = []
  18.         for row_logits, row_labels in zip(logits, labels):  # loop to reduce peak mem consumption
  19.             row_log_probs = F.log_softmax(row_logits, dim=-1)
  20.             row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
  21.             log_probs_labels.append(row_log_probs_labels)
  22.         log_probs_labels = torch.stack(log_probs_labels)
  23.     return log_probs_labels
复制代码
补充-1
在使用 AutoModelForCausalLM.from_pretrained使用得到 model之后,其支持输入参数为:
  1. outputs = model(
  2.     input_ids=None,            # 输入的token(batch_size, seq_length)
  3.     attention_mask=None,       # 指示哪些 token 是有效的(非 padding),形状同 input_ids
  4.     position_ids=None,         # 位置编码
  5.     past_key_values=None,
  6.     inputs_embeds=None,
  7.     use_cache=None,            # 是否使用k-v cache
  8.     labels=None,               # 输入标签就直接计算loss
  9.     output_attentions=None,
  10.     output_hidden_states=None,
  11.     return_dict=None,
  12. )
复制代码
补充-2
在LLM训练过程中遇到过短的语句为了节约显存(如果都将内容补充到相同长度,那么就会有较多的padding造成浪费),因此可以将几个短的拼接起来,但是为了区分那些是一个句子那些不是的,在 OpenRLHF中通过参数:self.packing_samples。如果没有 packing那么直接根据 attention_mask将位置编码在处理一下
  1. if not self.packing_samples:
  2.     position_ids = attention_mask.long().cumsum(-1) - 1
  3.     position_ids.masked_fill_(attention_mask == 0, 1)
  4. else:
  5.     # convert attention_mask to position_ids
  6.     if ring_attn_group is not None:
  7.         labels = sequences
  8.         sequences, attention_mask, position_ids = convert_ring_attn_params(
  9.             sequences, attention_mask, packed_seq_lens, ring_attn_group
  10.         )
  11.     else:
  12.         position_ids = reset_position_ids(attention_mask)
  13.     # explicitly ignore attention_mask for packing_samples
  14.     attention_mask = None
复制代码
其中 reset_position_ids做的就是重新做位置编码重新处理
2、model.py

https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/models/model.py
2.png

主要功能返回所需要的模型,主要返回2个模型:1、CriticModel;2、RewardModel 回顾一下这几类模型的作用:无论是在GRPO还是DPO中都会输出token然后需要去对token进行评分,起评分作用的就是 reward model 对应上面图中 reward model,除此之外都会计算 优势函数(\(Q(s,a)-V(s)\))来评估策略的好坏优势函数里面计算就是通过 critic model来对某一个策略进行评估对应上面图像中的:value model
  1. def _get_reward_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):
  2.     class RewardModel(base_pretrained_model):
  3.         def __init__(...):
  4.             ...
  5.             # 加载模型
  6.             setattr(self, self.base_model_prefix, base_llm_model(config))
  7.             self.value_head_prefix = value_head_prefix
  8.             setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False) # 输出评分
  9.             ...
  10.         def forward(self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, return_output=False, ring_attn_group=None,pad_sequence=False, packed_seq_lens=None,):
  11.             ...# 1、处理packing
  12.             outputs = getattr(self, self.base_model_prefix)(
  13.                 input_ids, attention_mask=attention_mask, position_ids=position_ids
  14.             )
  15.             last_hidden_states = outputs["last_hidden_state"]
  16.             values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)
  17.             ...# 1、处理packing
  18.             else:
  19.                 # 输出最后一个有效token的评分代替整个句子评分
  20.                 eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
  21.                 reward = values.gather(dim=1, index=eos_indices).squeeze(1)
  22.             if not self.training and self.normalize_reward:
  23.                 reward = (reward - self.mean) / self.std
  24.             return (reward, outputs) if return_output else reward
  25.     return RewardModel
  26. def _get_critic_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):
  27.     class CriticModel(base_pretrained_model):
  28.         def __init__(...):
  29.             ...
  30.         def forward(...):
  31.             ...# 1、处理packing
  32.             outputs = getattr(self, self.base_model_prefix)(
  33.                 input_ids, attention_mask=attention_mask, position_ids=position_ids
  34.             )
  35.             last_hidden_states = outputs["last_hidden_state"]
  36.             values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)
  37.             ...
  38.             if num_actions is None:
  39.                 assert return_output
  40.                 return outputs
  41.             if not self.packing_samples:
  42.                 action_values = values[:, -num_actions:]
  43.             else:
  44.                 assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens)
  45.                 action_values = []
  46.                 offset = 0
  47.                 for num_action, seq_len in zip(num_actions, packed_seq_lens):
  48.                     start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1
  49.                     action_values.append(values[:, start:end])
  50.                     offset += seq_len
  51.                 action_values = torch.cat(action_values, dim=1)
  52.             if return_output:
  53.                 return (action_values, outputs)
  54.             else:
  55.                 return action_values
  56.     return CriticModel
复制代码
1、reward model: 传入一个 base_pretrained_model(比如 PreTrainedModel)、一个 base_llm_model(比如 AutoModel)以及一些控制参数。函数内部返回一个定制化的奖励模型类 RewardModel,它可以在给定输入句子时,输出一个数值(reward 分数),反映输出文本的质量。在forward计算中,直接将输入model使用的几个参数(见上面的补充有具体解释)计算最后取最后一个状态的值,并且将这个值取计算评分。也就是说 reward model:首先计算下一个预测的token而后对这些token进行打分
2、critic model:具体输入参数和 reward model相同。参考之前介绍,上面代码中直接返回action_values = values[:, -num_actions:]( num_actions存在条件下)这样就会得到不同的Q(s, a1), Q(s, a2), ...
<blockquote>总结上面两组模型,在 LLM 的强化学习场景下,Reward Model 和 Critic Model 都从 last_hidden_state 得到 token-level 表达,再用 Linear 层输出每个 token 的 score。
<ul>Reward Model 最后提取的是 EOS token 的 score,表示整句话的奖励。
Critic Model 会进一步提取最后 num_actions 个 token 的 value,这些 token 是 Actor 生成的动作,对应到 PPO 中的:
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册