diff --git a/Code/S_NN_train.py b/Code/S_NN_train.py index da8983a..48d646b 100644 --- a/Code/S_NN_train.py +++ b/Code/S_NN_train.py @@ -61,8 +61,7 @@ def NN_train(pathdir, filename, epochs=1000, lrs=1e-2, N_red_lr=4, pretrained_pa if n_variables==0 or n_variables==1: print("Solved!")#, variables[0]) return 0 - #elif n_variables==1: - # variables = np.reshape(variables,(len(variables),1)) + else: for j in range(1,n_variables): v = np.loadtxt(pathdir+"%s" %filename, usecols=(j,)) @@ -117,9 +116,7 @@ def NN_train(pathdir, filename, epochs=1000, lrs=1e-2, N_red_lr=4, pretrained_pa if pretrained_path!="": model_feynman.load_state_dict(torch.load(pretrained_path)) - max_loss = 10000 - - print("EPOCH", epochs) + check_es_loss = 10000 for i_i in range(N_red_lr): optimizer_feynman = optim.Adam(model_feynman.parameters(), lr = lrs) @@ -136,12 +133,17 @@ def NN_train(pathdir, filename, epochs=1000, lrs=1e-2, N_red_lr=4, pretrained_pa prd = data[1].float() loss = rmse_loss(model_feynman(fct),prd) - if loss < max_loss: - torch.save(model_feynman.state_dict(), "results/NN_trained_models/models/" + filename + ".h5") - max_loss = loss - loss = rmse_loss(model_feynman(fct),prd) loss.backward() optimizer_feynman.step() + + # Early stopping + if epoch%20==0: + if check_es_loss < loss: + break + else: + torch.save(model_feynman.state_dict(), "results/NN_trained_models/models/" + filename + ".h5") + check_es_loss = loss + print(loss) lrs = lrs/10 diff --git a/Code/S_separability.py b/Code/S_separability.py index 7abde26..df072b6 100644 --- a/Code/S_separability.py +++ b/Code/S_separability.py @@ -230,7 +230,13 @@ def check_separability_multiply(pathdir, filename): f_dependent = np.loadtxt(pathdir+filename, usecols=(n_variables,)) + + # Pick only data which is close enough to the maximum value (5 times less or higher) + max_output = np.max(abs(f_dependent)) + use_idx = np.where(abs(f_dependent)>=max_output/5) + f_dependent = f_dependent[use_idx] f_dependent = np.reshape(f_dependent,(len(f_dependent),1)) + variables = variables[use_idx] factors = torch.from_numpy(variables) if is_cuda: