Conversation
long8v left a comment
(24.05.08)
Only read the compute_ablation part so far. Seems like I need to understand transformer circuits to follow it fully.
def get_args_parser():
    parser = argparse.ArgumentParser("Ablations part", add_help=False)
What the "mean ablation" from the paper actually is: a component's per-sample contribution gets replaced by its mean over the dataset, and the resulting accuracy drop is measured.
def main(args):

    attns = np.load(os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"), mmap_mode="r")  # [b, l, h, d]
    mlps = np.load(os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"), mmap_mode="r")  # [b, l+1, d]
The attention and MLP contributions are stored as numpy arrays. attns has shape [b, l, h, d]; the b dimension is presumably one entry per ImageNet sample?
with open(
    os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"),
    "rb",
) as f:
    classifier = np.load(f)
There's a classifier too. Is this the final output logits..?
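My guess from the shapes alone (not confirmed): baseline below is [b, d] and baseline @ classifier is passed to accuracy() as logits, so classifier should be the classification head / class-embedding matrix of shape [d, num_classes], not the logits themselves. A minimal shape sketch with made-up sizes:

import numpy as np

b, d, num_classes = 4, 512, 1000               # hypothetical sizes
features = np.random.randn(b, d)               # pooled per-image representation
classifier = np.random.randn(d, num_classes)   # class embeddings (my assumption)
logits = features @ classifier                 # [b, num_classes]; argmax = prediction
print(logits.shape)                            # (4, 1000)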
if args.dataset == "imagenet":
    labels = np.array([i // 5 for i in range(attns.shape[0])])
else:
    with open(
        os.path.join(args.input_dir, f"{args.dataset}_labels.npy"), "rb"
    ) as f:
        labels = np.load(f)
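Side note on the ImageNet branch: no label file is read, and labels = i // 5 only makes sense if the dumped activations are sorted by class with five samples per class (my inference from the code, not verified). What that produces:

import numpy as np
labels = np.array([i // 5 for i in range(10)])
print(labels)  # [0 0 0 0 0 1 1 1 1 1] -> sample i belongs to class i // 5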
baseline = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)
baseline_acc = (
    accuracy(
        torch.from_numpy(baseline @ classifier).float(), torch.from_numpy(labels)
    )[0]
    * 100
)
print("Baseline:", baseline_acc)
Baseline accuracy: sum attns over the layer/head axes and mlps over the (l+1) axis, then add the two:
- attns: attns.sum(axis=(1, 2)): [b, l, h, d] -> [b, d]
- mlps: mlps.sum(axis=1): [b, l+1, d] -> [b, d]
I'm not sure why the computation works out this way; understanding transformer circuits should probably clear it up (toy sketch below): https://transformer-circuits.pub/2021/framework/index.html#summarizing-ovqk-matrices:~:text=as%20independently%20additive.-,Attention%20Heads%20as%20Information%20Movement,-But%20if%20attention
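A toy sketch of why the sums reproduce the model output, assuming (per the transformer-circuits framing) that the arrays store each component's direct additive write into the residual stream, so the final representation is just the sum of all writes. Synthetic data, shapes made up:

import numpy as np

b, l, h, d = 2, 3, 4, 8                  # tiny made-up sizes
attns = np.random.randn(b, l, h, d)      # per-head residual-stream writes
mlps = np.random.randn(b, l + 1, d)      # per-layer MLP writes (l+1 presumably includes the embedding?)

# If the representation is the sum of everything written into the residual
# stream, any grouping of the sum gives the same [b, d] vector:
full = attns.sum(axis=(1, 2)) + mlps.sum(axis=1)
regrouped = attns.sum(axis=2).sum(axis=1) + mlps.sum(axis=1)
assert np.allclose(full, regrouped)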
mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=attns.shape[0])
mlps_ablation = attns.sum(axis=(1, 2)) + mlps_mean.sum(axis=1)
mlps_ablation_acc = (
    accuracy(
        torch.from_numpy(mlps_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ MLPs ablation:", mlps_ablation_acc)
mlps_mean works as shown above: take the batch-dimension mean of the MLP contributions, repeat it across the batch, and sum that with the attention terms in place of the real per-sample MLP outputs.
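A quick check (synthetic data) that this really is mean ablation: after the batch-mean + repeat, the MLP term is one shared vector for every sample, so only the attention contributions still vary per image:

import numpy as np
import einops

b, num_mlp_layers, d = 4, 3, 8           # tiny made-up sizes
mlps = np.random.randn(b, num_mlp_layers, d)
mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=b)

summed = mlps_mean.sum(axis=1)           # [b, d] with identical rows
assert np.allclose(summed, summed[0])
# Equivalent without the repeat:
assert np.allclose(summed[0], mlps.mean(axis=0).sum(axis=0))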
mlps_no_layers = mlps.sum(axis=1)
attns_no_cls = attns.sum(axis=2)
with open(
    os.path.join(args.input_dir, f"{args.dataset}_cls_attn_{args.model}.npy"), "rb"
) as f:
    cls_attn = np.load(f)  # [b, l, d]
attns_no_cls = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]
no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_no_layers
no_cls_acc = (
    accuracy(
        torch.from_numpy(no_cls_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ CLS ablation:", no_cls_acc)
mlp_and_no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_mean.sum(axis=1)
mlp_and_no_cls_ablation_acc = (
    accuracy(
        torch.from_numpy(mlp_and_no_cls_ablation @ classifier).float(),
        torch.from_numpy(labels),
    )[0]
    * 100
)
print("+ MLPs + CLS ablation:", mlp_and_no_cls_ablation_acc)
The CLS ablation is as shown above: load the CLS-token attention contribution, subtract it from attns, and add back its batch mean (I don't fully understand this part; my attempt at a sketch is below).
#172
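My reading of the CLS ablation, as a sketch (synthetic data; I'm assuming cls_attn [b, l, d] is the part of each layer's attention write that comes from the CLS token itself): only that component gets mean-ablated, while the rest of the attention contribution stays per-sample:

import numpy as np

b, l, d = 4, 3, 8                        # tiny made-up sizes
attns_no_cls = np.random.randn(b, l, d)  # per-layer attention writes, heads summed
cls_attn = np.random.randn(b, l, d)      # CLS-token component (my assumption)

# Remove the per-sample CLS component, put back its dataset mean:
ablated = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :]

# Sanity check: if cls_attn were already identical for every sample,
# the ablation would be a no-op.
const = np.broadcast_to(cls_attn[:1], cls_attn.shape)
assert np.allclose(attns_no_cls - const + const.mean(axis=0), attns_no_cls)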