Transformer Head Pruner#3884
Conversation
This comment has been minimized.
This comment has been minimized.
| break | ||
| except: | ||
| continue | ||
| if layer_idx is not None: |
There was a problem hiding this comment.
Is there a better way to get the index of the attention head? the first integer may be not strong.
There was a problem hiding this comment.
This layer_idx is the layer index of the BERT encoder. Here I include these lines of code only to show the user how they may take advantage of the pruned_heads dict inside pruner to get the pruned heads for each group, and then match each group to the original layer, and finally call the built-in transformers _prune_heads() function to do model speedup. This is meant to be a temporary workaround before we can properly handle speedup for transformers.
There was a problem hiding this comment.
If our speedup code after refactor can handle transformer, then I will replace these lines with our speedup methods (maybe in a separate pr)
There was a problem hiding this comment.
Since the users are aware of the naming of their own model to prune, I think they can also use their own rules to match layers to groups
| and include `model, optimizer, criterion, epoch` as function arguments. | ||
| criterion: function | ||
| Function used to calculate the loss between the target and the output. | ||
| For example, you can use ``torch.nn.CrossEntropyLoss()`` as input. |
There was a problem hiding this comment.
Feel like that the TransformerHeadPruner is too heavy. I prefer to locate this pruner as a one-shot pruner, which means we do not need handle with the num_iteration, optimizer, trainer, criterion, things. That's much clearer. All those finetuning related things we can offload to the outer search algorithms. We can discuss with Quanlu @QuanluZhang .
There was a problem hiding this comment.
Yes, we can further discuss on that. One challenge is that this does not fit well in our current compression V1 framework (since the current iterative and dependency aware pruner are limited to convolutions), and compression V2 is not ready yet. My initial thought was to first integrate all these logic in one pruner (because of empirically good performance compared to one-shot pruning), and then factor out when compression V2 is ready.
This pr adds a pruner for pruning attention heads in transformers.
To-do:
/ integration with pruning scheduler for iterative features