
Adds functionality to save best model in CheckpointingCallback and fixes a typo#3

Open
dmsm wants to merge 2 commits into mgharbi:master from dmsm:master

Conversation

@dmsm
Contributor

@dmsm dmsm commented Jun 4, 2020

No description provided.

Comment thread ttools/callbacks.py
self.best_val_key = best_val_key

if best_val_value is not None:
LOG.info("Loaded best model ({}={})".format(best_val_key, best_val_value))
Owner


"Loaded" may be misleading, since the load does not happen at init time.

Comment thread ttools/callbacks.py
@@ -184,7 +184,7 @@ class VisdomLoggingCallback(KeyedCallback):
0.0 disables smoothing.
Owner


Maybe add a couple of doc entries.

Comment thread ttools/callbacks.py
def __init__(self, checkpointer, interval=600,
max_files=5, max_epochs=10):
max_files=5, max_epochs=10,
best_val_key=None, best_val_value=None):
Owner


Do we need to initialize best_val_value externally? Could that be handled by the load/save mechanism, so we only have one parameter?

Comment thread ttools/callbacks.py
@@ -506,6 +515,22 @@ def batch_end(self, batch_data, train_step_data):
self.checkpointer.save(filename, extras={"epoch": self.epoch})
Owner


Here and elsewhere where save is called, we might want to always save the best_val_value for easy loading, and handle the "None" case. It could be nice, but not required, to add a "load_best" method to the Checkpointer class.
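A hedged sketch of the "load_best" idea from this comment. `SketchCheckpointer` and its save/load layout are illustrative assumptions, not ttools' actual Checkpointer API; the point is only that persisting the extras (including best_val_value) alongside the weights means the best value never needs to be passed in externally:

```python
import json
import os
import tempfile


class SketchCheckpointer:
    """Hypothetical stand-in for ttools' Checkpointer (not the real API)."""

    BEST_NAME = "best"

    def __init__(self, root):
        self.root = root
        os.makedirs(root, exist_ok=True)

    def save(self, name, extras=None):
        # Real code would also serialize model weights here; we only
        # persist the extras dict (e.g. best_val_value) next to them.
        with open(os.path.join(self.root, name + ".json"), "w") as f:
            json.dump(extras or {}, f)

    def load_best(self):
        """Return the extras of the best checkpoint, or None if absent."""
        path = os.path.join(self.root, self.BEST_NAME + ".json")
        if not os.path.exists(path):
            return None  # the "None" case: no best model saved yet
        with open(path) as f:
            return json.load(f)


ck = SketchCheckpointer(tempfile.mkdtemp())
assert ck.load_best() is None  # nothing saved yet
ck.save(SketchCheckpointer.BEST_NAME, extras={"best_val_value": 0.42})
```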

Comment thread ttools/callbacks.py

if self.best_val_key is None:
return
if val_data[self.best_val_key] > self.best_val_value:
Owner


We should probably document the implicit direction of "best" (e.g. if someone uses accuracy instead of a loss).
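One hedged way to make the direction explicit, as a standalone sketch (the `make_cmp` helper is illustrative, not part of the PR):

```python
def make_cmp(sense):
    """Return cmp(x, y) -> True when score x is better than score y,
    under an explicit "maximize" / "minimize" convention."""
    maximize = sense == "maximize"
    return lambda x, y: x > y if maximize else x < y


cmp = make_cmp("minimize")
assert cmp(0.1, 0.5)       # lower loss is better
cmp = make_cmp("maximize")
assert cmp(0.9, 0.5)       # higher accuracy is better
```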

@Vrroom

Vrroom commented Jun 28, 2021

Hello everyone. I have been using torch-tools and love it for its simplicity and readability. For example, I tried using pytorch-lightning once instead, but I got nervous because it was harder to see what it was doing behind the scenes, so I quickly came back to ttools. Anyway, I needed a solution for the above problem. It is pretty similar to the above proposal, with a couple of differences:

  1. This class saves the best N models.
  2. The sense in which to interpret the val_data metric (maximize or minimize) is specified at class instantiation.

A current limitation is that if two models have the exact same score, one of them will be deleted. Hope this is useful.

import os  # needed by __purge_old_files; lives at the top of ttools/callbacks.py

class CheckpointingBestNCallback(Callback):
    """A callback which saves the best N models.

    Args:
        checkpointer (Checkpointer): actual checkpointer responsible for the I/O.
        key (str): key into accumulated validation data that defines the metric.
        N (int, optional): number of models to keep.
        sense (str, optional): one of "maximize" / "minimize", denoting whether
            the validation metric should be maximized or minimized.
    """

    BEST_PREFIX = "best_"

    def __init__(self, checkpointer, key, N=3, sense="maximize"):
        super(CheckpointingBestNCallback, self).__init__()
        self.checkpointer = checkpointer
        self.key = key
        self.N = N
        self.sense = sense
        self.default = sense == "maximize"
        # cmp(x, y) is True when score x is better than score y.
        self.cmp = lambda x, y: x > y if self.default else y > x
        self.ckptDict = dict()

    def validation_end(self, val_data):
        super(CheckpointingBestNCallback, self).validation_end(val_data)
        score = val_data[self.key]
        isBetter = any(self.cmp(score, y) for y in self.ckptDict.keys())
        if len(self.ckptDict) < self.N or isBetter:
            path = "{}{:.3f}".format(CheckpointingBestNCallback.BEST_PREFIX, score)
            path = path.replace(".", "-")
            self.checkpointer.save(path, extras=dict(score=score))
            self.ckptDict[score] = path
            self.__purge_old_files()

    def __purge_old_files(self):
        """Delete checkpoints that are beyond the max to keep."""
        chkpts = os.listdir(self.checkpointer.root)
        toBeRemoved = sorted(self.ckptDict.keys(), reverse=self.default)[self.N:]
        for s in toBeRemoved:
            cpref = self.ckptDict[s]
            cname = [fname for fname in chkpts if cpref in fname].pop()
            self.checkpointer.delete(cname)
            self.ckptDict.pop(s)
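For what it's worth, the pruning behavior can be exercised end-to-end with a dummy checkpointer. `DummyCheckpointer` and `keep_best_n` below are self-contained stand-ins written for illustration (the save/delete/root interface mirrors how the class above uses it; it is not ttools' real API), with the callback's maximize path condensed into one function:

```python
import os
import tempfile


class DummyCheckpointer:
    """Stand-in for ttools' Checkpointer; save/delete/root are assumptions."""

    def __init__(self, root):
        self.root = root

    def save(self, name, extras=None):
        # Touch an empty file instead of writing real model weights.
        open(os.path.join(self.root, name + ".pth"), "w").close()

    def delete(self, name):
        os.remove(os.path.join(self.root, name))


def keep_best_n(checkpointer, scores, n=2):
    """Condensed version of the validation_end/purge logic above,
    maximizing the metric and keeping the best n checkpoints on disk."""
    ckpt_dict = {}
    for score in scores:
        if len(ckpt_dict) < n or any(score > y for y in ckpt_dict):
            name = "best_{:.3f}".format(score).replace(".", "-")
            checkpointer.save(name, extras={"score": score})
            ckpt_dict[score] = name
            # Purge everything beyond the n best scores.
            for s in sorted(ckpt_dict, reverse=True)[n:]:
                files = os.listdir(checkpointer.root)
                fname = [f for f in files if ckpt_dict[s] in f].pop()
                checkpointer.delete(fname)
                ckpt_dict.pop(s)
    return ckpt_dict


ckpt = DummyCheckpointer(tempfile.mkdtemp())
kept = keep_best_n(ckpt, [0.50, 0.72, 0.65, 0.80], n=2)
# Only the two best scores (0.80 and 0.72) survive the purge.
```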
