symbolic-regression/prior-art/Code/S_remove_input_neuron.py
2025-06-19 14:07:47 +03:00

33 lines
1.2 KiB
Python

# Remove one input neuron from a NN
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import torch
from torch.utils import data
import pickle
from matplotlib import pyplot as plt
import torch.utils.data as utils
import time
import os
# Module-level flag, evaluated once at import time: True when PyTorch
# reports that a CUDA device is available, False otherwise.
is_cuda = torch.cuda.is_available()
def remove_input_neuron(net,n_inp,idx_neuron,ct_median,save_filename):
    """Drop one input column from ``net.linear1`` and fold it into the bias.

    The column ``idx_neuron`` of ``net.linear1.weight`` is removed and
    ``ct_median`` times that column is added to ``net.linear1.bias``.  For a
    standard ``nn.Linear`` layer this makes the reduced layer produce the
    same pre-activations as the original layer would with the removed input
    held constant at ``ct_median``.  The network is modified in place and
    its ``state_dict`` is then written to ``save_filename``.

    Args:
        net: Module exposing a ``linear1`` sub-layer with ``weight`` of shape
            (outputs, inputs) and a ``bias``.
        n_inp: Number of input columns currently considered; columns
            ``0 .. n_inp-1`` other than ``idx_neuron`` are kept, in order.
        idx_neuron: Index of the input column to remove.  Negative values
            count from the end of the ``n_inp`` columns.
        ct_median: Constant value substituted for the removed input.
        save_filename: Destination passed to ``torch.save``.

    Returns:
        None.
    """
    layer = net.linear1

    # Normalise a negative index so the column that is read for the bias
    # correction is the same column that gets dropped from the weights.
    if idx_neuron < 0:
        idx_neuron += n_inp

    # Work on detached data: the new parameters must be fresh leaves and no
    # autograd history from the old ones should be recorded.
    with torch.no_grad():
        weight = layer.weight.detach()
        bias = layer.bias.detach()

        removed_weights = weight[:, idx_neuron].clone()

        # Explicitly ordered list of surviving columns, created on the same
        # device as the weights (no dependence on set iteration order or on
        # a global CUDA flag).
        preserved_ids = torch.tensor(
            [i for i in range(n_inp) if i != idx_neuron],
            dtype=torch.long,
            device=weight.device,
        )
        new_weight = weight[:, preserved_ids]

        # Fold the constant contribution of the removed input into the bias,
        # staying on the layer's own device and dtype.
        new_bias = bias + (ct_median * removed_weights).to(bias.dtype)

    layer.weight = nn.Parameter(new_weight)
    layer.bias = nn.Parameter(new_bias)

    # Keep the layer's bookkeeping attribute in sync with the new weight
    # shape when the layer has one (e.g. nn.Linear).
    if hasattr(layer, "in_features"):
        layer.in_features = new_weight.shape[1]

    torch.save(net.state_dict(), save_filename)