Skip to content

Commit 3023d9b

Browse files
Change FindValueOfExpression signature to return symbolic.Expression (#27)
* Initial plan * Change FindValueOfExpression signature to return symbolic.Expression Co-authored-by: kwesiRutledge <9002730+kwesiRutledge@users.noreply.github.com> * Add helper function to reduce code duplication in tests Co-authored-by: kwesiRutledge <9002730+kwesiRutledge@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: kwesiRutledge <9002730+kwesiRutledge@users.noreply.github.com>
1 parent 213c11b commit 3023d9b

2 files changed

Lines changed: 36 additions & 19 deletions

File tree

solution/solution.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ func ExtractValueOfVariable(s Solution, v symbolic.Variable) (float64, error) {
4444

4545
// FindValueOfExpression evaluates a symbolic expression using the values from a solution.
4646
// It substitutes all variables in the expression with their values from the solution
47-
// and returns the resulting scalar value.
48-
func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error) {
47+
// and returns the resulting symbolic expression (typically a constant).
48+
func FindValueOfExpression(s Solution, expr symbolic.Expression) (symbolic.Expression, error) {
4949
// Get all variables in the expression
5050
vars := expr.Variables()
5151

@@ -54,7 +54,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
5454
for _, v := range vars {
5555
val, err := ExtractValueOfVariable(s, v)
5656
if err != nil {
57-
return 0.0, fmt.Errorf(
57+
return nil, fmt.Errorf(
5858
"failed to extract value for variable %v: %w",
5959
v.ID,
6060
err,
@@ -66,16 +66,7 @@ func FindValueOfExpression(s Solution, expr symbolic.Expression) (float64, error
6666
// Substitute all variables with their values
6767
resultExpr := expr.SubstituteAccordingTo(subMap)
6868

69-
// Type assert to K (constant) to extract the float64 value
70-
resultK, ok := resultExpr.(symbolic.K)
71-
if !ok {
72-
return 0.0, fmt.Errorf(
73-
"expected substituted expression to be a constant, got type %T",
74-
resultExpr,
75-
)
76-
}
77-
78-
return float64(resultK), nil
69+
return resultExpr, nil
7970
}
8071

8172
// GetOptimalObjectiveValue evaluates the objective function of an optimization problem
@@ -95,10 +86,19 @@ func GetOptimalObjectiveValue(sol Solution) (float64, error) {
9586
}
9687

9788
// Use FindValueOfExpression to evaluate the objective at the solution point
98-
value, err := FindValueOfExpression(sol, objectiveExpr)
89+
resultExpr, err := FindValueOfExpression(sol, objectiveExpr)
9990
if err != nil {
10091
return 0.0, fmt.Errorf("failed to evaluate objective expression: %w", err)
10192
}
10293

103-
return value, nil
94+
// Type assert to K (constant) to extract the float64 value
95+
resultK, ok := resultExpr.(symbolic.K)
96+
if !ok {
97+
return 0.0, fmt.Errorf(
98+
"expected substituted expression to be a constant, got type %T",
99+
resultExpr,
100+
)
101+
}
102+
103+
return float64(resultK), nil
104104
}

testing/solution/solution_test.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ Description:
1717
(This seems like it is highly representative of the Gurobi solver; is there a reason to make it this way?)
1818
*/
1919

20+
// Helper function to convert a symbolic.Expression to float64
21+
func exprToFloat64(t *testing.T, expr symbolic.Expression) float64 {
22+
resultK, ok := expr.(symbolic.K)
23+
if !ok {
24+
t.Fatalf("Expected result to be a constant, got type %T", expr)
25+
}
26+
return float64(resultK)
27+
}
28+
2029
func TestSolution_ToMessage1(t *testing.T) {
2130
// Constants
2231
tempSol := solution.DummySolution{
@@ -161,11 +170,13 @@ func TestSolution_FindValueOfExpression1(t *testing.T) {
161170
expr := v1.Multiply(symbolic.K(2.0)).Plus(v2.Multiply(symbolic.K(3.0)))
162171

163172
// Algorithm
164-
result, err := solution.FindValueOfExpression(&tempSol, expr)
173+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
165174
if err != nil {
166175
t.Errorf("FindValueOfExpression returned an error: %v", err)
167176
}
168177

178+
result := exprToFloat64(t, resultExpr)
179+
169180
expected := 13.0
170181
if result != expected {
171182
t.Errorf(
@@ -194,11 +205,13 @@ func TestSolution_FindValueOfExpression2(t *testing.T) {
194205
expr := symbolic.K(42.0)
195206

196207
// Algorithm
197-
result, err := solution.FindValueOfExpression(&tempSol, expr)
208+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
198209
if err != nil {
199210
t.Errorf("FindValueOfExpression returned an error: %v", err)
200211
}
201212

213+
result := exprToFloat64(t, resultExpr)
214+
202215
expected := 42.0
203216
if result != expected {
204217
t.Errorf(
@@ -231,11 +244,13 @@ func TestSolution_FindValueOfExpression3(t *testing.T) {
231244
expr := v1.Plus(symbolic.K(10.0))
232245

233246
// Algorithm
234-
result, err := solution.FindValueOfExpression(&tempSol, expr)
247+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
235248
if err != nil {
236249
t.Errorf("FindValueOfExpression returned an error: %v", err)
237250
}
238251

252+
result := exprToFloat64(t, resultExpr)
253+
239254
expected := 15.5
240255
if result != expected {
241256
t.Errorf(
@@ -304,11 +319,13 @@ func TestSolution_FindValueOfExpression5(t *testing.T) {
304319
expr := v1.Plus(v2).Multiply(v3).Plus(symbolic.K(5.0))
305320

306321
// Algorithm
307-
result, err := solution.FindValueOfExpression(&tempSol, expr)
322+
resultExpr, err := solution.FindValueOfExpression(&tempSol, expr)
308323
if err != nil {
309324
t.Errorf("FindValueOfExpression returned an error: %v", err)
310325
}
311326

327+
result := exprToFloat64(t, resultExpr)
328+
312329
expected := 14.0
313330
if result != expected {
314331
t.Errorf(

0 commit comments

Comments
 (0)