diff --git a/README.md b/README.md index 9d03404..1047106 100644 --- a/README.md +++ b/README.md @@ -273,7 +273,7 @@ for epoch in range(epochs): out = model(data) loss = F.cross_entropy(out, target) loss.backward() # after loss.backward() - pruner.regularize(model) # <== for sparse training + pruner.regularize(model, loss) # <== for sparse training optimizer.step() # before optimizer.step() ``` diff --git a/README_CN.md b/README_CN.md index e2b71f0..aceeb81 100644 --- a/README_CN.md +++ b/README_CN.md @@ -225,7 +225,7 @@ for epoch in range(epochs): out = model(data) loss = F.cross_entropy(out, target) loss.backward() # after loss.backward() - pruner.regularize(model) # <== for sparse training + pruner.regularize(model, loss) # <== for sparse training optimizer.step() # before optimizer.step() ```