Add files via upload
This commit is contained in:
parent
233adeb7ae
commit
13148a5c6a
9 changed files with 6 additions and 21 deletions
|
|
@ -9,10 +9,6 @@ import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.utils.data as utils
|
import torch.utils.data as utils
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from sklearn.metrics import roc_curve, auc
|
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
from sklearn.manifold import TSNE
|
|
||||||
import seaborn as sns
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
import sympy
|
import sympy
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,6 @@ import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.utils.data as utils
|
import torch.utils.data as utils
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from sklearn.metrics import roc_curve, auc
|
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
from sklearn.manifold import TSNE
|
|
||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,6 @@ import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.utils.data as utils
|
import torch.utils.data as utils
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from sklearn.metrics import roc_curve, auc
|
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
from sklearn.manifold import TSNE
|
|
||||||
import seaborn as sns
|
|
||||||
import copy
|
import copy
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
|
||||||
|
|
@ -9,10 +9,6 @@ import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
import torch.utils.data as utils
|
import torch.utils.data as utils
|
||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from sklearn.metrics import roc_curve, auc
|
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
from sklearn.manifold import TSNE
|
|
||||||
import seaborn as sns
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
import sympy
|
import sympy
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -16,6 +15,8 @@ import torch.utils.data as utils
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
is_cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
def remove_input_neuron(net,n_inp,idx_neuron,ct_median,save_filename):
|
def remove_input_neuron(net,n_inp,idx_neuron,ct_median,save_filename):
|
||||||
removed_weights = net.linear1.weight[:,idx_neuron]
|
removed_weights = net.linear1.weight[:,idx_neuron]
|
||||||
# Remove the weights associated with the removed input neuron
|
# Remove the weights associated with the removed input neuron
|
||||||
|
|
@ -24,6 +25,9 @@ def remove_input_neuron(net,n_inp,idx_neuron,ct_median,save_filename):
|
||||||
t = nn.Parameter(t[preserved_ids, :])
|
t = nn.Parameter(t[preserved_ids, :])
|
||||||
net.linear1.weight = nn.Parameter(torch.transpose(t,0,1))
|
net.linear1.weight = nn.Parameter(torch.transpose(t,0,1))
|
||||||
# Adjust the biases
|
# Adjust the biases
|
||||||
net.linear1.bias = nn.Parameter(net.linear1.bias+torch.tensor(ct_median*removed_weights).float().cuda())
|
if is_cuda:
|
||||||
|
net.linear1.bias = nn.Parameter(net.linear1.bias+torch.tensor(ct_median*removed_weights).float().cuda())
|
||||||
|
else:
|
||||||
|
net.linear1.bias = nn.Parameter(net.linear1.bias+torch.tensor(ct_median*removed_weights).float())
|
||||||
torch.save(net.state_dict(), save_filename)
|
torch.save(net.state_dict(), save_filename)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import os
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import os
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torchvision import datasets, transforms
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue