Add files via upload

This commit is contained in:
Silviu Marian Udrescu 2020-04-07 02:31:51 -04:00 committed by GitHub
parent def13c9018
commit ae3d4e9a8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 9 deletions

View file

@ -61,8 +61,7 @@ def NN_train(pathdir, filename, epochs=1000, lrs=1e-2, N_red_lr=4, pretrained_pa
if n_variables==0 or n_variables==1:
print("Solved!")#, variables[0])
return 0
#elif n_variables==1:
# variables = np.reshape(variables,(len(variables),1))
else:
for j in range(1,n_variables):
v = np.loadtxt(pathdir+"%s" %filename, usecols=(j,))
@ -117,9 +116,7 @@ def NN_train(pathdir, filename, epochs=1000, lrs=1e-2, N_red_lr=4, pretrained_pa
if pretrained_path!="":
model_feynman.load_state_dict(torch.load(pretrained_path))
max_loss = 10000
print("EPOCH", epochs)
check_es_loss = 10000
for i_i in range(N_red_lr):
optimizer_feynman = optim.Adam(model_feynman.parameters(), lr = lrs)
@ -136,12 +133,17 @@ def NN_train(pathdir, filename, epochs=1000, lrs=1e-2, N_red_lr=4, pretrained_pa
prd = data[1].float()
loss = rmse_loss(model_feynman(fct),prd)
if loss < max_loss:
torch.save(model_feynman.state_dict(), "results/NN_trained_models/models/" + filename + ".h5")
max_loss = loss
loss = rmse_loss(model_feynman(fct),prd)
loss.backward()
optimizer_feynman.step()
# Early stopping
if epoch%20==0:
if check_es_loss < loss:
break
else:
torch.save(model_feynman.state_dict(), "results/NN_trained_models/models/" + filename + ".h5")
check_es_loss = loss
print(loss)
lrs = lrs/10