From 138f7687db527442b7038ee3d4619d6e9241cb51 Mon Sep 17 00:00:00 2001
From: Georg Rutishauser
Date: Fri, 9 Feb 2024 17:07:51 +0100
Subject: [PATCH 1/5] add capability to harmonize post-training

---
 .../fx_integerization/integerize_pactnets.py  | 159 +++++++++++++-----
 systems/CIFAR10/ResNet/quantize/__init__.py   |   1 +
 systems/CIFAR10/ResNet/quantize/pact.py       |  52 ++++--
 systems/CIFAR10/ResNet/resnet.py              |   3 +-
 systems/ILSVRC12/MobileNetV2/quantize/pact.py |  15 +-
 5 files changed, 163 insertions(+), 67 deletions(-)

diff --git a/examples/fx_integerization/integerize_pactnets.py b/examples/fx_integerization/integerize_pactnets.py
index 1015d5e..ef3d87f 100644
--- a/examples/fx_integerization/integerize_pactnets.py
+++ b/examples/fx_integerization/integerize_pactnets.py
@@ -1,23 +1,23 @@
-# 
+#
 # integerize_pactnets.py
-# 
+#
 # Author(s):
 #   Georg Rutishauser
-# 
+#
 # Copyright (c) 2020-2021 ETH Zurich.
-# 
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-# 
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# 
+#
 
 import argparse
@@ -77,9 +77,17 @@
 # import the DORY backend
 from quantlib.backends.dory import export_net, export_dvsnet, DORYHarmonizePass
 # import the PACT/TQT integerization pass
-from quantlib.editing.fx.passes.pact import IntegerizePACTNetPass
+from quantlib.editing.fx.passes.pact import HarmonizePACTNetPass, IntegerizePACTNetPass, PACT_symbolic_trace
 from quantlib.editing.fx.util import module_of_node
+from quantlib.editing.fx.passes import RetracePass
 from quantlib.algorithms.pact.pact_ops import *
+from quantlib.algorithms.pact.pact_controllers import *
+
+def read_json(filename : str):
+    with open(filename, 'r') as fp:
+        the_dict = json.load(fp)
+    return the_dict
+
 # organize quantization functions, datasets and transforms by network
 @dataclass
 class QuantUtil:
@@ -130,7 +138,7 @@ def get_valid_dataset(key : str, cfg : dict, quantize : str, pad_img : Optional[
 _QUANT_UTILS = {
     'VGG': QuantUtil(problem='CIFAR10', topo='VGG', quantize=quantize_vgg, get_controllers=controllers_vgg, network=VGG, in_shape=(1,3,32,32), eps_in=_CIFAR10_EPS, D=2**19, bs=256, get_in_shape=None, load_dataset_fn=load_cifar10, transform=CIFAR10PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=150000),
     'MobileNetV1': QuantUtil(problem='ILSVRC12', topo='MobileNetV1', quantize=quantize_mnv1, get_controllers=controllers_mnv1, network=MobileNetV1, in_shape=(1,3,224,224), eps_in=_ILSVRC12_EPS, D=2**19, bs=96, get_in_shape=None, load_dataset_fn=load_ilsvrc12, transform=ILSVRC12PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=135000),
-    'MobileNetV2': QuantUtil(problem='ILSVRC12',
topo='MobileNetV2', quantize=quantize_mnv2, get_controllers=controllers_mnv2, network=MobileNetV2, in_shape=(1,3,224,224), eps_in=_ILSVRC12_EPS, D=2**19, bs=43, get_in_shape=None, load_dataset_fn=load_ilsvrc12, transform=ILSVRC12PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=150000),
     'MobileNetV3': QuantUtil(problem='ILSVRC12', topo='MobileNetV3', quantize=quantize_mnv3, get_controllers=controllers_mnv3, network=MobileNetV3, in_shape=(1,3,224,224), eps_in=_ILSVRC12_EPS, D=2**19, bs=53, get_in_shape=None, load_dataset_fn=load_ilsvrc12, transform=ILSVRC12PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=150000),
     'ResNet': QuantUtil(problem='ILSVRC12', topo='ResNet', quantize=quantize_resnet, get_controllers=controllers_resnet, network=ResNet, in_shape=(1,3,224,224), eps_in=_ILSVRC12_EPS, D=2**19, bs=53, get_in_shape=None, load_dataset_fn=load_ilsvrc12, transform=ILSVRC12PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=160000),
     'ResNetCIFAR': QuantUtil(problem='CIFAR10', topo='ResNet', quantize=quantize_resnet_cifar, get_controllers=controllers_resnet_cifar, network=ResNetCIFAR, in_shape=(1,3,32,32), eps_in=_CIFAR10_EPS, D=2**19, bs=128, get_in_shape=None, load_dataset_fn=load_cifar10, transform=CIFAR10PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=110000),
@@ -159,7 +167,7 @@ def get_ckpt(key : str, exp_id : int, ckpt_id : Union[int, str]):
     ckpt_filepath = get_topology_dir(key).joinpath(f'logs/exp{exp_id:04}/fold0/saves/{ckpt_str}.ckpt')
     return torch.load(ckpt_filepath)
 
-def get_network(key : str, exp_id : int, ckpt_id : Union[int, str], quantized=False):
+def get_network(key : str, exp_id : int, ckpt_id : Union[int, str], quantized : bool = False, harmonize_cfg : Optional[str] = None):
     cfg = get_config(key, exp_id)
     qu = _QUANT_UTILS[key]
     quant_cfg = cfg['network']['quantize']['kwargs']
@@ -185,11 +193,51 @@ def get_network(key : str, exp_id : int, ckpt_id : Union[int, str], quantized=Fa
         for ctrl, sd in zip(qctrls, ckpt['qnt_ctrls']):
             ctrl.load_state_dict(sd)
 
-    # we don't want to train this network anymore
-    return quant_net.eval()
-
-def get_dataloader(key : str, cfg : dict, quantize : str, pad_img : Optional[int] = None, clip : bool = False):
+    quant_net = quant_net.eval()
+    # if requested, harmonize the net
+    if harmonize_cfg is not None:
+        hc = read_json(harmonize_cfg)
+        dl = get_dataloader(key, cfg, 'none', shuffle=True)
+        try:
+            init_clip_lo = cfg['training']['quantize']['kwargs_activation']['init_clip_lo']
+        except KeyError:
+            init_clip_lo = -1.
+        try:
+            init_clip_hi = cfg['training']['quantize']['kwargs_activation']['init_clip_hi']
+        except KeyError:
+            init_clip_hi = 1.
+ + quant_net = harmonize_network(quant_net, dl, hc, init_clip_lo, init_clip_hi) + return quant_net + +def harmonize_network(net : nn.Module, dl : torch.utils.data.DataLoader, harmonize_cfg : dict, init_clip_lo : float = -1., init_clip_hi : float = 1.): + print("Harmonizing trained network...") + harmonize_pass = HarmonizePACTNetPass(**harmonize_cfg) + net_traced = PACT_symbolic_trace(net) + # get all acts prior to harmonization so we can make a controller that only + # cares about the ones inserted by the harmonization pass + pre_harmonize_acts = PACTActController.get_modules(net_traced) + harmonized_net = harmonize_pass(net_traced) + print("Harmonization done!") + post_harmonize_acts = PACTActController.get_modules(harmonized_net) + harmonize_acts = [a for a in post_harmonize_acts if a not in pre_harmonize_acts] + act_schedule = {"0":["verbose_on", "start"]} + act_ctrl = PACTActController(harmonize_acts, act_schedule, init_clip_lo=init_clip_lo, init_clip_hi=init_clip_hi) + int_modules = PACTIntegerModulesController.get_modules(harmonized_net) + int_ctrl = PACTIntegerModulesController(int_modules) + #import ipdb; ipdb.set_trace() + print("Calibrating activations inserted by harmonization with validation set...") + validate(harmonized_net, dl, n_valid_batches=50) + + act_ctrl.step_pre_training_epoch(0) + int_ctrl.step_pre_validation_epoch() + print("Done!") + return harmonized_net + + + +def get_dataloader(key : str, cfg : dict, quantize : str, pad_img : Optional[int] = None, clip : bool = False, shuffle : bool = False): qu = _QUANT_UTILS[key] if torch.cuda.is_available(): bs = torch.cuda.device_count() * qu.bs @@ -197,7 +245,7 @@ def get_dataloader(key : str, cfg : dict, quantize : str, pad_img : Optional[int # network will be executed on CPU (not recommended!!) bs = 16 ds = get_valid_dataset(key, cfg, quantize, pad_img=pad_img, clip=clip) - return torch.utils.data.DataLoader(ds, bs) + return torch.utils.data.DataLoader(ds, bs, shuffle=shuffle) def validate(net : nn.Module, dl : torch.utils.data.DataLoader, print_interval : int = 10, n_valid_batches : int = None): @@ -234,7 +282,7 @@ def get_input_channels(net : fx.GraphModule): return conv.in_channels # THIS IS WHERE THE BUSINESS HAPPENS! -def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_harmonize : bool, word_align_channels : bool, requant_node : bool = False): +def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_harmonize : bool, word_align_channels : bool, requant_node : bool = False, ternarize : bool = False): qu = _QUANT_UTILS[key] # All we need to do to integerize a fake-quantized network is to run the # IntegerizePACTNetPass on it! 
Afterwards, the ONNX graph it produces will @@ -245,13 +293,30 @@ def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_har in_shp_cnn = (in_shp[0], in_shp[1]//net.tcn_window, in_shp[2], in_shp[3]) in_shp_tcn = (1, net.tcn.features[0].in_channels, net.tcn_window) tcn_eps_in = net.cnn.features[-1].get_eps() - cnn_int_pass = IntegerizePACTNetPass(shape_in=in_shp_cnn, eps_in=qu.eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels, word_align_channels=word_align_channels) + cnn_int_pass = IntegerizePACTNetPass(shape_in=in_shp_cnn, eps_in=qu.eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels, word_align_channels=word_align_channels, ternarize=ternarize) cnn_int = cnn_int_pass(net.cnn) - tcn_int_pass = IntegerizePACTNetPass(shape_in=in_shp_tcn, eps_in=tcn_eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels) - net.tcn.classifier = get_new_classifier(net.tcn.classifier) - net.tcn.cls_replaced = True + tcn_int_pass = IntegerizePACTNetPass(shape_in=in_shp_tcn, eps_in=tcn_eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels, ternarize=ternarize) + #net.tcn.classifier = get_new_classifier(net.tcn.classifier) + #net.tcn.cls_replaced = True tcn_int = tcn_int_pass(net.tcn) + tcn_nodes = list(tcn_int.graph.nodes) + for n in tcn_nodes[::-1]: + if n.op == 'call_method' and 'squeeze' in n.target: + squeeze_node = n + if n.op == 'call_module': + cls_node = n + break + tcn_int.__setattr__(cls_node.target, get_new_classifier(module_of_node(tcn_int, cls_node))) + squeeze_node.replace_all_uses_with(squeeze_node.all_input_nodes[0]) + tcn_int.graph.erase_node(squeeze_node) + tcn_int.recompile() + + cnn_retracer = RetracePass(PACT_symbolic_trace) + tcn_retracer = RetracePass(PACT_symbolic_trace) + # dissolve the "Module"s that FX makes out of sequentials + cnn_int = cnn_retracer(cnn_int) + tcn_int = tcn_retracer(tcn_int) if fix_channels: in_shp_l_cnn = list(in_shp_cnn) in_shp_l_cnn[1] = get_input_channels(cnn_int) @@ -285,13 +350,13 @@ def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_har return int_net -def export_integerized_network(net : nn.Module, cfg : dict, key : str, export_dir : str, name : str, in_idx : int = 42, pad_img : Optional[int] = None, clip : bool = False, change_n_levels : int = None): +def export_integerized_network(net : nn.Module, cfg : dict, key : str, export_dir : str, name : str, in_idx : int = 42, pad_img : Optional[int] = None, clip : bool = False, change_n_levels : int = None, ternarize : bool = False): qu = _QUANT_UTILS[key] # use a real image from the validation set ds = get_valid_dataset(key, cfg, quantize='int', pad_img=pad_img, clip=clip) test_input = ds[in_idx][0].unsqueeze(0) if key == 'dvs_cnn': - qu.export_fn(*net, name=name, out_dir=export_dir, eps_in=qu.eps_in, integerize=False, D=qu.D, in_data=test_input, change_n_levels=change_n_levels, code_size=qu.code_size) + qu.export_fn(*net, name=name, out_dir=export_dir, eps_in=qu.eps_in, integerize=False, D=qu.D, in_data=test_input, change_n_levels=change_n_levels, code_size=qu.code_size, compressed=ternarize) else: qu.export_fn(net, name=name, out_dir=export_dir, eps_in=qu.eps_in, integerize=False, D=qu.D, in_data=test_input, code_size=qu.code_size) @@ -313,32 +378,29 @@ def export_unquant_net(net : nn.Module, cfg : dict, key : str, export_dir : str, # quite hacky but there is no other way def get_new_classifier(classifier: PACTConv1d): + new_classifier = nn.Sequential(nn.Flatten(), - 
PACTLinear( - in_features=classifier.in_channels*classifier.kernel_size[0], - out_features=classifier.out_channels+1, - bias=True, - n_levels=classifier.n_levels, - quantize=classifier.quantize, - init_clip=classifier.init_clip, - learn_clip=classifier.learn_clip, - symm_wts=classifier.symm_wts, - nb_std=classifier.nb_std, - tqt=classifier.tqt, - tqt_beta=classifier.tqt_beta, - tqt_clip_grad=classifier.tqt_clip_grad)) + nn.Linear( + in_features=classifier.in_channels*classifier.kernel_size[0], + out_features=classifier.out_channels, + bias=True)) new_weights = classifier.weight.reshape(classifier.out_channels, -1) - new_weights = torch.cat((new_weights, torch.zeros(new_weights.shape[1]).unsqueeze(0))) + #new_weights = torch.cat((new_weights, torch.zeros(new_weights.shape[1]).unsqueeze(0))) new_classifier[1].weight.data.copy_(new_weights) if classifier.bias is not None: - new_classifier[1].bias.data.copy_(torch.cat((classifier.bias, torch.Tensor([0])))) + #new_classifier[1].bias.data.copy_(torch.cat((classifier.bias, torch.Tensor([0])))) + new_classifier[1].bias.data.copy_(classifier.bias) else: new_classifier[1].bias.data.fill_(0) - new_classifier[1].clip_lo = torch.nn.Parameter(torch.cat((classifier.clip_lo.squeeze(2), -torch.ones(1,1)))) - new_classifier[1].clip_hi = torch.nn.Parameter(torch.cat((classifier.clip_hi.squeeze(2), torch.ones(1,1)))) - new_classifier[1].clipping_params = classifier.clipping_params - new_classifier[1].started = classifier.started + #new_classifier[1].clip_lo = torch.nn.Parameter(torch.cat((classifier.clip_lo.squeeze(2), -torch.ones(1,1)))) + #new_classifier[1].clip_hi = torch.nn.Parameter(torch.cat((classifier.clip_hi.squeeze(2), torch.ones(1,1)))) + #new_classifier[1].clip_lo = torch.nn.Parameter(classifier.clip_lo.squeeze(2), -torch.ones(1,1)) + #new_classifier[1].clip_hi = torch.nn.Parameter(classifier.clip_hi.squeeze(2), torch.ones(1,1)) + #new_classifier[1].clipping_params = classifier.clipping_params + #new_classifier[1].started = classifier.started + # we just don't care anymore and hardcode this. + new_classifier[1].n_levels = classifier.n_levels return new_classifier @@ -375,14 +437,18 @@ def get_new_classifier(classifier: PACTConv1d): help='If supplied, don\'t align averagePool nodes\' associated requantization nodes and replace adders with DORYAdders') parser.add_argument('--change_n_levels', type=int, default=None, help='Only used in DVS128 export - override clipping bound of RequantShift modules of exported networks to this value') + parser.add_argument('--ternarize', action='store_true', help='Use threshold layers in exported ternary nets? 
Do not use together with change_n_levels!') parser.add_argument('--code_size', type=int, default=None, help="Override the default 'code reserved space' setting") parser.add_argument('--requant_node', action='store_true', help='Export RequantShift nodes instead of mul-add-div sequences in ONNX graph') parser.add_argument('--n_valid_batch', type=int, default=None, help='number of validation batches to run') - + parser.add_argument('--harmonize_cfg', type=str, default=None, + help='Run harmonization on the quantized net with this configuration') + + # export_dvsnet->compressed == ternarize args = parser.parse_args() @@ -394,7 +460,8 @@ def get_new_classifier(classifier: PACTConv1d): exp_id = int(args.exp_id) if args.exp_id.isnumeric() else args.exp_id print(f'Loading network {args.net}, experiment {exp_id}, checkpoint {args.ckpt_id}') - qnet = get_network(args.net, exp_id, args.ckpt_id, quantized=True) + qnet = get_network(args.net, exp_id, args.ckpt_id, quantized=True, harmonize_cfg=args.harmonize_cfg) + exp_cfg = get_config(args.net, exp_id) if args.validate_fq: @@ -404,7 +471,7 @@ def get_new_classifier(classifier: PACTConv1d): print(f'Integerizing network {args.net}') - int_net = integerize_network(qnet, args.net, args.fix_channels, not args.no_dory_harmonize, args.word_align_channels, args.requant_node) + int_net = integerize_network(qnet, args.net, args.fix_channels, not args.no_dory_harmonize, args.word_align_channels, args.requant_node, ternarize=args.ternarize) if args.fix_channels: pad_img = get_input_channels(int_net[0] if isinstance(int_net, tuple) else int_net) @@ -413,12 +480,12 @@ def get_new_classifier(classifier: PACTConv1d): if args.validate_tq: dl = get_dataloader(args.net, exp_cfg, quantize='int', pad_img=pad_img) + print(f'Validating integerized network {args.net} on dataset {get_system(args.net)}') validate(int_net, dl, args.accuracy_print_interval, n_valid_batches=args.n_valid_batch) if args.export_dir is not None: print(f'Exporting integerized network {args.net} to directory {args.export_dir} under name {export_name}') - export_integerized_network(int_net, exp_cfg, args.net, args.export_dir, export_name, pad_img=pad_img, clip=args.clip_inputs, change_n_levels=args.change_n_levels) + export_integerized_network(int_net, exp_cfg, args.net, args.export_dir, export_name, pad_img=pad_img, clip=args.clip_inputs, change_n_levels=args.change_n_levels, ternarize=args.ternarize) if args.export_unquant: net_unq = get_network(args.net, exp_id, args.ckpt_id, quantized=False) export_unquant_net(net_unq, exp_cfg, args.net, args.export_dir, export_name) - diff --git a/systems/CIFAR10/ResNet/quantize/__init__.py b/systems/CIFAR10/ResNet/quantize/__init__.py index 594a488..dde28f7 100644 --- a/systems/CIFAR10/ResNet/quantize/__init__.py +++ b/systems/CIFAR10/ResNet/quantize/__init__.py @@ -1 +1,2 @@ from .pact import * +from .bb import * diff --git a/systems/CIFAR10/ResNet/quantize/pact.py b/systems/CIFAR10/ResNet/quantize/pact.py index 0333339..912aed1 100644 --- a/systems/CIFAR10/ResNet/quantize/pact.py +++ b/systems/CIFAR10/ResNet/quantize/pact.py @@ -29,6 +29,7 @@ import quantlib.editing.lightweight.rules as qlr from quantlib.editing.lightweight.rules.filters import VariadicOrFilter, NameFilter, TypeFilter from quantlib.editing.fx.passes.pact import HarmonizePACTNetPass, PACT_symbolic_trace +from quantlib.editing.fx.util import module_of_node from quantlib.algorithms.pact.pact_ops import * from quantlib.algorithms.pact.pact_controllers import * @@ -45,39 +46,49 @@ def pact_recipe(net 
: nn.Module,
     # An additional dict is expected to be stored under the key "kwargs", which
     # is used as the default kwargs.
 
-    filter_conv2d = TypeFilter(nn.Conv2d)
-    filter_linear = TypeFilter(nn.Linear)
-    act_types = (nn.ReLU, nn.ReLU6)
-    filter_acts = VariadicOrFilter(*[TypeFilter(t) for t in act_types])
+    uact_types = (nn.ReLU, nn.ReLU6)
+    sact_types = (nn.Hardtanh,)
 
     rhos = []
     conv_cfg = config["PACTConv2d"]
     lin_cfg = config["PACTLinear"]
-    act_cfg = config["PACTUnsignedAct"]
-
-    harmonize_cfg = config["harmonize"]
+    uact_cfg = config["PACTUnsignedAct"]
+    try:
+        sact_cfg = config["PACTAsymmetricAct"]
+    except KeyError:
+        sact_cfg = {}
+    try:
+        last_add_8b = config['last_add_8b']
+    except KeyError:
+        last_add_8b = False
 
-    def make_rules(cfg : dict,
-                   rule : type):
+    harmonize_cfg = config["harmonize"]
+
+    def make_rules(cfg : dict, t : tuple,
+                   rule : type, **kwargs):
         rules = []
         default_cfg = cfg["kwargs"] if "kwargs" in cfg.keys() else {}
         layer_keys = [k for k in cfg.keys() if k != "kwargs"]
+        type_filter = VariadicOrFilter(*[TypeFilter(tt) for tt in t])
+        print("type filter: ", type_filter)
         for k in layer_keys:
-            filt = NameFilter(k)
-            kwargs = default_cfg.copy()
-            kwargs.update(cfg[k])
-            rho = rule(filt, **kwargs)
+            filt = NameFilter(k) & type_filter
+            # merge into a fresh copy each iteration so one layer's settings
+            # do not leak into the kwargs of subsequent layers
+            layer_kwargs = kwargs.copy()
+            layer_kwargs.update(default_cfg)
+            layer_kwargs.update(cfg[k])
+            rho = rule(filt, **layer_kwargs)
             rules.append(rho)
         return rules
 
-    rhos += make_rules(conv_cfg,
+    rhos += make_rules(conv_cfg, (nn.Conv2d,),
                        qlr.pact.ReplaceConvLinearPACTRule)
-    rhos += make_rules(lin_cfg,
+    rhos += make_rules(lin_cfg, (nn.Linear,),
                        qlr.pact.ReplaceConvLinearPACTRule)
-    rhos += make_rules(act_cfg,
-                       qlr.pact.ReplaceActPACTRule)
+    rhos += make_rules(uact_cfg, uact_types,
+                       qlr.pact.ReplaceActPACTRule, signed=False)
+    rhos += make_rules(sact_cfg, sact_types,
+                       qlr.pact.ReplaceActPACTRule, signed=True)
 
     lwg = qlw.LightweightGraph(net)
     lwe = qlw.LightweightEditor(lwg)
@@ -115,6 +126,17 @@ def make_rules(cfg : dict,
 
     final_net = harmonize_pass(net_traced)
 
+    if last_add_8b:
+        # walk the graph back-to-front so the first match is the last adder
+        for n in list(final_net.graph.nodes)[::-1]:
+            if n.op == 'call_module':
+                module = module_of_node(final_net, n)
+                if isinstance(module, PACTIntegerAdd):
+                    outact_node = [k for k in n.users.keys()][0]
+                    outact_module = module_of_node(final_net, outact_node)
+                    print(f"Setting node {outact_node}'s output n_levels attribute to 256!")
+                    outact_module.n_levels = 256
+                    break
+
     # the prec.
spec file might include layers that were added by the # harmonization pass; those need to be treated separately final_nodes = LightweightGraph.build_nodes_list(final_net) diff --git a/systems/CIFAR10/ResNet/resnet.py b/systems/CIFAR10/ResNet/resnet.py index 2748ebc..a6a5cb8 100644 --- a/systems/CIFAR10/ResNet/resnet.py +++ b/systems/CIFAR10/ResNet/resnet.py @@ -251,12 +251,11 @@ def __init__(self, block_class = _CONFIGS[config]['block_class'] block_cfgs = _CONFIGS[config]['block_cfgs'] do_maxpool = _CONFIGS[config]['maxpool'] - out_channels_pilot = 16 in_planes_features = out_channels_pilot out_planes_features = block_cfgs[-1][1] * block_class.expansion_factor out_channels_features = out_planes_features - self.act_type = nn.ReLU if activation.lower() == 'relu' else nn.ReLU6 + self.act_type = nn.ReLU if activation.lower() == 'relu' else nn.ReLU6 if activation.lower() == 'relu6' else nn.Hardtanh self.pilot = self._make_pilot(out_channels_pilot) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if do_maxpool else None diff --git a/systems/ILSVRC12/MobileNetV2/quantize/pact.py b/systems/ILSVRC12/MobileNetV2/quantize/pact.py index 9b84ccf..45142e7 100644 --- a/systems/ILSVRC12/MobileNetV2/quantize/pact.py +++ b/systems/ILSVRC12/MobileNetV2/quantize/pact.py @@ -73,8 +73,12 @@ def pact_recipe(net : nn.Module, conv_cfg = config["PACTConv2d"] lin_cfg = config["PACTLinear"] act_cfg = config["PACTUnsignedAct"] - - harmonize_cfg = config["harmonize"] + # we may get a config that does not include a harmonization configuration; + # in that case, simply don't harmonize + try: + harmonize_cfg = config["harmonize"] + except KeyError: + harmonize_cfg = None prec_override_spec = {} @@ -124,8 +128,11 @@ def make_rules(cfg : dict, lwe.shutdown() # now harmonize the graph according to the configuration - harmonize_pass = HarmonizePACTNetPass(**harmonize_cfg) - final_net = harmonize_pass(net) + if harmonize_cfg is not None: + harmonize_pass = HarmonizePACTNetPass(**harmonize_cfg) + final_net = harmonize_pass(net) + else: + final_net = net # the prec. 
spec file might include layers that were added by the # harmonization pass; those need to be treated separately From 16bbcf1ff447b008c5e53e42291d515522d7cb7a Mon Sep 17 00:00:00 2001 From: Georg Rutishauser Date: Tue, 13 Feb 2024 18:08:50 +0100 Subject: [PATCH 2/5] fixes for integerizing thresholding networks --- .../fx_integerization/integerize_pactnets.py | 146 +++++++++++++----- quantlib | 2 +- systems/CIFAR10/ResNet/resnet.py | 10 ++ systems/DVS128/dvs_cnn/dvs_cnn.py | 17 +- systems/ILSVRC12/ResNet/quantize/pact.py | 15 +- 5 files changed, 137 insertions(+), 53 deletions(-) diff --git a/examples/fx_integerization/integerize_pactnets.py b/examples/fx_integerization/integerize_pactnets.py index ef3d87f..78df01a 100644 --- a/examples/fx_integerization/integerize_pactnets.py +++ b/examples/fx_integerization/integerize_pactnets.py @@ -28,7 +28,7 @@ import json import torch from torch import nn, fx - +from copy import deepcopy # set the PYTHONPATH to include QuantLab's root directory import sys @@ -142,7 +142,7 @@ def get_valid_dataset(key : str, cfg : dict, quantize : str, pad_img : Optional[ 'MobileNetV3': QuantUtil(problem='ILSVRC12', topo='MobileNetV3', quantize=quantize_mnv3, get_controllers=controllers_mnv3, network=MobileNetV3, in_shape=(1,3,224,224), eps_in=_ILSVRC12_EPS, D=2**19, bs=53, get_in_shape=None, load_dataset_fn=load_ilsvrc12, transform=ILSVRC12PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=150000), 'ResNet': QuantUtil(problem='ILSVRC12', topo='ResNet', quantize=quantize_resnet, get_controllers=controllers_resnet, network=ResNet, in_shape=(1,3,224,224), eps_in=_ILSVRC12_EPS, D=2**19, bs=53, get_in_shape=None, load_dataset_fn=load_ilsvrc12, transform=ILSVRC12PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=160000), 'ResNetCIFAR': QuantUtil(problem='CIFAR10', topo='ResNet', quantize=quantize_resnet_cifar, get_controllers=controllers_resnet_cifar, network=ResNetCIFAR, in_shape=(1,3,32,32), eps_in=_CIFAR10_EPS, D=2**19, bs=128, get_in_shape=None, load_dataset_fn=load_cifar10, transform=CIFAR10PACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=110000), - 'dvs_cnn' : QuantUtil(problem='DVS128', topo='dvs_cnn', quantize=quantize_dvsnet, get_controllers=controllers_dvsnet, network=DVSHybridNet, network_args={'inject_eps':False}, in_shape=None, eps_in=1., D=2**19, bs=128, get_in_shape=get_in_shape_dvsnet, load_dataset_fn=load_dvs128, transform=DVSAugmentTransform, n_levels_in=3, export_fn=export_dvsnet, code_size=340000), + 'dvs_cnn' : QuantUtil(problem='DVS128', topo='dvs_cnn', quantize=quantize_dvsnet, get_controllers=controllers_dvsnet, network=DVSHybridNet, network_args={'inject_eps':True}, in_shape=None, eps_in=1., D=2**19, bs=128, get_in_shape=get_in_shape_dvsnet, load_dataset_fn=load_dvs128, transform=DVSAugmentTransform, n_levels_in=3, export_fn=export_dvsnet, code_size=340000), 'simpleCNN': QuantUtil(problem='MNIST', topo='simpleCNN', quantize=quantize_simpleCNN, get_controllers=controllers_simpleCNN, network=simpleCNN, in_shape=(1,1,32,32), eps_in=_MNIST_EPS, D=2**19, bs=256, get_in_shape=None, load_dataset_fn=load_mnist, transform=MNISTPACTQuantTransform, quant_transform_args={'n_q':256}, n_levels_in=256, export_fn=export_net, code_size=150000) } @@ -226,7 +226,6 @@ def harmonize_network(net : nn.Module, dl : torch.utils.data.DataLoader, harmoni act_ctrl = PACTActController(harmonize_acts, act_schedule, 
init_clip_lo=init_clip_lo, init_clip_hi=init_clip_hi) int_modules = PACTIntegerModulesController.get_modules(harmonized_net) int_ctrl = PACTIntegerModulesController(int_modules) - #import ipdb; ipdb.set_trace() print("Calibrating activations inserted by harmonization with validation set...") validate(harmonized_net, dl, n_valid_batches=50) @@ -248,7 +247,7 @@ def get_dataloader(key : str, cfg : dict, quantize : str, pad_img : Optional[int return torch.utils.data.DataLoader(ds, bs, shuffle=shuffle) -def validate(net : nn.Module, dl : torch.utils.data.DataLoader, print_interval : int = 10, n_valid_batches : int = None): +def validate(net : nn.Module, dl : torch.utils.data.DataLoader, print_interval : int = 10, n_valid_batches : int = None, eps_w : torch.Tensor = None): net = net.eval() # we assume that the net is on CPU as this is required for some # integerization passes @@ -262,10 +261,17 @@ def validate(net : nn.Module, dl : torch.utils.data.DataLoader, print_interval : n_tot = 0 n_correct = 0 + + n_cls = len(dl.dataset.classes) + if eps_w is None: + eps_w = torch.ones((1, n_cls)) + else: + eps_w = eps_w.squeeze()[None, :] + for i, (xb, yb) in enumerate(tqdm(dl)): - yn = net(xb.to(device)) + yn = net(xb.to(device)).to('cpu') * eps_w n_tot += xb.shape[0] - n_correct += (yn.to('cpu').argmax(dim=1) == yb).sum() + n_correct += (yn.argmax(dim=1) == yb).sum() if ((i+1)%print_interval == 0): print(f'Accuracy after {i+1} batches: {n_correct/n_tot}') if (i+1) == n_valid_batches: @@ -289,13 +295,16 @@ def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_har # contain only integer operations. Any divisions in the integerized graph # will be by powers of 2 and can be implemented as bit shifts. in_shp = qu.in_shape + net_cp = deepcopy(net) if key == 'dvs_cnn': in_shp_cnn = (in_shp[0], in_shp[1]//net.tcn_window, in_shp[2], in_shp[3]) in_shp_tcn = (1, net.tcn.features[0].in_channels, net.tcn_window) tcn_eps_in = net.cnn.features[-1].get_eps() + tcn_sgnd_in = net.cnn.features[-1].signed + tcn_last_act_eps = net.tcn.features[-1].get_eps() cnn_int_pass = IntegerizePACTNetPass(shape_in=in_shp_cnn, eps_in=qu.eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels, word_align_channels=word_align_channels, ternarize=ternarize) cnn_int = cnn_int_pass(net.cnn) - tcn_int_pass = IntegerizePACTNetPass(shape_in=in_shp_tcn, eps_in=tcn_eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels, ternarize=ternarize) + tcn_int_pass = IntegerizePACTNetPass(shape_in=in_shp_tcn, eps_in=tcn_eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels, ternarize=ternarize, signed_in=tcn_sgnd_in) #net.tcn.classifier = get_new_classifier(net.tcn.classifier) #net.tcn.cls_replaced = True @@ -307,7 +316,7 @@ def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_har if n.op == 'call_module': cls_node = n break - tcn_int.__setattr__(cls_node.target, get_new_classifier(module_of_node(tcn_int, cls_node))) + tcn_int.__setattr__(cls_node.target, get_new_classifier(module_of_node(tcn_int, cls_node), not net.tcn.features[-1].signed, tcn_last_act_eps)) squeeze_node.replace_all_uses_with(squeeze_node.all_input_nodes[0]) tcn_int.graph.erase_node(squeeze_node) tcn_int.recompile() @@ -329,7 +338,12 @@ def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_har cnn_int = dory_harmonize_pass_cnn(cnn_int) dory_harmonize_pass_tcn = DORYHarmonizePass(in_shape=in_shp_tcn) tcn_int = dory_harmonize_pass_tcn(tcn_int) - return 
cnn_int, tcn_int + + net.cnn = cnn_int + net.tcn = tcn_int + return net, net_cp, net_cp.tcn.classifier.get_eps_w() + #return cnn_int, tcn_int + else: int_pass = IntegerizePACTNetPass(shape_in=in_shp, eps_in=qu.eps_in, D=qu.D, n_levels_in=qu.n_levels_in, fix_channel_numbers=fix_channels, requant_node=requant_node) int_net = int_pass(net) @@ -350,13 +364,88 @@ def integerize_network(net : nn.Module, key : str, fix_channels : bool, dory_har return int_net +# quite hacky but there is no other way + +def get_new_classifier(classifier: PACTConv1d, unsigned_in, eps_in : float = 1.): + + new_classifier = nn.Sequential(nn.Flatten(), + nn.Linear( + in_features=classifier.in_channels*classifier.kernel_size[0], + out_features=classifier.out_channels, + bias=True)) + + new_weights = classifier.weight.reshape(classifier.out_channels, -1) + #new_weights = torch.cat((new_weights, torch.zeros(new_weights.shape[1]).unsqueeze(0))) + new_classifier[1].weight.data.copy_(new_weights) + new_bias = torch.zeros_like(new_classifier[1].bias) + if unsigned_in: + new_bias += torch.round(new_classifier[1].weight.sum(dim=1))# * eps_in) + if classifier.bias is not None: + new_bias += classifier.bias + new_classifier[1].bias.data.copy_(new_bias) + new_classifier[1].n_levels = classifier.n_levels + return new_classifier + +def compare_nets(int_net, fq_net, dl): + #indata = dl.dataset[42][0][None, ...] + cnn_win = fq_net.cnn.adapter[0].in_channels + for j, (indata, _) in tqdm(enumerate(dl)): + in_windows = torch.split(indata, cnn_win, dim=1) + cnn_outs_fq, cnn_outs_tq = [], [] + eps_w =fq_net.tcn.classifier.get_eps_w().squeeze()[None, :] + def rebuild_subnet(fx_net, subnet_name): + subnet_nodes = [n for n in fx_net.graph.nodes if subnet_name in n.name and n.op == 'call_module'] + subnet_modules = [module_of_node(fx_net, n) for n in subnet_nodes] + return nn.Sequential(*subnet_modules) + for i,win in enumerate(in_windows): + ad_fq = fq_net.cnn.adapter + #ad_int = int_net.cnn.adapter + ad_int = rebuild_subnet(int_net.cnn, 'adapter') + ad_out_fq = ad_fq(win) + ad_out_int = ad_int(win) + ad_out_int_fq = (ad_out_int + 1) * ad_fq[-1].get_eps() + if not torch.all(ad_out_int_fq == ad_out_fq): + print(f"failure in adapter, iteration {i}") + f_int = rebuild_subnet(int_net.cnn, 'features') + f_fq = fq_net.cnn.features + f_out_int = f_int(ad_out_int) + f_out_fq = f_fq(ad_out_fq) + f_out_intfq = (f_out_int+1)*f_fq[-1].get_eps() + if not torch.all(f_out_intfq == f_out_fq): + print(f"failure in features, iteration {i}") + cnn_outs_fq.append(f_out_fq.flatten(start_dim=1)) + cnn_outs_tq.append(f_out_int.flatten(start_dim=1)) + + tcn_in_fq = torch.stack(cnn_outs_fq, dim=2) + tcn_in_int = torch.stack(cnn_outs_tq, dim=2) + tcn_f_fq = fq_net.tcn.features + tcn_f_int = rebuild_subnet(int_net.tcn, 'features') + tcn_fout_fq = tcn_f_fq(tcn_in_fq) + tcn_fout_int = tcn_f_int(tcn_in_int) + tcn_fout_intfq = (tcn_fout_int+1)*tcn_f_fq[-1].get_eps() + if not torch.all(tcn_fout_intfq == tcn_fout_fq): + print(f"failure in tcn features") + tcn_fout_int_fl = tcn_fout_int.flatten(start_dim=1) + tcn_out_fq = fq_net.tcn.classifier(tcn_fout_fq) + tcn_cls = int_net.tcn._QL_REPLACED__INTEGERIZE_PACT_CONV1D_PASS_0.get_submodule('1') + tcn_out_int = tcn_cls(tcn_fout_int_fl) * eps_w + if not torch.all(tcn_out_int.argmax(dim=1) == tcn_out_fq.squeeze().argmax(dim=1)): + print(f"{torch.sum(tcn_out_int.argmax(dim=1) != tcn_out_fq.squeeze().argmax(dim=1))} failures in result argmax!") + net_out_int = int_net(indata) + net_out_fq = fq_net(indata) + + def 
export_integerized_network(net : nn.Module, cfg : dict, key : str, export_dir : str, name : str, in_idx : int = 42, pad_img : Optional[int] = None, clip : bool = False, change_n_levels : int = None, ternarize : bool = False): qu = _QUANT_UTILS[key] # use a real image from the validation set ds = get_valid_dataset(key, cfg, quantize='int', pad_img=pad_img, clip=clip) test_input = ds[in_idx][0].unsqueeze(0) if key == 'dvs_cnn': - qu.export_fn(*net, name=name, out_dir=export_dir, eps_in=qu.eps_in, integerize=False, D=qu.D, in_data=test_input, change_n_levels=change_n_levels, code_size=qu.code_size, compressed=ternarize) + #qu.export_fn(*net, name=name, out_dir=export_dir, eps_in=qu.eps_in, + #integerize=False, D=qu.D, in_data=test_input, + #change_n_levels=change_n_levels, code_size=qu.code_size, + #compressed=ternarize) + qu.export_fn(net.cnn, net.tcn, name=name, out_dir=export_dir, eps_in=qu.eps_in, integerize=False, D=qu.D, in_data=test_input, change_n_levels=change_n_levels, code_size=qu.code_size, compressed=ternarize) else: qu.export_fn(net, name=name, out_dir=export_dir, eps_in=qu.eps_in, integerize=False, D=qu.D, in_data=test_input, code_size=qu.code_size) @@ -375,33 +464,6 @@ def export_unquant_net(net : nn.Module, cfg : dict, key : str, export_dir : str, do_constant_folding=True) -# quite hacky but there is no other way - -def get_new_classifier(classifier: PACTConv1d): - - new_classifier = nn.Sequential(nn.Flatten(), - nn.Linear( - in_features=classifier.in_channels*classifier.kernel_size[0], - out_features=classifier.out_channels, - bias=True)) - - new_weights = classifier.weight.reshape(classifier.out_channels, -1) - #new_weights = torch.cat((new_weights, torch.zeros(new_weights.shape[1]).unsqueeze(0))) - new_classifier[1].weight.data.copy_(new_weights) - if classifier.bias is not None: - #new_classifier[1].bias.data.copy_(torch.cat((classifier.bias, torch.Tensor([0])))) - new_classifier[1].bias.data.copy_(classifier.bias) - else: - new_classifier[1].bias.data.fill_(0) - #new_classifier[1].clip_lo = torch.nn.Parameter(torch.cat((classifier.clip_lo.squeeze(2), -torch.ones(1,1)))) - #new_classifier[1].clip_hi = torch.nn.Parameter(torch.cat((classifier.clip_hi.squeeze(2), torch.ones(1,1)))) - #new_classifier[1].clip_lo = torch.nn.Parameter(classifier.clip_lo.squeeze(2), -torch.ones(1,1)) - #new_classifier[1].clip_hi = torch.nn.Parameter(classifier.clip_hi.squeeze(2), torch.ones(1,1)) - #new_classifier[1].clipping_params = classifier.clipping_params - #new_classifier[1].started = classifier.started - # we just don't care anymore and hardcode this. 
- new_classifier[1].n_levels = classifier.n_levels - return new_classifier @@ -471,7 +533,11 @@ def get_new_classifier(classifier: PACTConv1d): print(f'Integerizing network {args.net}') - int_net = integerize_network(qnet, args.net, args.fix_channels, not args.no_dory_harmonize, args.word_align_channels, args.requant_node, ternarize=args.ternarize) + int_net= integerize_network(qnet, args.net, args.fix_channels, not args.no_dory_harmonize, args.word_align_channels, args.requant_node, ternarize=args.ternarize) + #integerizing dvs_cnn also returns the fq network and the last layer's + #weight epsilons + if args.net == 'dvs_cnn': + int_net, fq_net, eps_w = int_net if args.fix_channels: pad_img = get_input_channels(int_net[0] if isinstance(int_net, tuple) else int_net) @@ -481,7 +547,9 @@ def get_new_classifier(classifier: PACTConv1d): if args.validate_tq: dl = get_dataloader(args.net, exp_cfg, quantize='int', pad_img=pad_img) print(f'Validating integerized network {args.net} on dataset {get_system(args.net)}') - validate(int_net, dl, args.accuracy_print_interval, n_valid_batches=args.n_valid_batch) + # uncomment to debug DVS CNN + #compare_nets(int_net, fq_net, dl) + validate(int_net, dl, args.accuracy_print_interval, n_valid_batches=args.n_valid_batch, eps_w=eps_w) if args.export_dir is not None: print(f'Exporting integerized network {args.net} to directory {args.export_dir} under name {export_name}') diff --git a/quantlib b/quantlib index ba13b49..cfcd780 160000 --- a/quantlib +++ b/quantlib @@ -1 +1 @@ -Subproject commit ba13b4957bd23c54d94b5aae78457b78341e76bf +Subproject commit cfcd7809a32e4a4137c90fffb6a9fc6f3c755e7a diff --git a/systems/CIFAR10/ResNet/resnet.py b/systems/CIFAR10/ResNet/resnet.py index a6a5cb8..a129cd4 100644 --- a/systems/CIFAR10/ResNet/resnet.py +++ b/systems/CIFAR10/ResNet/resnet.py @@ -203,6 +203,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: 'block_cfgs': [( 1, 16, 1), ( 1, 32, 2), ( 1, 64, 2)], + 'maxpool': False}, + 'ResNet8_tb': {'block_class': BasicBlock, + 'block_cfgs': [( 1, 20, 1), + ( 1, 40, 2), + ( 1, 80, 2)], + 'maxpool': False}, + 'ResNet8_ts': {'block_class': BasicBlock, + 'block_cfgs': [( 1, 15, 1), + ( 1, 30, 2), + ( 1, 60, 2)], 'maxpool': False}, 'ResNet0': {'block_class': NonResidualBlock, 'block_cfgs': [( 1, 16, 2), diff --git a/systems/DVS128/dvs_cnn/dvs_cnn.py b/systems/DVS128/dvs_cnn/dvs_cnn.py index 235d73a..bb4b459 100644 --- a/systems/DVS128/dvs_cnn/dvs_cnn.py +++ b/systems/DVS128/dvs_cnn/dvs_cnn.py @@ -12,6 +12,7 @@ __CNN_CFGS__ = { 'first_try' : [128, 128, 128, 128], 'ninetysix_ch' : [96, 96, 96, 96], + 'eighty_ch' : [80, 80, 80, 80], 'reduced_channels' : [64, 64, 64, 64], '128_channels' : [128, 128, 128, 128], '96_channels' : [96, 96, 96, 96], @@ -27,12 +28,13 @@ '64_channels' : [(2, 1, 64), (2, 2, 64), (2, 4, 64)], '64_channels_k3' : [(3, 1, 64), (3, 2, 64), (3, 4, 64)], '64_channels_k4' : [(4, 1, 64), (4, 2, 64), (4, 4, 64)], - 'ninetysix_ch' : [(2, 1, 96), (2, 2, 96), (2, 4, 96)], + 'eighty_ch' : [(2, 1, 80), (2, 2, 80), (2, 4, 80)], '128_ch' : [(2, 1, 128), (2, 2, 128), (2, 4, 128)], '128_channels' : [(2, 1, 128), (2, 2, 128), (2, 4, 128)], 'k3' : [(3, 1, 64), (3, 2, 64), (3, 4, 64)], '96_channels' : [(2, 1, 96), (2, 2, 96), (2, 4, 96)], '96_channels_k3' : [(3, 1, 96), (3, 2, 96), (3, 4, 96)], + '80_channels_k3' : [(3, 1, 80), (3, 2, 80), (3, 4, 80)], '96_channels_k4' : [(4, 1, 96), (4, 2, 96), (4, 4, 96)], '32_channels' : [(2, 1, 32), (2, 2, 32), (2, 4, 32)], '32_channels_k3' : [(3, 1, 32), (3, 2, 32), (3, 4, 32)], @@ -51,7 
+53,7 @@ def get_input_shape(cfg : dict): class DVSNet2D(nn.Module): def __init__(self, cnn_cfg_key : str, pool_type : str = "stride", cnn_window : int = 16, activation : str = 'relu', - out_size : int = 11, use_classifier : bool = True, fix_cnn_pool=False, k : int = 3, layer_order : str = 'pool_bn', last_conv_nopad : bool = False, **kwargs): + out_size : int = 11, use_classifier : bool = True, fix_cnn_pool=False, k : int = 3, layer_order : str = 'pool_bn', last_conv_nopad : bool = False, adapter_out_ch : int = 32, **kwargs): super(DVSNet2D, self).__init__() cfg = __CNN_CFGS__[cnn_cfg_key] self.k = k @@ -69,22 +71,22 @@ def __init__(self, cnn_cfg_key : str, pool_type : str = "stride", cnn_window : i adapter_list = [] - adapter_list.append(nn.Conv2d(cnn_window, 32, kernel_size=k, padding=k//2, bias=False)) + adapter_list.append(nn.Conv2d(cnn_window, adapter_out_ch, kernel_size=k, padding=k//2, bias=False)) if pool_type != 'max_pool': adapter_pool = nn.AvgPool2d(kernel_size=2) else: adapter_pool = nn.MaxPool2d(kernel_size=2) if layer_order == 'pool_bn': adapter_list.append(adapter_pool) - adapter_list.append(nn.BatchNorm2d(32)) + adapter_list.append(nn.BatchNorm2d(adapter_out_ch)) else: - adapter_list.append(nn.BatchNorm2d(32)) + adapter_list.append(nn.BatchNorm2d(adapter_out_ch)) adapter_list.append(adapter_pool) adapter_list.append(self._act(inplace=True)) adapter = nn.Sequential(*adapter_list) self.adapter = adapter - features = self._get_features(32, cfg, k, pool_type, self._act, layer_order, last_conv_nopad) + features = self._get_features(adapter_out_ch, cfg, k, pool_type, self._act, layer_order, last_conv_nopad) self.features = features # after features block, we should have a 4x4 feature map if use_classifier: @@ -282,8 +284,7 @@ def forward(self, x): # 1. split it up into cnn_window-sized stacks if self.inject_eps: - #x = QTensor(x, eps=1.) - pass + x = QTensor(x, eps=1.) 
#print(f"type of x - hybridnet: {type(x)}") cnn_wins = torch.split(x, self.cnn_window, dim=1) #print(f"eps of x - hybridnet after split: {tuple(w.eps for w in cnn_wins)}") diff --git a/systems/ILSVRC12/ResNet/quantize/pact.py b/systems/ILSVRC12/ResNet/quantize/pact.py index 9be1886..611f186 100644 --- a/systems/ILSVRC12/ResNet/quantize/pact.py +++ b/systems/ILSVRC12/ResNet/quantize/pact.py @@ -51,7 +51,10 @@ def pact_recipe(net : nn.Module, lin_cfg = config["PACTLinear"] act_cfg = config["PACTUnsignedAct"] - harmonize_cfg = config["harmonize"] + try: + harmonize_cfg = config["harmonize"] + except KeyError: + harmonize_cfg = None def make_rules(cfg : dict, rule : type): @@ -82,10 +85,12 @@ def make_rules(cfg : dict, lwe.apply() lwe.shutdown() # now harmonize the graph - harmonize_pass = HarmonizePACTNetPass(**harmonize_cfg) - #harmonize_pass = HarmonizePACTNetPass(n_levels=harmonize_cfg["n_levels"]) - net_traced = PACT_symbolic_trace(lwg.net) - final_net = harmonize_pass(net_traced) + if harmonize_cfg is not None: + harmonize_pass = HarmonizePACTNetPass(**harmonize_cfg) + net_traced = PACT_symbolic_trace(lwg.net) + final_net = harmonize_pass(net_traced) + else: + final_net = lwg.net return final_net From 5e61c27313024329b273e62386ad30224b818727 Mon Sep 17 00:00:00 2001 From: Georg Rutishauser Date: Wed, 14 Feb 2024 12:10:10 +0100 Subject: [PATCH 3/5] fix bug in integerize_pactnets.py that crashed export --- examples/fx_integerization/integerize_pactnets.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/fx_integerization/integerize_pactnets.py b/examples/fx_integerization/integerize_pactnets.py index 78df01a..70f003f 100644 --- a/examples/fx_integerization/integerize_pactnets.py +++ b/examples/fx_integerization/integerize_pactnets.py @@ -538,6 +538,8 @@ def export_unquant_net(net : nn.Module, cfg : dict, key : str, export_dir : str, #weight epsilons if args.net == 'dvs_cnn': int_net, fq_net, eps_w = int_net + else: + eps_w = None if args.fix_channels: pad_img = get_input_channels(int_net[0] if isinstance(int_net, tuple) else int_net) From 49e4d41a0a7ea8e8a6977e36d70dd66683dc0742 Mon Sep 17 00:00:00 2001 From: Georg Rutishauser Date: Tue, 5 Mar 2024 16:31:14 +0100 Subject: [PATCH 4/5] bump quantlib to final PR version --- quantlib | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantlib b/quantlib index cfcd780..0f536ba 160000 --- a/quantlib +++ b/quantlib @@ -1 +1 @@ -Subproject commit cfcd7809a32e4a4137c90fffb6a9fc6f3c755e7a +Subproject commit 0f536ba243936a1334eb1bb2c64fab2fbf30d661 From 5cb7392645fdb6f14b5c78c758506ccd42e8eb93 Mon Sep 17 00:00:00 2001 From: Georg Rutishauser Date: Tue, 5 Mar 2024 17:01:20 +0100 Subject: [PATCH 5/5] bump quantlib again --- quantlib | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quantlib b/quantlib index 0f536ba..4ba0a49 160000 --- a/quantlib +++ b/quantlib @@ -1 +1 @@ -Subproject commit 0f536ba243936a1334eb1bb2c64fab2fbf30d661 +Subproject commit 4ba0a49bdef8a330bac0e6669bf90181a713f330
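
Usage sketch (illustrative only, not part of the patches above). The series wires post-training harmonization into integerize_pactnets.py through the new --harmonize_cfg flag; the JSON file's contents are forwarded verbatim as keyword arguments to HarmonizePACTNetPass, and the activations that pass inserts are then calibrated on a few validation batches before integerization. Driving the same flow from Python with the functions defined in the patched script might look as follows; the network key, experiment/checkpoint IDs and the {"n_levels": 256} config file are assumed example values (n_levels is the one HarmonizePACTNetPass kwarg hinted at in these diffs):

    # assumed contents of harmonize_cfg.json: {"n_levels": 256}
    # load the fake-quantized checkpoint and harmonize + calibrate it
    qnet = get_network('ResNetCIFAR', exp_id=0, ckpt_id='best',
                       quantized=True, harmonize_cfg='harmonize_cfg.json')
    exp_cfg = get_config('ResNetCIFAR', 0)
    # integerize the harmonized net with the DORY harmonization passes enabled
    int_net = integerize_network(qnet, 'ResNetCIFAR', fix_channels=False,
                                 dory_harmonize=True, word_align_channels=False)
    # sanity-check the integerized net on the integer-quantized validation set
    dl = get_dataloader('ResNetCIFAR', exp_cfg, quantize='int')
    validate(int_net, dl, n_valid_batches=10)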