fix _clone_args when there are non-tensor input#1963
fix _clone_args when there are non-tensor input#1963shunting314 wants to merge 1 commit intoshunting314/stack/33from
Conversation
102ad18 to
632298b
Compare
stack-info: PR: #1963, branch: shunting314/stack/34
stack-info: PR: #1963, branch: shunting314/stack/34
632298b to
241d005
Compare
|
|
||
| def _should_clone(idx: int) -> bool: | ||
| return idx_to_clone is None or idx in idx_to_clone | ||
| idx_to_clone_set = set(idx_to_clone) if idx_to_clone is not None else None |
There was a problem hiding this comment.
Is this the same data structure as in #1961? That PR didn't ignore non-tensors I believe. Is there a cleaner way to to manage this?
There was a problem hiding this comment.
Yes, it's the same data structure. I think the main complexity is from letting the data-structure store indices of flattend-filted-tensor list.
If the code comment mentioned in #1961 (comment) is not a problem, we can change this data structure to store indices of flatted arg list before filtering away tensors. Otherwise, I think things can be cleaned up by added a API to return flattend-filted-tensor list from the original args.
Any comments about this?
There was a problem hiding this comment.
Let's try to clean this up. I think we need:
- A better name that reflects the filtered+flattened nature.
- Some helpers methods or data structures to make working this easier.
Stacked PRs:
fix _clone_args when there are non-tensor input