Implementation of Gradient Agreement Filtering (GAF), from Chaubard et al. of Stanford, but done for single-machine microbatches, in PyTorch.
The official repository that does filtering for macrobatches across machines is here
```bash
$ pip install GAF-microbatch-pytorch
```

```python
import torch
from torch import nn

# mock network

net = nn.Sequential(
    nn.Linear(512, 256),
    nn.SiLU(),
    nn.Linear(256, 128)
)

# import the gradient agreement filtering (GAF) wrapper

from GAF_microbatch_pytorch import GAFWrapper

# just wrap your neural net

gaf_net = GAFWrapper(
    net,
    filter_distance_thres = 0.97
)

# your batch of data

x = torch.randn(16, 1024, 512)

# forward and backwards as usual

out = gaf_net(x)
out.sum().backward()

# gradients should now be filtered by the set threshold, comparing per-sample gradients within the batch, as in the paper
```
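For intuition, here is a minimal sketch of the agreement test the paper describes, shown pairwise between two gradients: keep and average them only if their cosine distance falls under the threshold. This is an illustration, not this library's internal code, and the function name `agree_and_average` is made up for the example.

```python
import torch
import torch.nn.functional as F
from torch import Tensor
from typing import Optional

def agree_and_average(g1: Tensor, g2: Tensor, dist_thres: float = 0.97) -> Optional[Tensor]:
    # cosine distance between the two flattened gradients
    cos_dist = 1. - F.cosine_similarity(g1.flatten(), g2.flatten(), dim = 0)

    # if the gradients disagree beyond the threshold, filter the update out entirely
    if cos_dist > dist_thres:
        return None

    # otherwise average as usual
    return (g1 + g2) / 2
```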
You can supply your own gradient filtering method as a `Callable[[Tensor], Tensor]` with the `filter_gradients_fn` kwarg, like so:

```python
def filtering_fn(grads):
    # make your big discovery here
    return grads

gaf_net = GAFWrapper(
    net = net,
    filter_gradients_fn = filtering_fn
)
```
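As a concrete example, and assuming the tensor handed to `filter_gradients_fn` stacks per-sample gradients along its first dimension (consistent with the per-sample comparison described above), a custom filter could zero out gradients that point away from the batch consensus. The function name `consensus_filter` and the masking rule are illustrative, not part of this library:

```python
import torch
import torch.nn.functional as F
from torch import Tensor

def consensus_filter(grads: Tensor, min_cos_sim: float = 0.) -> Tensor:
    # assumes shape (num samples, ...), one gradient per sample
    flat = grads.flatten(1)

    # cosine similarity of each per-sample gradient to the batch mean gradient
    sims = F.cosine_similarity(flat, flat.mean(dim = 0, keepdim = True), dim = -1)

    # zero out per-sample gradients that disagree with the consensus direction
    mask = (sims >= min_cos_sim).float()
    return grads * mask.view(-1, *((1,) * (grads.ndim - 1)))

gaf_net = GAFWrapper(
    net = net,
    filter_gradients_fn = consensus_filter
)
```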
To set all `GAFWrapper` states within a network, use `set_filter_gradients_`:

```python
from GAF_microbatch_pytorch import set_filter_gradients_

# turning on / off
set_filter_gradients_(net, False)

# or perhaps set filter thresholds on some schedule
set_filter_gradients_(net, True, 0.98)
```
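For instance, a minimal sketch of putting the threshold on a schedule; the linear anneal and epoch loop here are illustrative assumptions, not anything prescribed by the paper:

```python
from GAF_microbatch_pytorch import set_filter_gradients_

num_epochs = 10

for epoch in range(num_epochs):
    # hypothetical linear anneal from a looser to a tighter filter threshold
    thres = 0.99 - 0.02 * epoch / max(num_epochs - 1, 1)
    set_filter_gradients_(net, True, thres)

    # ... train for one epoch as usual
```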
Todo:

- replicate CIFAR results on a single machine
- allow for excluding certain parameters from being filtered
```bibtex
@inproceedings{Chaubard2024BeyondGA,
    title  = {Beyond Gradient Averaging in Parallel Optimization: Improved Robustness through Gradient Agreement Filtering},
    author = {Francois Chaubard and Duncan Eddy and Mykel J. Kochenderfer},
    year   = {2024},
    url    = {https://api.semanticscholar.org/CorpusID:274992650}
}
```