From 22af9a3f7128ed1c3cac3cae09b0a9485e48d280 Mon Sep 17 00:00:00 2001 From: Ben Wibking Date: Sat, 20 Jul 2024 22:02:55 -0400 Subject: [PATCH 1/7] attempt to simplify within sympy --- rate_collection.py | 63 ++++++++++++++++++++++++++++++++++----------- write_actual_rhs.py | 2 +- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/rate_collection.py b/rate_collection.py index c1614e2..a915801 100644 --- a/rate_collection.py +++ b/rate_collection.py @@ -1,20 +1,11 @@ # Common Imports -import warnings import functools -import math import os -from operator import mul, add -from collections import OrderedDict, Counter - -from ipywidgets import interact +from operator import mul +from collections import Counter import numpy as np -import matplotlib as mpl -import matplotlib.pyplot as plt -from mpl_toolkits.axes_grid1 import make_axes_locatable -from matplotlib.ticker import MaxNLocator -import networkx as nx # Import Rate from rate import ChemSpecie, ChemRate, ChemComposition, SympyChemRate @@ -1927,6 +1918,44 @@ def write_cxxnetwork_withoutCSE(self, T, composition, redshift, density=0): return None + def simplify_expression(self, old_expr): + """Simplify a Sympy expression before turning it into C++ code.""" + import sympy as sp + + def pow_as_exp_of_log(expr): + """rewrites Pow(x, a) as Exp(a * Log(x))""" + x, a = expr.args + if not isinstance(a, sp.core.numbers.Integer) \ + and not ((a == sp.Rational(1, 2)) or (a == sp.Rational(-1, 2))) \ + and x.is_positive: + my_new_expr = sp.Pow(sp.E, sp.UnevaluatedExpr(a * sp.log(x))) + #print("transforming:", expr, "->", my_new_expr) + return my_new_expr + else: + return expr + + new_expr = old_expr.replace(lambda expr: expr.is_Pow, pow_as_exp_of_log) + + simplified_expr = sp.powsimp(new_expr) + ## powsimp is buggy, doesn't perform either replacement below: + ## 1.4e-18*X(0)*X(2)*exp(-6.17283950617284e-5*T)*exp(0.928*log(T)) -> 1.4e-18*X(0)*X(2)*exp(-6.17283950617284e-5*T + 0.928*log(T)) + ## 35.5*X(0)*X(8)*exp(-46707.0/T)*exp(-2.28*log(T)) -> 35.5*X(0)*X(8)*exp(-46707.0/T + -2.28*log(T)) + + # fails with "RecursionError: maximum recursion depth exceeded" + #print("simplifying all...", end="", flush=True) + #full_simplified_expr = sp.simplify(simplified_expr) + #print("finished.") + full_simplified_expr = simplified_expr + + print("") + print("Original expression:", old_expr) + print("") + print("Simplified expression:", full_simplified_expr) + print("") + + return full_simplified_expr + + def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): import sympy as sp from sympy.printing import cxxcode @@ -1946,13 +1975,16 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): ydots_subs = [] fupdated = '' for i, specie in enumerate(composition): - ydots_subs.append(sp.expand_log(ydots[specie].subs({ChemSpecie('elec').sym_name: sy[0], ChemSpecie('hp').sym_name: sy[1], \ + this_ydot = ydots[specie].subs({ChemSpecie('elec').sym_name: sy[0], ChemSpecie('hp').sym_name: sy[1], \ ChemSpecie('h').sym_name: sy[2], ChemSpecie('hm').sym_name: sy[3], \ ChemSpecie('dp').sym_name: sy[4], ChemSpecie('d').sym_name: sy[5], \ ChemSpecie('h2p').sym_name: sy[6], ChemSpecie('dm').sym_name: sy[7], \ ChemSpecie('h2').sym_name: sy[8], ChemSpecie('hdp').sym_name: sy[9], \ ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \ - ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}))) + ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}) + #this_ydot = sp.expand_log(this_ydot) + this_simplified_ydot = self.simplify_expression(this_ydot) + ydots_subs.append(this_simplified_ydot) print('Substituted ydot ', i+1) a, b = sp.cse(ydots_subs) @@ -1978,7 +2010,8 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \ ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}) print('EDOT substituted') - tdot_subs = sp.expand_log(tdot_subs) + #tdot_subs = sp.expand_log(tdot_subs) + simplified_tdot = self.simplify_expression(tdot_subs) #for Tdot and jacobians, we will use sympy.cse (common subexpression eliminatino) #this greatly simplifies the code and redduces the file size by a factor of 10 @@ -1986,7 +2019,7 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): fupdated = '' - a, b = sp.cse(tdot_subs) + a, b = sp.cse(simplified_tdot) for j in range(0, len(a)): a1 = cxxcode(a[j]) fupdated += a1.replace('{',' Real ',1).replace(',',' =',1).replace('}',';\n\n',len(a1)) diff --git a/write_actual_rhs.py b/write_actual_rhs.py index 1158781..678ebe6 100644 --- a/write_actual_rhs.py +++ b/write_actual_rhs.py @@ -14,7 +14,7 @@ T = sp.symbols('T', positive=True) -redshift = sp.symbols('z', real=True) +redshift = sp.symbols('z', positive=True) bb = ratecol.SympyChemRateCollection(rates=list(sym_rates.values()),tdot_switch=1,ydots_lambdified=False,withD=1,massfracs=0) From b559398df8ee815f95a0a165eed7d9ceac9c1d9e Mon Sep 17 00:00:00 2001 From: Ben Wibking Date: Sat, 20 Jul 2024 22:19:09 -0400 Subject: [PATCH 2/7] update docstring --- rate_collection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rate_collection.py b/rate_collection.py index a915801..369bbda 100644 --- a/rate_collection.py +++ b/rate_collection.py @@ -1923,7 +1923,7 @@ def simplify_expression(self, old_expr): import sympy as sp def pow_as_exp_of_log(expr): - """rewrites Pow(x, a) as Exp(a * Log(x))""" + """rewrites Pow(x, a) as Exp(a * Log(x)) whenever a is non-integer, a is not in {1/2, -1/2}, and x > 0.""" x, a = expr.args if not isinstance(a, sp.core.numbers.Integer) \ and not ((a == sp.Rational(1, 2)) or (a == sp.Rational(-1, 2))) \ @@ -1970,7 +1970,7 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): f = open('empty_rhs.H', 'r') fcontent = f.read() - print('Expanding log') + #print('Expanding log') ydots_subs = [] fupdated = '' From f6ecb0bf64b4ff89783f1fa6b6c1e74c20f062ef Mon Sep 17 00:00:00 2001 From: Ben Wibking Date: Sat, 20 Jul 2024 22:21:53 -0400 Subject: [PATCH 3/7] remove commented-out code --- rate_collection.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/rate_collection.py b/rate_collection.py index 369bbda..3ecf5ae 100644 --- a/rate_collection.py +++ b/rate_collection.py @@ -1970,7 +1970,6 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): f = open('empty_rhs.H', 'r') fcontent = f.read() - #print('Expanding log') ydots_subs = [] fupdated = '' @@ -1982,7 +1981,6 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): ChemSpecie('h2').sym_name: sy[8], ChemSpecie('hdp').sym_name: sy[9], \ ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \ ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}) - #this_ydot = sp.expand_log(this_ydot) this_simplified_ydot = self.simplify_expression(this_ydot) ydots_subs.append(this_simplified_ydot) print('Substituted ydot ', i+1) @@ -2010,7 +2008,6 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \ ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}) print('EDOT substituted') - #tdot_subs = sp.expand_log(tdot_subs) simplified_tdot = self.simplify_expression(tdot_subs) #for Tdot and jacobians, we will use sympy.cse (common subexpression eliminatino) From 1c8f42e9ac0fd30759b7a8ef9c0ed29761be3c49 Mon Sep 17 00:00:00 2001 From: Ben Wibking Date: Sat, 20 Jul 2024 22:35:42 -0400 Subject: [PATCH 4/7] rename vars --- rate_collection.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/rate_collection.py b/rate_collection.py index 3ecf5ae..ba25c4b 100644 --- a/rate_collection.py +++ b/rate_collection.py @@ -1943,17 +1943,16 @@ def pow_as_exp_of_log(expr): # fails with "RecursionError: maximum recursion depth exceeded" #print("simplifying all...", end="", flush=True) - #full_simplified_expr = sp.simplify(simplified_expr) + #simplified_expr = sp.simplify(simplified_expr) #print("finished.") - full_simplified_expr = simplified_expr print("") print("Original expression:", old_expr) print("") - print("Simplified expression:", full_simplified_expr) + print("Simplified expression:", simplified_expr) print("") - return full_simplified_expr + return simplified_expr def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): @@ -2008,7 +2007,7 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \ ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}) print('EDOT substituted') - simplified_tdot = self.simplify_expression(tdot_subs) + #tdot_subs = self.simplify_expression(tdot_subs) # this fails due to a SymPy bug :/ #for Tdot and jacobians, we will use sympy.cse (common subexpression eliminatino) #this greatly simplifies the code and redduces the file size by a factor of 10 @@ -2016,7 +2015,7 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): fupdated = '' - a, b = sp.cse(simplified_tdot) + a, b = sp.cse(tdot_subs) for j in range(0, len(a)): a1 = cxxcode(a[j]) fupdated += a1.replace('{',' Real ',1).replace(',',' =',1).replace('}',';\n\n',len(a1)) From 213f45723e46eab5ba4f5c074c8f5f221fc14010 Mon Sep 17 00:00:00 2001 From: Ben Wibking Date: Sun, 21 Jul 2024 17:30:32 -0400 Subject: [PATCH 5/7] add sympy bug reproducer --- reproducer.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 reproducer.py diff --git a/reproducer.py b/reproducer.py new file mode 100644 index 0000000..71c9024 --- /dev/null +++ b/reproducer.py @@ -0,0 +1,25 @@ +import sympy as sp + +## expected output: +# old expr: exp(2*T) +# simplified expr: exp(2*T) +# +# old expr: exp(T)*exp(T) +# simplified expr: exp(T)*exp(T) + +# works +expr1 = sp.sympify("exp(T)") +expr2 = sp.sympify("T") +old_expr = expr1 * sp.Pow(sp.E, expr2) +simplified_expr = sp.powsimp(old_expr) +print("old expr:", old_expr) +print("simplified expr:", simplified_expr) +print("") + +# fails +expr1 = sp.sympify("exp(T)") +expr2 = sp.UnevaluatedExpr(sp.sympify("T")) +old_expr = expr1 * sp.Pow(sp.E, expr2) +simplified_expr = sp.powsimp(old_expr) +print("old expr:", old_expr) +print("simplified expr:", simplified_expr) From ee94232e11a2ef9778e886a1ab5faad13619febb Mon Sep 17 00:00:00 2001 From: Ben Wibking Date: Sun, 21 Jul 2024 17:32:47 -0400 Subject: [PATCH 6/7] update reproducer --- reproducer.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/reproducer.py b/reproducer.py index 71c9024..2fb6cb6 100644 --- a/reproducer.py +++ b/reproducer.py @@ -1,25 +1,32 @@ -import sympy as sp +import sympy -## expected output: +## output when run: # old expr: exp(2*T) # simplified expr: exp(2*T) # # old expr: exp(T)*exp(T) # simplified expr: exp(T)*exp(T) +## desired output: +# old expr: exp(2*T) +# simplified expr: exp(2*T) +# +# old expr: exp(T)*exp(T) +# simplified expr: exp(T+T) + # works -expr1 = sp.sympify("exp(T)") -expr2 = sp.sympify("T") -old_expr = expr1 * sp.Pow(sp.E, expr2) -simplified_expr = sp.powsimp(old_expr) +expr1 = sympy.sympify("exp(T)") +expr2 = sympy.sympify("T") +old_expr = expr1 * sympy.Pow(sympy.E, expr2) +simplified_expr = sympy.powsimp(old_expr) print("old expr:", old_expr) print("simplified expr:", simplified_expr) print("") # fails -expr1 = sp.sympify("exp(T)") -expr2 = sp.UnevaluatedExpr(sp.sympify("T")) -old_expr = expr1 * sp.Pow(sp.E, expr2) -simplified_expr = sp.powsimp(old_expr) +expr1 = sympy.sympify("exp(T)") +expr2 = sympy.UnevaluatedExpr(sympy.sympify("T")) +old_expr = expr1 * sympy.Pow(sympy.E, expr2) +simplified_expr = sympy.powsimp(old_expr) print("old expr:", old_expr) print("simplified expr:", simplified_expr) From c7102c0c51df9f9e01a6fb63647d5a13e7a175e3 Mon Sep 17 00:00:00 2001 From: Ben Wibking Date: Sun, 21 Jul 2024 17:41:50 -0400 Subject: [PATCH 7/7] simplify reproducer --- reproducer.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/reproducer.py b/reproducer.py index 2fb6cb6..14ca0f1 100644 --- a/reproducer.py +++ b/reproducer.py @@ -1,32 +1,17 @@ -import sympy - -## output when run: -# old expr: exp(2*T) -# simplified expr: exp(2*T) -# -# old expr: exp(T)*exp(T) -# simplified expr: exp(T)*exp(T) +## actual output: +# old expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T')))) +# simplified expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T')))) -## desired output: -# old expr: exp(2*T) -# simplified expr: exp(2*T) -# -# old expr: exp(T)*exp(T) -# simplified expr: exp(T+T) +## expected output: +# old expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T')))) +# simplified expr: Mul(exp(Add(Symbol('T')), UnevaluatedExpr(Symbol('T'))))) -# works -expr1 = sympy.sympify("exp(T)") -expr2 = sympy.sympify("T") -old_expr = expr1 * sympy.Pow(sympy.E, expr2) -simplified_expr = sympy.powsimp(old_expr) -print("old expr:", old_expr) -print("simplified expr:", simplified_expr) -print("") +import sympy -# fails expr1 = sympy.sympify("exp(T)") expr2 = sympy.UnevaluatedExpr(sympy.sympify("T")) old_expr = expr1 * sympy.Pow(sympy.E, expr2) simplified_expr = sympy.powsimp(old_expr) -print("old expr:", old_expr) -print("simplified expr:", simplified_expr) + +print("old expr:", sympy.srepr(old_expr)) +print("simplified expr:", sympy.srepr(simplified_expr))