Skip to content

[fix] Combine solver solution with time offset#809

Open
maxbriel wants to merge 8 commits intomainfrom
mb_solve_fix
Open

[fix] Combine solver solution with time offset#809
maxbriel wants to merge 8 commits intomainfrom
mb_solve_fix

Conversation

@maxbriel
Copy link
Collaborator

@maxbriel maxbriel commented Feb 17, 2026

For DCO we're doing multiple solves. We tried to implement a fix by changing the solver range to correctly account for the evolution time. However, this causes issues due to the scale differences in time that the equations evolve over.
We were doing multiple iterations of solve_ivp to account for this time scale difference. The fix did not account for this and ends up causing systems within the double_CO step to fail.

Here I re-implement the old method of solving the ODE's in multiple steps.
To propagate the changes to after the evolution in step_detached, I create a "fake" solution object combines the multiple solves into a single result object with similar attributes/functions as the original solve_ivp object.

@maxbriel
Copy link
Collaborator Author

I'm getting similar results with this and main at hash d39be4d
I'm verifying why there's some small differences still.

image

@maxbriel
Copy link
Collaborator Author

maxbriel commented Feb 17, 2026

Output from this branch:
evolve_binaries_mb_solve_fix.txt

Output from main at d39be4d
evolve_binaries_main.txt

There are some differences in the exact solutions between these two, but they're closer than the current main:
evolve_binaries_main.txt

I hope we can solve this quickly

Comment on lines +115 to +124
if len(sol) == 1:
output_solution = CombinedSolution()
output_solution.t = sol[0].t + t0
output_solution.y = sol[0].y
output_solution.status = sol[0].status
output_solution.message = sol[0].message
output_solution.t_events = sol[0].t_events
output_solution.y_events = sol[0].y_events
output_solution.success = sol[0].success
output_solution.sol = lambda t: sol[0].sol(t-t0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if len(sol) == 1:
output_solution = CombinedSolution()
output_solution.t = sol[0].t + t0
output_solution.y = sol[0].y
output_solution.status = sol[0].status
output_solution.message = sol[0].message
output_solution.t_events = sol[0].t_events
output_solution.y_events = sol[0].y_events
output_solution.success = sol[0].success
output_solution.sol = lambda t: sol[0].sol(t-t0)

Comment on lines +127 to +142
else:
output_solution = CombinedSolution()
output_solution.t = np.concatenate([t0+t.t for t, t0 in zip(sol, time_sol)])
output_solution.y = np.hstack([s.y for s in sol])
output_solution.status = sol[-1].status
output_solution.message = sol[-1].message
output_solution.t_events = sol[-1].t_events
output_solution.y_events = sol[-1].y_events
output_solution.success = sol[-1].success

# dynamically create a combined sol method that can interpolate across the combined solution
def combined_sol(t):
for s, t0 in zip(sol, time_sol):
if t0 <= t <= t0 + s.t[-1]:
return s.sol(t - t0)
raise ValueError(f"Time {t} is out of bounds for the combined solution.")
Copy link
Contributor

@sgossage sgossage Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
else:
output_solution = CombinedSolution()
output_solution.t = np.concatenate([t0+t.t for t, t0 in zip(sol, time_sol)])
output_solution.y = np.hstack([s.y for s in sol])
output_solution.status = sol[-1].status
output_solution.message = sol[-1].message
output_solution.t_events = sol[-1].t_events
output_solution.y_events = sol[-1].y_events
output_solution.success = sol[-1].success
# dynamically create a combined sol method that can interpolate across the combined solution
def combined_sol(t):
for s, t0 in zip(sol, time_sol):
if t0 <= t <= t0 + s.t[-1]:
return s.sol(t - t0)
raise ValueError(f"Time {t} is out of bounds for the combined solution.")
output_solution = CombinedSolution()
output_solution.t = np.concatenate([t0+s.t for s, t0 in zip(sol, time_sol)])
output_solution.y = np.hstack([s.y for s in sol])
output_solution.status = sol[-1].status
output_solution.message = sol[-1].message
output_solution.t_events = sol[-1].t_events
output_solution.y_events = sol[-1].y_events
output_solution.success = sol[-1].success
# dynamically create a combined sol method that can interpolate across the combined solution
def combined_sol(t):
for s, t0 in zip(sol, time_sol):
if t0 <= t <= t0 + s.t[-1]:
return s.sol(t - t0)
raise ValueError(f"Time {t} is out of bounds for the combined solution.")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe move this into a function

return s.sol(t - t0)
raise ValueError(f"Time {t} is out of bounds for the combined solution.")

output_solution.sol = combined_sol
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
output_solution.sol = combined_sol
output_solution.sol = combined_sol

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments