Add support for causal LM fine-tuning #80
Conversation
Signed-off-by: gkumbhat <kumbhat.gaurav@gmail.com>
alex-jw-brooks left a comment
Thanks @gkumbhat, looks good! Just a few things.
Also, a bit of a side note: do you think we should remove support (for now) for HFAutoSequenceClassifier? It seems effectively unusable between the trainer changes and the tokenizer builder stuff, and it's confusing to have it there when we don't enable it for anything.
```python
lr: float = 2e-5,
# Directory where model predictions and checkpoints will be written
checkpoint_dir: str = "/tmp",
**training_arguments,
```
This is better! Can you link the trainer args in the docstring, though?
```python
**training_arguments,
):
    """
# FIXME: Below is currently configured for Seq2Seq only
```
This should be removed, right?
Yep, good catch! Will remove this.
```python
"<NLP39984681E>",
NotImplementedError(
    f"Generation on {type(self.model)} not supported \
    currently! Please try saving and running this model in TGIS."
```
Oof. Does exporting via the trainer save API and then reloading give you a transformers model back? I wonder if it would be better to have the first inference call export and reload with a warning, until we find something better or implement a causal LM trainer that does something similar. Slow feels better than completely broken here, IMO.
Or, is there any way we can cast to the seq2seq trainer and leverage its generate API? I guess that probably doesn't handle shifting etc. the same way...
Yeah, I think converting to the seq2seq trainer could lead to weird mismatch issues.
Saving and reloading is certainly an option; it would simplify this block of the run function entirely. But it could be less efficient: the model is already on the appropriate devices at this point, so by loading it again we would lose the distribution, which is mainly what I was trying to preserve here.
But certainly, not having a solution for causal LM would not be great either.
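The export-and-reload fallback floated above could be sketched roughly like this. The wrapper class and export path are hypothetical names for illustration; `trainer.save_model` and `AutoModelForCausalLM.from_pretrained` are the standard transformers APIs:

```python
import logging

log = logging.getLogger(__name__)


class LazyReloadGenerator:
    """Hypothetical fallback: on the first generate() call, export the
    trained model via the trainer's save API and reload it as a plain
    transformers causal LM, warning about the one-time cost."""

    def __init__(self, trainer, export_dir):
        self._trainer = trainer
        self._export_dir = export_dir
        self._model = None  # reloaded lazily on first use

    def generate(self, input_ids, **kwargs):
        if self._model is None:
            log.warning(
                "First inference call: exporting and reloading the model; "
                "slow, but better than a broken generation path."
            )
            # Deferred import so the wrapper can be defined without transformers
            from transformers import AutoModelForCausalLM

            self._trainer.save_model(self._export_dir)
            self._model = AutoModelForCausalLM.from_pretrained(self._export_dir)
        return self._model.generate(input_ids, **kwargs)
```

The trade-off discussed above still applies: the reloaded model loses the existing device distribution, so this is only a stopgap until a proper causal LM generation path exists.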
```python
device = PeftPromptTuning._get_device(device)
inputs = {k: v.to(device) for k, v in tok_tensors.items()}
```
```python
inputs = {k: v.to(self.model.device) for k, v in tok_tensors.items()}
```
FYI @rawkintrevo is actually making this change in a separate PR (it's issue #3). Can we put it back as part of this PR and use his change when it's ready instead? Since this PR is primarily targeting fine-tuning anyway.
Ah, true. I had to make this change to make some tests pass 😄 but yes, I can change it back.
```python
    "device_placement": True,
}
```
```python
accelerator = Accelerator(**accelerator_args)
```
Why build a separate dict here?
I was playing with an optional parameter regarding cpu=True, but that didn't work well, so I removed it; this dict is kind of left over from that. I will switch it back to direct arguments instead of a separate dict.
```python
@pytest.fixture()
def set_cpu_device(request):
```
Nice - thanks for adding this
```python
2. compute_metrics
3. callbacks
4. preprocess_logits_for_metrics
"""
```
Same question about documenting the kwargs here in the docstring (at least the non-expanded ones). I assume the other one probably needs it too.
alex-jw-brooks left a comment
Looks awesome! Some small typos and stuff, but LGTM
```python
# eval_steps=1,
# load_best_model_at_end
**training_arguments,
**dtype_based_params,
```
It might be a nice good-first-issue in the future to cleanly make sure there aren't collisions in these expanded dicts, but for now we can leave it.
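A possible shape for that future collision check, as a hypothetical helper (not part of this PR):

```python
def merge_without_collisions(*dicts):
    """Merge keyword-argument dicts, raising on duplicate keys instead
    of letting a later dict silently overwrite an earlier one."""
    merged = {}
    for d in dicts:
        collisions = merged.keys() & d.keys()
        if collisions:
            raise ValueError(
                f"Conflicting keyword arguments: {sorted(collisions)}"
            )
        merged.update(d)
    return merged
```

The training-arguments call above could then expand `**merge_without_collisions(training_arguments, dtype_based_params)` and fail loudly on overlap rather than silently preferring the last dict.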
Co-authored-by: Alex Brooks <alex.brooks@ibm.com>
Add support for causal LM fine-tuning
Closes #77
Changes
Adds a `set_cpu_device` fixture which changes the CUDA environment variable and patches the `is_available` function in `torch.cuda`.
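The fixture described above might look roughly like this. This is a sketch using pytest's built-in `monkeypatch` fixture; the actual implementation in the PR (which takes `request`) may differ:

```python
import pytest


@pytest.fixture()
def set_cpu_device(monkeypatch):
    """Force tests onto CPU: hide GPUs from CUDA via the environment
    variable and patch torch.cuda.is_available to report False."""
    monkeypatch.setenv("CUDA_VISIBLE_DEVICES", "")
    # Deferred import so collecting this file does not itself require torch
    import torch

    monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
    yield
```

Using `monkeypatch` means both the environment variable and the patched function are automatically restored after each test, so GPU-enabled tests in the same session are unaffected.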