symbolic-regression/Code/S_remove_input_neuron.py
Silviu Marian Udrescu f0cc7dfcaa
Add files via upload
2020-03-23 03:35:19 -04:00

29 lines
1.1 KiB
Python

# Remove one input neuron from a NN
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
import torch
from torch.utils import data
import pickle
from matplotlib import pyplot as plt
import torch.utils.data as utils
import time
import os
def remove_input_neuron(net, n_inp, idx_neuron, ct_median, save_filename):
    """Remove one input neuron from ``net.linear1`` and save the pruned state dict.

    The removed input is treated as held fixed at the constant ``ct_median``.
    Its contribution, ``ct_median * W[:, idx_neuron]``, is folded into the bias.
    For any input whose ``idx_neuron``-th entry equals ``ct_median``, the
    pruned layer (fed the remaining inputs) produces the same output as the
    original layer.

    The net is modified in place: ``linear1.weight``, ``linear1.bias`` and
    ``linear1.in_features`` are replaced. Then ``net.state_dict()`` is written
    to ``save_filename``.

    Args:
        net: module whose first layer is ``net.linear1``. That layer must have
            a 2-D ``weight`` of shape ``(out_features, n_inp)`` and a ``bias``.
        n_inp: number of inputs currently feeding ``net.linear1``.
        idx_neuron: index of the input to remove, ``0 <= idx_neuron < n_inp``.
        ct_median: constant substituted for the removed input (Python number,
            NumPy scalar or 0-d tensor).
        save_filename: destination path passed to ``torch.save``.

    Raises:
        ValueError: if ``n_inp`` does not match the layer's input width.
        IndexError: if ``idx_neuron`` is outside ``[0, n_inp)``.
    """
    linear = net.linear1
    weight = linear.weight
    bias = linear.bias

    if weight.shape[1] != n_inp:
        raise ValueError(
            "n_inp=%d does not match linear1 input width %d"
            % (n_inp, weight.shape[1])
        )
    # Reject negative indices explicitly. Otherwise the column slice would
    # silently wrap around while the column itself would never be removed.
    if not 0 <= idx_neuron < n_inp:
        raise IndexError(
            "idx_neuron=%d out of range for %d inputs" % (idx_neuron, n_inp)
        )

    with torch.no_grad():
        # Build an explicit ordered list so the surviving columns keep their
        # original order (set iteration order is not guaranteed).
        keep = torch.tensor(
            [i for i in range(n_inp) if i != idx_neuron],
            dtype=torch.long,
            device=weight.device,
        )
        removed_weights = weight[:, idx_neuron].to(bias.dtype)
        # Do the correction on the bias's own device/dtype instead of forcing
        # CUDA/float32, so CPU models and non-float32 models both work.
        median = torch.as_tensor(ct_median, dtype=bias.dtype, device=bias.device)
        new_bias = bias + median * removed_weights
        new_weight = weight.index_select(1, keep)

    linear.weight = nn.Parameter(new_weight)
    linear.bias = nn.Parameter(new_bias)
    # Keep the layer's metadata consistent with the new weight shape.
    linear.in_features = n_inp - 1
    torch.save(net.state_dict(), save_filename)