diff --git a/Code/RPN_to_pytorch.py b/Code/RPN_to_pytorch.py index b03ed0e..4387bba 100644 --- a/Code/RPN_to_pytorch.py +++ b/Code/RPN_to_pytorch.py @@ -122,9 +122,12 @@ def RPN_to_pytorch(data_file, math_expr, lr = 1e-2, N_epochs = 500): else: eq = eq.subs(parm, trainable_parameters[ii]) complexity = complexity + get_number_DL(trainable_parameters[ii].detach().numpy()) + n_variables = len(eq.free_symbols) + n_operations = len(count_ops(eq,visual=True).free_symbols) + if n_operations!=0 or n_variables!=0: + complexity = complexity + (n_variables+n_operations)*np.log2((n_variables+n_operations)) ii = ii+1 - error = torch.mean((f(*input)-y)**2).data.numpy()*1 return error, complexity, eq