From 5e059634a65ca2bdcd4349f5ac66bae03868da27 Mon Sep 17 00:00:00 2001 From: Silviu Marian Udrescu Date: Fri, 26 Jun 2020 08:54:38 -0400 Subject: [PATCH] Add files via upload --- Code/S_final_gd.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/Code/S_final_gd.py b/Code/S_final_gd.py index b4c95a4..6c3c6a3 100644 --- a/Code/S_final_gd.py +++ b/Code/S_final_gd.py @@ -30,7 +30,7 @@ def final_gd(data, math_expr, lr = 1e-2, N_epochs = 5000): """Recursively transform each numerical value into a learnable parameter.""" import sympy from sympy import Symbol - if isinstance(expr, sympy.numbers.Float): + if isinstance(expr, sympy.numbers.Float) or isinstance(expr, sympy.numbers.Integer) or isinstance(expr, sympy.numbers.Rational) or isinstance(expr, sympy.numbers.Pi): used_param_names = list(param_dict.keys()) + list(unsnapped_param_dict) unsnapped_param_name = get_next_available_key(used_param_names, "p", is_underscore=False) unsnapped_param_dict[unsnapped_param_name] = float(expr) @@ -45,7 +45,6 @@ def final_gd(data, math_expr, lr = 1e-2, N_epochs = 5000): unsnapped_sub_expr_list.append(unsnapped_sub_expr) return expr.func(*unsnapped_sub_expr_list) - def get_next_available_key(iterable, key, midfix="", suffix="", is_underscore=True): """Get the next available key that does not collide with the keys in the dictionary.""" if key + suffix not in iterable: @@ -61,10 +60,9 @@ def final_gd(data, math_expr, lr = 1e-2, N_epochs = 5000): # Turn BF expression to pytorch expression eq = parse_expr(math_expr) eq = unsnap_recur(eq,param_dict,unsnapped_param_dict) - + N_vars = len(data[0])-1 N_params = len(unsnapped_param_dict) - possible_vars = ["x%s" %i for i in np.arange(0,30,1)] variables = [] params = [] @@ -76,7 +74,6 @@ def final_gd(data, math_expr, lr = 1e-2, N_epochs = 5000): symbols = params + variables f = lambdify(symbols, N(eq), torch) - # Set the trainable parameters in the expression trainable_parameters = [] @@ -143,5 +140,3 @@ def final_gd(data, math_expr, lr = 1e-2, N_epochs = 5000): error = get_symbolic_expr_error(data,str(eq)) return error, complexity, eq - -