Add files via upload

This commit is contained in:
Silviu Marian Udrescu 2020-06-26 09:07:10 -04:00 committed by GitHub
parent 5e059634a6
commit 84aa8e27fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletions

View file

@ -103,6 +103,8 @@ def RPN_to_pytorch(data, math_expr, lr = 1e-2, N_epochs = 500):
for j in range(N_params-1): for j in range(N_params-1):
trainable_parameters[j] -= lr * trainable_parameters[j].grad trainable_parameters[j] -= lr * trainable_parameters[j].grad
trainable_parameters[j].grad.zero_() trainable_parameters[j].grad.zero_()
if torch.isnan(loss):
break
for nan_i in range(len(trainable_parameters)): for nan_i in range(len(trainable_parameters)):
if torch.isnan(trainable_parameters[nan_i])==True or abs(trainable_parameters[nan_i])>1e7: if torch.isnan(trainable_parameters[nan_i])==True or abs(trainable_parameters[nan_i])>1e7:

View file

@ -101,6 +101,8 @@ def final_gd(data, math_expr, lr = 1e-2, N_epochs = 5000):
for j in range(N_params-1): for j in range(N_params-1):
trainable_parameters[j] -= lr * trainable_parameters[j].grad trainable_parameters[j] -= lr * trainable_parameters[j].grad
trainable_parameters[j].grad.zero_() trainable_parameters[j].grad.zero_()
if torch.isnan(loss):
break
for i in range(N_epochs): for i in range(N_epochs):
# this order is fixed i.e. first parameters # this order is fixed i.e. first parameters
@ -111,6 +113,8 @@ def final_gd(data, math_expr, lr = 1e-2, N_epochs = 5000):
for j in range(N_params-1): for j in range(N_params-1):
trainable_parameters[j] -= lr/10 * trainable_parameters[j].grad trainable_parameters[j] -= lr/10 * trainable_parameters[j].grad
trainable_parameters[j].grad.zero_() trainable_parameters[j].grad.zero_()
if torch.isnan(loss):
break
for nan_i in range(len(trainable_parameters)): for nan_i in range(len(trainable_parameters)):
if torch.isnan(trainable_parameters[nan_i])==True or abs(trainable_parameters[nan_i])>1e7: if torch.isnan(trainable_parameters[nan_i])==True or abs(trainable_parameters[nan_i])>1e7: