diff --git a/.gitignore b/.gitignore index a005653dd..81f5d2ece 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ /env __pycache__/ + +/dataset/dataflow/*/* +/evaluation/log/* \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..026229000 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "captum"] + path = captum + url = git@github.com:shitongzhu/captum.git diff --git a/captum b/captum new file mode 160000 index 000000000..caa7f2ae5 --- /dev/null +++ b/captum @@ -0,0 +1 @@ +Subproject commit caa7f2ae52016e95ceb5f1c67f3a0cbf625770df diff --git a/evaluation/script/aggregate_exp_res.py b/evaluation/script/aggregate_exp_res.py new file mode 100644 index 000000000..0fbf9b279 --- /dev/null +++ b/evaluation/script/aggregate_exp_res.py @@ -0,0 +1,205 @@ +import os +import argparse +import matplotlib +matplotlib.use('Agg') # to avoid using Xserver +import matplotlib.pyplot as plt + + +COLOR_LIST = ["blue", "red", "yellow", "orange"] + + +def parse_log_file(fpath): + global log_info + + with open(fpath, 'r') as fin: + data = fin.readlines() + + header_str = '\t'.join(data[1].strip().split(' | ')[1].split('\t')[1:]) + variants = header_str.split('\t') + for variant in variants: + if variant not in log_info: + if "DELETION" in variant or "RETENTION" in variant: + log_info[variant] = [[] for _ in range(10)] + else: + log_info[variant] = [] + + del data[:2] + for row in data: + row = row.strip() + result_str = '\t'.join(row.split(' | ')[1].split('\t')[1:]) + + scores = list(map(lambda x: float(x), result_str.split('\t')[:8])) + for i in range(8): + log_info[variants[i]].append(scores[i]) + + prob_delta = list(map(lambda x: eval(x), result_str.split('\t')[8:])) + for i in range(8, 16): + for j in range(len(prob_delta[i - 8])): + log_info[variants[i]][j].append(prob_delta[i - 8][j]) + + +parser = argparse.ArgumentParser(description='Aggregate attr accu logs.') +parser.add_argument( + "--task", type=str, 
help="specify the task that log aggregation should be applied to.") +parser.add_argument("--dir", type=str, help="specify log directory.") +args = parser.parse_args() + +log_info = {} + +log_filenames = os.listdir(args.dir) +for fname in log_filenames: + if args.task in fname: + parse_log_file(args.dir + '/' + fname) + +print("Analyzing task name: %s..." % args.task) +print("In directory: %s" % args.dir) + +print("====== Mean Attribution Score ======") +color_choice_deletion = 0 +color_choice_retention = 0 +for variant, scores in log_info.items(): + if "DELETION" in variant: + if color_choice_deletion == 0: + save_img_path = args.dir + "/viz/DELETION_comparison.png" + x_list, y_list = [], [] + for i in range(len(scores)): + print("[%s] Step #%d (mean) --> %f" % + (variant, i, sum(scores[i]) / len(scores[i]))) + x_list.append(i) + y_list.append(sum(scores[i]) / len(scores[i])) + plt.plot(x_list, y_list, color=COLOR_LIST[color_choice_deletion], label=variant.replace("DELETION_RES_", '')) + plt.legend() + color_choice_deletion += 1 + if color_choice_deletion == 4: + plt.xlabel("number of steps") + plt.ylabel("predicted class probability") + plt.title("Deletion Game Results") + plt.show() + plt.savefig(save_img_path, format="PNG") + plt.clf() + elif "RETENTION" in variant: + if color_choice_retention == 0: + save_img_path = args.dir + "/viz/RETENTION_comparison.png" + x_list, y_list = [], [] + for i in range(len(scores)): + print("[%s] Step #%d (mean) --> %f" % + (variant, i, sum(scores[i]) / len(scores[i]))) + x_list.append(i) + y_list.append(sum(scores[i]) / len(scores[i])) + plt.plot(x_list, y_list, color=COLOR_LIST[color_choice_retention], label=variant.replace("RETENTION_RES_", '')) + plt.legend() + color_choice_retention += 1 + if color_choice_retention == 4: + plt.xlabel("number of steps") + plt.ylabel("predicted class probability") + plt.title("Retention Game Results") + plt.show() + plt.savefig(save_img_path, format="PNG") + plt.clf() + else: + mean_score = 
sum(scores) / len(scores) + if variant in {"ASCENDING_DEPENDENCY_GUIDED_IG", "UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG", "DESCENDING_DEPENDENCY_GUIDED_IG"}: + std_mean_score = sum(log_info["STANDARD_IG"]) / \ + len(log_info["STANDARD_IG"]) + if std_mean_score == 0.0: + continue + margin = (mean_score - std_mean_score) / std_mean_score + print("[ATTR_ACC_SCORE] Variant: %s | # Samples: %d | Mean score: %f (%s)" % + (variant, len(scores), mean_score, "{:.2f}".format(margin * 100) + "%")) + elif variant == "STANDARD_IG": + print("[ATTR_ACC_SCORE] Variant: %s | # Samples: %d | Mean score: %f" % + (variant, len(scores), mean_score)) + + if variant in {"FAITH_ASCENDING_DEPENDENCY_GUIDED_IG", "FAITH_UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG", "FAITH_DESCENDING_DEPENDENCY_GUIDED_IG"}: + std_mean_score = sum(log_info["FAITH_STANDARD_IG"]) / \ + len(log_info["FAITH_STANDARD_IG"]) + if std_mean_score == 0.0: + continue + margin = (mean_score - std_mean_score) / std_mean_score + print("[FAITH_SCORE] Variant: %s | # Samples: %d | Mean score: %f (%s)" % + (variant, len(scores), mean_score, "{:.2f}".format(margin * 100) + "%")) + elif variant == "FAITH_STANDARD_IG": + print("[FAITH_SCORE] Variant: %s | # Samples: %d | Mean score: %f" % + (variant, len(scores), mean_score)) + +running_ranks = {} +for variant, _ in log_info.items(): + running_ranks[variant] = [] + +for i in range(len(log_info["STANDARD_IG"])): + attr_acc_std_ig = log_info["STANDARD_IG"][i] + attr_acc_dep_guided_ig = log_info["ASCENDING_DEPENDENCY_GUIDED_IG"][i] + attr_acc_dep_guided_ig_unaccumulated = log_info["UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG"][i] + attr_acc_reverse_dep_guided_ig = log_info["DESCENDING_DEPENDENCY_GUIDED_IG"][i] + faith_score_std_ig = log_info["FAITH_STANDARD_IG"][i] + faith_score_dep_guided_ig = log_info["FAITH_ASCENDING_DEPENDENCY_GUIDED_IG"][i] + faith_score_dep_guided_ig_unaccumulated = log_info[ + "FAITH_UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG"][i] + 
faith_score_reverse_dep_guided_ig = log_info["FAITH_DESCENDING_DEPENDENCY_GUIDED_IG"][i] + + sorted_acc_scores = sorted([ + attr_acc_std_ig, + attr_acc_dep_guided_ig, + attr_acc_dep_guided_ig_unaccumulated, + attr_acc_reverse_dep_guided_ig + ]) + variant_rank = list(map(lambda x: sorted_acc_scores.index(x), [ + attr_acc_std_ig, + attr_acc_dep_guided_ig, + attr_acc_dep_guided_ig_unaccumulated, + attr_acc_reverse_dep_guided_ig, + ])) + + sorted_faith_scores = sorted([faith_score_std_ig, + faith_score_dep_guided_ig, + faith_score_dep_guided_ig_unaccumulated, + faith_score_reverse_dep_guided_ig + ]) + variant_rank_faith = list(map(lambda x: sorted_faith_scores.index(x), [ + faith_score_std_ig, + faith_score_dep_guided_ig, + faith_score_dep_guided_ig_unaccumulated, + faith_score_reverse_dep_guided_ig, + ])) + + running_ranks["STANDARD_IG"].append(variant_rank[0]) + running_ranks["ASCENDING_DEPENDENCY_GUIDED_IG"].append(variant_rank[1]) + running_ranks["UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG"].append( + variant_rank[2]) + running_ranks["DESCENDING_DEPENDENCY_GUIDED_IG"].append(variant_rank[3]) + running_ranks["FAITH_STANDARD_IG"].append(variant_rank_faith[0]) + running_ranks["FAITH_ASCENDING_DEPENDENCY_GUIDED_IG"].append( + variant_rank_faith[1]) + running_ranks["FAITH_UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG"].append( + variant_rank_faith[2]) + running_ranks["FAITH_DESCENDING_DEPENDENCY_GUIDED_IG"].append( + variant_rank_faith[3]) + +print("====== Ranking For Variants ======") +for variant, ranks in running_ranks.items(): + if "DELETION" in variant or "RETENTION" in variant: + continue + mean_rank = sum(ranks) / len(ranks) + if variant in {"ASCENDING_DEPENDENCY_GUIDED_IG", "UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG", "DESCENDING_DEPENDENCY_GUIDED_IG"}: + std_mean_rank = sum(running_ranks["STANDARD_IG"]) / \ + len(running_ranks["STANDARD_IG"]) + if std_mean_rank == 0.0: + continue + margin = (mean_rank - std_mean_rank) / std_mean_rank + 
print("[ATTR_ACC_SCORE] Variant: %s | # Samples: %d | Mean rank: %f (%s)" % + (variant, len(ranks), mean_rank, "{:.2f}".format(margin * 100) + "%")) + elif variant == "STANDARD_IG": + print("[ATTR_ACC_SCORE] Variant: %s | # Samples: %d | Mean rank: %f" % + (variant, len(ranks), mean_rank)) + + if variant in {"FAITH_ASCENDING_DEPENDENCY_GUIDED_IG", "FAITH_UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG", "FAITH_DESCENDING_DEPENDENCY_GUIDED_IG"}: + std_mean_rank = sum(running_ranks["STANDARD_IG"]) / \ + len(running_ranks["STANDARD_IG"]) + if std_mean_rank == 0.0: + continue + margin = (mean_rank - std_mean_rank) / std_mean_rank + print("[FAITH_SCORE] Variant: %s | # Samples: %d | Mean rank: %f (%s)" % + (variant, len(ranks), mean_rank, "{:.2f}".format(margin * 100) + "%")) + elif variant == "FAITH_STANDARD_IG": + print("[FAITH_SCORE] Variant: %s | # Samples: %d | Mean rank: %f" % + (variant, len(ranks), mean_rank)) diff --git a/evaluation/script/batch_ggnn_test.sh b/evaluation/script/batch_ggnn_test.sh new file mode 100644 index 000000000..2b1476a90 --- /dev/null +++ b/evaluation/script/batch_ggnn_test.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +echo "Total number of instances to run: $1" +echo "Max number of graph nodes: $2" +echo "Max toleratable ratio of edge removal for removing cycles: $3" +echo "Task name: $4" +echo "Total test size: $5" +echo "Additional arg #1: $6" + +per_instance_size=$(($5/$1)) + +echo "Per-instance test size: $per_instance_size" + +if [ $4 = "domtree" ]; then + model_id="14" +else + model_id="15" +fi + +for instance_id in `seq 1 $1` +do + echo "Starting instance with ID $instance_id..." 
+ nohup bazel run --verbose_failures //programl/task/dataflow:ggnn_test \ + -- --model=/logs/programl/$4/ddf_30/checkpoints/0$model_id.Checkpoint.pb \ + --ig -dep_guided_ig --save_vis --only_pred_y --batch --random_test_size $per_instance_size \ + --max_vis_graph_complexity $2 --max_removed_edges_ratio $3 --task $4 \ + --filter_adjacant_nodes --instance_id $instance_id --num_instances $1 $6\ + > ../log/nohup_$4_exp_$2_$3_$1_$instance_id.log 2>&1 & +done diff --git a/evaluation/script/clean_up_after_exp.sh b/evaluation/script/clean_up_after_exp.sh new file mode 100644 index 000000000..61cc6fbef --- /dev/null +++ b/evaluation/script/clean_up_after_exp.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +echo "Cleaning up dirs..." + +rm ../../dataset/dataflow/exp_log/* +rm ../../dataset/dataflow/vis_res/* +rm ../log/* \ No newline at end of file diff --git a/evaluation/script/single_ggnn_test.sh b/evaluation/script/single_ggnn_test.sh new file mode 100644 index 000000000..90fb0f392 --- /dev/null +++ b/evaluation/script/single_ggnn_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +echo "Max number of graph nodes: $1" +echo "Max toleratable ratio of edge removal for removing cycles: $2" +echo "Task name: $3" +echo "Additional arg #1: $4" + +if [ $3 = "domtree" ]; then + model_id="14" +else + model_id="15" +fi + +bazel run --verbose_failures //programl/task/dataflow:ggnn_test \ + -- --model=/logs/programl/$3/ddf_30/checkpoints/0$model_id.Checkpoint.pb \ + --ig -dep_guided_ig --save_vis --only_pred_y --batch --random_test_size 100 \ + --max_vis_graph_complexity $1 --max_removed_edges_ratio $2 --task $3 \ + --filter_adjacant_nodes --instance_id 1 --num_instances 1 --debug $4 \ No newline at end of file diff --git a/install b/install new file mode 100755 index 000000000..84222ffc1 --- /dev/null +++ b/install @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +help() { + cat </dev/null || + source "$(grep -sm1 "^$f " "${RUNFILES_MANIFEST_FILE:-/dev/null}" | cut -f2- -d' ')" 2>/dev/null || + source "$0.runfiles/$f" 
2>/dev/null || + source "$(grep -sm1 "^$f " "$0.runfiles_manifest" | cut -f2- -d' ')" 2>/dev/null || + source "$(grep -sm1 "^$f " "$0.exe.runfiles_manifest" | cut -f2- -d' ')" 2>/dev/null || + { + echo >&2 "ERROR: cannot find $f" + exit 1 + } +f= +# --- end app init --- + +set -euo pipefail + +BINARIES=( + "$(DataPath programl/programl/cmd/analyze)" + "$(DataPath programl/programl/cmd/clang2graph)" + "$(DataPath programl/programl/cmd/graph2cdfg)" + "$(DataPath programl/programl/cmd/graph2dot)" + "$(DataPath programl/programl/cmd/graph2json)" + "$(DataPath programl/programl/cmd/llvm2graph)" + "$(DataPath programl/programl/cmd/pbq)" + "$(DataPath programl/programl/cmd/xla2graph)" +) + +if [[ $(uname) == Darwin ]]; then + LLVM_LIBS="$(DataPath clang-llvm-10.0.0-x86_64-apple-darwin/lib)" +else + LLVM_LIBS="$(DataPath clang-llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04/lib)" +fi + +main() { + set +u + if [[ "$1" == "--help" ]]; then + help + exit 1 + fi + set -u + + local prefix=${1:-~/.local/opt/programl} + mkdir -p "$prefix/bin" "$prefix/lib" + + echo "Installing ProGraML command line tools ..." + echo + for bin in "${BINARIES[@]}"; do + dst="$prefix/bin/$(basename $bin)" + echo " $dst" + rm -f "$dst" + cp $bin "$dst" + done + + echo + echo "Installing libraries to $prefix/libs ..." + rsync -ah --delete --exclude '*.a' "$LLVM_LIBS/" "$prefix/lib/" + + echo + echo "====================================================" + echo "To use them, add the following to your ~/.$(basename $SHELL)rc:" + echo + echo "export PATH=$prefix/bin:\$PATH" + echo "export LD_LIBRARY_PATH=$prefix/lib:\$LD_LIBRARY_PATH" +} +main "$@" diff --git a/programl/BUILD b/programl/BUILD new file mode 100644 index 000000000..ea401114d --- /dev/null +++ b/programl/BUILD @@ -0,0 +1,27 @@ +# Copyright 2019-2020 the ProGraML authors. +# +# Contact Chris Cummins . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +py_library( + name = "serialize_ops", + srcs = [ + "exceptions.py", + "serialize_ops.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//programl/proto:program_graph_py", + ], +) diff --git a/programl/exceptions.py b/programl/exceptions.py new file mode 100644 index 000000000..35b0a55fd --- /dev/null +++ b/programl/exceptions.py @@ -0,0 +1,27 @@ +# Copyright 2019-2020 the ProGraML authors. +# +# Contact Chris Cummins . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class UnsupportedCompiler(TypeError): + """Exception raised if the requested compiler is not supported.""" + + +class GraphCreationError(ValueError): + """Exception raised if a graph creation op fails.""" + + +class GraphTransformError(ValueError): + """Exception raised if a graph transform op fails.""" diff --git a/programl/models/batch_results.py b/programl/models/batch_results.py index a0e617603..9dafc2724 100644 --- a/programl/models/batch_results.py +++ b/programl/models/batch_results.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """This module defines data structures for model results for a mini-batch.""" -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, List import numpy as np import sklearn.metrics @@ -55,6 +55,13 @@ class BatchResults(NamedTuple): f1: float # The confusion matrix. confusion_matrix: np.array + # Attributions. + attributions: np.array + # Faithfulness score. + faithfulness_score: float + # Deletion/Retention game results. + deletion_res: List[float] + retention_res: List[float] @property def has_learning_rate(self) -> bool: @@ -99,6 +106,10 @@ def Create( model_converged: bool = False, learning_rate: Optional[float] = None, loss: Optional[float] = None, + attributions: List[int] = None, + faithfulness_score: float = None, + deletion_res: List[float] = None, + retention_res: List[float] = None, ): """Construct a results instance from 1-hot targets and predictions. 
@@ -169,6 +180,10 @@ def Create( confusion_matrix=BuildConfusionMatrix( targets=true_y, predictions=pred_y, num_classes=y_dimensionality ), + attributions=attributions, + faithfulness_score=faithfulness_score, + deletion_res=deletion_res, + retention_res=retention_res, ) diff --git a/programl/models/ggnn/ggnn.py b/programl/models/ggnn/ggnn.py index 12f93c318..90fa8e704 100644 --- a/programl/models/ggnn/ggnn.py +++ b/programl/models/ggnn/ggnn.py @@ -15,13 +15,14 @@ # limitations under the License. """A gated graph neural network classifier.""" import typing -from typing import Dict, Tuple +from typing import Dict, Tuple, List import numpy as np import torch from labm8.py import app from labm8.py.progress import NullContext, ProgressContext from torch import nn +import torch.nn.functional as F from programl.graph.format.py.graph_tuple import GraphTuple from programl.models.batch_data import BatchData @@ -36,6 +37,10 @@ from programl.models.model import Model from programl.proto import epoch_pb2 +# For model explainability +from captum.attr import IntegratedGradients +from copy import deepcopy + FLAGS = app.FLAGS # Graph unrolling flags. @@ -318,7 +323,7 @@ def trainable_parameter_count(self) -> int: return self.model.trainable_parameter_count def PrepareModelInputs( - self, epoch_type: epoch_pb2.EpochType, batch: BatchData + self, epoch_type: epoch_pb2.EpochType, batch: BatchData, node_out=None ) -> Dict[str, torch.Tensor]: """RunBatch() helper method to prepare inputs to model. @@ -356,46 +361,215 @@ def PrepareModelInputs( torch.from_numpy(x).to(self.model.dev, torch.long) for x in graph_tuple.edge_positions ] + + raw_in = self.model.node_embeddings(vocab_ids, selector_ids) model_inputs = { - "vocab_ids": vocab_ids, - "selector_ids": selector_ids, + "raw_in": raw_in, "labels": labels, "edge_lists": edge_lists, + "node_out": node_out, "pos_lists": edge_positions, } - # maybe fetch more inputs. 
- # TODO: - # if graph_tuple.has_graph_y: - # assert ( - # epoch_type != epoch_pb2.TRAIN - # or graph_tuple.graph_tuple_count > 1 - # ), f"graph_count is {graph_tuple.graph_tuple_count}" - # num_graphs = torch.tensor(graph_tuple.graph_tuple_count).to( - # self.model.dev, torch.long - # ) - # graph_nodes_list = torch.from_numpy( - # graph_tuple.disjoint_nodes_list - # ).to(self.model.dev, torch.long) - # - # aux_in = torch.from_numpy(graph_tuple.graph_x).to( - # self.model.dev, torch.get_default_dtype() - # ) - # model_inputs.update( - # { - # "num_graphs": num_graphs, - # "graph_nodes_list": graph_nodes_list, - # "aux_in": aux_in, - # } - # ) - return model_inputs + def DoFaithfulnessTest( + self, + model_inputs, + attr_orders, + predictions, + verbal=True, + ) -> float: + # This function calculates the faithfulness score by doing the following: + # -- (1) find the node with the highest attribution score, remove it and + # test if the model prediction changes; if not, go to (2); if so, return 1 + # -- (2) find the node with the 2nd highest attribution score, remove it and + # test if the model prediction changes; if not, go to (3); if so, return 2 + # ...... 
+ # -- (n) find the node with the nth highest attribution score, remove it and + # test if the model prediction changes; if not, go to (n+1); if so, return n + if self.model.training: + self.model.eval() + self.model.opt.zero_grad() + + if verbal: + print("Number of attributions: %d" % len(attr_orders)) + + for i in range(1, len(attr_orders)): + curr_raw_in = torch.clone(model_inputs["raw_in"]) + labels = deepcopy(model_inputs["labels"]) + edge_lists = deepcopy(model_inputs["edge_lists"]) + node_out = deepcopy(model_inputs["node_out"]) + pos_lists = deepcopy(model_inputs["pos_lists"]) + + for j in range(i): + curr_attr_idx = attr_orders.index(j) + curr_raw_in[curr_attr_idx] = 0.0 + logits = self.model( + raw_in=curr_raw_in, + labels=labels, + edge_lists=edge_lists, + node_out=node_out, + pos_lists=pos_lists, + )[1] + logits = F.softmax(logits, dim=1) + curr_predictions = torch.argmax(logits) + + if predictions.detach().cpu() != curr_predictions.detach().cpu(): + break + + if verbal: + print("Removed nodes with highest %d attributions (logits: %s)..." 
% (i + 1, str(logits))) + + del curr_raw_in + + return float(1 - (i + 1) / (len(attr_orders))) + + def RemoveEdges( + self, + edge_lists, + pos_lists, + removed_nodes, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + new_edge_lists = [] + new_pos_lists = [] + + for i in range(len(edge_lists)): + edge_list = edge_lists[i] + new_edge_list = [] + new_pos_list = [] + + for j in range(len(edge_list)): + tail, head = edge_list[j] + if tail.item() in set(removed_nodes) or head.item() in set(removed_nodes): + continue + else: + new_edge_list.append(edge_lists[i][j]) + new_pos_list.append(pos_lists[i][j]) + + if new_edge_list == [] and new_pos_list == []: + new_edge_lists.append(torch.empty((0, edge_lists[i].shape[1]), device=edge_lists[i].device, dtype=edge_lists[i].dtype)) + new_pos_lists.append(torch.empty(0, device=pos_lists[i].device, dtype=pos_lists[i].dtype)) + else: + new_edge_lists.append(torch.stack(new_edge_list)) + new_pos_lists.append(torch.stack(new_pos_list)) + + return new_edge_lists, new_pos_lists + + def DoDeletionAndRetentionGameTests( + self, + model_inputs, + attr_orders, + predictions, + logits, + verbal=True, + remove_edges=False, + ) -> float: + # This function calculates the faithfulness score by playing two games (deletion and retention): + # -- (1) For deletion: incrementally find the node with highest attribution scores, + # remove it and measure the drop in prediction accuracy (logit probs) + # -- (2) For retention: incrementally find the node with lowest attribution scores, + # remove it and measure the increase in prediction accuracy (logit probs) + if self.model.training: + self.model.eval() + self.model.opt.zero_grad() + + if verbal: + print("Number of attributions: %d" % len(attr_orders)) + + # First pass for Deletion Game + deletion_res = [logits.detach().cpu().tolist()[0][predictions.detach().cpu().item()]] + for i in range(1, len(attr_orders)): + curr_raw_in = torch.clone(model_inputs["raw_in"]) + labels = 
deepcopy(model_inputs["labels"]) + edge_lists = deepcopy(model_inputs["edge_lists"]) + node_out = deepcopy(model_inputs["node_out"]) + pos_lists = deepcopy(model_inputs["pos_lists"]) + + removed_nodes = [] + for j in range(i): + curr_attr_idx = attr_orders.index(j) + curr_raw_in[curr_attr_idx] = 0.0 + removed_nodes.append(curr_attr_idx) + if remove_edges: + edge_lists, pos_lists = self.RemoveEdges(edge_lists, pos_lists, removed_nodes) + curr_logits = self.model( + raw_in=curr_raw_in, + labels=labels, + edge_lists=edge_lists, + node_out=node_out, + pos_lists=pos_lists, + )[1] + curr_logits = F.softmax(curr_logits, dim=1) + + curr_class_prob = curr_logits.detach().cpu().tolist()[0][predictions.detach().cpu().item()] + class_prob_drop = logits.detach().cpu().tolist()[0][predictions.detach().cpu().item()] - curr_class_prob + deletion_res.append(curr_class_prob) + + if verbal: + print("Deleted nodes with highest %d attributions (prob drop: %f -- from %f to %f)..." % (i, class_prob_drop, logits.detach( + ).cpu().tolist()[0][predictions.detach().cpu().item()], curr_logits.detach().cpu().tolist()[0][predictions.detach().cpu().item()])) + + del curr_raw_in + + # Second pass for Retention Game + retention_res = [] + baseline_class_prob = 0.0 + for i in range(len(attr_orders)): + curr_raw_in = torch.clone(model_inputs["raw_in"]) + labels = deepcopy(model_inputs["labels"]) + edge_lists = deepcopy(model_inputs["edge_lists"]) + node_out = deepcopy(model_inputs["node_out"]) + pos_lists = deepcopy(model_inputs["pos_lists"]) + + removed_nodes = [] + for j in list(range(len(attr_orders)))[i:]: + curr_attr_idx = attr_orders.index(j) + curr_raw_in[curr_attr_idx] = 0.0 + removed_nodes.append(curr_attr_idx) + if remove_edges: + edge_lists, pos_lists = self.RemoveEdges(edge_lists, pos_lists, removed_nodes) + curr_logits = self.model( + raw_in=curr_raw_in, + labels=labels, + edge_lists=edge_lists, + node_out=node_out, + pos_lists=pos_lists, + )[1] + curr_logits = F.softmax(curr_logits, 
dim=1) + + if i == 0: + baseline_class_prob = curr_logits.detach().cpu().tolist()[0][predictions.detach().cpu().item()] + retention_res.append(baseline_class_prob) + else: + curr_class_prob = curr_logits.detach().cpu().tolist()[0][predictions.detach().cpu().item()] + class_prob_increase = curr_class_prob - baseline_class_prob + retention_res.append(curr_class_prob) + + if verbal: + print("Retained nodes with highest %d attributions (prob increase: %s -- from %f to %f)..." % (i, class_prob_increase, baseline_class_prob, + curr_logits.detach().cpu().tolist()[0][predictions.detach().cpu().item()])) + + del curr_raw_in + + return deletion_res, retention_res + def RunBatch( self, epoch_type: epoch_pb2.EpochType, batch: BatchData, ctx: ProgressContext = NullContext, + run_ig=False, + dep_guided_ig=False, + interpolation_order=None, + node_out=None, + accumulate_gradients=True, + return_delta=False, + reverse=False, + average_attrs=True, + do_faithfulness_test=False, + do_deletion_retention_games=True, + remove_edges=False, ) -> BatchResults: """Process a mini-batch of data through the GGNN. @@ -407,7 +581,17 @@ def RunBatch( Returns: A batch results instance. """ - model_inputs = self.PrepareModelInputs(epoch_type, batch) + def get_sorted_indices(array): + a = np.argsort(array) + a[a.copy()] = np.arange(len(a)) + a = len(a) - a - 1 # remember to reserve the index + return a + + assert dep_guided_ig == (interpolation_order is not None), "Invalid dep_guided_ig/interpolation_order combination!" + + interpolation_order = deepcopy(interpolation_order) + + model_inputs = self.PrepareModelInputs(epoch_type, batch, node_out) unroll_steps = np.array( GetUnrollSteps(epoch_type, batch, FLAGS.unroll_strategy), dtype=np.int64, @@ -424,6 +608,14 @@ def RunBatch( self.model.opt.zero_grad() # Inference only, don't trace the computation graph. 
with torch.no_grad(): + if run_ig: + # Let's make a copy in case any parameter is altered in place + raw_in = torch.clone(model_inputs["raw_in"]) + labels = deepcopy(model_inputs["labels"]) + edge_lists = deepcopy(model_inputs["edge_lists"]) + node_out = deepcopy(model_inputs["node_out"]) + pos_lists = deepcopy(model_inputs["pos_lists"]) + outputs = self.model(**model_inputs) ( @@ -433,6 +625,122 @@ def RunBatch( *unroll_stats, ) = outputs + logits = F.softmax(logits, dim=1) + predictions = torch.argmax(logits) + + if run_ig: + print("Starting IG explanation...") + print("Dim of input: %s" % str(raw_in.shape)) + print("Original logits (no manipulation): %s" % str(logits)) + + ig = IntegratedGradients(self.model) + + if dep_guided_ig: + method = "dependency_guided_ig_nonuniform" + n_steps = len(interpolation_order[0]) + print("Using dependency-guided IG (method: %s | n_steps: %d) | accumulate gradients: %s | reversed: %s | average attrs: %s" % + (method, n_steps, str(accumulate_gradients), str(reverse), str(average_attrs))) + else: + method = "gausslegendre" + n_steps = raw_in.shape[0] # for fair comparison + print("Using stock IG (method: %s | n_steps: %d) | accumulate gradients: %s | reversed: %s | average attrs: %s" % ( + method, n_steps, str(accumulate_gradients), str(reverse), str(average_attrs))) + + if node_out is not None: + node_outs = [node_out] * n_steps + + if return_delta: + attributions_list, approximation_error_list = [], [] + if interpolation_order is None: + attributions, approximation_error = ig.attribute( + raw_in, + additional_forward_args=( + labels, edge_lists, node_outs, pos_lists), + method=method, + return_convergence_delta=True, + target=predictions, + interpolation_order=interpolation_order, + n_steps=n_steps, + accumulate_gradients=accumulate_gradients, + ) + else: + for order in interpolation_order: + attributions, approximation_error = ig.attribute( + raw_in, + additional_forward_args=( + labels, edge_lists, node_outs, pos_lists), + 
method=method, + return_convergence_delta=True, + target=predictions, + interpolation_order=order, + n_steps=n_steps, + accumulate_gradients=accumulate_gradients, + ) + attributions_list.append(attributions) + approximation_error_list.append(approximation_error) + attributions = torch.mean(torch.stack(attributions_list)) + approximation_error = torch.mean( + torch.stack(approximation_error_list)) + else: + attributions_list, approximation_error_list = [], [] + if interpolation_order is None: + attributions = ig.attribute( + raw_in, + additional_forward_args=( + labels, edge_lists, node_outs, pos_lists), + method=method, + return_convergence_delta=False, + target=predictions, + interpolation_order=interpolation_order, + n_steps=n_steps, + accumulate_gradients=accumulate_gradients, + ) + else: + for order in interpolation_order: + attributions = ig.attribute( + raw_in, + additional_forward_args=( + labels, edge_lists, node_outs, pos_lists), + method=method, + return_convergence_delta=False, + target=predictions, + interpolation_order=order, + n_steps=n_steps, + accumulate_gradients=accumulate_gradients, + ) + attributions_list.append(attributions) + attributions = torch.mean(torch.stack(attributions_list), dim=0) + summerized_attributions = torch.mean(attributions, dim=1) + summerized_attributions_indices = get_sorted_indices( + summerized_attributions.detach().cpu().numpy()) + + if return_delta: + print("Mean error: %f" % approximation_error.mean()) + print("Summerized attributions indices: %s" % str(summerized_attributions_indices)) + + if do_faithfulness_test: + faithfulness_score = self.DoFaithfulnessTest( + model_inputs=model_inputs, + attr_orders=summerized_attributions_indices.tolist(), + predictions=predictions, + ) + print("Faithfulness score: %f" % faithfulness_score) + else: + faithfulness_score = 0.0 + + if do_deletion_retention_games: + deletion_res, retention_res = self.DoDeletionAndRetentionGameTests( + model_inputs=model_inputs, + 
attr_orders=summerized_attributions_indices.tolist(), + logits=logits, + predictions=predictions, + remove_edges=remove_edges, + ) + print("Deletion game results: %s" % str(deletion_res)) + print("Retention game results: %s" % str(retention_res)) + + print("IG steps finished.") + loss = self.model.loss((logits, graph_features), targets) if epoch_type == epoch_pb2.TRAIN: @@ -464,14 +772,42 @@ def RunBatch( model_converged = unroll_stats[1] if unroll_stats else False iteration_count = unroll_stats[0] if unroll_stats else unroll_steps - return BatchResults.Create( - targets=batch.model_data.node_labels, - predictions=logits.detach().cpu().numpy(), - model_converged=model_converged, - learning_rate=self.model.learning_rate, - iteration_count=iteration_count, - loss=loss.item(), - ) + if run_ig: + if node_out is not None: + return BatchResults.Create( + targets=np.expand_dims(batch.model_data.node_labels[node_out], axis=0), + predictions=logits.detach().cpu().numpy(), + model_converged=model_converged, + learning_rate=self.model.learning_rate, + iteration_count=iteration_count, + loss=loss.item(), + attributions=summerized_attributions_indices, + faithfulness_score=faithfulness_score, + deletion_res=deletion_res, + retention_res=retention_res, + ) + else: + return BatchResults.Create( + targets=batch.model_data.node_labels, + predictions=logits.detach().cpu().numpy(), + model_converged=model_converged, + learning_rate=self.model.learning_rate, + iteration_count=iteration_count, + loss=loss.item(), + attributions=summerized_attributions_indices, + faithfulness_score=faithfulness_score, + deletion_res=deletion_res, + retention_res=retention_res, + ) + else: + return BatchResults.Create( + targets=batch.model_data.node_labels, + predictions=logits.detach().cpu().numpy(), + model_converged=model_converged, + learning_rate=self.model.learning_rate, + iteration_count=iteration_count, + loss=loss.item(), + ) def GetModelData(self) -> typing.Any: return { diff --git 
a/programl/models/ggnn/ggnn_model.py b/programl/models/ggnn/ggnn_model.py index 60ad1e018..a3b302114 100644 --- a/programl/models/ggnn/ggnn_model.py +++ b/programl/models/ggnn/ggnn_model.py @@ -29,7 +29,10 @@ FLAGS = app.FLAGS app.DEFINE_boolean( - "block_gpu", True, "Prevent model from hitchhiking on an occupied gpu." + "block_gpu", False, "Prevent model from hitchhiking on an occupied gpu." +) +app.DEFINE_boolean( + "cpu_only", False, "Prevent model from using any gpu." ) @@ -56,14 +59,20 @@ def __init__( self.metrics = Metrics() # Move the model to device before making the optimizer. - if FLAGS.block_gpu: - self.dev = ( - torch.device("cuda") - if gpu_scheduler.LockExclusiveProcessGpuAccess() - else torch.device("cpu") - ) + if FLAGS.cpu_only: + self.dev = torch.device("cpu") else: - self.dev = torch.device("cuda") + import random + gpu_num = random.randint(0, 1) + if FLAGS.block_gpu: + self.dev = ( + torch.device("cuda:%d" % gpu_num) + if gpu_scheduler.LockExclusiveProcessGpuAccess() + else + torch.device("cpu") + ) + else: + self.dev = torch.device("cuda:%d" % gpu_num) self.to(self.dev) @@ -95,20 +104,18 @@ def GetLRScheduler(self, optimizer, gamma): def forward( self, - vocab_ids, + raw_in, labels, edge_lists, - selector_ids=None, + node_out=None, pos_lists=None, num_graphs=None, graph_nodes_list=None, node_types=None, aux_in=None, ): - raw_in = self.node_embeddings(vocab_ids, selector_ids) - # self.ggnn might change raw_in inplace, so use the two outputs instead! 
raw_out, raw_in, *unroll_stats = self.ggnn( - edge_lists, raw_in, pos_lists, node_types + raw_in, edge_lists, pos_lists, node_types ) if self.has_graph_labels: @@ -128,7 +135,13 @@ def forward( # accuracy, pred_targets, correct, targets # metrics_tuple = self.metrics(logits, labels) targets = labels.argmax(dim=1) - + + if node_out is not None: + # We now need to select the output target/logit for the node + # we are interested (instead of all) + targets = torch.reshape(targets[node_out], (-1,)) + logits = torch.reshape(logits[node_out], (-1, logits.shape[1])) + outputs = ( targets, logits, diff --git a/programl/models/ggnn/ggnn_proper.py b/programl/models/ggnn/ggnn_proper.py index a021ea6d7..a695f8b1e 100644 --- a/programl/models/ggnn/ggnn_proper.py +++ b/programl/models/ggnn/ggnn_proper.py @@ -24,6 +24,8 @@ from programl.models.ggnn.messaging_layer import MessagingLayer from programl.models.ggnn.readout import Readout +from copy import deepcopy # for avoiding side-effects from forward pass + class GGNNProper(nn.Module): def __init__( @@ -106,26 +108,31 @@ def __init__( def forward( self, - edge_lists, node_states, + edge_lists, pos_lists=None, node_types=None, ): old_node_states = node_states.clone() - + # edge_lists will also be manipulated in forward pass + # so we need to first make a copy + old_edge_lists = deepcopy(edge_lists) + old_pos_lists = deepcopy(pos_lists) + old_node_types = deepcopy(node_types) + if self.use_backward_edges: - back_edge_lists = [x.flip([1]) for x in edge_lists] - edge_lists.extend(back_edge_lists) + back_edge_lists = [x.flip([1]) for x in old_edge_lists] + old_edge_lists.extend(back_edge_lists) # For backward edges we keep the positions of the forward edge. 
if self.position_embeddings: - pos_lists.extend(pos_lists) + old_pos_lists.extend(old_pos_lists) if self.unroll_strategy == "label_convergence": node_states, unroll_steps, converged = self.label_convergence_forward( - edge_lists, + old_edge_lists, node_states, - pos_lists, - node_types, + old_pos_lists, + old_node_types, initial_node_states=old_node_states, ) return node_states, old_node_states, unroll_steps, converged @@ -135,7 +142,7 @@ def forward( bincount = torch.zeros( node_states.size()[0], dtype=torch.long, device=node_states.device ) - for edge_list in edge_lists: + for edge_list in old_edge_lists: edge_targets = edge_list[:, 1] edge_bincount = edge_targets.bincount(minlength=node_states.size()[0]) bincount += edge_bincount @@ -147,19 +154,19 @@ def forward( # Clamp the position lists in the range [0,edge_position_max) to match # the pre-computed position embeddings table. - if pos_lists: - for pos_list in pos_lists: + if old_pos_lists: + for pos_list in old_pos_lists: pos_list.clamp_(0, self.edge_position_max - 1) for (layer_idx, num_timesteps) in enumerate(self.layer_timesteps): for t in range(num_timesteps): messages = self.message[layer_idx]( - edge_lists, + old_edge_lists, node_states, msg_mean_divisor=msg_mean_divisor, - pos_lists=pos_lists, + pos_lists=old_pos_lists, ) - node_states = self.update[layer_idx](messages, node_states, node_types) + node_states = self.update[layer_idx](messages, node_states, old_node_types) return node_states, old_node_states diff --git a/programl/models/ggnn/messaging_layer.py b/programl/models/ggnn/messaging_layer.py index 6907ca5c9..e9942efb5 100644 --- a/programl/models/ggnn/messaging_layer.py +++ b/programl/models/ggnn/messaging_layer.py @@ -91,7 +91,7 @@ def forward(self, edge_lists, node_states, msg_mean_divisor=None, pos_lists=None ) messages_by_targets = torch.zeros_like(node_states) - + for i, edge_list in enumerate(edge_lists): edge_sources = edge_list[:, 0] edge_targets = edge_list[:, 1] diff --git 
a/programl/models/lstm/BUILD b/programl/models/lstm/BUILD index 7c9c47677..b1b643a16 100644 --- a/programl/models/lstm/BUILD +++ b/programl/models/lstm/BUILD @@ -27,7 +27,7 @@ py_library( "//programl/proto:epoch_py", "//third_party/py/labm8", "//third_party/py/numpy", - "//third_party/py/tensorflow", + "//third_party/py/torch", ], ) diff --git a/programl/models/lstm/lstm.py b/programl/models/lstm/lstm.py index 95688e15d..8b894cad9 100644 --- a/programl/models/lstm/lstm.py +++ b/programl/models/lstm/lstm.py @@ -19,12 +19,15 @@ from typing import Any, Dict, List import numpy as np -import tensorflow as tf +import torch from labm8.py import app from labm8.py.progress import NullContext, ProgressContext +from torch import nn, optim from programl.models.batch_data import BatchData from programl.models.batch_results import BatchResults +from programl.models.ggnn.loss import Loss +from programl.models.ggnn.node_embeddings import NodeEmbeddings from programl.models.lstm.lstm_batch import LstmBatchData from programl.models.model import Model from programl.proto import epoch_pb2 @@ -57,19 +60,24 @@ "The value used for the positive class in the 1-hot selector embedding " "vectors. Has no effect when selector embeddings are not used.", ) -app.DEFINE_boolean( - "cudnn_lstm", - True, - "If set, use CuDNNLSTM implementation when a GPU is available. Else use " - "default Keras implementation. Note that the two implementations are " - "incompatible - a model saved using one LSTM type cannot be restored using " - "the other LSTM type.", -) app.DEFINE_float("learning_rate", 0.001, "The mode learning rate.") app.DEFINE_boolean( "trainable_embeddings", True, "Whether the embeddings are trainable." ) +# Embeddings options. +app.DEFINE_string( + "text_embedding_type", + "random", + "The type of node embeddings to use. 
One of " + "{constant_zero, constant_random, random}.", +) +app.DEFINE_integer( + "text_embedding_dimensionality", + 32, + "The dimensionality of node text embeddings.", +) + class Lstm(Model): """An LSTM model for node-level classification.""" @@ -78,106 +86,56 @@ def __init__( self, vocabulary: Dict[str, int], node_y_dimensionality: int, + graph_y_dimensionality: int, + graph_x_dimensionality: int, + use_selector_embeddings: bool, test_only: bool = False, name: str = "lstm", ): """Constructor.""" - super(Lstm, self).__init__( - test_only=test_only, vocabulary=vocabulary, name=name - ) + super().__init__(test_only=test_only, vocabulary=vocabulary, name=name) self.vocabulary = vocabulary self.node_y_dimensionality = node_y_dimensionality + self.graph_y_dimensionality = graph_y_dimensionality + self.graph_x_dimensionality = graph_x_dimensionality + self.node_selector_dimensionality = 2 if use_selector_embeddings else 0 # Flag values. self.batch_size = FLAGS.batch_size self.padded_sequence_length = FLAGS.padded_sequence_length - # Reset any previous Tensorflow session. This is required when running - # consecutive LSTM models in the same process. - tf.compat.v1.keras.backend.clear_session() - - @staticmethod - def MakeLstmLayer(*args, **kwargs): - """Construct an LSTM layer. - - If a GPU is available and --cudnn_lstm, this will use NVIDIA's fast - CuDNNLSTM implementation. Else it will use Keras' builtin LSTM, which is - much slower but works on CPU. 
- """ - if FLAGS.cudnn_lstm and tf.compat.v1.test.is_gpu_available(): - return tf.compat.v1.keras.layers.CuDNNLSTM(*args, **kwargs) - else: - return tf.compat.v1.keras.layers.LSTM(*args, **kwargs, implementation=1) - - def CreateKerasModel(self) -> tf.compat.v1.keras.Model: - """Construct the tensorflow computation graph.""" - vocab_ids = tf.compat.v1.keras.layers.Input( - batch_shape=( - self.batch_size, - self.padded_sequence_length, + self.model = LstmModel( + node_embeddings=NodeEmbeddings( + node_embeddings_type=FLAGS.text_embedding_type, + use_selector_embeddings=self.node_selector_dimensionality, + selector_embedding_value=FLAGS.selector_embedding_value, + embedding_shape=( + # Add one to the vocabulary size to account for the out-of-vocab token. + len(vocabulary) + 1, + FLAGS.text_embedding_dimensionality, + ), ), - dtype="int32", - name="sequence_in", - ) - embeddings = tf.compat.v1.keras.layers.Embedding( - input_dim=len(self.vocabulary) + 2, - input_length=self.padded_sequence_length, - output_dim=FLAGS.hidden_size, - name="embedding", - trainable=FLAGS.trainable_embeddings, - )(vocab_ids) - - selector_vectors = tf.compat.v1.keras.layers.Input( - batch_shape=(self.batch_size, self.padded_sequence_length, 2), - dtype="float32", - name="selector_vectors", - ) - - lang_model_input = tf.compat.v1.keras.layers.Concatenate( - axis=2, name="embeddings_and_selector_vectorss" - )( - [embeddings, selector_vectors], - ) - - # Recurrent layers. - lang_model = self.MakeLstmLayer( - FLAGS.hidden_size, return_sequences=True, name="lstm_1" - )(lang_model_input) - lang_model = self.MakeLstmLayer( - FLAGS.hidden_size, - return_sequences=True, - return_state=False, - name="lstm_2", - )(lang_model) - - # Dense layers. 
- for i in range(1, FLAGS.hidden_dense_layer_count + 1): - lang_model = tf.compat.v1.keras.layers.Dense( - FLAGS.hidden_size, - activation="relu", - name=f"dense_{i}", - )(lang_model) - node_out = tf.compat.v1.keras.layers.Dense( - self.node_y_dimensionality, - activation="sigmoid", - name="node_out", - )(lang_model) - - model = tf.compat.v1.keras.Model( - inputs=[vocab_ids, selector_vectors], - outputs=[node_out], - ) - model.compile( - optimizer=tf.compat.v1.keras.optimizers.Adam( - learning_rate=FLAGS.learning_rate + loss=Loss( + num_classes=self.node_y_dimensionality, + has_aux_input=self.has_aux_input, + intermediate_loss_weight=None, # NOTE(cec): Intentionally broken. + class_prevalence_weighting=False, ), - metrics=["accuracy"], - loss=["categorical_crossentropy"], - loss_weights=[1.0], + padded_sequence_length=self.padded_sequence_length, + learning_rate=FLAGS.learning_rate, + test_only=test_only, + hidden_size=FLAGS.hidden_size, + hidden_dense_layer_count=FLAGS.hidden_dense_layer_count, ) - return model + @property + def num_classes(self) -> int: + return self.node_y_dimensionality or self.graph_y_dimensionality + + @property + def has_aux_input(self) -> bool: + return self.graph_x_dimensionality > 0 def CreateModelData(self, test_only: bool) -> None: """Initialize an LSTM model. 
This is called during Initialize().""" @@ -209,24 +167,43 @@ def RunBatch( self.batch_size, self.padded_sequence_length, ), model_data.encoded_sequences.shape - assert model_data.selector_vectors.shape == ( + assert model_data.selector_ids.shape == ( self.batch_size, self.padded_sequence_length, - 2, - ), model_data.selector_vectors.shape + ), model_data.selector_ids.shape x = [model_data.encoded_sequences, model_data.selector_vectors] y = [model_data.node_labels] if epoch_type == epoch_pb2.TRAIN: - loss, *_ = self.model.train_on_batch(x, y) + if not self.model.training: + self.model.train() + targets, logits = self.model( + model_data.encoded_sequences, + model_data.selector_ids, + model_data.node_labels, + ) else: - loss = None + if self.model.training: + self.model.eval() + self.model.opt.zero_grad() + # Inference only, don't trace the computation graph. + with torch.no_grad(): + targets, logits = self.model( + model_data.encoded_sequences, + model_data.selector_ids, + model_data.node_labels, + ) + + loss = self.model.loss((logits, None), targets) - padded_predictions = self.model.predict_on_batch(x) + if epoch_type == epoch_pb2.TRAIN: + loss.backward() + self.model.opt.step() + self.model.opt.zero_grad() # Reshape the outputs. - predictions = self.ReshapePaddedModelOutput(batch_data, padded_predictions) + predictions = self.ReshapePaddedModelOutput(batch_data, outputs) # Flatten the targets and predictions lists so that we can compare them. # Shape (batch_node_count, node_y_dimensionality). 
@@ -234,9 +211,10 @@ def RunBatch( predictions = np.concatenate(predictions) return BatchResults.Create( - targets=targets, - predictions=predictions, - loss=loss, + targets=model_data.node_labels, + predictions=logits.detach().cpu().numpy(), + learning_rate=self.model.learning_rate, + loss=loss.item(), ) def ReshapePaddedModelOutput( @@ -282,36 +260,74 @@ def ReshapePaddedModelOutput( def GetModelData(self) -> Any: """Get the model state.""" - # According to https://keras.io/getting-started/faq/, it is not recommended - # to pickle a Keras model. So as a workaround, I use Keras's saving - # mechanism to store the weights, and pickle that. - with tempfile.TemporaryDirectory(prefix="lstm_pickle_") as d: - path = pathlib.Path(d) / "weights.h5" - self.model.save(path) - with open(path, "rb") as f: - model_data = f.read() - return model_data + return { + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.model.opt.state_dict(), + "scheduler_state_dict": self.model.scheduler.state_dict(), + } def LoadModelData(self, data_to_load: Any) -> None: """Restore the model state.""" - # Load the weights from a file generated by ModelDataToSave(). - with tempfile.TemporaryDirectory(prefix="lstm_pickle_") as d: - path = pathlib.Path(d) / "weights.h5" - with open(path, "wb") as f: - f.write(data_to_load) - - # The default TF graph is finalized in Initialize(), so we must - # first reset the session and create a new graph. - tf.compat.v1.reset_default_graph() - SetAllowedGrowthOnKerasSession() - - self.model = tf.compat.v1.keras.models.load_model(path) - - -def SetAllowedGrowthOnKerasSession(): - """Allow growth on GPU for Keras.""" - config = tf.compat.v1.ConfigProto() - config.gpu_options.allow_growth = True - session = tf.compat.v1.Session(config=config) - tf.compat.v1.keras.backend.set_session(session) - return session + self.model.load_state_dict(data_to_load["model_state_dict"]) + # only restore opt if needed. opt should be None o/w. 
+ if not self.test_only: + self.model.opt.load_state_dict(data_to_load["optimizer_state_dict"]) + self.model.scheduler.load_state_dict(data_to_load["scheduler_state_dict"]) + + +class LstmModel(nn.Module): + def __init__( + self, + node_embeddings: NodeEmbeddings, + loss: Loss, + padded_sequence_length: int, + test_only: bool, + learning_rate: float, + hidden_size: int, + hidden_dense_layer_count: int, # TODO(cec): Implement. + ): + super().__init__() + self.node_embeddings = node_embeddings + self.loss = loss + self.padded_sequence_length = padded_sequence_length + self.learning_rate = learning_rate + self.hidden_size = hidden_size + self.learning_rate = learning_rate + + self.lstm = nn.LSTM( + self.node_embeddings.embedding_dimensionality + 2, + self.hidden_size, + ) + self.hidden2label = nn.Linear(self.hidden_size, 2) + + if test_only: + self.opt = None + self.eval() + else: + self.opt = optim.AdamW(self.parameters(), lr=self.learning_rate) + + def forward( + self, + encoded_sequences, + selector_ids, + node_labels, + ): + print("SHAPES", encoded_sequences.shape, selector_ids.shape, node_labels.shape) + + encoded_sequences = torch.tensor(encoded_sequences, dtype=torch.long) + selector_ids = torch.tensor(selector_ids, dtype=torch.long) + node_labels = torch.tensor(node_labels, dtype=torch.long) + + # Embed and concatenate sequences and selector vectors. 
+ embeddings = self.node_embeddings(encoded_sequences, selector_ids) + + lstm_out, _ = self.lstm( + embeddings.view(self.padded_sequence_length, len(encoded_sequences), -1) + ) + print(lstm_out.shape) + + label_space = self.hidden2label(lstm_out.view(self.padded_sequence_length, -1)) + logits = F.log_softmax(label_space, dim=2) + + targets = node_labels + return logits, targets diff --git a/programl/models/lstm/lstm_batch.py b/programl/models/lstm/lstm_batch.py index bcac7f619..f9bf3a417 100644 --- a/programl/models/lstm/lstm_batch.py +++ b/programl/models/lstm/lstm_batch.py @@ -31,8 +31,8 @@ class LstmBatchData(NamedTuple): # Shape (batch_size, padded_sequence_length, 1), dtype np.int32 encoded_sequences: np.array - # Shape (batch_size, padded_sequence_length, 2), dtype np.int32 - selector_vectors: np.array + # Shape (batch_size, padded_sequence_length, 1), dtype np.int32 + selector_ids: np.array # Shape (batch_size, padded_sequence_length, node_y_dimensionality), # dtype np.float32 node_labels: np.array diff --git a/programl/serialize_ops.py b/programl/serialize_ops.py new file mode 100644 index 000000000..bc4ef32c6 --- /dev/null +++ b/programl/serialize_ops.py @@ -0,0 +1,181 @@ +# Copyright 2019-2020 the ProGraML authors. +# +# Contact Chris Cummins . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Graph serialization ops are used for storing or transferring Program +Graphs. 
+""" +import gzip +from pathlib import Path +from typing import Iterable, List, Optional + +import google.protobuf.message +import google.protobuf.text_format + +from programl.exceptions import GraphCreationError +from programl.proto.program_graph_pb2 import ProgramGraph, ProgramGraphList + + +def save_graphs( + path: Path, graphs: Iterable[ProgramGraph], compression: Optional[str] = "gz" +) -> None: + """Save a sequence of program graphs to a file. + + :param path: The file to write. + + :param graphs: A sequence of Program Graphs. + + :param compression: Either :code:`gz` for GZip compression (the default), or + :code:`None` for no compression. Compression increases the cost of + serializing and deserializing but can greatly reduce the size of the + serialized graphs. + + :raises TypeError: If an unsupported :code:`compression` is given. + """ + with open(path, "wb") as f: + f.write(to_bytes(graphs, compression=compression)) + + +def load_graphs( + path: Path, idx_list: Optional[List[int]] = None, compression: Optional[str] = "gz" +) -> List[ProgramGraph]: + """Load program graphs from a file. + + :param path: The file to read from. + + :param idx_list: A zero-based list of graph indices to return. If not + provided, all graphs are loaded. + + :param compression: Either :code:`gz` for GZip compression (the default), or + :code:`None` for no compression. Compression increases the cost of + serializing and deserializing but can greatly reduce the size of the + serialized graphs. + + :return: A sequence of Program Graphs. + + :raises TypeError: If an unsupported :code:`compression` is given. + + :raise GraphCreationError: If deserialization fails. + """ + with open(path, "rb") as f: + return from_bytes(f.read(), idx_list=idx_list, compression=compression) + + +def to_bytes( + graphs: Iterable[ProgramGraph], compression: Optional[str] = "gz" +) -> bytes: + """Serialize a sequence of Program Graphs to a byte array. + + :param graphs: A sequence of Program Graphs. 
+ + :param compression: Either :code:`gz` for GZip compression (the default), or + :code:`None` for no compression. Compression increases the cost of + serializing and deserializing but can greatly reduce the size of the + serialized graphs. + + :return: The serialized program graphs. + """ + compressors = { + "gz": gzip.compress, + None: lambda d: d, + } + if compression not in compressors: + compressors = ", ".join(sorted(str(x) for x in compressors)) + raise TypeError( + f"Invalid compression argument: {compression}. " + f"Supported compressions: {compressors}" + ) + compress = compressors[compression] + + return compress(ProgramGraphList(graph=list(graphs)).SerializeToString()) + + +def from_bytes( + data: bytes, idx_list: Optional[List[int]] = None, compression: Optional[str] = "gz" +) -> List[ProgramGraph]: + """Deserialize Program Graphs from a byte array. + + :param data: The serialized Program Graphs. + + :param idx_list: A zero-based list of graph indices to return. If not + provided, all graphs are returned. + + :param compression: Either :code:`gz` for GZip compression (the default), or + :code:`None` for no compression. Compression increases the cost of + serializing and deserializing but can greatly reduce the size of the + serialized graphs. + + :return: A list of Program Graphs. + + :raise GraphCreationError: If deserialization fails. + """ + decompressors = { + "gz": gzip.decompress, + None: lambda d: d, + } + if compression not in decompressors: + decompressors = ", ".join(sorted(str(x) for x in decompressors)) + raise TypeError( + f"Invalid compression argument: {compression}. 
" + f"Supported compressions: {decompressors}" + ) + decompress = decompressors[compression] + + graph_list = ProgramGraphList() + try: + graph_list.ParseFromString(decompress(data)) + except (gzip.BadGzipFile, google.protobuf.message.DecodeError) as e: + raise GraphCreationError(str(e)) from e + + if idx_list: + return [graph_list.graph[i] for i in idx_list] + return list(graph_list.graph) + + +def to_string(graphs: Iterable[ProgramGraph]) -> str: + """Serialize a sequence of Program Graphs to a human-readable string. + + The generated string has a JSON-like syntax that is designed for human + readability. This is the least compact form of serialization. + + :param graphs: A sequence of Program Graphs. + + :return: The serialized program graphs. + """ + return str(ProgramGraphList(graph=list(graphs))) + + +def from_string( + string: str, idx_list: Optional[List[int]] = None +) -> List[ProgramGraph]: + """Deserialize Program Graphs from a human-readable string. + + :param data: The serialized Program Graphs. + + :param idx_list: A zero-based list of graph indices to return. If not + provided, all graphs are returned. + + :return: A list of Program Graphs. + + :raise GraphCreationError: If deserialization fails. 
+ """ + graph_list = ProgramGraphList() + try: + google.protobuf.text_format.Merge(string, graph_list) + except google.protobuf.text_format.ParseError as e: + raise GraphCreationError(str(e)) from e + + if idx_list: + return [graph_list.graph[i] for i in idx_list] + return list(graph_list.graph) diff --git a/programl/task/dataflow/BUILD b/programl/task/dataflow/BUILD index 8ad5bbb00..a49ad96ae 100644 --- a/programl/task/dataflow/BUILD +++ b/programl/task/dataflow/BUILD @@ -74,8 +74,8 @@ py_binary( ) py_binary( - name = "ggnn_test_one", - srcs = ["ggnn_test_one.py"], + name = "ggnn_test", + srcs = ["ggnn_test.py"], deps = [ "//programl/models:base_graph_loader", "//programl/models:batch_results", @@ -88,8 +88,7 @@ py_binary( "//programl/task/dataflow:ggnn_batch_builder", "//programl/task/dataflow:vocabulary", "//programl/task/dataflow/dataset:pathflag", - "//third_party/py/labm8", - "//third_party/py/numpy", + "//programl:serialize_ops", ], ) @@ -99,6 +98,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//programl/graph/format/py:cdfg", + "//programl/graph/format/py:nx_format", "//programl/models:base_graph_loader", "//programl/proto:epoch_py", "//programl/proto:program_graph_features_py", @@ -117,6 +117,7 @@ py_library( "//programl/models:base_graph_loader", "//programl/models:batch_data", "//programl/models/lstm:lstm_batch", + "//third_party/py/keras_preprocessing", "//third_party/py/labm8", "//third_party/py/numpy", ], @@ -178,6 +179,18 @@ py_binary( ], ) +py_test( + name = "train_lstm_test", + srcs = ["train_lstm_test.py"], + data = [ + "//programl/test/data:reachability_dataflow_dataset", + ], + deps = [ + ":train_lstm", + "//third_party/py/labm8", + ], +) + py_library( name = "vocabulary", srcs = ["vocabulary.py"], diff --git a/programl/task/dataflow/ggnn_test.py b/programl/task/dataflow/ggnn_test.py new file mode 100644 index 000000000..69805a929 --- /dev/null +++ b/programl/task/dataflow/ggnn_test.py @@ -0,0 +1,1155 @@ +# Copyright 2019-2020 
the ProGraML authors. +# +# Contact Chris Cummins . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run inference of a trained GGNN model on a single graph input. +""" +import matplotlib +matplotlib.use('Agg') # to avoid using Xserver +import networkx as nx +from labm8.py import app, pbutil +import numpy as np +import matplotlib.pyplot as plt +import torch +from programl import serialize_ops +from programl.task.dataflow.ggnn_batch_builder import DataflowGgnnBatchBuilder +from programl.task.dataflow import dataflow, vocabulary +from programl.proto import ( + checkpoint_pb2, + epoch_pb2, + program_graph_features_pb2, + program_graph_pb2, +) +from programl.models.ggnn.ggnn import Ggnn +from programl.models.batch_results import BatchResults +from programl.models.base_graph_loader import BaseGraphLoader +from programl.graph.format.py.nx_format import ProgramGraphToNetworkX +import igraph as ig +from datetime import datetime +import logging +from copy import deepcopy +from networkx.drawing.nx_agraph import graphviz_layout +import pathlib +from typing import Any, Iterable, List, Tuple +from os import listdir +import random + +random.seed(888) + + +app.DEFINE_boolean( + "cdfg", + False, + "If set, use the CDFG representation for programs. 
Defaults to ProGraML " + "representations.", +) +app.DEFINE_boolean( + "ig", + False, + "If set, run IG analysis.", +) +app.DEFINE_boolean( + "batch", + False, + "If set, test samples in batch.", +) +app.DEFINE_boolean( + "save_graph", + False, + "If set, save annotated graphs.", +) +app.DEFINE_boolean( + "save_vis", + False, + "If set, save visualization images.", +) +app.DEFINE_boolean( + "dep_guided_ig", + False, + "If set, enable dependency-guided IG attribution.", +) +app.DEFINE_boolean( + "only_pred_y", + False, + "If set, only calculate IG attributions for pred_y=1 nodes.", +) +app.DEFINE_boolean( + "use_acyclic_for_std_ig", + False, + "If set, use acyclicalized graphs to standard IG as well.", +) +app.DEFINE_integer( + "max_vocab_size", + 0, + "If > 0, limit the size of the vocabulary to this number.", +) +app.DEFINE_integer( + "max_vis_graph_complexity", + 0, + "If > 0, limit the max complexity of visualized graphs.", +) +app.DEFINE_integer( + "random_test_size", + 0, + "If > 0, randomly select this many graph evaluations.", +) +app.DEFINE_float( + "max_removed_edges_ratio", + -1, + "If > -1, limit the max number of removed edges.", +) +app.DEFINE_boolean( + "filter_adjacant_nodes", + False, + "If set, filter out nodes that are too far from source.", +) +app.DEFINE_float("target_vocab_cumfreq", 1.0, + "The target cumulative frequency that.") +app.DEFINE_string( + "ds_path", + str(pathlib.Path("~/code-model-interpretation/ProGraML/dataset/dataflow").expanduser()), + "The dataset directory.", +) +app.DEFINE_string("model", None, "The model checkpoint to restore") +app.DEFINE_string( + "input", + None, + "Path of the input graph features list and index into it, " + "e.g., /path/to/foo:1 to select the second graph from file /path/to/foo.", +) +app.DEFINE_integer( + "instance_id", + None, + "ID for this instance (for concurrency)." +) +app.DEFINE_integer( + "num_instances", + None, + "Total number of instances to run (for concurrency)." 
+) +app.DEFINE_string( + "task", + "datadep", + "Specify what task to test against." +) +app.DEFINE_boolean( + "debug", + False, + "Whether to stop encountering exceptions." +) +app.DEFINE_boolean( + "remove_edges", + False, + "Whether to remove edges in deletion/retention games." +) +FLAGS = app.FLAGS + +ATTR_ACC_ASC_ORDER_TASKS = { + "datadep", +} +ATTR_ACC_DES_ORDER_TASKS = { + "reachability", + "domtree", + "liveness", +} + + +class SingleGraphLoader(BaseGraphLoader): + """`A graph loader which reads a single graph.""" + + def __init__( + self, + graph: program_graph_pb2.ProgramGraph, + features: program_graph_features_pb2.ProgramGraphFeatures, + ): + self.graph = graph + self.features = features + + def IterableType(self) -> Any: + return ( + program_graph_pb2.ProgramGraph, + program_graph_features_pb2.ProgramGraphFeatures, + ) + + def __iter__(self) -> Iterable["IterableType"]: + yield self.graph, self.features + + def Stop(self): + pass + + +class TooComplexGraphError(Exception): + pass + + +class TooManyRootNodesError(Exception): + pass + + +class CycleInGraphError(Exception): + pass + + +class TooManyEdgesRemovedError(Exception): + pass + + +class NoQualifiedOutNodeError(Exception): + pass + + +def RemoveCyclesFromGraph( + networkx_graph: nx.DiGraph, + method: str = "eades", +) -> nx.DiGraph: + # This is a sufficint solution, with optimality guarantee if "ip" method + # is seleted (but it will be very slow on large graph). A heuristic-based + # method "eades" is also provided, with fast speed but only upper bound of + # the number of removed edges of "|E|/2 - |V|/6". 
+ print("Removing cycles...") + igraph_graph = ig.Graph.from_networkx(networkx_graph) + edges_to_remove = ig.Graph.feedback_arc_set(igraph_graph, method=method) + edge_list = igraph_graph.get_edgelist() + for edge_id in edges_to_remove: + tail, head = edge_list[edge_id] + networkx_graph.remove_edge(u=tail, v=head) + return networkx_graph + + +def FilterDistantNodes( + nodes_out: List[int], + target_node_id: int, + graph: program_graph_pb2.ProgramGraph, +) -> List[int]: + filtered_nodes_out = [] + networkx_graph = ProgramGraphToNetworkX(graph) + for source_node_id in nodes_out: + dist = nx.algorithms.shortest_paths.generic.shortest_path_length( + networkx_graph, + source=source_node_id, + target=target_node_id, + ) + if dist > 1: + filtered_nodes_out.append(source_node_id) + return filtered_nodes_out + + +def CalculateInterpolationOrderFromGraph( + graph: program_graph_pb2.ProgramGraph, + max_removed_edges_ratio: float = -1, + sample_multi_topo_orders: bool = True, + max_num_topo_orders: int = 20, +) -> List[int]: + # This function returns the (topological) order of nodes to evaluate + # for interpolations in IG + networkx_graph = ProgramGraphToNetworkX(graph) + is_acyclic = nx.algorithms.dag.is_directed_acyclic_graph(networkx_graph) + if is_acyclic: + if sample_multi_topo_orders: + sampled_ordered_nodes = [] + sampled_ordered_nodes_gen = nx.algorithms.dag.all_topological_sorts( + networkx_graph) + topo_orders_cnt = 0 + for order_nodes in sampled_ordered_nodes_gen: + topo_orders_cnt += 1 + if topo_orders_cnt > max_num_topo_orders: + break + sampled_ordered_nodes.append(order_nodes) + ordered_nodes = sampled_ordered_nodes + else: + ordered_nodes = [list(nx.topological_sort(networkx_graph))] + + return ordered_nodes, networkx_graph + else: + # Cycle(s) detected and we need to remove them now + if max_removed_edges_ratio != -1: + original_num_edges = len(networkx_graph.edges) + acyclic_networkx_graph = RemoveCyclesFromGraph(networkx_graph) + if max_removed_edges_ratio 
!= -1: + trimmed_num_edges = len(acyclic_networkx_graph.edges) + num_removed_edges = original_num_edges - trimmed_num_edges + print("Total edges: %d | Removed %d edges." % + (original_num_edges, num_removed_edges)) + if (num_removed_edges / original_num_edges) > max_removed_edges_ratio: + raise TooManyEdgesRemovedError + + # Sanity check, only return the graph if it is acyclic + is_acyclic = nx.algorithms.dag.is_directed_acyclic_graph( + acyclic_networkx_graph) + if not is_acyclic: + raise CycleInGraphError + else: + if sample_multi_topo_orders: + sampled_ordered_nodes = [] + sampled_ordered_nodes_gen = nx.algorithms.dag.all_topological_sorts( + acyclic_networkx_graph) + topo_orders_cnt = 0 + for order_nodes in sampled_ordered_nodes_gen: + topo_orders_cnt += 1 + if topo_orders_cnt > max_num_topo_orders: + break + sampled_ordered_nodes.append(order_nodes) + ordered_nodes = sampled_ordered_nodes + else: + ordered_nodes = [ + list(nx.topological_sort(acyclic_networkx_graph))] + + return ordered_nodes, acyclic_networkx_graph + + +def GenerateInterpolationOrderFromGraph( + features_list_path: pathlib.Path, + features_list_index: int, + ds_path: str, + max_removed_edges_ratio: float, +) -> Tuple[List[int], nx.DiGraph]: + path = pathlib.Path(ds_path) + + features_list = pbutil.FromFile( + features_list_path, + program_graph_features_pb2.ProgramGraphFeaturesList(), + ) + features = features_list.graph[features_list_index] + + graph_name = features_list_path.name[: - + len(".ProgramGraphFeaturesList.pb")] + graph = pbutil.FromFile( + path / "graphs" / f"{graph_name}.ProgramGraph.pb", + program_graph_pb2.ProgramGraph(), + ) + + if FLAGS.max_vis_graph_complexity != 0: + if len(graph.node) > FLAGS.max_vis_graph_complexity: + raise TooComplexGraphError + if len(graph.edge) > FLAGS.max_vis_graph_complexity * 2: + raise TooComplexGraphError + + # First, we need to fix empty node features + graph = FixEmptyNodeFeatures(graph, features) + + interpolation_order, 
def TestOne(
    features_list_path: pathlib.Path,
    features_list_index: int,
    checkpoint_path: pathlib.Path,
    ds_path: str,
    run_ig: bool,
    dep_guided_ig: bool,
    all_nodes_out: bool,
    reverse: bool,
    filter_adjacant_nodes: bool,
    accumulate_gradients: bool,
    interpolation_order: List[int],
    acyclic_networkx_graph: nx.DiGraph,
) -> BatchResults:
    """Run one graph through the restored GGNN model and annotate it.

    Loads the graph/features from disk, restores the model from
    `checkpoint_path`, runs inference (optionally with IG or
    dependency-guided IG attributions), and returns the annotated graph
    (a list of graphs when `all_nodes_out` is set — one per out-node).

    Raises:
      RuntimeError: If `dep_guided_ig` is requested without `run_ig`.
      NoQualifiedOutNodeError: If out-node filtering leaves no candidates.
      TooManyRootNodesError: If more than one data-flow root exists when a
        single out-node is expected.
    """
    if dep_guided_ig and not run_ig:
        # Dependency-guided IG is a refinement of IG; it cannot run alone.
        raise RuntimeError(
            "run_ig and dep_guided_ig args take different values which is invalid!")

    path = pathlib.Path(ds_path)

    features_list = pbutil.FromFile(
        features_list_path,
        program_graph_features_pb2.ProgramGraphFeaturesList(),
    )
    features = features_list.graph[features_list_index]

    graph_name = features_list_path.name[:-len(".ProgramGraphFeaturesList.pb")]
    graph = pbutil.FromFile(
        path / "graphs" / f"{graph_name}.ProgramGraph.pb",
        program_graph_pb2.ProgramGraph(),
    )

    # First, we need to fix empty node features.
    graph = FixEmptyNodeFeatures(graph, features)

    # Work on copies: orders are reversed in place below and the caller
    # reuses the same objects across IG variants.
    interpolation_order = deepcopy(interpolation_order)
    acyclic_networkx_graph = deepcopy(acyclic_networkx_graph)

    if run_ig:
        if reverse:
            # Reverse each sampled topological order in place.
            for order in interpolation_order:
                order.reverse()
        if not dep_guided_ig:
            # Standard IG does not consume a node ordering.
            interpolation_order = None
    else:
        interpolation_order = None

    root_nodes = []
    if all_nodes_out:
        nodes_out = []
        for i in range(len(graph.node)):
            if features.node_features.feature_list["data_flow_root_node"].feature[i].int64_list.value == [1]:
                root_nodes.append(i)
            if features.node_features.feature_list["data_flow_value"].feature[i].int64_list.value == [1]:
                nodes_out.append(i)

        # Filter out nodes that are not suitable for evaluation (i.e. keep
        # only nodes more than one hop from the root).
        if filter_adjacant_nodes:
            nodes_out = FilterDistantNodes(nodes_out, root_nodes[0], graph)
            if len(nodes_out) == 0:
                raise NoQualifiedOutNodeError
    else:
        for i in range(len(graph.node)):
            if features.node_features.feature_list["data_flow_root_node"].feature[i].int64_list.value == [1]:
                root_nodes.append(i)
        if len(root_nodes) > 1:
            raise TooManyRootNodesError

    # Instantiate and restore the model.
    vocab = vocabulary.LoadVocabulary(
        path,
        model_name="cdfg" if FLAGS.cdfg else "programl",
        max_items=FLAGS.max_vocab_size,
        target_cumfreq=FLAGS.target_vocab_cumfreq,
    )

    if FLAGS.cdfg:
        FLAGS.use_position_embeddings = False

    model = Ggnn(
        vocabulary=vocab,
        test_only=True,
        node_y_dimensionality=2,
        graph_y_dimensionality=0,
        graph_x_dimensionality=0,
        use_selector_embeddings=True,
    )
    checkpoint = pbutil.FromFile(checkpoint_path, checkpoint_pb2.Checkpoint())
    model.RestoreCheckpoint(checkpoint)

    batches = list(
        DataflowGgnnBatchBuilder(
            graph_loader=SingleGraphLoader(graph=graph, features=features),
            vocabulary=vocab,
            max_node_size=int(1e9),
            use_cdfg=FLAGS.cdfg,
            max_batch_count=1,
        )
    )
    assert len(batches) == 1, "[ERROR] More than one graph exist in this batch (which should not happen in testing)!"
    batch = batches[0]

    if all_nodes_out:
        # Attribute towards every qualified out-node separately.
        results_predicted_nodes = []
        for node_out in nodes_out:
            results = model.RunBatch(
                epoch_pb2.TEST,
                batch,
                run_ig=run_ig,
                dep_guided_ig=dep_guided_ig,
                interpolation_order=interpolation_order,
                node_out=node_out,
                accumulate_gradients=accumulate_gradients,
                reverse=reverse,
                average_attrs=True,
                remove_edges=FLAGS.remove_edges,
            )
            results_predicted_nodes.append(results)
        return AnnotateGraphWithBatchResultsForPredictedNodes(
            graph,
            features,
            results_predicted_nodes,
            nodes_out,
            run_ig,
            acyclic_networkx_graph,
        )
    results = model.RunBatch(
        epoch_pb2.TEST,
        batch,
        run_ig=run_ig,
        dep_guided_ig=dep_guided_ig,
        interpolation_order=interpolation_order,
    )
    return AnnotateGraphWithBatchResults(graph, features, results, run_ig)


def FixEmptyNodeFeatures(
    graph: program_graph_pb2.ProgramGraph,
    features: program_graph_features_pb2.ProgramGraphFeatures,
) -> program_graph_pb2.ProgramGraph:
    """Give every node an explicit (possibly empty) "full_text" feature.

    Nodes persisted without any "full_text" bytes cannot be serialized
    later, so an empty byte string is appended in place.
    """
    assert len(graph.node) == len(
        features.node_features.feature_list["data_flow_value"].feature
    )
    assert len(graph.node) == len(
        features.node_features.feature_list["data_flow_root_node"].feature
    )

    for node in graph.node:
        # Fix empty node feature errors so that we can persist graphs.
        if node.features.feature["full_text"].bytes_list.value == []:
            node.features.feature["full_text"].bytes_list.value.append(b'')

    return graph
def CalculateAttributionAccuracyScore(
    graph: nx.DiGraph,
    attribution_order: List[int],
    source_node_id: int,
    target_node_id: int,
    use_all_paths: bool = True,
    dummy_score: bool = True,
) -> float:
    """Score how well attribution ranks align with source->target paths.

    Every node that lies on a source->target path contributes
    1 / (attribution rank + 1), so earlier-attributed on-path nodes score
    higher. With `use_all_paths` the final score averages the all-simple-paths
    score and the shortest-paths score; otherwise only shortest paths count.

    Args:
      graph: The acyclic graph the paths are searched in.
      attribution_order: Per-node attribution rank, indexed by node id.
      source_node_id: Path source node.
      target_node_id: Path target node.
      use_all_paths: Also score against all simple paths, not just shortest.
      dummy_score: Short-circuit and return 0.0 (scoring disabled).

    Raises:
      nx.exception.NetworkXNoPath: If no source->target path exists.
    """
    if dummy_score:
        # NOTE(review): scoring is short-circuited by default — presumably a
        # temporary switch while the metric is under development; confirm.
        return 0.0

    # Shortest paths are needed in both modes. The original only built these
    # under `use_all_paths`, so use_all_paths=False crashed with NameError.
    all_shortest_paths = nx.algorithms.shortest_paths.generic.all_shortest_paths(
        graph,
        source=source_node_id,
        target=target_node_id,
    )
    shortest_path_nodes_set = {
        node for path in all_shortest_paths for node in path}
    shortest_path_score = 0.0

    if use_all_paths:
        all_paths = nx.algorithms.simple_paths.all_simple_paths(
            graph,
            source=source_node_id,
            target=target_node_id,
        )
        path_nodes_set = {node for path in all_paths for node in path}
        path_score = 0.0

    for node_id, attr_order in enumerate(attribution_order):
        if node_id in shortest_path_nodes_set:
            shortest_path_score += 1 / (attr_order + 1)
        if use_all_paths and node_id in path_nodes_set:
            path_score += 1 / (attr_order + 1)

    if use_all_paths:
        return 0.5 * path_score + 0.5 * shortest_path_score
    return shortest_path_score


def _PadOrTruncateScores(values, size: int) -> List[float]:
    """Truncate `values` to `size` entries, right-padding with 0.0 if short."""
    padded = list(values[:size])
    padded += [0.0] * (size - len(padded))
    return padded


def AnnotateGraphWithBatchResultsForPredictedNodes(
    base_graph: program_graph_pb2.ProgramGraph,
    features: program_graph_features_pb2.ProgramGraphFeatures,
    results_predicted_nodes: List[BatchResults],
    nodes_out: List[int],
    run_ig: bool,
    acyclic_networkx_graph: nx.DiGraph,
    num_deletion_retention_nodes: int = 10,
) -> List[program_graph_pb2.ProgramGraph]:
    """Annotate one copy of the graph per out-node with labels and outcomes.

    Returns one annotated graph per entry of `results_predicted_nodes`.
    Graphs whose source/target pair has no connecting path are tagged with
    an "attribution_accuracy" of -1.0 and skipped.
    """
    assert len(base_graph.node) == len(
        features.node_features.feature_list["data_flow_value"].feature
    )
    assert len(base_graph.node) == len(
        features.node_features.feature_list["data_flow_root_node"].feature
    )

    graphs = []

    for i, results in enumerate(results_predicted_nodes):
        graph = deepcopy(base_graph)

        if run_ig:
            assert len(graph.node) == results.attributions.shape[0]

        true_y = np.argmax(results.targets, axis=1)
        pred_y = np.argmax(results.predictions, axis=1)

        for j, node in enumerate(graph.node):
            node.features.feature["data_flow_root_node"].CopyFrom(
                features.node_features.feature_list["data_flow_root_node"].feature[j]
            )
            if j == nodes_out[i]:
                node.features.feature["true_y"].int64_list.value.append(
                    true_y[0])
                node.features.feature["pred_y"].int64_list.value.append(
                    pred_y[0])
            else:
                node.features.feature["true_y"].int64_list.value.append(0)
                node.features.feature["pred_y"].int64_list.value.append(0)
            if run_ig:
                node.features.feature["attribution_order"].int64_list.value.append(
                    results.attributions[j])
            # The data-flow root is one endpoint of the attribution path; the
            # task type decides whether it acts as source or target.
            if features.node_features.feature_list["data_flow_root_node"].feature[j].int64_list.value == [1]:
                if FLAGS.task in ATTR_ACC_ASC_ORDER_TASKS:
                    target_node_id = j
                elif FLAGS.task in ATTR_ACC_DES_ORDER_TASKS:
                    source_node_id = j

        graph.features.feature["loss"].float_list.value.append(results.loss)
        graph.features.feature["accuracy"].float_list.value.append(
            results.accuracy)
        graph.features.feature["precision"].float_list.value.append(
            results.precision)
        graph.features.feature["recall"].float_list.value.append(
            results.recall)
        graph.features.feature["f1"].float_list.value.append(results.f1)
        graph.features.feature["confusion_matrix"].int64_list.value[:] = np.hstack(
            results.confusion_matrix
        )

        if run_ig:
            # The out-node is the other endpoint of the path.
            if FLAGS.task in ATTR_ACC_ASC_ORDER_TASKS:
                source_node_id = nodes_out[i]
            elif FLAGS.task in ATTR_ACC_DES_ORDER_TASKS:
                target_node_id = nodes_out[i]
            try:
                attribution_acc_score = CalculateAttributionAccuracyScore(
                    acyclic_networkx_graph,
                    results.attributions,
                    source_node_id,
                    target_node_id,
                )
            except nx.exception.NetworkXNoPath:
                print("No feasible path from source %d to target %d was found!" %
                      (source_node_id, target_node_id))
                graph.features.feature["attribution_accuracy"].float_list.value.append(
                    -1.0)
                continue
            print("Feasible path from source %d to target %d was found." %
                  (source_node_id, target_node_id))
            graph.features.feature["attribution_accuracy"].float_list.value.append(
                attribution_acc_score)
            graph.features.feature["faithfulness_score"].float_list.value.append(
                results.faithfulness_score)

            # Deletion/retention curves are reported over a fixed number of
            # steps; pad with zeros when fewer steps were recorded.
            for deletion_delta in _PadOrTruncateScores(
                    results.deletion_res, num_deletion_retention_nodes):
                graph.features.feature["deletion_res"].float_list.value.append(
                    deletion_delta)
            for retention_delta in _PadOrTruncateScores(
                    results.retention_res, num_deletion_retention_nodes):
                graph.features.feature["retention_res"].float_list.value.append(
                    retention_delta)

        graphs.append(graph)

    return graphs


def AnnotateGraphWithBatchResults(
    graph: program_graph_pb2.ProgramGraph,
    features: program_graph_features_pb2.ProgramGraphFeatures,
    results: BatchResults,
    run_ig: bool,
) -> program_graph_pb2.ProgramGraph:
    """Annotate graph with features describing the target labels and predicted outcomes."""
    assert len(graph.node) == len(
        features.node_features.feature_list["data_flow_value"].feature
    )
    assert len(graph.node) == len(
        features.node_features.feature_list["data_flow_root_node"].feature
    )
    assert len(graph.node) == results.targets.shape[0]
    if run_ig:
        assert len(graph.node) == results.attributions.shape[0]

    true_y = np.argmax(results.targets, axis=1)
    pred_y = np.argmax(results.predictions, axis=1)

    for i, node in enumerate(graph.node):
        node.features.feature["data_flow_value"].CopyFrom(
            features.node_features.feature_list["data_flow_value"].feature[i]
        )
        node.features.feature["data_flow_root_node"].CopyFrom(
            features.node_features.feature_list["data_flow_root_node"].feature[i]
        )
        node.features.feature["target"].float_list.value[:] = results.targets[i]
        node.features.feature["prediction"].float_list.value[:] = results.predictions[i]
        node.features.feature["true_y"].int64_list.value.append(true_y[i])
        node.features.feature["pred_y"].int64_list.value.append(pred_y[i])
        node.features.feature["correct"].int64_list.value.append(
            true_y[i] == pred_y[i])
        if run_ig:
            node.features.feature["attribution_order"].int64_list.value.append(
                results.attributions[i])

    graph.features.feature["loss"].float_list.value.append(results.loss)
    graph.features.feature["accuracy"].float_list.value.append(
        results.accuracy)
    graph.features.feature["precision"].float_list.value.append(
        results.precision)
    graph.features.feature["recall"].float_list.value.append(results.recall)
    graph.features.feature["f1"].float_list.value.append(results.f1)
    graph.features.feature["confusion_matrix"].int64_list.value[:] = np.hstack(
        results.confusion_matrix
    )

    return graph
def TestOneGraph(
    ds_path,
    model_path,
    graph_path,
    graph_idx,
    run_ig=False,
    dep_guided_ig=False,
    all_nodes_out=False,
    reverse=False,
    filter_adjacant_nodes=False,
    accumulate_gradients=True,
    interpolation_order=None,
    acyclic_networkx_graph=None,
):
    """Thin wrapper over TestOne: builds paths, runs, then frees CUDA memory.

    Returns a single annotated graph, or a list of them when
    `all_nodes_out` is set (TestOne decides which).
    """
    # The two branches of the original were byte-identical apart from the
    # local variable name, so they are collapsed into a single call.
    annotated = TestOne(
        features_list_path=pathlib.Path(graph_path),
        features_list_index=int(graph_idx),
        checkpoint_path=pathlib.Path(ds_path + model_path),
        ds_path=ds_path,
        run_ig=run_ig,
        dep_guided_ig=dep_guided_ig,
        all_nodes_out=all_nodes_out,
        reverse=reverse,
        filter_adjacant_nodes=filter_adjacant_nodes,
        accumulate_gradients=accumulate_gradients,
        interpolation_order=interpolation_order,
        acyclic_networkx_graph=acyclic_networkx_graph,
    )
    # Attribution runs accumulate large autograd buffers; release them.
    torch.cuda.empty_cache()
    return annotated


def DrawAndSaveGraph(
    graph,
    ds_path,
    graph_fname,
    save_graph=False,
    save_vis=False,
    suffix=''
):
    """Optionally persist/visualize annotated graph(s) and collect metrics.

    Args:
      graph: A single annotated graph, or a list of per-out-node graphs.
      ds_path: Dataset root; outputs go under <ds_path>/vis_res/.
      graph_fname: Base filename for the saved artifacts.
      save_graph: Persist the annotated graph proto(s).
      save_vis: Render and save a PNG visualization per graph.
      suffix: IG-variant tag embedded in the output filenames.

    Returns:
      Tuple (attr scores, faithfulness scores, deletion curves, retention
      curves), one entry per graph with a valid attribution.
    """
    # Per-node IG hands us a list of annotated graphs (one per out-node); a
    # bare graph is wrapped so both cases share one code path.
    graphs = graph if isinstance(graph, list) else [graph]

    attr_scores, faithfulness_scores = [], []
    all_deletion_res, all_retention_res = [], []

    for i, graph in enumerate(graphs):
        # -1.0 marks "no feasible path" graphs; skip them entirely.
        if graph.features.feature["attribution_accuracy"].float_list.value[0] == -1.0:
            continue

        if save_graph:
            save_graph_path = ds_path + '/vis_res/' + graph_fname + \
                ".AttributedProgramGraphFeaturesList.%s.%d.pb" % (suffix, i)
            print("Saving annotated graph to %s..." % save_graph_path)
            serialize_ops.save_graphs(save_graph_path, [graph])

        # Read the scores up front: the original bound these inside the
        # `save_vis` branch only, raising NameError when save_vis was False.
        attr_acc_score = graph.features.feature["attribution_accuracy"].float_list.value[0]
        faithfulness_score = graph.features.feature["faithfulness_score"].float_list.value[0]

        networkx_graph = ProgramGraphToNetworkX(graph)
        original_labels = nx.get_node_attributes(networkx_graph, "features")

        labels = {}
        for node, node_features in original_labels.items():
            curr_label = "Pred: " + str(node_features["pred_y"]) + " | "
            curr_label += "True: " + str(node_features["true_y"]) + " | \n"
            curr_label += "Attr: " + str(node_features["attribution_order"]) + " | "
            if node_features["data_flow_root_node"] == 0:
                curr_label += "Target"
            labels[node] = '[' + curr_label + ']'

        color = []
        for node in networkx_graph.nodes():
            if original_labels[node]["data_flow_root_node"] == [1]:
                color.append('red')
            elif original_labels[node]["pred_y"] == [1]:
                color.append('purple')
            else:
                color.append('grey')

        pos = graphviz_layout(networkx_graph, prog='neato')
        nx.draw(networkx_graph, pos=pos, labels=labels,
                node_size=500, node_color=color)

        if save_vis:
            save_img_path = ds_path + '/vis_res/' + graph_fname + \
                ".AttributedProgramGraph.%s.%d.png" % (suffix, i)
            print("Saving visualization of annotated graph to %s..." %
                  save_img_path)
            plt.text(x=20, y=20, s="Attr acc: %f" % attr_acc_score)
            plt.text(x=20, y=40, s="Faith score: %f" % faithfulness_score)
            plt.show()
            plt.savefig(save_img_path, format="PNG")
            plt.clf()

        attr_scores.append(attr_acc_score)
        faithfulness_scores.append(faithfulness_score)
        all_deletion_res.append(
            graph.features.feature["deletion_res"].float_list.value)
        all_retention_res.append(
            graph.features.feature["retention_res"].float_list.value)

    return attr_scores, faithfulness_scores, all_deletion_res, all_retention_res
# Logging columns: attribution accuracy, faithfulness, and deletion/retention
# curves for each of the four IG variants, in a fixed order.
_LOG_HEADER_COLUMNS = [
    "GRAPH_NAME", "STANDARD_IG", "ASCENDING_DEPENDENCY_GUIDED_IG",
    "UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG",
    "DESCENDING_DEPENDENCY_GUIDED_IG",
    "FAITH_STANDARD_IG", "FAITH_ASCENDING_DEPENDENCY_GUIDED_IG",
    "FAITH_UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG",
    "FAITH_DESCENDING_DEPENDENCY_GUIDED_IG",
    "DELETION_RES_STANDARD_IG", "DELETION_RES_ASCENDING_DEPENDENCY_GUIDED_IG",
    "DELETION_RES_UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG",
    "DELETION_RES_DESCENDING_DEPENDENCY_GUIDED_IG",
    "RETENTION_RES_STANDARD_IG", "RETENTION_RES_ASCENDING_DEPENDENCY_GUIDED_IG",
    "RETENTION_RES_UNACCUMULATED_ASCENDING_DEPENDENCY_GUIDED_IG",
    "RETENTION_RES_DESCENDING_DEPENDENCY_GUIDED_IG",
]

# (suffix, dep_guided_ig, reverse, accumulate_gradients) per IG variant,
# in the same order as the logging columns above.
_IG_VARIANTS = [
    ("std_ig", False, False, True),
    ("dep_guided_ig", True, False, True),
    ("dep_guided_ig_unaccumulated", True, False, False),
    ("reverse_dep_guided_ig", True, True, True),
]


def _MakeLogger(instance_id):
    """Create a file logger under <ds_path>/exp_log, one file per instance."""
    ts_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    log_filepath = FLAGS.ds_path + \
        "/exp_log/batch_exp_%s_res_%d_%s.log" % (
            FLAGS.task, instance_id, ts_string)
    logger = logging.getLogger("logger")
    logger.setLevel(logging.DEBUG)
    file_handler = logging.FileHandler(log_filepath, mode="w")
    file_handler.setFormatter(logging.Formatter("%(asctime)s | %(message)s"))
    logger.addHandler(file_handler)
    logger.debug("Log file being written into at %s" % log_filepath)
    return logger


def _RunAllVariants(graph_path, graph_idx, interpolation_order,
                    acyclic_networkx_graph):
    """Run every IG variant on one graph; returns {suffix: annotated graph(s)}.

    The original duplicated this four-call sequence verbatim in both the
    batch and single-graph modes.
    """
    variant_graphs = {}
    for suffix, dep_guided, reverse, accumulate in _IG_VARIANTS:
        variant_graphs[suffix] = TestOneGraph(
            FLAGS.ds_path,
            FLAGS.model,
            graph_path,
            graph_idx,
            run_ig=FLAGS.ig,
            dep_guided_ig=dep_guided,
            all_nodes_out=FLAGS.only_pred_y,
            reverse=reverse,
            filter_adjacant_nodes=FLAGS.filter_adjacant_nodes,
            accumulate_gradients=accumulate,
            interpolation_order=interpolation_order,
            acyclic_networkx_graph=acyclic_networkx_graph,
        )
    return variant_graphs


def _RunBatchMode(instance_id, logger):
    """Evaluate every graph of the task assigned to this instance."""
    graphs_dir = FLAGS.ds_path + '/labels/%s/' % FLAGS.task
    # Sort the list to ensure the stability of concurrency.
    graph_fnames = sorted(listdir(graphs_dir))
    if FLAGS.random_test_size != 0:
        random.shuffle(graph_fnames)
    success_count = 0

    logger.info('\t'.join(_LOG_HEADER_COLUMNS))

    for i, graph_fname in enumerate(graph_fnames):
        if i % FLAGS.num_instances != instance_id:
            continue  # This graph belongs to another instance.
        try:
            original_graph_fname = graph_fname[
                :-len(".ProgramGraphFeaturesList.pb")].split('/')[-1]
            print("Processing graph file: %s..." % graph_fname)
            graph_path = graphs_dir + graph_fname
            try:
                interpolation_order, acyclic_networkx_graph = \
                    GenerateInterpolationOrderFromGraph(
                        features_list_path=pathlib.Path(graph_path),
                        features_list_index=int('-1'),
                        ds_path=FLAGS.ds_path,
                        max_removed_edges_ratio=FLAGS.max_removed_edges_ratio,
                    )
                variant_graphs = _RunAllVariants(
                    graph_path, '-1', interpolation_order,
                    acyclic_networkx_graph)
                print("Acyclic graph found and loaded.")
            except TooComplexGraphError:
                print("Skipping graph %s due to exceeding number of nodes..." %
                      original_graph_fname)
                continue
            except CycleInGraphError:
                print("Skipping graph %s due to presence of graph cycle(s)..." %
                      original_graph_fname)
                continue
            except TooManyEdgesRemovedError:
                print("Skipping graph %s due to exceeding number of removed edges..." %
                      original_graph_fname)
                continue
            except TooManyRootNodesError:
                print("Skipping graph %s due to exceeding number of root nodes..." %
                      original_graph_fname)
                continue
            except NoQualifiedOutNodeError:
                print("Skipping graph %s due to no out node found..." %
                      original_graph_fname)
                continue

            if FLAGS.ig and FLAGS.dep_guided_ig:
                # One (attr, faith, deletion, retention) 4-tuple per variant.
                # NOTE: the original computed the reverse variant's metrics
                # from graph_dep_guided_ig instead of the reverse graphs;
                # iterating the variant table fixes that.
                per_variant_metrics = [
                    DrawAndSaveGraph(
                        variant_graphs[suffix], FLAGS.ds_path,
                        original_graph_fname, save_graph=FLAGS.save_graph,
                        save_vis=FLAGS.save_vis, suffix=suffix,
                    )
                    for suffix, _, _, _ in _IG_VARIANTS
                ]
                # Column order mirrors _LOG_HEADER_COLUMNS: metric-major,
                # variant-minor. Rows iterate per evaluated out-node.
                columns = [metrics[j]
                           for j in range(4)
                           for metrics in per_variant_metrics]
                for row in zip(*columns):
                    logger.info('\t'.join(
                        [graph_fname] + [str(value) for value in row]))
                if FLAGS.random_test_size != 0:
                    success_count += 1
                    print("Successfully finished %d graphs." % success_count)
                    # `>` let the loop run one graph past the requested
                    # sample size; `>=` stops exactly at the target.
                    if success_count >= FLAGS.random_test_size:
                        print("Finished all graphs.")
                        exit()

        except Exception as err:
            if FLAGS.debug:
                raise err
            print("Error testing %s -- %s" % (graph_fname, str(err)))
            continue


def _RunSingleMode():
    """Evaluate the single graph named by FLAGS.input ("path:index")."""
    features_list_path, features_list_index = FLAGS.input.split(":")
    graph_fname = features_list_path[
        :-len(".ProgramGraphFeaturesList.pb")].split('/')[-1]
    graph_path = FLAGS.ds_path + features_list_path
    try:
        # The original passed FLAGS.max_vis_graph_complexity as an extra
        # positional argument to TestOneGraph, colliding with run_ig
        # (TypeError), and never computed an interpolation order in this
        # mode; the complexity check already lives in
        # GenerateInterpolationOrderFromGraph.
        interpolation_order, acyclic_networkx_graph = \
            GenerateInterpolationOrderFromGraph(
                features_list_path=pathlib.Path(graph_path),
                features_list_index=int(features_list_index),
                ds_path=FLAGS.ds_path,
                max_removed_edges_ratio=FLAGS.max_removed_edges_ratio,
            )
        variant_graphs = _RunAllVariants(
            graph_path, features_list_index, interpolation_order,
            acyclic_networkx_graph)
    except TooComplexGraphError:
        print("Skipping graph %s due to exceeding number of nodes..." %
              graph_fname)
        exit()
    except CycleInGraphError:
        print("Skipping graph %s due to presence of graph cycle(s)..." %
              graph_fname)
        exit()
    except TooManyEdgesRemovedError:
        print("Skipping graph %s due to exceeding number of removed edges..." %
              graph_fname)
        exit()
    except TooManyRootNodesError:
        print("Skipping graph %s due to exceeding number of root nodes..." %
              graph_fname)
        exit()
    except NoQualifiedOutNodeError:
        print("Skipping graph %s due to no out node found..." % graph_fname)
        exit()

    if FLAGS.ig and FLAGS.dep_guided_ig:
        # The original saved graph_dep_guided_ig under the
        # "dep_guided_ig_unaccumulated" suffix; driving the saves from the
        # variant table matches each graph to its own suffix.
        for suffix, _, _, _ in _IG_VARIANTS:
            DrawAndSaveGraph(
                variant_graphs[suffix], FLAGS.ds_path,
                graph_fname, save_graph=FLAGS.save_graph,
                save_vis=FLAGS.save_vis, suffix=suffix,
            )


def Main():
    """Main entry point."""
    dataflow.PatchWarnings()

    instance_id = FLAGS.instance_id - 1  # due to Shell for loop convention
    logger = _MakeLogger(instance_id)

    if FLAGS.batch:
        _RunBatchMode(instance_id, logger)
    else:
        _RunSingleMode()


if __name__ == "__main__":
    app.Run(Main)
-"""Run inference of a trained GGNN model on a single graph input. -""" -import pathlib -from typing import Any, Iterable - -import numpy as np -from labm8.py import app, pbutil - -from programl.models.base_graph_loader import BaseGraphLoader -from programl.models.batch_results import BatchResults -from programl.models.ggnn.ggnn import Ggnn -from programl.proto import ( - checkpoint_pb2, - epoch_pb2, - program_graph_features_pb2, - program_graph_pb2, -) -from programl.task.dataflow import dataflow, vocabulary -from programl.task.dataflow.dataset import pathflag -from programl.task.dataflow.ggnn_batch_builder import DataflowGgnnBatchBuilder - -app.DEFINE_boolean( - "cdfg", - False, - "If set, use the CDFG representation for programs. Defaults to ProGraML " - "representations.", -) -app.DEFINE_integer( - "max_vocab_size", - 0, - "If > 0, limit the size of the vocabulary to this number.", -) -app.DEFINE_float("target_vocab_cumfreq", 1.0, "The target cumulative frequency that.") -app.DEFINE_input_path("model", None, "The model checkpoint to restore") -app.DEFINE_string( - "input", - None, - "Path of the input graph features list and index into it, " - "e.g., /path/to/foo:1 to select the second graph from file /path/to/foo.", -) -FLAGS = app.FLAGS - - -class SingleGraphLoader(BaseGraphLoader): - """`A graph loader which reads a single graph.""" - - def __init__( - self, - graph: program_graph_pb2.ProgramGraph, - features: program_graph_features_pb2.ProgramGraphFeatures, - ): - self.graph = graph - self.features = features - - def IterableType(self) -> Any: - return ( - program_graph_pb2.ProgramGraph, - program_graph_features_pb2.ProgramGraphFeatures, - ) - - def __iter__(self) -> Iterable["IterableType"]: - yield self.graph, self.features - - def Stop(self): - pass - - -def TestOne( - features_list_path: pathlib.Path, - features_list_index: int, - checkpoint_path: pathlib.Path, -) -> BatchResults: - path = pathlib.Path(pathflag.path()) - - features_list = 
pbutil.FromFile( - features_list_path, - program_graph_features_pb2.ProgramGraphFeaturesList(), - ) - features = features_list.graph[features_list_index] - - graph_name = features_list_path.name[: -len(".ProgramGraphFeaturesList.pb")] - graph = pbutil.FromFile( - path / "graphs" / f"{graph_name}.ProgramGraph.pb", - program_graph_pb2.ProgramGraph(), - ) - - # Instantiate and restore the model. - vocab = vocabulary.LoadVocabulary( - path, - model_name="cdfg" if FLAGS.cdfg else "programl", - max_items=FLAGS.max_vocab_size, - target_cumfreq=FLAGS.target_vocab_cumfreq, - ) - - if FLAGS.cdfg: - FLAGS.use_position_embeddings = False - - model = Ggnn( - vocabulary=vocab, - test_only=True, - node_y_dimensionality=2, - graph_y_dimensionality=0, - graph_x_dimensionality=0, - use_selector_embeddings=True, - ) - checkpoint = pbutil.FromFile(checkpoint_path, checkpoint_pb2.Checkpoint()) - model.RestoreCheckpoint(checkpoint) - - batch = list( - DataflowGgnnBatchBuilder( - graph_loader=SingleGraphLoader(graph=graph, features=features), - vocabulary=vocab, - max_node_size=int(1e9), - use_cdfg=FLAGS.cdfg, - max_batch_count=1, - ) - )[0] - - results = model.RunBatch(epoch_pb2.TEST, batch) - - return AnnotateGraphWithBatchResults(graph, features, results) - - -def AnnotateGraphWithBatchResults( - graph: program_graph_pb2.ProgramGraph, - features: program_graph_features_pb2.ProgramGraphFeatures, - results: BatchResults, -): - """Annotate graph with features describing the target labels and predicted outcomes.""" - assert len(graph.node) == len( - features.node_features.feature_list["data_flow_value"].feature - ) - assert len(graph.node) == len( - features.node_features.feature_list["data_flow_root_node"].feature - ) - assert len(graph.node) == results.targets.shape[0] - - true_y = np.argmax(results.targets, axis=1) - pred_y = np.argmax(results.predictions, axis=1) - - for i, node in enumerate(graph.node): - node.features.feature["data_flow_value"].CopyFrom( - 
features.node_features.feature_list["data_flow_value"].feature[i] - ) - node.features.feature["data_flow_root_node"].CopyFrom( - features.node_features.feature_list["data_flow_root_node"].feature[i] - ) - node.features.feature["target"].float_list.value[:] = results.targets[i] - node.features.feature["prediction"].float_list.value[:] = results.predictions[i] - node.features.feature["true_y"].int64_list.value.append(true_y[i]) - node.features.feature["pred_y"].int64_list.value.append(pred_y[i]) - node.features.feature["correct"].int64_list.value.append(true_y[i] == pred_y[i]) - - graph.features.feature["loss"].float_list.value.append(results.loss) - graph.features.feature["accuracy"].float_list.value.append(results.accuracy) - graph.features.feature["precision"].float_list.value.append(results.precision) - graph.features.feature["recall"].float_list.value.append(results.recall) - graph.features.feature["f1"].float_list.value.append(results.f1) - graph.features.feature["confusion_matrix"].int64_list.value[:] = np.hstack( - results.confusion_matrix - ) - return graph - - -def Main(): - """Main entry point.""" - dataflow.PatchWarnings() - - features_list_path, features_list_index = FLAGS.input.split(":") - graph = TestOne( - features_list_path=pathlib.Path(features_list_path), - features_list_index=int(features_list_index), - checkpoint_path=FLAGS.model, - ) - print(graph) - - -if __name__ == "__main__": - app.Run(Main) diff --git a/programl/task/dataflow/lstm_batch_builder.py b/programl/task/dataflow/lstm_batch_builder.py index 26a807af1..34c495bd1 100644 --- a/programl/task/dataflow/lstm_batch_builder.py +++ b/programl/task/dataflow/lstm_batch_builder.py @@ -17,7 +17,7 @@ from typing import Dict, Optional import numpy as np -import tensorflow as tf +from keras_preprocessing.sequence import pad_sequences from labm8.py import app from programl.graph.format.py import graph_serializer @@ -50,12 +50,12 @@ def __init__( # Mutable state. 
self.graph_node_sizes = [] self.vocab_ids = [] - self.selector_vectors = [] + self.selector_ids = [] self.targets = [] # Padding values. self._vocab_id_pad = len(self.vocabulary) + 1 - self._selector_vector_pad = np.zeros((0, 2), dtype=np.int32) + self._selector_id_pad = 0 self._node_label_pad = np.zeros((0, self.node_y_dimensionality), dtype=np.int32) # Call super-constructor last since it starts the worker thread. @@ -74,14 +74,16 @@ def _Build(self) -> BatchData: self.vocab_ids += [ np.array([self._vocab_id_pad], dtype=np.int32) ] * pad_count - self.selector_vectors += [self._selector_vector_pad] * pad_count + self.selector_ids += [ + np.array([self._selector_id_pad], dtype=np.int32) + ] * pad_count self.targets += [self._node_label_pad] * pad_count batch = BatchData( graph_count=len(self.graph_node_sizes), model_data=LstmBatchData( graph_node_sizes=np.array(self.graph_node_sizes, dtype=np.int32), - encoded_sequences=tf.compat.v1.keras.preprocessing.sequence.pad_sequences( + encoded_sequences=pad_sequences( self.vocab_ids, maxlen=self.padded_sequence_length, dtype="int32", @@ -89,15 +91,15 @@ def _Build(self) -> BatchData: truncating="post", value=self._vocab_id_pad, ), - selector_vectors=tf.compat.v1.keras.preprocessing.sequence.pad_sequences( - self.selector_vectors, + selector_ids=pad_sequences( + self.selector_ids, maxlen=self.padded_sequence_length, - dtype="float32", + dtype="int32", padding="pre", truncating="post", - value=np.zeros(2, dtype=np.float32), + value=self._selector_id_pad, ), - node_labels=tf.compat.v1.keras.preprocessing.sequence.pad_sequences( + node_labels=pad_sequences( self.targets, maxlen=self.padded_sequence_length, dtype="float32", @@ -113,7 +115,7 @@ def _Build(self) -> BatchData: # Reset mutable state. 
self.graph_node_sizes = [] self.vocab_ids = [] - self.selector_vectors = [] + self.selector_ids = [] self.targets = [] return batch @@ -139,7 +141,7 @@ def OnItem(self, item) -> Optional[BatchData]: ) for n in node_list ] - selector_values = np.array( + selector_ids = np.array( [ features.node_features.feature_list["data_flow_root_node"] .feature[n] @@ -148,10 +150,7 @@ def OnItem(self, item) -> Optional[BatchData]: ], dtype=np.int32, ) - selector_vectors = np.zeros((selector_values.size, 2), dtype=np.float32) - selector_vectors[ - np.arange(selector_values.size), selector_values - ] = FLAGS.selector_embedding_value + # TODO: FLAGS.selector_embedding_value targets = np.array( [ features.node_features.feature_list["data_flow_value"] @@ -171,7 +170,7 @@ def OnItem(self, item) -> Optional[BatchData]: self.graph_node_sizes.append(len(node_list)) self.vocab_ids.append(vocab_ids) - self.selector_vectors.append(selector_vectors) + self.selector_ids.append(selector_ids) self.targets.append(targets_1hot) if len(self.graph_node_sizes) >= self.batch_size: diff --git a/programl/task/dataflow/train_lstm.py b/programl/task/dataflow/train_lstm.py index 6f3ee3d5b..55c891dea 100644 --- a/programl/task/dataflow/train_lstm.py +++ b/programl/task/dataflow/train_lstm.py @@ -114,6 +114,9 @@ def TrainDataflowLSTM( vocabulary=vocab, test_only=False, node_y_dimensionality=2, + graph_y_dimensionality=0, + graph_x_dimensionality=0, + use_selector_embeddings=True, ) if restore_from: @@ -132,8 +135,6 @@ def TrainDataflowLSTM( model.Initialize() start_epoch_step, start_graph_cumsum = 1, 0 - model.model.summary() - # Create training batches and split into epochs. epochs = EpochBatchIterator( MakeBatchBuilder( diff --git a/requirements.txt b/requirements.txt index 3baff297e..122ac582c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ decorator==4.3.0 gast==0.2.2. # Dependency of tensorflow. 
GPUtil==1.4.0 Keras==2.3.1 +keras_preprocessing >= 1.1.1, < 1.2 kiwisolver==1.0.1 # Needed by matplotlib. joblib>=0.16.0 # Needed by scikit-learn matplotlib==2.2.0rc1 diff --git a/third_party/py/keras_preprocessing/BUILD b/third_party/py/keras_preprocessing/BUILD new file mode 100644 index 000000000..d9fe238bd --- /dev/null +++ b/third_party/py/keras_preprocessing/BUILD @@ -0,0 +1,16 @@ +# A wrapper around pip package to pull in undeclared dependencies. + +load("@programl_requirements//:requirements.bzl", "requirement") + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT + +py_library( + name = "keras_preprocessing", + srcs = ["//third_party/py:empty.py"], + deps = [ + requirement("keras_preprocessing"), + "//third_party/py/numpy", + ], +)