class CallbackHandler(TrainerCallback):
    """Internal class that just calls the list of callbacks in order."""

    def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
        self.callbacks = []
        for callback in callbacks:
            self.add_callback(callback)
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_dataloader = None
        self.eval_dataloader = None
        # ... initialization logic ...

    def call_event(self, event, args, state, control, **kwargs):
        # Shared objects handed to every callback on every event.
        shared = dict(
            model=self.model,
            tokenizer=self.tokenizer,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            train_dataloader=self.train_dataloader,
            eval_dataloader=self.eval_dataloader,
        )
        for callback in self.callbacks:
            handler = getattr(callback, event)
            result = handler(args, state, control, **shared, **kwargs)
            # A callback may return an updated control object; None means
            # "keep the current one".
            if result is not None:
                control = result
        return control
## 三、`self._save_checkpoint` 源码解读 (Reading the `self._save_checkpoint` source)

### 1. 完整的源码 (The complete source)
def _save_checkpoint(self, model, trial, metrics=None):
    """Save a full training checkpoint (model, optimizer, scheduler, trainer state, RNG).

    Writes everything under `<run_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>`,
    updates the best-metric bookkeeping when `metrics` is provided, and
    optionally pushes the checkpoint to the Hub and rotates old checkpoints.
    """
    # In all cases, including ddp/dp/deepspeed, self.model is always a reference
    # to the model we want to save except FullyShardedDDP.
    assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"

    # Save model checkpoint
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    # Only fold the running FLO counter into state when not doing an HP search.
    if self.hp_search_backend is None and trial is None:
        self.store_flos()

    run_dir = self._get_output_dir(trial=trial)
    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir, _internal_call=True)
    if self.is_deepspeed_enabled:
        # under zero3 model file itself doesn't get saved since it's bogus!
        self.model_wrapped.save_checkpoint(output_dir)

    # Save optimizer and scheduler
    if self.sharded_ddp == ShardedDDPOption.SIMPLE:
        # Gather the sharded optimizer state onto rank 0 before saving.
        self.optimizer.consolidate_state_dict()

    if self.fsdp or self.is_fsdp_enabled:
        if self.is_fsdp_enabled:
            # accelerate-managed FSDP: let accelerate save the optimizer shards.
            save_fsdp_optimizer(
                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
            )
        else:
            # Legacy FSDP path: build the full (unsharded) optimizer state dict;
            # it is written to disk further down, in the `should_save` branch.
            full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

    if is_torch_tpu_available():
        # XLA: synchronize all TPU cores before saving.
        xm.rendezvous("saving_optimizer_states")
        xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
        with warnings.catch_warnings(record=True) as caught_warnings:
            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
    elif is_sagemaker_mp_enabled():
        # SageMaker model parallel: each rank saves a partial optimizer state.
        opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
        smp.barrier()
        if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
            smp.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME), partial=True, v3=smp.state.cfg.shard_optimizer_state)
        if self.args.should_save:
            with warnings.catch_warnings(record=True) as caught_warnings:
                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
            reissue_pt_warnings(caught_warnings)
            if self.do_grad_scaling:
                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
    elif self.args.should_save and not self.is_deepspeed_enabled:
        # deepspeed.save_checkpoint above already saved the optimizer/scheduler.
        if self.fsdp and not self.is_fsdp_enabled:
            # `full_osd` was built in the legacy-FSDP branch above.
            torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
        else:
            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
        reissue_pt_warnings(caught_warnings)
        if self.do_grad_scaling:
            torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics[metric_to_check]

        operator = np.greater if self.args.greater_is_better else np.less
        if (self.state.best_metric is None or self.state.best_model_checkpoint is None or operator(metric_value, self.state.best_metric)):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

    # Save the Trainer state
    if self.args.should_save:
        self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

    # Save RNG state in non-distributed training
    rng_states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "cpu": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
            # In distributed mode every process saves the state of all GPUs it sees.
            rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
        else:
            rng_states["cuda"] = torch.cuda.random.get_rng_state()

    if is_torch_tpu_available():
        rng_states["xla"] = xm.get_rng_state()

    # `output_dir` may not yet exist on non-saving ranks (save_model only wrote
    # on the main process), but every rank writes its own RNG file below.
    os.makedirs(output_dir, exist_ok=True)

    if self.args.world_size <= 1:
        torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
    else:
        # One RNG file per process so each rank can restore its own streams.
        torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

    if self.args.push_to_hub:
        self._push_from_checkpoint(output_dir)

    # Maybe delete some older checkpoints.
    if self.args.should_save:
        self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def store_flos(self):
    """Fold the running floating-point-operation counter into the trainer state.

    In distributed mode the per-process counters are summed across all ranks
    first; in either case the local counter is reset to zero afterwards.
    """
    distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
    if distributed:
        flos = distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
    else:
        flos = self.current_flos
    self.state.total_flos += flos
    self.current_flos = 0
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
    """Save the model to `output_dir` (defaults to `args.output_dir`).

    Dispatches to the backend-specific save path (TPU, SageMaker MP,
    sharded-DDP/FSDP, DeepSpeed) or to the plain `_save` on the main process.
    `_internal_call` is True when invoked from checkpointing, in which case the
    Hub push at the end is skipped. NOTE: the `...` bodies below are elided in
    this excerpt.
    """
    if output_dir is None:
        output_dir = self.args.output_dir

    if is_torch_tpu_available():
        self._save_tpu(output_dir)
    elif is_sagemaker_mp_enabled():
        # Calling the state_dict needs to be done on the wrapped model and on all processes.
        ...
    elif (ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp or self.fsdp is not None or self.is_fsdp_enabled):
        ...
    elif self.is_deepspeed_enabled:
        ...
    elif self.args.should_save:
        self._save(output_dir)

    # Push to the Hub when `save_model` is called by the user.
    if self.args.push_to_hub and not _internal_call:
        self.push_to_hub(commit_message="Model save")
def _save(self, output_dir: Optional[str] = None, state_dict=None):
    """Save the model (via `save_pretrained` when possible), tokenizer and
    training arguments to `output_dir`.

    Falls back to saving only a state dict when neither the model nor its
    unwrapped inner model is a supported `save_pretrained`-capable class.
    """
    output_dir = output_dir if output_dir is not None else self.args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Saving model checkpoint to {output_dir}")

    supported_classes = (PreTrainedModel,)
    # BUG FIX: the guard was inverted (`if not is_peft_available()`), which
    # referenced `PeftModel` exactly when peft is NOT installed (NameError)
    # and skipped PEFT support when it IS installed. Only extend the tuple
    # when peft is importable.
    if is_peft_available():
        supported_classes = (PreTrainedModel, PeftModel)

    if not isinstance(self.model, supported_classes):
        if state_dict is None:
            state_dict = self.model.state_dict()

        # The model may be wrapped (DDP, etc.): try saving the inner model.
        if isinstance(unwrap_model(self.model), supported_classes):
            unwrap_model(self.model).save_pretrained(
                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
            )
        else:
            # NOTE(review): as excerpted, this branch only logs and never writes
            # the state dict to disk — confirm against the full upstream source.
            logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
    else:
        self.model.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )

    if self.tokenizer is not None:
        self.tokenizer.save_pretrained(output_dir)

    # Good practice: save your training arguments together with the trained model
    torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))