# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bounded-Variables Least-Squares (BVLS) operations.

Solves ``min ||A x - b||^2  s.t.  l <= x <= u`` with an active-set style
iteration: variables pinned at a bound are fixed, the unconstrained least
square is solved over the free set, and bound variables are released one at
a time based on the Kuhn-Tucker optimality conditions.
"""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.linalg import linalg_impl

# Numerical slack used both when classifying a variable as "at a bound" and
# as a floor for the largest boundary gradient.
_EPSILON = 1E-9


def tf_bvls_input_validation(matrix, rhs, lower_bounds, upper_bounds):
    """
    Perform basic sanity checks on the inputs to the BVLS algorithm.

    :param matrix: Matrix of the least square regression.
    :param rhs: Right hand side of the least square regression.
    :param lower_bounds: Lower bounds of the regression variables.
    :param upper_bounds: Upper bounds on the regression variables.
    :return: the validated (matrix, rhs, lower_bounds, upper_bounds) tensors,
        wired through finiteness assertions.
    """

    # TODO later with tf.assert_shapes
    # assert lower_bounds.shape == upper_bounds.shape
    # assert matrix.shape[-2] == rhs.shape[-1]
    # assert matrix.shape[-1] == lower_bounds.shape[-1]

    # The feasible region must be non-empty: every lower bound at or below
    # its matching upper bound.
    bounds_check = check_ops.assert_less_equal(
        lower_bounds,
        upper_bounds,
        message="Bvls input lower bounds must be less than or equal to upper bounds.",
    )

    checks = [bounds_check]

    # Attach the assertion as a control dependency so it runs before the
    # finiteness verification below.
    with ops.control_dependencies(checks):
        matrix = numerics.verify_tensor_all_finite(matrix, "Bvls input matrix.")
        rhs = numerics.verify_tensor_all_finite(rhs, "Bvls input rhs.")
        lower_bounds = numerics.verify_tensor_all_finite(lower_bounds, "Bvls input lower bounds.")
        upper_bounds = numerics.verify_tensor_all_finite(upper_bounds, "Bvls input upper bounds.")

    return matrix, rhs, lower_bounds, upper_bounds


def kuhn_tucker_convergence_test_lower(n_grad, variables, lower_bounds):
    """
    Convergence test for the variables at the lower bound.
    The negative gradient for a variable at the lower bound must be
    non-positive (moving further down is infeasible, moving up would not
    decrease the loss).

    :param n_grad: float Tensor variables negative gradient
    :param variables: float Tensor variables
    :param lower_bounds: float Tensor lower bounds
    :return: bool Tensor indicating if variables converged
    """

    return math_ops.logical_and(
        math_ops.less_equal(variables, lower_bounds),
        math_ops.less_equal(n_grad, array_ops.zeros_like(n_grad)),
    )


def kuhn_tucker_convergence_test_center(variables, lower_bounds, upper_bounds):
    """
    Convergence test for the variables strictly between the lower and upper
    bound: interior variables are unconstrained and count as converged.

    :param variables: float Tensor variables
    :param lower_bounds: float Tensor lower bounds
    :param upper_bounds: float Tensor upper bounds
    :return: bool Tensor indicating if variables converged
    """

    return math_ops.logical_and(
        math_ops.greater(variables, lower_bounds),
        math_ops.less(variables, upper_bounds),
    )


def kuhn_tucker_convergence_test_upper(n_grad, variables, upper_bounds):
    """
    Convergence test for the variables at the upper bound.
    The negative gradient for a variable at the upper bound must be
    non-negative.

    :param n_grad: float Tensor variables negative gradient
    :param variables: float Tensor variables
    :param upper_bounds: float Tensor upper bounds
    :return: bool Tensor indicating if variables converged
    """

    return math_ops.logical_and(
        math_ops.greater_equal(variables, upper_bounds),
        # zeros_like for consistency with the lower-bound test above.
        math_ops.greater_equal(n_grad, array_ops.zeros_like(n_grad)),
    )


def kuhn_tucker_convergence_test(n_grad, variables, lower_bounds, upper_bounds):
    """
    Combined Kuhn-Tucker convergence test: a variable has converged when it
    satisfies any of the lower-bound, interior, or upper-bound conditions.

    :param n_grad: float Tensor variables negative gradient
    :param variables: float Tensor variables
    :param lower_bounds: float Tensor lower bounds
    :param upper_bounds: float Tensor upper bounds
    :return: bool Tensor indicating if variables converged
    """

    lower_converged = kuhn_tucker_convergence_test_lower(n_grad, variables, lower_bounds)
    center_converged = kuhn_tucker_convergence_test_center(variables, lower_bounds, upper_bounds)
    upper_converged = kuhn_tucker_convergence_test_upper(n_grad, variables, upper_bounds)

    converged = array_ops.stack([
        lower_converged,
        center_converged,
        upper_converged,
    ], axis=-1)

    return math_ops.reduce_any(converged, axis=-1)


def free_variable_with_largest_gradient(n_grad, lower_mask, upper_mask):
    """
    Free the variable at a bound with the largest gradient pointing away from
    its bound (into the feasible region).

    :param n_grad: float Tensor variables negative gradient
    :param lower_mask: bool Tensor mask of variables at the lower bound
    :param upper_mask: bool Tensor mask of variables at the upper bound
    :return: (bool Tensor, bool Tensor) updated lower and upper mask tuple
    """

    # A lower-bound variable wants to move up when n_grad > 0; an upper-bound
    # variable wants to move down when n_grad < 0. Score both accordingly.
    lower_values = array_ops.where(lower_mask, x=n_grad, y=array_ops.zeros_like(n_grad))
    upper_values = array_ops.where(upper_mask, x=-n_grad, y=array_ops.zeros_like(n_grad))
    values = lower_values + upper_values
    v_max = math_ops.reduce_max(values)
    # Floor the maximum so that near-zero gradients do not free any variable.
    v_max = math_ops.maximum(v_max, _EPSILON)

    # Keep a variable pinned only if its score is strictly below the maximum;
    # the arg-max variable (ties included) is released from its mask.
    lower_mask = math_ops.logical_and(
        lower_mask,
        math_ops.less(lower_values, v_max),
    )

    upper_mask = math_ops.logical_and(
        upper_mask,
        math_ops.less(upper_values, v_max),
    )

    return lower_mask, upper_mask


def lstsq_negative_gradient(
        matrix,
        rhs,
        variables,
        axis=1,
        noise_precision=None,
        prior_precision=None,
        target_weights=None):
    """
    Negative gradient of the (weighted) least square loss at ``variables``.

    :param matrix: float Tensor design matrix (or batch of matrices)
    :param rhs: float Tensor right hand side
    :param variables: float Tensor variables
    :param axis: int 1 for a single problem, >1 for a batched problem
    :param noise_precision: float Scalar noise precision of targets
    :param prior_precision: float Tensor prior precision of variables
    :param target_weights: float Tensor weights of targets
    :return: float Tensor negative gradient, same shape as ``variables``
    """

    if target_weights is None:
        target_weights = array_ops.ones_like(rhs)

    einsum1 = "ij,j->i" if axis <= 1 else "ijk,ik->ij"
    einsum2 = "ji,j->i" if axis <= 1 else "ikj,ik->ij"

    # Negative gradient of the weighted residual term: A^T W^2 (rhs - A x).
    b = rhs - special_math_ops.einsum(einsum1, matrix, variables)
    b = math_ops.square(target_weights) * b
    w = special_math_ops.einsum(einsum2, matrix, b)

    if prior_precision is not None:
        if noise_precision is None:
            noise_precision = constant_op.constant(1., dtype=matrix.dtype)

        np_sqrt = math_ops.sqrt(noise_precision)

        # NOTE(review): this term looks inconsistent with the augmented system
        # built in free_lstsq, whose negative gradient would be
        # noise_precision * w - prior_precision * variables. Behavior is kept
        # as-is because n_grad only drives set selection/convergence here —
        # TODO confirm the intended prior gradient.
        return np_sqrt * w + prior_precision * w

    return w


def free_lstsq(
        matrix,
        rhs,
        center_mask,
        lower_mask,
        lower_bounds,
        upper_mask,
        upper_bounds,
        noise_precision=None,
        prior_precision=None,
        target_weights=None,
        l2_regularizer=0.,
        fast=True):
    """
    Least square regression with variables fixed at lower or upper bound
    values: the contribution of the pinned variables is moved to the right
    hand side, and the regression is solved over the free columns only.

    :param matrix: float Tensor design matrix
    :param rhs: float Tensor right hand side
    :param center_mask: bool Tensor mask of free variables
    :param lower_mask: bool Tensor mask of variables at the lower bound
    :param lower_bounds: float Tensor lower bounds
    :param upper_mask: bool Tensor mask of variables at the upper bound
    :param upper_bounds: float Tensor upper bounds
    :param noise_precision: float Scalar noise precision of targets
    :param prior_precision: float Tensor prior precision of variables
    :param target_weights: float Tensor weights of targets
    :param l2_regularizer: float least square regularization
    :param fast: bool fast least square (differentiable but less stable)
    :return: float Tensor least square result for free variables
    """

    if target_weights is None:
        target_weights = array_ops.ones_like(rhs)

    if prior_precision is not None:
        if noise_precision is None:
            noise_precision = constant_op.constant(1., dtype=matrix.dtype)

        np_sqrt = math_ops.sqrt(noise_precision)

        # Fold the Gaussian prior into the regression by augmenting the system
        # with sqrt(prior_precision) rows targeting zero.
        matrix = array_ops.concat([
            np_sqrt * matrix,
            array_ops.diag(math_ops.sqrt(prior_precision)),
        ], axis=0)

        rhs = array_ops.concat([
            np_sqrt * rhs,
            array_ops.zeros(array_ops.shape(prior_precision), dtype=rhs.dtype)
        ], axis=0)

        target_weights = array_ops.concat([
            target_weights,
            array_ops.ones(array_ops.shape(prior_precision), dtype=target_weights.dtype),
        ], axis=0)

    lm = math_ops.cast(lower_mask, dtype=lower_bounds.dtype)
    um = math_ops.cast(upper_mask, dtype=upper_bounds.dtype)
    cm = math_ops.cast(center_mask, dtype=upper_bounds.dtype)

    # Zero out the pinned columns and apply the per-target weights.
    m = special_math_ops.einsum("ij,j->ij", matrix, cm)
    m = special_math_ops.einsum("i,ij->ij", target_weights, m)

    # Move the fixed (bound) variables' contribution to the right hand side.
    b = rhs
    b -= math_ops.tensordot(matrix, lm * lower_bounds + um * upper_bounds, axes=[[1], [0]])
    b = target_weights * b
    b = array_ops.expand_dims(b, -1)

    # TODO: performance optimize with QR decomposition
    result = linalg_impl.lstsq(m, b, l2_regularizer=l2_regularizer, fast=fast)

    return result[:, 0]


def free_bounded_step(
        center_mask,
        variables,
        lower_mask,
        lower_bounds,
        upper_mask,
        upper_bounds):
    """
    Update variables based on the least square result but respecting the
    bounds: the free variables are scaled back towards the previous feasible
    point by the largest step ``alpha`` in (0, 1] that keeps them in bounds.

    :param center_mask: bool Tensor mask of free variables
    :param variables: float Tensor proposed (unconstrained) variables
    :param lower_mask: bool Tensor mask of variables at the lower bound
    :param lower_bounds: float Tensor lower bounds
    :param upper_mask: bool Tensor mask of variables at the upper bound
    :param upper_bounds: float Tensor upper bounds
    :return: (float Tensor, float Tensor): updated variables and step size
    """

    zero = array_ops.zeros((), dtype=lower_bounds.dtype)
    one = array_ops.ones((), dtype=lower_bounds.dtype)

    lm = math_ops.cast(lower_mask, dtype=lower_bounds.dtype)
    um = math_ops.cast(upper_mask, dtype=upper_bounds.dtype)
    cm = math_ops.cast(center_mask, dtype=upper_bounds.dtype)

    # 1 where the proposed value already satisfies the bound, else the
    # fractional step bound/value at which the bound would be hit.
    lower_alphas = math_ops.cast(math_ops.less_equal(lower_bounds, variables), dtype=lower_bounds.dtype)
    upper_alphas = math_ops.cast(math_ops.greater_equal(upper_bounds, variables), dtype=lower_bounds.dtype)

    lower_alphas += (1 - lower_alphas) * math_ops.truediv(lower_bounds, variables)
    upper_alphas += (1 - upper_alphas) * math_ops.truediv(upper_bounds, variables)

    # Non-free variables contribute alpha = 1 (no restriction on the step).
    lower_alphas = (one - cm) + cm * lower_alphas
    upper_alphas = (one - cm) + cm * upper_alphas

    # A negative ratio means the bound lies in the opposite direction and
    # cannot be hit by shrinking the step; treat it as unrestricted.
    lower_alphas = array_ops.where(
        math_ops.less(lower_alphas, zero),
        x=array_ops.ones_like(lower_alphas),
        y=lower_alphas,
    )
    upper_alphas = array_ops.where(
        math_ops.less(upper_alphas, zero),
        x=array_ops.ones_like(upper_alphas),
        y=upper_alphas,
    )

    min_alpha = math_ops.reduce_min([lower_alphas, upper_alphas])
    alpha = math_ops.minimum(one, min_alpha)

    variables = lm * lower_bounds + alpha * cm * variables + um * upper_bounds

    return variables, alpha


def compute_variables_sets(variables, lower_bounds, upper_bounds):
    """
    Compute the lower bound mask, center (free) mask, and upper bound mask.
    A small numerical tolerance is allowed when classifying a variable as
    sitting at a bound.

    :param variables: float Tensor variables
    :param lower_bounds: float Tensor lower bounds
    :param upper_bounds: float Tensor upper bounds
    :return: bool Tensor tuple: lower bound variables, free variables,
        upper bound variables.
    """

    lm = math_ops.less_equal(variables, lower_bounds + _EPSILON)
    um = math_ops.greater_equal(variables, upper_bounds - _EPSILON)
    cm = math_ops.logical_not(math_ops.logical_or(lm, um))

    return lm, cm, um


def lstsq_squared_residuals_sum(matrix, variables, rhs, tws):
    """
    Sum of squared weighted residuals ``||tws * (A x - rhs)||^2``.

    :param matrix: float Tensor design matrix
    :param variables: float Tensor variables
    :param rhs: float Tensor right hand side
    :param tws: float Tensor per-target weights
    :return: float scalar Tensor
    """
    residuals = tws * (special_math_ops.einsum("ij,j->i", matrix, variables) - rhs)
    return math_ops.reduce_sum(math_ops.square(residuals))


# TODO: warm start
# TODO: convergence result
def tf_bvls(
        matrix,
        rhs,
        lower_bounds,
        upper_bounds,
        noise_precision=None,
        prior_precision=None,
        target_weights=None,
        l2_regularizer=0.,
        fast=True,
        maximum_iterations=20,
        return_iterations=False,
        name="bvls",
):
    """
    Least square regression with variable bounds (BVLS).

    :param matrix: float Tensor design matrix
    :param rhs: float Tensor right hand side
    :param lower_bounds: float Tensor lower bounds
    :param upper_bounds: float Tensor upper bounds
    :param noise_precision: float Scalar noise precision of targets
    :param prior_precision: float Tensor prior precision of variables
    :param target_weights: float Tensor weights of targets
    :param l2_regularizer: float Scalar regression regularization
    :param fast: bool Use fast regression, less stable
    :param maximum_iterations: int Maximum number of iterations
    :param return_iterations: bool returns number of iterations if True
    :param name: str Name of the node in the graph
    :return: float Tensor bounded least square result (and int iteration
        count when ``return_iterations`` is True)
    """

    # Validate the inputs
    matrix, rhs, lower_bounds, upper_bounds = tf_bvls_input_validation(
        matrix, rhs, lower_bounds, upper_bounds
    )

    def tf_bvls_condition(_, vs, __, ___, n_grad, free):
        """
        Termination condition of the BVLS loop.

        :param _: unused iteration counter
        :param vs: float Tensor variables
        :param __: unused lower mask
        :param ___: unused upper mask
        :param n_grad: float Tensor variables negative gradient
        :param free: bool Tensor flag: next step would free a bound variable
        :return: bool Tensor, True while the loop should continue
        """

        converged = kuhn_tucker_convergence_test(n_grad, vs, lower_bounds, upper_bounds)

        # BVLS terminates when the Kuhn-Tucker conditions are met at a point
        # where the algorithm would otherwise free a variable.
        return math_ops.logical_not(math_ops.logical_and(
            free,
            math_ops.reduce_all(converged),
        ))

    def tf_bvls_body(i, _, lm, um, n_grad, free):
        """
        One BVLS iteration: optionally free a variable, solve the free least
        square, step within the bounds, and recompute sets and gradient.

        :param i: int Scalar iteration counter
        :param _: unused previous variables
        :param lm: bool Tensor variables at lower bound
        :param um: bool Tensor variables at upper bound
        :param n_grad: float Tensor variables negative gradient
        :param free: bool Tensor flag to free the bound variable with the
            largest gradient
        :return: Tensor tuple with loop variables for the next iteration
        """

        lm, um = control_flow_ops.cond(
            free,
            true_fn=lambda: free_variable_with_largest_gradient(n_grad, lm, um),
            false_fn=lambda: (lm, um),
        )
        cm = math_ops.logical_not(math_ops.logical_or(lm, um))

        # Compute the least square regression over the free variables
        result = free_lstsq(
            matrix, rhs, cm, lm, lower_bounds, um, upper_bounds,
            noise_precision=noise_precision,
            prior_precision=prior_precision,
            target_weights=target_weights,
            l2_regularizer=l2_regularizer,
            fast=fast,
        )

        # Perform a bound respecting update step for the variables
        vs, alpha = free_bounded_step(cm, result, lm, lower_bounds, um, upper_bounds)

        # Compute the sets of variables at the lower and upper bounds
        lm, _, um = compute_variables_sets(vs, lower_bounds, upper_bounds)

        # Compute the negative gradient for each variable
        n_grad = lstsq_negative_gradient(
            matrix,
            rhs,
            vs,
            noise_precision=noise_precision,
            prior_precision=prior_precision,
            target_weights=target_weights,
        )

        # When no free variable hit a bound (full step taken), the next step
        # is to free the variable at a bound with the largest gradient.
        free = math_ops.greater_equal(alpha, 1.0)

        return i + 1, vs, lm, um, n_grad, free

    # Cold start: midpoint of the box. int32 counter — int8 would overflow
    # past 127 iterations and mismatches while_loop's iteration bookkeeping.
    i0 = constant_op.constant(0, dtype=dtypes.int32)
    vs0 = (lower_bounds + upper_bounds) / 2.
    lower_mask0 = math_ops.less_equal(vs0, lower_bounds)
    upper_mask0 = math_ops.greater_equal(vs0, upper_bounds)
    # NOTE(review): the cold-start gradient ignores noise/prior/target
    # weights; presumably acceptable as it only seeds the first iteration —
    # TODO confirm.
    n_grad0 = lstsq_negative_gradient(matrix, rhs, vs0)
    free0 = constant_op.constant(False)

    iterations, variables, lower_mask, upper_mask, _, _ = control_flow_ops.while_loop(
        tf_bvls_condition,
        tf_bvls_body,
        loop_vars=(i0, vs0, lower_mask0, upper_mask0, n_grad0, free0),
        back_prop=False,
        maximum_iterations=maximum_iterations,
        parallel_iterations=1,
        name="bvls_loop",
    )

    if return_iterations:
        return (
            array_ops.identity(variables, name=name),
            array_ops.identity(iterations, name="%s_iterations" % name),
        )
    else:
        return array_ops.identity(variables, name=name)


def tf_bvls_batch(
        matrix,
        rhs,
        lower_bounds,
        upper_bounds,
        noise_precision=None,
        prior_precision=None,
        l2_regularizer=0.,
        fast=True,
        maximum_iterations=20,
        parallel_iterations=None,
        name="bvls_batch",
        target_weights=None):
    """
    Batched BVLS: maps ``tf_bvls`` over the leading batch dimension of
    ``matrix``, ``rhs``, ``lower_bounds`` and ``upper_bounds``.

    ``noise_precision``, ``prior_precision``, and the other scalar options
    are shared across the batch; ``target_weights``, when given, is indexed
    per batch element.

    :param matrix: float Tensor batch of design matrices
    :param rhs: float Tensor batch of right hand sides
    :param lower_bounds: float Tensor batch of lower bounds
    :param upper_bounds: float Tensor batch of upper bounds
    :param noise_precision: float Scalar noise precision of targets
    :param prior_precision: float Tensor prior precision of variables
    :param l2_regularizer: float Scalar regression regularization
    :param fast: bool Use fast regression, less stable
    :param maximum_iterations: int Maximum number of iterations
    :param parallel_iterations: int parallelism of the map
    :param name: str Name of the node in the graph
    :param target_weights: optional float Tensor batch of target weights
    :return: float Tensor batch of bounded least square results
    """

    def map_multi_args(fn, arrays, dtype=dtypes.float32):
        indices = math_ops.range(0, limit=array_ops.shape(arrays[0])[0], dtype=dtypes.int32)

        def solve_one(ii):
            kwargs = dict(
                noise_precision=noise_precision,
                prior_precision=prior_precision,
                l2_regularizer=l2_regularizer,
                fast=fast,
                maximum_iterations=maximum_iterations,
            )
            if target_weights is not None:
                kwargs["target_weights"] = target_weights[ii]
            return fn(*[array[ii] for array in arrays], **kwargs)

        return map_fn.map_fn(
            solve_one,
            indices,
            dtype=dtype,
            parallel_iterations=parallel_iterations,
        )

    ws = map_multi_args(
        fn=tf_bvls,
        arrays=[matrix, rhs, lower_bounds, upper_bounds],
        dtype=matrix.dtype,
    )

    return array_ops.identity(ws, name=name)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for tf.contrib.linalg.bvls."""

