Unverified Commit f0416ebf authored by NingMa's avatar NingMa Committed by GitHub

Add files via upload

parent fec106d0
This diff is collapsed.
import numpy as np
import os
import os.path
from PIL import Image
def pil_loader(path):
    """Open the image at *path* and return it converted to RGB."""
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def make_dataset_fromlist(image_list):
    """Parse "path label" lines into parallel numpy arrays.

    Each entry of *image_list* is a space-separated "<path> <label>"
    string; labels are parsed as ints.

    Returns:
        (paths, labels): numpy arrays of equal length.
    """
    paths = [entry.split(' ')[0] for entry in image_list]
    targets = [int(entry.split(' ')[1].strip()) for entry in image_list]
    return np.array(paths), np.array(targets)
def return_classlist(image_list):
    """Return the unique class names found in the list file *image_list*.

    The class name is taken as the parent directory of each image path,
    in first-seen order.
    """
    classes = []
    with open(image_list) as handle:
        for line in handle.readlines():
            name = line.split(' ')[0].split('/')[-2]
            if name not in classes:
                classes.append(str(name))
    return classes
class Imagelists_VISDA(object):
    """Dataset over "<path> <label>" entries rooted at *root*.

    Paths/labels are parsed by the module-level ``make_dataset_fromlist``;
    images are loaded lazily with ``pil_loader``.
    """

    def __init__(self, image_list, root="./data/multi/",
                 transform=None, target_transform=None, test=False):
        imgs, labels = make_dataset_fromlist(image_list)
        self.imgs = imgs                # relative image paths (numpy array)
        self.labels = labels            # int class labels (numpy array)
        self.transform = transform      # applied to the PIL image
        self.target_transform = target_transform  # applied to the label
        self.loader = pil_loader
        self.root = root                # prefix joined onto each path
        self.test = test                # test mode also returns the path
        self.return_index = False       # train mode may also return the index

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is
            class_index of the target class.  In test mode the relative
            image path is appended; otherwise, if ``return_index`` is set,
            the dataset index is appended.
        """
        path = os.path.join(self.root, self.imgs[index])
        target = self.labels[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if not self.test:
            if self.return_index:
                return img, target, index
            else:
                return img, target
        else:
            # test mode: expose the (relative) image path as well
            return img, target, self.imgs[index]

    def __len__(self):
        return len(self.imgs)
class STL(object):
    """STL dataset backed by numpy arrays saved on disk.

    Depending on *ttype* ("labeled" / "unlabeled" / anything else = test),
    loads images and labels from ``<root>/<split>_data`` and
    ``<root>/<split>_label`` with ``np.load``.
    """

    def __init__(self, root="./data/multi/", ttype="labeled",
                 transform=None, target_transform=None, test=False):
        # BUG FIX: labels were loaded into a misspelled local ("lables")
        # while `self.labels = labels` read an undefined name, raising
        # NameError at construction time.
        if ttype == "labeled":
            imgs = np.load(os.path.join(root, "labeled_data"))
            labels = np.load(os.path.join(root, "labeled_label"))
        elif ttype == "unlabeled":
            imgs = np.load(os.path.join(root, "unlabeled_data"))
            labels = np.load(os.path.join(root, "unlabeled_label"))
        else:
            imgs = np.load(os.path.join(root, "test_data"))
            labels = np.load(os.path.join(root, "test_label"))
        self.imgs = imgs                # image arrays (H, W, C per sample)
        self.labels = labels            # int class labels
        self.transform = transform      # applied to the PIL image
        self.target_transform = target_transform  # applied to the label
        self.loader = pil_loader        # kept for interface compatibility
        self.root = root
        self.test = test                # test mode also returns the raw image
        self.return_index = False       # train mode may also return the index

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is
            class_index of the target class.
        """
        target = self.labels[index]
        # images are stored in-memory as arrays, not on disk
        img = Image.fromarray(self.imgs[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if not self.test:
            if self.return_index:
                return img, target, index
            else:
                return img, target
        else:
            return img, target, self.imgs[index]

    def __len__(self):
        return len(self.imgs)
\ No newline at end of file
class BaseDataLoader():
    """Minimal base class holding common data-loader settings."""

    def __init__(self):
        pass

    def initialize(self, batch_size):
        """Record the batch size and default worker/serialization settings."""
        self.batch_size = batch_size
        self.serial_batches = 0
        self.nThreads = 2
        self.max_dataset_size = float("inf")

    def load_data(self):
        # BUG FIX: was declared without `self`, so calling it on an
        # instance raised TypeError.  Subclasses override this.
        return None
import torch
import numpy as np
import random
from PIL import Image
from torch.utils.data import Dataset,Sampler
from collections import defaultdict
import os
import os.path
import cv2
import torchvision
def make_dataset(image_list, labels):
    """Pair each image path with its label(s).

    If *labels* is given, it is zipped row-wise with the stripped lines;
    otherwise each "path label [label ...]" line is parsed, producing an
    int label or an int array for the multi-label case.
    """
    if labels:
        return [(image_list[i].strip(), labels[i, :]) for i in range(len(image_list))]
    if len(image_list[0].split()) > 2:
        # multi-label: everything after the path is a label vector
        return [(line.split()[0], np.array([int(v) for v in line.split()[1:]]))
                for line in image_list]
    return [(line.split()[0], int(line.split()[1])) for line in image_list]
def rgb_loader(path):
    """Load the file at *path* as an RGB PIL image."""
    with open(path, 'rb') as handle, Image.open(handle) as image:
        return image.convert('RGB')
def l_loader(path):
    """Load the file at *path* as a single-channel (grayscale) PIL image."""
    with open(path, 'rb') as handle, Image.open(handle) as image:
        return image.convert('L')
class ImageList(Dataset):
    """Dataset over "<path> <label>" lines with per-class index bookkeeping.

    ``classwise_indices`` maps each class label to the list of sample
    indices carrying it, which pair-samplers use to draw same-class pairs.
    """

    def __init__(self, image_list, root=None, labels=None, transform=None, target_transform=None, mode='RGB'):
        self.imgs = make_dataset(image_list, labels)
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        # index -> class, and class -> indices
        self.idx2class = [target for _, target in self.imgs]
        self.classwise_indices = defaultdict(list)
        for position, cls in enumerate(self.idx2class):
            self.classwise_indices[cls].append(position)

    def get_class(self, idx):
        """Return the class label of sample *idx*."""
        return self.idx2class[idx]

    def __getitem__(self, index):
        path, target = self.imgs[index]
        image = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.imgs)
class ImageList_idx(Dataset):
    """Like ``ImageList`` but ``__getitem__`` also returns the sample index,
    and per-sample pseudo-labels can be swapped in via ``set_psd_class``.
    """

    def __init__(self, image_list, root=None, labels=None, transform=None, target_transform=None, mode='RGB'):
        imgs = make_dataset(image_list, labels)
        if len(imgs) == 0:
            # BUG FIX: the message referenced an undefined IMG_EXTENSIONS,
            # turning the intended RuntimeError into a NameError.
            raise RuntimeError("Found 0 images in subfolders of: " + str(root))
        self.imgs = imgs
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        # index -> class, and class -> indices (used by pair samplers)
        self.classwise_indices = defaultdict(list)
        self.idx2class = []
        for _, target in self.imgs:
            self.idx2class.append(target)
        for idx, classes in enumerate(self.idx2class):
            self.classwise_indices[classes].append(idx)

    def set_psd_class(self, pred):
        """Replace the per-sample classes with pseudo-labels *pred*.

        *pred* must be in dataset order (predictions not shuffled); the
        class-wise index is rebuilt accordingly.
        """
        self.idx2class = list(pred.reshape(-1))
        self.classwise_indices = defaultdict(list)
        for idx, classes in enumerate(self.idx2class):
            self.classwise_indices[classes].append(idx)

    def get_class(self, idx):
        """Return the (possibly pseudo-) class label of sample *idx*."""
        return self.idx2class[idx]

    def __getitem__(self, index):
        path, target = self.imgs[index]
        path = os.path.join(self.root, path)
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target, index

    def __len__(self):
        return len(self.imgs)
class PairBatchSampler(Sampler):
    """Batch sampler that appends, for every sampled index, a partner index
    drawn from the same class (via ``dataset.get_class`` and
    ``dataset.classwise_indices``), yielding batches of 2 * batch_size.
    """

    def __init__(self, dataset, batch_size, num_iterations=None):
        self.dataset = dataset
        self.batch_size = batch_size
        # When set, batches are drawn with random.sample each iteration
        # instead of walking one shuffled permutation of the dataset.
        self.num_iterations = num_iterations

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        random.shuffle(indices)
        for k in range(len(self)):
            if self.num_iterations is None:
                offset = k*self.batch_size
                batch_indices = indices[offset:offset+self.batch_size]
            else:
                batch_indices = random.sample(range(len(self.dataset)),
                                              self.batch_size)
            # For each picked sample, add a random same-class partner
            # (may pick the sample itself).
            pair_indices = []
            for idx in batch_indices:
                y = self.dataset.get_class(idx)
                pair_indices.append(random.choice(self.dataset.classwise_indices[y]))
            yield batch_indices + pair_indices

    def __len__(self):
        # number of batches per epoch
        if self.num_iterations is None:
            return (len(self.dataset)) // self.batch_size
        else:
            return self.num_iterations
\ No newline at end of file
import sys
# sys.path.append('../data')
from data_pro.unaligned_data_loader import UnalignedDataLoader
from data_pro.svhn import load_svhn
from data_pro.mnist import load_mnist
from data_pro.mnist_m import load_mnistm
from data_pro.usps_ import load_usps
from data_pro.gtsrb import load_gtsrb
from data_pro.synth_number import load_syn
from data_pro.synth_traffic import load_syntraffic
def return_dataset(data, scale=False, usps=False, all_use='no'):
    """Load the train/test split of the digit dataset named *data*.

    *data* is one of 'svhn', 'mnist', 'mnistm', 'usps', 'synth', 'gtsrb',
    'syn'.  ``scale``/``usps``/``all_use`` are accepted for interface
    compatibility but unused here.

    Returns:
        (train_images, train_labels, test_images, test_labels)
    """
    if data == 'svhn':
        train_image, train_label, test_image, test_label = load_svhn()
    elif data == 'mnist':
        train_image, train_label, test_image, test_label = load_mnist()
    elif data == 'mnistm':
        train_image, train_label, test_image, test_label = load_mnistm()
    elif data == 'usps':
        train_image, train_label, test_image, test_label = load_usps()
    elif data == 'synth':
        train_image, train_label, test_image, test_label = load_syntraffic()
    elif data == 'gtsrb':
        train_image, train_label, test_image, test_label = load_gtsrb()
    elif data == 'syn':
        train_image, train_label, test_image, test_label = load_syn()
    return train_image, train_label, test_image, test_label
def dataset_read(target, batch_size):
    """Build paired train/test loaders for multi-source digit adaptation.

    *target* is one of 'mnistm', 'mnist', 'usps', 'svhn', 'syn'; the four
    remaining domains become the sources.  Returns the (train, test)
    datasets produced by ``UnalignedDataLoader.load_data``.
    """
    S1 = {}
    S1_test = {}
    S2 = {}
    S2_test = {}
    S3 = {}
    S3_test = {}
    S4 = {}
    S4_test = {}
    S = [S1, S2, S3, S4]          # one dict per source domain
    S_test = [S1_test, S2_test, S3_test, S4_test]
    T = {}
    T_test = {}
    domain_all = ['mnistm', 'mnist', 'usps', 'svhn', 'syn']
    domain_all.remove(target)     # sources = every domain except the target
    target_train, target_train_label, target_test, target_test_label = return_dataset(target)
    print(domain_all)
    for i in range(len(domain_all)):
        source_train, source_train_label, source_test, source_test_label = return_dataset(domain_all[i])
        S[i]['imgs'] = source_train
        S[i]['labels'] = source_train_label
        # input target sample when testing; source performance is not important
        S_test[i]['imgs'] = target_test
        S_test[i]['labels'] = target_test_label
    T['imgs'] = target_train
    T['labels'] = target_train_label
    T_test['imgs'] = target_test
    T_test['labels'] = target_test_label
    scale = 32                    # all domains resized to 32x32
    train_loader = UnalignedDataLoader()
    train_loader.initialize(S, T, batch_size, batch_size, scale=scale)
    dataset = train_loader.load_data()
    test_loader = UnalignedDataLoader()
    test_loader.initialize(S_test, T_test, batch_size, batch_size, scale=scale)
    dataset_test = test_loader.load_data()
    return dataset, dataset_test
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import numpy as np
class Dataset(data.Dataset):
    """In-memory dataset over pre-loaded image arrays.

    Args:
        data: array-like of images, each shaped (C, H, W).
        label: array-like of integer targets.
        transform (callable, optional): applied to the PIL image.
        target_transform (callable, optional): applied to the target.
    """

    def __init__(self, data, label,
                 transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.data = data
        self.labels = label

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample *index* as a PIL image."""
        img, target = self.data[index], self.labels[index]
        if img.shape[0] == 1:
            # grayscale: replicate the channel 3x, then move to HWC
            gray = np.uint8(np.asarray(img))
            stacked = np.vstack([gray, gray, gray]).transpose((1, 2, 0))
            img = Image.fromarray(stacked)
        else:
            # CHW -> HWC and cast to uint8 so PIL accepts it
            img = Image.fromarray(np.uint8(np.asarray(img.transpose((1, 2, 0)))))
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.data)
import pickle as p
import numpy as np
import os
from PIL import Image
def load_CIFAR_batch(filename):
    """Load one CIFAR-10 batch file.

    Returns (images, labels) with images as float arrays shaped
    (10000, 32, 32, 3).
    """
    with open(filename, 'rb') as handle:
        batch = p.load(handle, encoding='latin1')
    images = batch['data'].reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    labels = np.array(batch['labels'])
    return images, labels
def load_CIFAR10(ROOT):
    """Load all five CIFAR-10 training batches plus the test batch.

    Returns (Xtr, Ytr, Xte, Yte) with Xtr shaped (50000, 32, 32, 3).
    """
    train_parts = [load_CIFAR_batch(os.path.join(ROOT, 'data_batch_%d' % (b,)))
                   for b in range(1, 6)]
    Xtr = np.concatenate([images for images, _ in train_parts])
    Ytr = np.concatenate([labels for _, labels in train_parts])
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
def split_ssl_data(data, target, num_labels, num_classes, index=None, include_lb_to_ulb=True):
    """Split (data, target) into a labeled subset and an unlabeled set.

    Args:
        index: optional precomputed indices for the labeled samples.
        include_lb_to_ulb: if True the "unlabeled" set is the full data
            (labeled samples included); otherwise only the complement.
    """
    data, target = np.array(data), np.array(target)
    lb_data, lbs, lb_idx = sample_labeled_data(data, target, num_labels, num_classes, index)
    if include_lb_to_ulb:
        return lb_data, lbs, data, target
    remaining = np.array(sorted(set(range(len(data))) - set(lb_idx)))
    return lb_data, lbs, data[remaining], target[remaining]
def sample_labeled_data(data, target,
                        num_labels, num_classes,
                        index=None, name=None):
    """Sample a class-balanced labeled subset.

    When *index* is given it is used verbatim; otherwise
    num_labels/num_classes samples are drawn per class with a fixed RNG
    seed so the split is reproducible.

    Returns:
        (lb_data, lbs, lb_idx) numpy arrays.
    """
    assert num_labels % num_classes == 0
    if index is not None:
        index = np.array(index, dtype=np.int32)
        return data[index], target[index], index
    per_class = int(num_labels / num_classes)
    picked_data = []
    picked_labels = []
    picked_idx = []
    np.random.seed(2022)  # fixed seed: labeled split is reproducible
    for cls in range(num_classes):
        candidates = np.where(target == cls)[0]
        chosen = np.random.choice(candidates, per_class, False)
        picked_idx.extend(chosen)
        picked_data.extend(data[chosen])
        picked_labels.extend(target[chosen])
    return np.array(picked_data), np.array(picked_labels), np.array(picked_idx)
def mysave(dataset, lb_data, lbs, txt_path, ROOT, cnt):
    """Save (image, label) pairs as JPEGs and write a list file.

    Each image is written to ``<ROOT>/<label>/<cnt>.jpg`` and one
    "<dataset>/<label>/<cnt>.jpg <label>" line is appended to *txt_path*.
    *cnt* is a running filename counter; the updated value is returned so
    consecutive calls produce unique names.
    """
    lines = []
    with open(txt_path, "w") as f:
        isfirst = True
        for (img, label) in zip(lb_data, lbs):
            if not os.path.exists(os.path.join(ROOT, str(label))):
                os.makedirs(os.path.join(ROOT, str(label)))
            image = Image.fromarray(img.astype(np.uint8))
            image = image.convert('RGB')
            image.save(os.path.join(ROOT, str(label), "{}.jpg".format(cnt)))
            # only the first line omits the leading newline, so the file
            # has no trailing newline
            if isfirst:
                isfirst = False
                f.writelines(["{}/{}/{}.jpg {}".format(dataset, label, cnt, label)])
            else:
                f.writelines(["\n{}/{}/{}.jpg {}".format(dataset, label, cnt, label)])
            cnt += 1
    return cnt
# --- script: dump a class-balanced SSL split of CIFAR-10 to disk ---
ROOT = "/data/maning/datasets/cifar-10-batches-py/"
Xtr, Ytr, Xte, Yte = load_CIFAR10(ROOT)
print(Xtr.shape, Ytr.shape, Xte.shape, Yte.shape)
num_classes = 10
dataset = "CIFAR10_4(1)"
num_labels = num_classes * 4  # 4 labeled samples per class
lb_data, lbs, data, target = split_ssl_data(Xtr, Ytr, num_labels, num_classes, index=None, include_lb_to_ulb=True)
# NOTE: from here on `p` shadows the module-level `pickle as p` alias
p = os.path.join(ROOT, "SSL", dataset)
import shutil
if os.path.exists(p):
    shutil.rmtree(p)  # start from a clean output directory
cnt = 0
cnt1 = mysave(dataset, lb_data, lbs, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "labeled_target_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))), p, cnt)
cnt2 = mysave(dataset, data, target, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "unlabeled_target_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))), p, cnt1)
cnt3 = mysave(dataset, Xte, Yte, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "validation_target_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))), p, cnt2)
print(cnt1, cnt2, cnt3)
import pickle as p
import numpy as np
import os
from PIL import Image
def load_CIFAR_batch(filename):
    """Load one CIFAR-10 batch file.

    Returns (images, labels) with images as float arrays shaped
    (10000, 32, 32, 3).
    """
    with open(filename, 'rb') as handle:
        batch = p.load(handle, encoding='latin1')
    images = batch['data'].reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    labels = np.array(batch['labels'])
    return images, labels
def unpickle(file):
    """Load and return the pickled object stored in *file*.

    Uses a context manager so the handle is closed even when
    ``p.load`` raises (the original leaked the handle on error).
    """
    with open(file, 'rb') as handle:
        return p.load(handle, encoding='latin1')
def load_CIFAR100(ROOT):
    """Load the CIFAR-100 train/test splits (python format).

    Returns (Xtr, Ytr, Xte, Yte); images stay in the raw (N, 3072)
    row layout, labels are the fine-grained (100-class) ones.
    """
    train = unpickle(os.path.join(ROOT, "train"))
    Xtr = np.array(train['data'])
    Ytr = np.array(train['fine_labels'])
    test = unpickle(os.path.join(ROOT, "test"))
    Xte = np.array(test['data'])
    Yte = np.array(test['fine_labels'])
    return Xtr, Ytr, Xte, Yte
def split_ssl_data(data, target, num_labels, num_classes, index=None, include_lb_to_ulb=True):
    """Split (data, target) into a labeled subset and an unlabeled set.

    Args:
        index: optional precomputed indices for the labeled samples.
        include_lb_to_ulb: if True the "unlabeled" set is the full data
            (labeled samples included); otherwise only the complement.
    """
    data, target = np.array(data), np.array(target)
    lb_data, lbs, lb_idx = sample_labeled_data(data, target, num_labels, num_classes, index)
    if include_lb_to_ulb:
        return lb_data, lbs, data, target
    remaining = np.array(sorted(set(range(len(data))) - set(lb_idx)))
    return lb_data, lbs, data[remaining], target[remaining]
def sample_labeled_data(data, target,
                        num_labels, num_classes,
                        index=None, name=None):
    """Sample a class-balanced labeled subset.

    When *index* is given it is used verbatim; otherwise
    num_labels/num_classes samples are drawn per class with a fixed RNG
    seed so the split is reproducible.

    Returns:
        (lb_data, lbs, lb_idx) numpy arrays.
    """
    assert num_labels % num_classes == 0
    if index is not None:
        index = np.array(index, dtype=np.int32)
        return data[index], target[index], index
    per_class = int(num_labels / num_classes)
    picked_data = []
    picked_labels = []
    picked_idx = []
    np.random.seed(2022)  # fixed seed: labeled split is reproducible
    for cls in range(num_classes):
        candidates = np.where(target == cls)[0]
        chosen = np.random.choice(candidates, per_class, False)
        picked_idx.extend(chosen)
        picked_data.extend(data[chosen])
        picked_labels.extend(target[chosen])
    return np.array(picked_data), np.array(picked_labels), np.array(picked_idx)
def mysave(dataset, lb_data, lbs, txt_path, ROOT, cnt):
    """Save raw CIFAR rows as JPEGs and write a list file.

    Each (3072,) image row is reshaped to CHW, transposed to HWC, saved
    to ``<ROOT>/<label>/<cnt>.jpg``, and one
    "<dataset>/<label>/<cnt>.jpg <label>" line is appended to *txt_path*.
    Returns the updated running counter *cnt*.
    """
    lines = []
    with open(txt_path, "w") as f:
        isfirst = True
        for (img, label) in zip(lb_data, lbs):
            if not os.path.exists(os.path.join(ROOT, str(label))):
                os.makedirs(os.path.join(ROOT, str(label)))
            # raw CIFAR row -> (3, 32, 32) -> (32, 32, 3) for PIL
            image = Image.fromarray(img.reshape(3, 32, 32).transpose(1, 2, 0))
            image = image.convert('RGB')
            image.save(os.path.join(ROOT, str(label), "{}.jpg".format(cnt)))
            # only the first line omits the leading newline, so the file
            # has no trailing newline
            if isfirst:
                isfirst = False
                f.writelines(["{}/{}/{}.jpg {}".format(dataset, label, cnt, label)])
            else:
                f.writelines(["\n{}/{}/{}.jpg {}".format(dataset, label, cnt, label)])
            cnt += 1
    return cnt
# --- script: dump a class-balanced SSL split of CIFAR-100 to disk ---
ROOT = "/data/maning/datasets/cifar-100-python/"
target_path = "/data/maning/git/shot/data"
Xtr, Ytr, Xte, Yte = load_CIFAR100(ROOT)
print(Xtr.shape, Ytr.shape, Xte.shape, Yte.shape)
num_classes = 100
dataset = "CIFAR100_4"
num_labels = num_classes * 4  # 4 labeled samples per class
lb_data, lbs, data, target = split_ssl_data(Xtr, Ytr, num_labels, num_classes, index=None, include_lb_to_ulb=True)
# NOTE: from here on `p` shadows the module-level `pickle as p` alias
p = os.path.join(target_path, "SSL", dataset)
import shutil
if os.path.exists(p):
    shutil.rmtree(p)  # start from a clean output directory
cnt = 0
cnt1 = mysave(dataset, lb_data, lbs, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "labeled_target_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))), p, cnt)
cnt2 = mysave(dataset, data, target, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "unlabeled_target_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))), p, cnt1)
cnt3 = mysave(dataset, Xte, Yte, os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "validation_target_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))), p, cnt2)
# the validation list doubles as the "labeled source" list
shutil.copy(os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "validation_target_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))),
            os.path.join("/data/maning/git/shot/data/SSDA_split/SSL", "labeled_source_images_{}_{}.txt".format(dataset, int(num_labels / num_classes))))
print(cnt1, cnt2, cnt3)
import numpy as np
import pickle as pkl
def load_gtsrb():
    """Load the GTSRB pickle and return a shuffled train/test split.

    Returns (train_images, train_labels, test_images, test_labels);
    images are float32 NCHW, labels are shifted up by 1.
    """
    # BUG FIX: the pickle file was opened in text mode (no 'rb') and the
    # handle was leaked; pickled bytes must be read in binary mode.
    with open('../data/data_gtsrb', 'rb') as handle:
        data_target = pkl.load(handle)
    # random permutation, then a fixed 31367-sample train split
    target_train = np.random.permutation(len(data_target['image']))
    data_t_im = data_target['image'][target_train[:31367], :, :, :]
    data_t_im_test = data_target['image'][target_train[31367:], :, :, :]
    data_t_label = data_target['label'][target_train[:31367]] + 1
    data_t_label_test = data_target['label'][target_train[31367:]] + 1
    # NHWC -> NCHW, float32
    data_t_im = data_t_im.transpose(0, 3, 1, 2).astype(np.float32)
    data_t_im_test = data_t_im_test.transpose(0, 3, 1, 2).astype(np.float32)
    return data_t_im, data_t_label, data_t_im_test, data_t_label_test
This diff is collapsed.
import numpy as np
from scipy.io import loadmat
base_dir = './data'
def load_mnist(scale=True, usps=False, all_use=False):
    """Load MNIST train/test from ``mnist_data.mat``.

    With ``scale=True`` returns 32x32 images replicated to 3 channels;
    otherwise the stored 28x28 images.  Train data is shuffled and capped
    at 25000 samples; images come back as float32 NCHW.
    """
    mnist_data = loadmat(base_dir + '/mnist_data.mat')
    if scale:
        mnist_train = np.reshape(mnist_data['train_32'], (55000, 32, 32, 1))
        mnist_test = np.reshape(mnist_data['test_32'], (10000, 32, 32, 1))
        # replicate the grayscale channel to get 3-channel input
        mnist_train = np.concatenate([mnist_train, mnist_train, mnist_train], 3)
        mnist_test = np.concatenate([mnist_test, mnist_test, mnist_test], 3)
        mnist_train = mnist_train.transpose(0, 3, 1, 2).astype(np.float32)
        mnist_test = mnist_test.transpose(0, 3, 1, 2).astype(np.float32)
        mnist_labels_train = mnist_data['label_train']
        mnist_labels_test = mnist_data['label_test']
    else:
        mnist_train = mnist_data['train_28']
        mnist_test = mnist_data['test_28']
        mnist_labels_train = mnist_data['label_train']
        mnist_labels_test = mnist_data['label_test']
        mnist_train = mnist_train.astype(np.float32)
        mnist_test = mnist_test.astype(np.float32)
        mnist_train = mnist_train.transpose((0, 3, 1, 2))
        mnist_test = mnist_test.transpose((0, 3, 1, 2))
    # labels are stored one-hot; recover the class index
    train_label = np.argmax(mnist_labels_train, axis=1)
    inds = np.random.permutation(mnist_train.shape[0])
    mnist_train = mnist_train[inds]
    train_label = train_label[inds]
    test_label = np.argmax(mnist_labels_test, axis=1)
    # cap dataset sizes
    mnist_train = mnist_train[:25000]
    train_label = train_label[:25000]
    mnist_test = mnist_test[:25000]
    test_label = test_label[:25000]
    return mnist_train, train_label, mnist_test, test_label
import numpy as np
from scipy.io import loadmat
base_dir = './data'
def load_mnistm(scale=True, usps=False, all_use=False):
    """Load MNIST-M train/test from ``mnistm_with_label.mat``.

    Train data is shuffled and capped at 25000 samples, test at 9000;
    images come back as float32 NCHW.  ``scale``/``usps``/``all_use``
    are unused here (kept for a uniform loader signature).
    """
    mnistm_data = loadmat(base_dir + '/mnistm_with_label.mat')
    mnistm_train = mnistm_data['train']
    mnistm_test = mnistm_data['test']
    # NHWC -> NCHW, float32
    mnistm_train = mnistm_train.transpose(0, 3, 1, 2).astype(np.float32)
    mnistm_test = mnistm_test.transpose(0, 3, 1, 2).astype(np.float32)
    mnistm_labels_train = mnistm_data['label_train']
    mnistm_labels_test = mnistm_data['label_test']
    # labels are stored one-hot; recover the class index
    train_label = np.argmax(mnistm_labels_train, axis=1)
    inds = np.random.permutation(mnistm_train.shape[0])
    mnistm_train = mnistm_train[inds]
    train_label = train_label[inds]
    test_label = np.argmax(mnistm_labels_test, axis=1)
    # cap dataset sizes
    mnistm_train = mnistm_train[:25000]
    train_label = train_label[:25000]
    mnistm_test = mnistm_test[:9000]
    test_label = test_label[:9000]
    return mnistm_train, train_label, mnistm_test, test_label
This diff is collapsed.
"""Data loading facilities for Omniglot experiment."""
import random
import os
from os.path import join
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torch.utils.data import Dataset
from torchvision.datasets.utils import list_dir, list_files
from torchvision import transforms
from PIL import Image
# Default preprocessing: resize to 100x100, tensorize, and normalize with
# the standard ImageNet channel statistics.
DEFAULT_TRANSFORM = transforms.Compose([transforms.Resize((100, 100)),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            mean=(0.485, 0.456, 0.406),
                                            std=(0.229, 0.224, 0.225)
                                        )])
# Category-name -> integer label mapping (10 object classes; presumably the
# Office-31/Caltech shared categories — verify against the dataset used).
CAT2LABEL = {'bike': 0, 'monitor': 1, 'laptop_computer': 2, 'mug': 3, 'calculator': 4, 'projector': 5, 'keyboard': 6, 'headphones': 7, 'back_pack': 8, 'mouse': 9}
###############################################################################
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
def make_dataset(image_list, labels):
    """Pair each image path with its label(s).

    If *labels* is given, it is zipped row-wise with the stripped lines;
    otherwise each "path label [label ...]" line is parsed, producing an
    int label or an int array for the multi-label case.
    """
    if labels:
        return [(image_list[i].strip(), labels[i, :]) for i in range(len(image_list))]
    if len(image_list[0].split()) > 2:
        # multi-label: everything after the path is a label vector
        return [(line.split()[0], np.array([int(v) for v in line.split()[1:]]))
                for line in image_list]
    return [(line.split()[0], int(line.split()[1])) for line in image_list]
def rgb_loader(path):
    """Load the file at *path* as an RGB PIL image."""
    with open(path, 'rb') as handle, Image.open(handle) as image:
        return image.convert('RGB')
def l_loader(path):
    """Load the file at *path* as a single-channel (grayscale) PIL image."""
    with open(path, 'rb') as handle, Image.open(handle) as image:
        return image.convert('L')
class ImageList(Dataset):
    """Dataset over "<path> <label>" lines; paths are used as-is (no root).

    Setting ``return_index`` to 1 makes ``__getitem__`` also yield the
    sample index.
    """

    def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
        self.imgs = make_dataset(image_list, labels)
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        self.return_index = 0

    def __getitem__(self, index):
        path, target = self.imgs[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.return_index == 1:
            return image, target, index
        return image, target

    def __len__(self):
        return len(self.imgs)
import os
import torch
from torchvision import transforms
from data_pro.SSDA_data_list import Imagelists_VISDA, return_classlist
from data_pro.data_list import ImageList_idx,ImageList,PairBatchSampler
class ResizeImage():
    """Callable that resizes a PIL image to a fixed (h, w) size."""

    def __init__(self, size):
        # a single int means a square output; otherwise an explicit pair
        self.size = (int(size), int(size)) if isinstance(size, int) else size

    def __call__(self, img):
        th, tw = self.size
        return img.resize((th, tw))
def return_dataset(args):
    """Build train/val/test loaders for one SSDA benchmark.

    Reads the split list files under ``SSDA_split/<dataset>`` for the
    source (args.s) and target (args.t) domains with args.num labeled
    target shots.

    Returns:
        (source_loader, source_val_loader, target_loader,
         target_loader_unl, target_loader_val, target_loader_test,
         target_dataset_unl)
    """
    base_path = '/data/maning/git/shot/data/SSDA_split/%s' % args.dataset
    # NOTE(review): `in` does substring matching here, not equality — any
    # args.dataset that is a substring of "office-home" takes this branch.
    if args.dataset in "office-home":
        root = '/data/maning/git/shot/data/OfficeHomeDataset/'
    else:
        root = '/data/maning/git/shot/data/%s/' % args.dataset
    # list files for the four splits
    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_images_' +
                     args.s + '.txt')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.t + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.t + '_%d.txt' % (args.num))
    # alexnet takes 227x227 crops, other backbones 224x224
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        # 'val' uses the same random flip/crop augmentation as 'train'
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    def split(train_r, source_path):
        # Randomly split the lines of a source list file into train/val.
        with open(source_path, 'r') as f:
            data = f.readlines()
        train_len = int(len(data) * train_r)
        train, val = torch.utils.data.random_split(data, [train_len, len(data) - train_len])
        return train, val

    # NOTE(review): substring matching again (see above)
    if args.dataset in "multi":
        source_train, source_val = split(train_r=0.95, source_path=image_set_file_s)
    else:
        source_train, source_val = split(train_r=0.90, source_path=image_set_file_s)
    print("source_train and val num", len(source_train), len(source_val))
    source_dataset = ImageList(source_train, root=root,
                               transform=data_transforms['train'])
    source_val_dataset = ImageList(source_val, root=root,
                                   transform=data_transforms['val'])
    target_dataset = ImageList(open(image_set_file_t).readlines(), root=root,
                               transform=data_transforms['val'])
    target_dataset_val = ImageList(open(image_set_file_t_val).readlines(), root=root,
                                   transform=data_transforms['val'])
    target_dataset_unl = ImageList_idx(open(image_set_file_unl).readlines(), root=root,
                                       transform=data_transforms['val'])
    target_dataset_test = ImageList(open(image_set_file_unl).readlines(), root=root,
                                    transform=data_transforms['test'])
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    bs = args.batch_size * 2  # doubled batch for the KD term
    if args.skd_src == 1:
        # pair-sampled batches: every batch carries same-class pairs
        source_loader = torch.utils.data.DataLoader(source_dataset, batch_sampler=PairBatchSampler(source_dataset, args.batch_size), num_workers=args.worker)
    else:
        source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs,
                                                    num_workers=args.worker, shuffle=True,
                                                    drop_last=False)
    source_val_loader = torch.utils.data.DataLoader(source_val_dataset, batch_size=bs,
                                                    num_workers=args.worker, shuffle=False,
                                                    drop_last=False)
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=True)
    target_loader_val = \
        torch.utils.data.DataLoader(target_dataset_val,
                                    batch_size=min(bs,
                                                   len(target_dataset_val)),
                                    num_workers=args.worker,
                                    shuffle=True, drop_last=False)
    if args.skd_src == 1:
        target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                        batch_sampler=PairBatchSampler(target_dataset_unl, args.batch_size), num_workers=args.worker)
    else:
        target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                        batch_size=bs * 1, num_workers=args.worker,
                                                        shuffle=True, drop_last=True)
    target_loader_test = \
        torch.utils.data.DataLoader(target_dataset_test,
                                    batch_size=bs * 1, num_workers=args.worker,
                                    shuffle=False, drop_last=False)
    return source_loader, source_val_loader, target_loader, target_loader_unl, \
        target_loader_val, target_loader_test, target_dataset_unl
def return_dataset_test(args):
    """Build the unlabeled-target evaluation loader (no shuffling).

    Returns (target_loader_unl, class_list) where class_list is derived
    from the source "_all" list file.
    """
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = os.path.join(base_path, args.source + '_all' + '.txt')
    image_set_file_test = os.path.join(base_path,
                                       'unlabeled_target_images_' +
                                       args.target + '_%d.txt' % (args.num))
    # alexnet takes 227x227 crops, other backbones 224x224
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    # test=True makes the dataset also return the image path.
    # NOTE(review): a file *path* is passed here, while Imagelists_VISDA
    # parses its first argument as a list of "path label" lines — confirm
    # the intended calling convention.
    target_dataset_unl = Imagelists_VISDA(image_set_file_test, root=root,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    if args.net == 'alexnet':
        bs = 32
    else:
        bs = 24
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 2, num_workers=3,
                                    shuffle=False, drop_last=False)
    return target_loader_unl, class_list
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from scipy.io import loadmat
import numpy as np
import sys
sys.path.append('../utils/')
from data_pro.utils import dense_to_one_hot
base_dir = './data'
def load_svhn():
    """Load SVHN train/test splits from .mat files.

    Returns:
        (train_images, train_labels, test_images, test_labels) with images
        as NCHW float32 arrays. Labels pass through dense_to_one_hot.
        The train split is truncated to 25000 samples, the test split to 9000.
    """
    train_mat = loadmat(base_dir + '/svhn_train_32x32.mat')
    test_mat = loadmat(base_dir + '/svhn_test_32x32.mat')
    # .mat stores HWCN; reorder to NCHW for the conv nets.
    train_im = train_mat['X'].transpose(3, 2, 0, 1).astype(np.float32)[:25000]
    train_labels = dense_to_one_hot(train_mat['y'])[:25000]
    test_im = test_mat['X'].transpose(3, 2, 0, 1).astype(np.float32)[:9000]
    test_labels = dense_to_one_hot(test_mat['y'])[:9000]
    return train_im, train_labels, test_im, test_labels
import numpy as np
from scipy.io import loadmat
import sys
sys.path.append('../utils/')
from data_pro.utils import dense_to_one_hot
base_dir = './data'
def load_syn(scale=True, usps=False, all_use=False):
    """Load the synthetic digits dataset from syn_number.mat.

    Args:
        scale, usps, all_use: kept for interface compatibility with the
            other load_* helpers; not used by this loader.

    Returns:
        (train_images, train_labels, test_images, test_labels) with images
        as NCHW float32 arrays; the train split is shuffled and both label
        arrays pass through dense_to_one_hot.
    """
    syn_data = loadmat(base_dir + '/syn_number.mat')
    # .mat stores NHWC; reorder to NCHW for the conv nets.
    train_images = syn_data['train_data'].transpose(0, 3, 1, 2).astype(np.float32)
    test_images = syn_data['test_data'].transpose(0, 3, 1, 2).astype(np.float32)
    # Shuffle the training split with a single shared permutation.
    perm = np.random.permutation(train_images.shape[0])
    train_images = train_images[perm]
    train_labels = dense_to_one_hot(syn_data['train_label'][perm])
    test_labels = dense_to_one_hot(syn_data['test_label'])
    return train_images, train_labels, test_images, test_labels
import numpy as np
import pickle as pkl
def load_syntraffic():
    """Load the synthetic traffic-sign dataset from a pickle file.

    Returns:
        (train_images, train_labels, test_images, test_labels) with images
        as NCHW float32 arrays.

    NOTE(review): the "test" slice is the last 2000 entries of the SAME
    shuffled index array used for training, so it overlaps the train split;
    this preserves the original behavior — confirm whether a disjoint split
    was intended.
    """
    # Pickle data must be opened in binary mode ('rb'); text mode raises
    # under Python 3. The context manager also closes the handle, which the
    # original bare open() leaked.
    with open('../data/data_synthetic', 'rb') as f:
        data_source = pkl.load(f)
    n = len(data_source['image'])
    source_train = np.random.permutation(n)
    data_s_im = data_source['image'][source_train[:n], :, :, :]
    data_s_im_test = data_source['image'][source_train[n - 2000:], :, :, :]
    data_s_label = data_source['label'][source_train[:n]]
    data_s_label_test = data_source['label'][source_train[n - 2000:]]
    # NHWC -> NCHW for the conv nets.
    data_s_im = data_s_im.transpose(0, 3, 1, 2).astype(np.float32)
    data_s_im_test = data_s_im_test.transpose(0, 3, 1, 2).astype(np.float32)
    return data_s_im, data_s_label, data_s_im_test, data_s_label_test
\ No newline at end of file
import numpy as np
import os
import os.path
from PIL import Image
# def pil_loader(path):
# with open(path, 'rb') as f:
# img = Image.open(f)
# return img.convert('RGB')
# def make_dataset_fromlist(image_list):
# # with open(image_list) as f:
# image_index = [x.split(' ')[0] for x in image_list]
# # with open(image_list) as f:
# label_list = []
# selected_list = []
# for ind, x in enumerate(image_list):
# label = x.split(' ')[1].strip()
# label_list.append(int(label))
# selected_list.append(ind)
# image_index = np.array(image_index)
# label_list = np.array(label_list)
# image_index = image_index[selected_list]
# return image_index, label_list
# def return_classlist(image_list):
# with open(image_list) as f:
# label_list = []
# for ind, x in enumerate(f.readlines()):
# label = x.split(' ')[0].split('/')[-2]
# if label not in label_list:
# label_list.append(str(label))
# return label_list
class Text(object):
    """In-memory dataset over pre-tokenized texts and their labels.

    Mirrors the Imagelists_VISDA interface but indexes plain sequences
    instead of loading images from disk.
    """

    def __init__(self, text, root="./data/multi/",
                 transform=None, target_transform=None, test=False):
        # `text` is a pair: (sequence of samples, sequence of labels).
        self.texts = text[0]
        self.labels = text[1]
        self.transform = transform
        self.target_transform = target_transform
        self.root = root
        self.test = test
        # Callers may flip this to also receive the sample index.
        self.return_index = False

    def __getitem__(self, index):
        """Return (text, target); in test mode also the raw text, or the
        index when return_index is set.

        Args:
            index (int): position of the sample.
        """
        sample = self.texts[index]
        target = self.labels[index]
        # `transform` is intentionally not applied (matches original code).
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.test:
            return sample, target, self.texts[index]
        if self.return_index:
            return sample, target, index
        return sample, target

    def __len__(self):
        """Number of samples."""
        return len(self.texts)
import torch.utils.data
import torchnet as tnt
from builtins import object
import torchvision.transforms as transforms
from data_pro.datasets_ import Dataset
class PairedData(object):
    """Joint iterator over four source data loaders and one target loader.

    Every step yields a dict with one batch per domain. A loader that runs
    out is silently restarted from the beginning; iteration raises
    StopIteration only once every loader has wrapped around at least once
    in the current epoch, or after max_dataset_size steps.
    """

    def __init__(self, data_loader_A, data_loader_B, data_loader_C, data_loader_D, data_loader_t, max_dataset_size):
        # Four source-domain loaders plus the target-domain loader.
        self.data_loader_A = data_loader_A
        self.data_loader_B = data_loader_B
        self.data_loader_C = data_loader_C
        self.data_loader_D = data_loader_D
        self.data_loader_t = data_loader_t
        # stop_* records that the corresponding loader was exhausted (and
        # restarted) at least once during the current epoch.
        self.stop_A = False
        self.stop_B = False
        self.stop_C = False
        self.stop_D = False
        self.stop_t = False
        self.max_dataset_size = max_dataset_size

    def __iter__(self):
        # Reset exhaustion flags and build fresh iterators for a new epoch.
        self.stop_A = False
        self.stop_B = False
        self.stop_C = False
        self.stop_D = False
        self.stop_t = False
        self.data_loader_A_iter = iter(self.data_loader_A)
        self.data_loader_B_iter = iter(self.data_loader_B)
        self.data_loader_C_iter = iter(self.data_loader_C)
        self.data_loader_D_iter = iter(self.data_loader_D)
        self.data_loader_t_iter = iter(self.data_loader_t)
        # Step counter for the max_dataset_size cap (shadows builtin iter).
        self.iter = 0
        return self

    def __next__(self):
        # Each pair is (batch, labels) despite the *_paths naming below.
        A, A_paths = None, None
        B, B_paths = None, None
        C, C_paths = None, None
        D, D_paths = None, None
        t, t_paths = None, None
        # For each loader: take the next batch; on exhaustion mark the stop
        # flag, restart the iterator, and take the first batch of the new
        # pass. NOTE(review): the `is None` guards are always true inside
        # the except branch (next() raised before assignment), so restart
        # is effectively unconditional.
        try:
            A, A_paths = next(self.data_loader_A_iter)
        except StopIteration:
            if A is None or A_paths is None:
                self.stop_A = True
                self.data_loader_A_iter = iter(self.data_loader_A)
                A, A_paths = next(self.data_loader_A_iter)
        try:
            B, B_paths = next(self.data_loader_B_iter)
        except StopIteration:
            if B is None or B_paths is None:
                self.stop_B = True
                self.data_loader_B_iter = iter(self.data_loader_B)
                B, B_paths = next(self.data_loader_B_iter)
        try:
            C, C_paths = next(self.data_loader_C_iter)
        except StopIteration:
            if C is None or C_paths is None:
                self.stop_C = True
                self.data_loader_C_iter = iter(self.data_loader_C)
                C, C_paths = next(self.data_loader_C_iter)
        try:
            D, D_paths = next(self.data_loader_D_iter)
        except StopIteration:
            if D is None or D_paths is None:
                self.stop_D = True
                self.data_loader_D_iter = iter(self.data_loader_D)
                D, D_paths = next(self.data_loader_D_iter)
        try:
            t, t_paths = next(self.data_loader_t_iter)
        except StopIteration:
            if t is None or t_paths is None:
                self.stop_t = True
                self.data_loader_t_iter = iter(self.data_loader_t)
                t, t_paths = next(self.data_loader_t_iter)
        # End the epoch once every loader has wrapped, or the step cap hit.
        if (self.stop_A and self.stop_B and self.stop_C and self.stop_D and self.stop_t) or self.iter > self.max_dataset_size:
            self.stop_A = False
            self.stop_B = False
            self.stop_C = False
            self.stop_D = False
            self.stop_t = False
            raise StopIteration()
        else:
            self.iter += 1
            return {'S1': A, 'S1_label': A_paths,
                    'S2': B, 'S2_label': B_paths,
                    'S3': C, 'S3_label': C_paths,
                    'S4': D, 'S4_label': D_paths,
                    'T': t, 'T_label': t_paths}
class UnalignedDataLoader():
    """Bundles four source loaders and one target loader into a PairedData
    stream of per-domain batches."""

    def initialize(self, source, target, batch_size1, batch_size2, scale=32):
        """Create the per-domain DataLoaders and the paired iterator.

        Args:
            source: sequence of four dicts, each with 'imgs' and 'labels'.
            target: dict with 'imgs' and 'labels'.
            batch_size1: batch size for every source loader.
            batch_size2: batch size for the target loader.
            scale: side length images are resized to before normalization.
        """
        transform = transforms.Compose([
            transforms.Scale(scale),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # One dataset/loader per source domain; keep dataset_s1..dataset_s4
        # attributes for __len__ and external access.
        source_loaders = []
        for i in range(4):
            ds = Dataset(source[i]['imgs'], source[i]['labels'],
                         transform=transform)
            loader = torch.utils.data.DataLoader(
                ds, batch_size=batch_size1, shuffle=True, num_workers=4)
            setattr(self, 'dataset_s%d' % (i + 1), ds)
            source_loaders.append(loader)
        dataset_target = Dataset(target['imgs'], target['labels'],
                                 transform=transform)
        data_loader_t = torch.utils.data.DataLoader(
            dataset_target, batch_size=batch_size2, shuffle=True,
            num_workers=4)
        self.dataset_t = dataset_target
        self.paired_data = PairedData(source_loaders[0], source_loaders[1],
                                      source_loaders[2], source_loaders[3],
                                      data_loader_t, float("inf"))

    def name(self):
        return 'UnalignedDataLoader'

    def load_data(self):
        return self.paired_data

    def __len__(self):
        # min() with inf keeps the original (no-op) cap semantics.
        return min(max(len(self.dataset_s1), len(self.dataset_s2),
                       len(self.dataset_s3), len(self.dataset_s4),
                       len(self.dataset_t)), float("inf"))
"""Dataset setting and data loader for USPS.
Modified from
https://github.com/mingyuliutw/CoGAN/blob/master/cogan_pytorch/src/dataset_usps.py
"""
import gzip
import os
import pickle
import urllib
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import datasets, transforms
class USPS(data.Dataset):
    """USPS Dataset.

    Args:
        root (string): Root directory of dataset where dataset file exist.
        train (bool, optional): If True, use the training split.
        download (bool, optional): If true, downloads the dataset
            from the internet and puts it in root directory.
            If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in
            an PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
    """

    # Upstream gzip pickle holding ((train_x, train_y), (test_x, test_y)).
    url = "https://raw.githubusercontent.com/mingyuliutw/CoGAN/master/cogan_pytorch/data/uspssample/usps_28x28.pkl"

    def __init__(self, root, train=True, transform=None, download=False):
        """Init USPS dataset."""
        # init params
        self.root = os.path.expanduser(root)
        self.filename = "usps_28x28.pkl"
        self.train = train
        # Num of Train = 7438, Num of Test = 1860
        self.transform = transform
        # Filled in by load_samples() below as a side effect.
        self.dataset_size = None
        # download dataset.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError("Dataset not found." +
                               " You can use download=True to download it")
        self.train_data, self.train_labels = self.load_samples()
        if self.train:
            total_num_samples = self.train_labels.shape[0]
            indices = np.arange(total_num_samples)
            # dataset_size equals the split size here, so the slice keeps
            # all samples; it exists to allow sub-sampling.
            self.train_data = self.train_data[indices[0:self.dataset_size], ::]
            self.train_labels = self.train_labels[indices[0:self.dataset_size]]
        # Pixels are stored in [0, 1]; rescale in place to 8-bit grayscale.
        self.train_data *= 255.0
        self.train_data = np.squeeze(self.train_data).astype(np.uint8)

    def __getitem__(self, index):
        """Get images and target for data loader.

        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, label = self.train_data[index], self.train_labels[index]
        img = Image.fromarray(img, mode='L')
        # Presumably a defensive copy so transforms cannot mutate the
        # cached array-backed image — TODO confirm it is still needed.
        img = img.copy()
        if self.transform is not None:
            img = self.transform(img)
        return img, label.astype("int64")

    def __len__(self):
        """Return size of dataset."""
        return len(self.train_data)

    def _check_exists(self):
        """Check if dataset is download and in right place."""
        return os.path.exists(os.path.join(self.root, self.filename))

    def download(self):
        """Download dataset."""
        filename = os.path.join(self.root, self.filename)
        dirname = os.path.dirname(filename)
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        if os.path.isfile(filename):
            # Already downloaded; nothing to do.
            return
        print("Download %s to %s" % (self.url, os.path.abspath(filename)))
        # NOTE(review): the module imports `urllib` only; `urllib.request`
        # may not be bound unless imported explicitly — confirm at runtime.
        urllib.request.urlretrieve(self.url, filename)
        print("[DONE]")
        return

    def load_samples(self):
        """Load sample images from dataset.

        Also sets self.dataset_size to the chosen split's label count.
        """
        filename = os.path.join(self.root, self.filename)
        f = gzip.open(filename, "rb")
        data_set = pickle.load(f, encoding="bytes")
        f.close()
        if self.train:
            images = data_set[0][0]
            labels = data_set[0][1]
            self.dataset_size = labels.shape[0]
        else:
            images = data_set[1][0]
            labels = data_set[1][1]
            self.dataset_size = labels.shape[0]
        return images, labels
\ No newline at end of file
import numpy as np
from scipy.io import loadmat
import gzip
import pickle
import sys
sys.path.append('../utils/')
from data_pro.utils import dense_to_one_hot
base_dir = './data'
def load_usps(all_use=False):
    """Load USPS digits from usps_28x28.mat.

    Args:
        all_use: kept for interface compatibility; not used.

    Returns:
        (train_images, train_labels, test_images, test_labels): images as
        (N, 1, 28, 28) arrays scaled to [0, 255]; the shuffled training
        split is replicated four times; labels pass through
        dense_to_one_hot.
    """
    data_set = loadmat(base_dir + '/usps_28x28.mat')['dataset']
    img_train, label_train = data_set[0][0], data_set[0][1]
    img_test, label_test = data_set[1][0], data_set[1][1]
    # Shuffle the training split with one shared permutation.
    order = np.random.permutation(img_train.shape[0])
    img_train = img_train[order]
    label_train = label_train[order]
    # Pixels are stored in [0, 1]; scale to [0, 255] and add channel dim.
    img_train = (img_train * 255).reshape((img_train.shape[0], 1, 28, 28))
    img_test = (img_test * 255).reshape((img_test.shape[0], 1, 28, 28))
    label_train = dense_to_one_hot(label_train)
    label_test = dense_to_one_hot(label_test)
    # Replicate the train set 4x to roughly balance against larger domains.
    img_train = np.concatenate([img_train] * 4, 0)
    label_train = np.concatenate([label_train] * 4, 0)
    return img_train, label_train, img_test, label_test
import numpy as np
def weights_init(m):
    """Initialize a module's parameters in place (use via net.apply).

    Conv layers: weight and bias ~ N(0, 0.01).
    BatchNorm layers: weight ~ N(1, 0.01), bias zeroed.
    Other module types are left untouched.
    """
    name = m.__class__.__name__
    if 'Conv' in name:
        m.weight.data.normal_(0.0, 0.01)
        m.bias.data.normal_(0.0, 0.01)
    elif 'BatchNorm' in name:
        m.weight.data.normal_(1.0, 0.01)
        m.bias.data.fill_(0)
def dense_to_one_hot(labels_dense):
    """Flatten dense labels to a 1-D float array, remapping label 10 -> 0.

    Despite the historical name, this does NOT build one-hot vectors (the
    original docstring was wrong): it returns an array of shape (n,) whose
    entries equal the input labels, except that SVHN-style class 10 is
    mapped back to digit 0. Accepts any sequence or array, including the
    (n, 1) label arrays produced by scipy.io.loadmat.
    """
    labels_out = np.zeros((len(labels_dense),))
    for i, t in enumerate(labels_dense):
        # SVHN encodes digit '0' as class 10; normalize it back to 0.
        labels_out[i] = 0 if t == 10 else t
    return labels_out
SNPC,SP,CS,RS
5,64.6,60.6,60.7
10,66.6,63.2,62.8
15,67.1,64.3,62.2
20,66.6,64.2,63.5
25,66.8,64.8,62.6
30,67.1,64.7,63.2
"""
draw singular values of DomainNet with resnet34
"""
import csv,os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# plt.style.use('ggplot')
import matplotlib
import numpy as np
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus']=False
# sns.set_theme(style="darkgrid")
sns.set_theme(style="whitegrid", palette="pastel")
def draw(file_paths, dir, name):
    """Plot SNPC-vs-accuracy curves (RS, CS, SP columns) and save as PDF.

    Args:
        file_paths: list of CSV paths, each with columns SNPC, SP, CS, RS.
        dir: output directory for the figure.
        name: output PDF file name.

    Fixes vs. original: removed `num = sum(1 for line in open(path))`,
    which leaked a file handle to compute a value that was never used, and
    passes dir/name as separate os.path.join arguments instead of
    pre-concatenating them.
    """
    scoremarkers = ["v", "s", "*", "o", "x", "+"]
    for path in file_paths:
        fmri = pd.read_csv(path, sep=',')
        # One curve per column, each with its own marker, in the same
        # order as the legend below (RS, CS, SP drawn; legend lists S->P,
        # C->S, R->S to match seaborn's draw order).
        for marker, column in zip(scoremarkers, ("RS", "CS", "SP")):
            sns.lineplot(x="SNPC", y=column, err_style="band", ci="sd",
                         marker=marker, linewidth=3, data=fmri)
    plt.xlabel("SNPC", fontsize=20)
    plt.ylabel('Accuracy,%', fontsize=20)
    plt.yticks(np.arange(55, 72, 5))
    plt.legend([r"S $\rightarrow$ P", r"C $\rightarrow$ S", r"R $\rightarrow$ S"],
               loc="lower left", fontsize=12)
    plt.savefig(os.path.join(dir, name), format="pdf", bbox_inches="tight",
                dpi=400)
# Render the SNPC-vs-accuracy figure from the collected results CSV.
draw(["/data/maning/git/shot/draw/SNPC.csv"
],
"/data/maning/git/shot/draw/", "SNPCAcc.pdf")
# plt.clf()
# R-C (domainnet)
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_clipart.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2clipart_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_clipart.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/", "real2clipart_alphas1.pdf")
#R-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/tar_real2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/real_norm1_temp0.05_lr0.001/","real2sketch_alphas1.pdf")
# C-S
# draw(["/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_sketch.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/tar_clipart2sketch_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_sketch.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/clipart_norm1_temp0.05_lr0.001/","clipart2sketch_alphas1.pdf")
# S-P
# draw([
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha0.5_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha1.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha2.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha3.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear0_wnl0_alpha10.0_num3_painting.csv",
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/tar_sketch2painting_lr0.001MNPC5_im1.0_u200.0_unlent0.0_nonlinear1_wnl0_alpha0.75_num3_painting.csv"
# ],
# "/data/maning/git/shot/ssda/2022_05_10multi/seed2022/multi/sketch_norm1_temp0.05_lr0.001/","sketch2painting_alphas1.pdf")
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment