Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions rate_collection.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,11 @@
# Common Imports
import warnings
import functools
import math
import os

from operator import mul, add
from collections import OrderedDict, Counter

from ipywidgets import interact
from operator import mul
from collections import Counter

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import MaxNLocator
import networkx as nx
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed all of the unused imports.


# Import Rate
from rate import ChemSpecie, ChemRate, ChemComposition, SympyChemRate
Expand Down Expand Up @@ -1927,6 +1918,43 @@ def write_cxxnetwork_withoutCSE(self, T, composition, redshift, density=0):
return None


def simplify_expression(self, old_expr):
"""Simplify a Sympy expression before turning it into C++ code."""
import sympy as sp

def pow_as_exp_of_log(expr):
"""rewrites Pow(x, a) as Exp(a * Log(x)) whenever a is non-integer, a is not in {1/2, -1/2}, and x > 0."""
x, a = expr.args
if not isinstance(a, sp.core.numbers.Integer) \
and not ((a == sp.Rational(1, 2)) or (a == sp.Rational(-1, 2))) \
and x.is_positive:
my_new_expr = sp.Pow(sp.E, sp.UnevaluatedExpr(a * sp.log(x)))
#print("transforming:", expr, "->", my_new_expr)
return my_new_expr
else:
return expr

new_expr = old_expr.replace(lambda expr: expr.is_Pow, pow_as_exp_of_log)

simplified_expr = sp.powsimp(new_expr)
## powsimp is buggy, doesn't perform either replacement below:
## 1.4e-18*X(0)*X(2)*exp(-6.17283950617284e-5*T)*exp(0.928*log(T)) -> 1.4e-18*X(0)*X(2)*exp(-6.17283950617284e-5*T + 0.928*log(T))
## 35.5*X(0)*X(8)*exp(-46707.0/T)*exp(-2.28*log(T)) -> 35.5*X(0)*X(8)*exp(-46707.0/T + -2.28*log(T))

# fails with "RecursionError: maximum recursion depth exceeded"
#print("simplifying all...", end="", flush=True)
#simplified_expr = sp.simplify(simplified_expr)
#print("finished.")

print("")
print("Original expression:", old_expr)
print("")
print("Simplified expression:", simplified_expr)
print("")

return simplified_expr


def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0):
import sympy as sp
from sympy.printing import cxxcode
Expand All @@ -1941,18 +1969,19 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0):
f = open('empty_rhs.H', 'r')
fcontent = f.read()

print('Expanding log')

ydots_subs = []
fupdated = ''
for i, specie in enumerate(composition):
ydots_subs.append(sp.expand_log(ydots[specie].subs({ChemSpecie('elec').sym_name: sy[0], ChemSpecie('hp').sym_name: sy[1], \
this_ydot = ydots[specie].subs({ChemSpecie('elec').sym_name: sy[0], ChemSpecie('hp').sym_name: sy[1], \
ChemSpecie('h').sym_name: sy[2], ChemSpecie('hm').sym_name: sy[3], \
ChemSpecie('dp').sym_name: sy[4], ChemSpecie('d').sym_name: sy[5], \
ChemSpecie('h2p').sym_name: sy[6], ChemSpecie('dm').sym_name: sy[7], \
ChemSpecie('h2').sym_name: sy[8], ChemSpecie('hdp').sym_name: sy[9], \
ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \
ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]})))
ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]})
this_simplified_ydot = self.simplify_expression(this_ydot)
ydots_subs.append(this_simplified_ydot)
print('Substituted ydot ', i+1)

a, b = sp.cse(ydots_subs)
Expand All @@ -1978,7 +2007,7 @@ def write_cxxnetwork_withCSE(self, T, composition, redshift, density=0):
ChemSpecie('hd').sym_name: sy[10], ChemSpecie('hepp').sym_name: sy[11], \
ChemSpecie('hep').sym_name: sy[12], ChemSpecie('he').sym_name: sy[13]})
print('EDOT substituted')
tdot_subs = sp.expand_log(tdot_subs)
#tdot_subs = self.simplify_expression(tdot_subs) # this fails due to a SymPy bug :/

#for Tdot and jacobians, we will use sympy.cse (common subexpression eliminatino)
#this greatly simplifies the code and redduces the file size by a factor of 10
Expand Down
17 changes: 17 additions & 0 deletions reproducer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## actual output:
# old expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T'))))
# simplified expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T'))))

## expected output:
# old expr: Mul(exp(Symbol('T')), exp(UnevaluatedExpr(Symbol('T'))))
# simplified expr: Mul(exp(Add(Symbol('T')), UnevaluatedExpr(Symbol('T')))))

import sympy

expr1 = sympy.sympify("exp(T)")
expr2 = sympy.UnevaluatedExpr(sympy.sympify("T"))
old_expr = expr1 * sympy.Pow(sympy.E, expr2)
simplified_expr = sympy.powsimp(old_expr)

print("old expr:", sympy.srepr(old_expr))
print("simplified expr:", sympy.srepr(simplified_expr))
2 changes: 1 addition & 1 deletion write_actual_rhs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

T = sp.symbols('T', positive=True)

redshift = sp.symbols('z', real=True)
redshift = sp.symbols('z', positive=True)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can assume redshift is always positive 😄


bb = ratecol.SympyChemRateCollection(rates=list(sym_rates.values()),tdot_switch=1,ydots_lambdified=False,withD=1,massfracs=0)

Expand Down