import time
import unittest

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.training import adam

from tensorflow.contrib.linalg.bvls import (
    lstsq_negative_gradient,
    tf_bvls,
    free_variable_with_largest_gradient,
    free_lstsq,
    free_bounded_step,
    tf_bvls_batch,
    lstsq_squared_residuals_sum,
)


class TestTfBvls(unittest.TestCase):
    """
    Test cases for the bounded variable least square solver.
    """

    @staticmethod
    def getTestCase(nd=5, nw=4):
        """Generate a random bounded regression problem with nd targets and nw variables."""
        m = np.random.normal(0.0, 1.0, (nd, nw))
        rhs = np.random.normal(0.0, 1.0, (nd,))
        lower_bounds = np.random.uniform(-1.0, 0.0, (nw,))
        upper_bounds = np.random.uniform(0.0, 1.0, (nw,))
        noise_precision = np.random.uniform(1E-2, 1E-1, ())
        prior_precision = np.random.uniform(0., 1E-2, (nw,))
        target_weights = np.random.uniform(0., 1., (nd,))

        # Use the public ops.convert_to_tensor consistently (the bare
        # `from tensorflow.python import convert_to_tensor` path is private
        # and unstable).
        m = ops.convert_to_tensor(m)
        rhs = ops.convert_to_tensor(rhs)
        lower_bounds = ops.convert_to_tensor(lower_bounds)
        upper_bounds = ops.convert_to_tensor(upper_bounds)

        return m, rhs, lower_bounds, upper_bounds, noise_precision, prior_precision, target_weights

    @staticmethod
    def solve(
            m,
            rhs,
            lower_bounds,
            upper_bounds,
            noise_precision=None,
            prior_precision=None,
            target_weights=None):
        """Reference solver: minimize the same loss with projected Adam steps."""

        nw = m.shape[1]

        if noise_precision is None:
            noise_precision = constant_op.constant(1., dtype=m.dtype)

        if prior_precision is None:
            prior_precision = array_ops.zeros(nw, dtype=rhs.dtype)

        if target_weights is None:
            target_weights = array_ops.ones_like(rhs, dtype=rhs.dtype)

        # Variable constrained into the box via clipping after each update.
        w = variables.Variable(
            array_ops.zeros((nw,), dtype=m.dtype),
            constraint=lambda x: clip_ops.clip_by_value(x, lower_bounds, upper_bounds),
            name="w",
        )

        # Weighted least square residuals
        lstsq_residuals = target_weights * (special_math_ops.einsum("ij,j->i", m, w) - rhs)

        # Least square loss
        tf_loss = noise_precision * math_ops.reduce_sum(math_ops.square(lstsq_residuals))

        # Prior loss
        tf_loss += math_ops.reduce_sum(prior_precision * w * w)

        train = adam.AdamOptimizer(0.01).minimize(
            tf_loss,
            var_list=[w],
        )
        init = variables.global_variables_initializer()

        with session.Session() as sess:
            sess.run(init)

            loss_prev = 1E30
            loss = 1E29
            i = 0
            start = time.time()
            # Iterate until the loss plateaus or the iteration budget runs out.
            while abs(loss - loss_prev) > 1E-16 and i < 5000:
                i += 1
                loss_prev = loss
                loss, _ = sess.run((tf_loss, train))

            end = time.time()
            w_result = sess.run(w)
            loss = sess.run(tf_loss)
            print("Loss ", i, ":", loss, "Time (ms): ", round(1000 * (end - start)))

        return w_result, loss

    @staticmethod
    def timed_execution(func):
        """Run func 100 times and return (last result, mean time in ms)."""
        result = None
        start = time.time()

        for _ in range(100):
            result = func()

        end = time.time()
        execution_time = round(1000 * (end - start) / 100, 3)

        return result, execution_time

    def setUp(self):
        """
        Initialize a fixed bounded least square regression problem.
        """

        # Bounded regression example
        self.m1 = np.array([
            [0.890197, 0.98748, 0.597844],
            [0.686742, 0.0558757, 0.201711],
            [0.383872, 0.96083, 0.319599],
        ])
        self.rhs1 = np.array([0.360696, 0.945096, 0.106577])

        # Lower bounds
        self.l1 = np.array([-1, -0.5, -1])

        # Upper bounds
        self.u1 = np.array([1, 0.5, 1])

        # Known bounded solution for (m1, rhs1, l1, u1)
        self.v1 = np.array([1, -0.5, 0.202432])

    def test_bvls_free_variable_with_largest_gradient(self):
        """
        Check that the variable at the boundary with the largest gradient is
        being freed.
        """

        m1 = ops.convert_to_tensor(self.m1)
        rhs1 = ops.convert_to_tensor(self.rhs1)
        l1 = ops.convert_to_tensor(self.l1)

        # Negative gradient
        n_grad = lstsq_negative_gradient(m1, rhs1, l1)

        lower_mask = [False, True, False]
        upper_mask = [True, False, False]
        tf_result = free_variable_with_largest_gradient(n_grad, lower_mask, upper_mask)

        with session.Session() as sess:
            result = sess.run(tf_result)
            np.testing.assert_equal(result[0], [False, False, False])
            np.testing.assert_equal(result[1], [True, False, False])

    def test_bvls_free_variable_with_largest_gradient_batch(self):
        """
        Check that the variable at the boundary with the largest gradient is
        being freed, for a batched gradient.
        """

        m1 = ops.convert_to_tensor(self.m1.reshape((1, 3, 3)))
        rhs1 = ops.convert_to_tensor(self.rhs1.reshape((1, 3)))
        l1 = ops.convert_to_tensor(self.l1.reshape((1, 3)))

        # Negative gradient
        n_grad = lstsq_negative_gradient(m1, rhs1, l1, axis=2)

        lower_mask = [[False, True, False]]
        upper_mask = [[True, False, False]]
        tf_result = free_variable_with_largest_gradient(n_grad, lower_mask, upper_mask)

        with session.Session() as sess:
            result = sess.run(tf_result)
            np.testing.assert_equal(result[0], [[False, False, False]])
            np.testing.assert_equal(result[1], [[True, False, False]])

    def test_bvls_free_lstsq(self):
        """
        Check that least square regression over the free variables works.
        """

        lower_mask = [False, True, False]
        upper_mask = [True, False, False]
        center_mask = [False, False, True]

        m = ops.convert_to_tensor(self.m1)
        rhs = ops.convert_to_tensor(self.rhs1)

        tf_result = free_lstsq(
            m, rhs, center_mask, lower_mask, self.l1, upper_mask, self.u1, fast=False)

        with session.Session() as sess:
            result = sess.run(tf_result)
            np.testing.assert_almost_equal(result, [0, 0, self.v1[-1]], decimal=4)

    def test_bvls_free_bounded_step_passing(self):
        """
        Check that a least square regression step already inside the bounds is
        taken in full.
        """

        lower_mask = [False, True, False]
        upper_mask = [True, False, False]
        center_mask = [False, False, True]

        cvs = np.array([0.0, 0.0, self.v1[-1]], dtype=np.float64)

        tf_result = free_bounded_step(
            center_mask, cvs, lower_mask, self.l1, upper_mask, self.u1)

        with session.Session() as sess:
            result, _ = sess.run(tf_result)
            np.testing.assert_almost_equal(result, [1, -0.5, self.v1[-1]], decimal=4)

    def test_bvls_free_bounded_step_clipped(self):
        """
        Check that a least square regression step exceeding the bounds is
        clipped to the bounds.
        """

        lower_mask = [False, True, False]
        upper_mask = [True, False, False]
        center_mask = [False, False, True]

        cvs = np.array([0.0, 0.0, 2.0], dtype=np.float64)

        tf_result = free_bounded_step(
            center_mask, cvs, lower_mask, self.l1, upper_mask, self.u1)

        with session.Session() as sess:
            result, _ = sess.run(tf_result)
            np.testing.assert_almost_equal(result, [1, -0.5, 1.0], decimal=4)

    # TODO: write test cases for all boundary scenarios

    def test_bvls(self):
        """
        Check that the bounded least square regression works.
        """

        tf_result = tf_bvls(self.m1, self.rhs1, self.l1, self.u1, fast=False)

        with session.Session() as sess:
            result = sess.run(tf_result)
            print(result)
            np.testing.assert_almost_equal(result, [1, -0.5, self.v1[-1]], decimal=4)

    def test_bvls_batch(self):
        """
        Check that the bounded least square regression works for a batch input.
        """

        m_batch = ops.convert_to_tensor(np.tile(self.m1, (2, 1, 1)))
        rhs_batch = ops.convert_to_tensor(np.tile(self.rhs1, (2, 1)))
        lb_batch = ops.convert_to_tensor(np.tile(self.l1, (2, 1)))
        ub_batch = ops.convert_to_tensor(np.tile(self.u1, (2, 1)))

        tf_result = tf_bvls_batch(m_batch, rhs_batch, lb_batch, ub_batch, fast=False)

        with session.Session() as sess:
            result = sess.run(tf_result)
            print(result)
            # Each batch element is a copy of the fixed problem, so every row
            # must match the known bounded solution.
            expected = np.tile([1, -0.5, self.v1[-1]], (2, 1))
            np.testing.assert_almost_equal(result, expected, decimal=4)

    def test_bvls_random_test_cases(self):
        """
        Check that the bounded least square regression works for random test
        cases, against a projected-gradient reference solver.
        """

        for _ in range(10):
            print("-" * 100)

            m, rhs, lb, ub, noise_precision, prior_precision, tws = self.getTestCase(nd=5, nw=3)
            expected_w0, gloss0 = self.solve(m, rhs, lb, ub)
            expected_w, gloss = self.solve(m, rhs, lb, ub, noise_precision, prior_precision, tws)
            tf_bvls_result = tf_bvls(
                m, rhs, lb, ub,
                noise_precision=noise_precision,
                prior_precision=prior_precision,
                target_weights=tws,
                fast=False,
                return_iterations=True,
            )
            tf_lstsq_result = linalg_impl.lstsq(m, array_ops.expand_dims(rhs, -1))
            tf_loss = noise_precision * lstsq_squared_residuals_sum(m, tf_bvls_result[0], rhs, tws)

            with session.Session() as sess:
                (w_result, i), bvls_time = self.timed_execution(lambda: sess.run(tf_bvls_result))
                _, lstsq_time = self.timed_execution(lambda: sess.run(tf_lstsq_result))
                loss = sess.run(tf_loss)

                print("W sample")
                print("Lower bound: ", w_result <= sess.run(lb))
                print("Upper bound: ", w_result >= sess.run(ub))
                print("Result:      ", w_result)
                print("Expected0:   ", expected_w0)
                print("Expected:    ", expected_w)
                print("Assert:      ", abs(w_result - expected_w) < 1E-2)
                print("Time (ms):   ", round(bvls_time / i, 3), lstsq_time, i)
                print("Loss:        ", loss, 100 * (gloss - loss) / loss, "%")

                np.testing.assert_almost_equal(w_result, expected_w, decimal=3)
                self.assertGreaterEqual(gloss, loss)


if __name__ == '__main__':
    unittest.main()