To facilitate grid search over hyper-parameters,
early stopping is now available in SSLRec for methods trained by the basic trainer.
To use early stopping, add the patience keyword under the train keyword
in the yaml file to define the patience value n.
When n consecutive validation results fail to improve on the best result so far,
the algorithm stops training and restores the best model parameters for evaluation on the test set.
Here is an example:
    train:
      epoch: 3000
      batch_size: 4096
      save_model: false
      loss: pairwise
      log_loss: false
      test_step: 3
      patience: 3

If there is no patience keyword, the trainer will train the model for the fixed number of epochs defined by epoch.
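
The patience rule described above can be sketched in a few lines of Python. This is a minimal illustration, not SSLRec's actual trainer: the function name train_with_early_stop and the train_one_epoch and validate callbacks are hypothetical, and the sketch assumes that validation runs whenever the epoch index is a multiple of test_step, that a higher score is better, and that the non-improving evaluations must be consecutive.

```python
def train_with_early_stop(train_cfg, train_one_epoch, validate):
    """Run training with optional patience-based early stopping.

    train_cfg: dict mirroring the `train` section of the yaml file.
    train_one_epoch, validate: hypothetical callbacks; validate returns
    a score where higher is better.
    Returns (best_epoch, last_epoch_run).
    """
    # A missing patience keyword disables early stopping.
    patience = train_cfg.get("patience")
    best_score = float("-inf")
    best_epoch = -1
    bad_evals = 0      # consecutive evaluations without improvement
    last_epoch = -1

    for epoch in range(train_cfg["epoch"]):
        last_epoch = epoch
        train_one_epoch(epoch)

        # Assumption: validate only every `test_step` epochs.
        if epoch % train_cfg["test_step"] != 0:
            continue

        score = validate(epoch)
        if score > best_score:
            # New best: remember it and reset the patience counter.
            # A real trainer would also snapshot the model parameters here.
            best_score, best_epoch, bad_evals = score, epoch, 0
        else:
            bad_evals += 1
            if patience is not None and bad_evals >= patience:
                break  # patience exhausted, stop training

    return best_epoch, last_epoch


# Usage with made-up validation scores, one per evaluation.
scores = [0.10, 0.12, 0.11, 0.12, 0.09, 0.20]

def fake_validate(epoch):
    idx = epoch // 3
    return scores[idx] if idx < len(scores) else 0.0

cfg = {"epoch": 3000, "test_step": 3, "patience": 3}
best_epoch, last_epoch = train_with_early_stop(cfg, lambda e: None, fake_validate)
```

With these made-up scores, training stops after the third consecutive evaluation that fails to beat the best score, so the later, higher score is never reached. That is the usual trade-off of a small patience value: less compute per grid-search run, at the risk of stopping before a late improvement.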
Also, make sure the data_handler file loads the correct validation and test sets.
Otherwise, the algorithm will fall back to using the test set as the validation set.
Here is an example in data_handler_general_cf.py:
    def load_data(self):
        # Load the training, test, and validation interaction matrices
        trn_mat = self._load_one_mat(self.trn_file)
        tst_mat = self._load_one_mat(self.tst_file)
        val_mat = self._load_one_mat(self.val_file)

        self.trn_mat = trn_mat
        configs['data']['user_num'], configs['data']['item_num'] = trn_mat.shape
        self.torch_adj = self._make_torch_adj(trn_mat)

        # Build the training dataset according to the configured loss
        if configs['train']['loss'] == 'pairwise':
            trn_data = PairwiseTrnData(trn_mat)
        elif configs['train']['loss'] == 'pairwise_with_epoch_flag':
            trn_data = PairwiseWEpochFlagTrnData(trn_mat)

        # Separate validation and test datasets
        val_data = AllRankTstData(val_mat, trn_mat)
        tst_data = AllRankTstData(tst_mat, trn_mat)

        # The validation loader drives early stopping; the test loader is used for the final test
        self.valid_dataloader = data.DataLoader(val_data, batch_size=configs['test']['batch_size'], shuffle=False, num_workers=0)
        self.test_dataloader = data.DataLoader(tst_data, batch_size=configs['test']['batch_size'], shuffle=False, num_workers=0)
        self.train_dataloader = data.DataLoader(trn_data, batch_size=configs['train']['batch_size'], shuffle=True, num_workers=0)
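
The fallback to the test set can be pictured with the sketch below. It is an assumption about how such a fallback might look, not the trainer's actual code: the helper pick_eval_loaders is hypothetical, and only the attribute names valid_dataloader and test_dataloader come from the example above.

```python
from types import SimpleNamespace


def pick_eval_loaders(data_handler):
    """Choose the loaders used for validation and for final testing.

    Hypothetical helper: if the data handler defines no valid_dataloader,
    the test loader is reused for validation.
    Returns (valid_loader, test_loader, used_fallback).
    """
    test_loader = data_handler.test_dataloader
    valid_loader = getattr(data_handler, "valid_dataloader", None)
    used_fallback = valid_loader is None
    if used_fallback:
        # Model selection will now see the test data.
        valid_loader = test_loader
    return valid_loader, test_loader, used_fallback


# Usage with stand-in objects instead of real DataLoader instances.
full_handler = SimpleNamespace(valid_dataloader="val", test_dataloader="tst")
bare_handler = SimpleNamespace(test_dataloader="tst")

with_validation = pick_eval_loaders(full_handler)
without_validation = pick_eval_loaders(bare_handler)
```

When the fallback is used, the stopping epoch is selected on the same data that produces the reported test metrics, so those metrics tend to be optimistic. Defining a separate validation set avoids this.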