diff --git a/rate_collection.py b/rate_collection.py index c1614e2..ba25c4b 100644 --- a/rate_collection.py +++ b/rate_collection.py @@ -1,20 +1,11 @@ # Common Imports -import warnings import functools -import math import os -from operator import mul, add -from collections import OrderedDict, Counter - -from ipywidgets import interact +from operator import mul +from collections import Counter import numpy as np -import matplotlib as mpl -import matplotlib.pyplot as plt -from mpl_toolkits.axes_grid1 import make_axes_locatable -from matplotlib.ticker import MaxNLocator -import networkx as nx # Import Rate from rate import ChemSpecie, ChemRate, ChemComposition, SympyChemRate @@ -1927,6 +1918,43 @@ def write_cxxnetwork_withoutCSE(self, T, composition, redshift, density=0): return None + def simplify_expression(self, old_expr): + """Simplify a Sympy expression before turning it into C++ code.""" + import sympy as sp + + def pow_as_exp_of_log(expr): + """rewrites Pow(x, a) as Exp(a * Log(x)) whenever a is non-integer, a is not in {1/2, -1/2}, and x > 0.""" + x, a = expr.args + if not isinstance(a, sp.core.numbers.Integer) \ + and not ((a == sp.Rational(1, 2)) or (a == sp.Rational(-1, 2))) \ + and x.is_positive: + my_new_expr = sp.Pow(sp.E, sp.UnevaluatedExpr(a * sp.log(x))) + #print("transforming:", expr, "->", my_new_expr) + return my_new_expr + else: + return expr + + new_expr = old_expr.replace(lambda expr: expr.is_Pow, pow_as_exp_of_log) + + simplified_expr = sp.powsimp(new_expr) + ## powsimp is buggy, doesn't perform either replacement below: + ## 1.4e-18*X(0)*X(2)*exp(-6.17283950617284e-5*T)*exp(0.928*log(T)) -> 1.4e-18*X(0)*X(2)*exp(-6.17283950617284e-5*T + 0.928*log(T)) + ## 35.5*X(0)*X(8)*exp(-46707.0/T)*exp(-2.28*log(T)) -> 35.5*X(0)*X(8)*exp(-46707.0/T + -2.28*log(T)) + + # fails with "RecursionError: maximum recursion depth exceeded" + #print("simplifying all...", end="", flush=True) + #simplified_expr = sp.simplify(simplified_expr) + #print("finished.") + + print("") + print("Original expression:", old_expr) + print("") + print("Simplified expression:", simplified_expr) + print("") + + return simplified_expr + + def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): import sympy as sp from sympy.printing import cxxcode @@ -1941,18 +1969,19 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): f = open('empty_rhs.H', 'r') fcontent = f.read() - print('Expanding log') ydots_subs = [] fupdated = '' for i, specie in enumerate(composition): - ydots_subs.append(sp.expand_log(ydots[specie].subs({ChemSpecie('elec').sym_name: sy[0], ChemSpecie('hp').sym_name: sy[1], \ + this_ydot = ydots[specie].subs({ChemSpecie('elec').sym_name: sy[0], ChemSpecie('hp').sym_name: sy[1], \ ChemSpecie('h').sym_name: sy[2], ChemSpecie('hm').sym_name: sy[3], \ ChemSpecie('dp').sym_name: sy[4], ChemSpecie('d').sym_name: sy[5], \ ChemSpecie('h2p').sym_name: sy[6], ChemSpecie('dm').sym_name: sy[7], \ ChemSpecie('h2').sym_name: sy[8], ChemSpecie('hdp').sym_name: sy[9], \ ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \ - ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}))) + ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}) + this_simplified_ydot = self.simplify_expression(this_ydot) + ydots_subs.append(this_simplified_ydot) print('Substituted ydot ', i+1) a, b = sp.cse(ydots_subs) @@ -1978,7 +2007,7 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0): ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \ ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]}) print('EDOT substituted') - tdot_subs = sp.expand_log(tdot_subs) + #tdot_subs = self.simplify_expression(tdot_subs) # this fails due to a SymPy bug :/ #for Tdot and jacobians, we will use sympy.cse (common subexpression eliminatino) #this greatly simplifies the code and redduces the file size by a factor of 10 diff --git a/reproducer.py b/reproducer.py new file mode 100644 index 0000000..14ca0f1 --- /dev/null +++ b/reproducer.py @@ -0,0 +1,17 @@ +## actual output: +# old expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T')))) +# simplified expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T')))) + +## expected output: +# old expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T')))) +# simplified expr: Mul(exp(Add(Symbol('T')), UnevaluatedExpr(Symbol('T'))))) + +import sympy + +expr1 = sympy.sympify("exp(T)") +expr2 = sympy.UnevaluatedExpr(sympy.sympify("T")) +old_expr = expr1 * sympy.Pow(sympy.E, expr2) +simplified_expr = sympy.powsimp(old_expr) + +print("old expr:", sympy.srepr(old_expr)) +print("simplified expr:", sympy.srepr(simplified_expr)) diff --git a/write_actual_rhs.py b/write_actual_rhs.py index 1158781..678ebe6 100644 --- a/write_actual_rhs.py +++ b/write_actual_rhs.py @@ -14,7 +14,7 @@ T = sp.symbols('T', positive=True) -redshift = sp.symbols('z', real=True) +redshift = sp.symbols('z', positive=True) bb = ratecol.SympyChemRateCollection(rates=list(sym_rates.values()),tdot_switch=1,ydots_lambdified=False,withD=1,massfracs=0)