Add files via upload

This commit is contained in:
Silviu Marian Udrescu 2020-04-29 13:41:52 -04:00 committed by GitHub
parent 233adeb7ae
commit 13148a5c6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 6 additions and 21 deletions

View file

@ -5,7 +5,6 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
import torch
@ -16,6 +15,8 @@ import torch.utils.data as utils
import time
import os
is_cuda = torch.cuda.is_available()
def remove_input_neuron(net,n_inp,idx_neuron,ct_median,save_filename):
removed_weights = net.linear1.weight[:,idx_neuron]
# Remove the weights associated with the removed input neuron
@ -24,6 +25,9 @@ def remove_input_neuron(net,n_inp,idx_neuron,ct_median,save_filename):
t = nn.Parameter(t[preserved_ids, :])
net.linear1.weight = nn.Parameter(torch.transpose(t,0,1))
# Adjust the biases
net.linear1.bias = nn.Parameter(net.linear1.bias+torch.tensor(ct_median*removed_weights).float().cuda())
if is_cuda:
net.linear1.bias = nn.Parameter(net.linear1.bias+torch.tensor(ct_median*removed_weights).float().cuda())
else:
net.linear1.bias = nn.Parameter(net.linear1.bias+torch.tensor(ct_median*removed_weights).float())
torch.save(net.state_dict(), save_filename)