File data_flow.py changed (mode: 100644) (index 1df0ec7..26d0022) |
... |
... |
class ListDataset(Dataset): |
1082 |
1082 |
def __getitem__(self, index): |
def __getitem__(self, index): |
1083 |
1083 |
assert index <= len(self), 'index range error' |
assert index <= len(self), 'index range error' |
1084 |
1084 |
img_path = self.lines[index] |
img_path = self.lines[index] |
1085 |
|
if self.debug: |
|
1086 |
|
print(img_path) |
|
|
1085 |
|
# if self.debug: |
|
1086 |
|
# print(img_path) |
1087 |
1087 |
# try to check cache item if exist |
# try to check cache item if exist |
1088 |
1088 |
if self.cache and self.train and index in self.cache_train.keys(): |
if self.cache and self.train and index in self.cache_train.keys(): |
1089 |
1089 |
img, target = self.cache_train[index] |
img, target = self.cache_train[index] |
|
... |
... |
class ListDataset(Dataset): |
1107 |
1107 |
self.cache_train[index] = (img, target) |
self.cache_train[index] = (img, target) |
1108 |
1108 |
else: |
else: |
1109 |
1109 |
self.cache_eval[index] = (img, target) |
self.cache_eval[index] = (img, target) |
1110 |
|
|
|
1111 |
|
return img, target |
|
|
1110 |
|
if self.debug: |
|
1111 |
|
_, p_count = self.load_data_fn(img_path, train=False) |
|
1112 |
|
print(img_path + " " + str(target.sum()) + " " + str(p_count)) |
|
1113 |
|
return img, target, p_count |
|
1114 |
|
else: |
|
1115 |
|
return img, target |
1112 |
1116 |
|
|
1113 |
1117 |
|
|
1114 |
1118 |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1, |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1, |
1115 |
|
train_loader_for_eval_check=False, cache=False, pin_memory=False): |
|
|
1119 |
|
train_loader_for_eval_check=False, cache=False, pin_memory=False, |
|
1120 |
|
debug=False): |
|
1121 |
|
|
1116 |
1122 |
if visualize_mode: |
if visualize_mode: |
1117 |
1123 |
transformer = transforms.Compose([ |
transformer = transforms.Compose([ |
1118 |
1124 |
transforms.ToTensor() |
transforms.ToTensor() |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
1134 |
1140 |
train=True, |
train=True, |
1135 |
1141 |
batch_size=batch_size, |
batch_size=batch_size, |
1136 |
1142 |
num_workers=0, |
num_workers=0, |
|
1143 |
|
debug=debug, |
1137 |
1144 |
dataset_name=dataset_name, cache=cache), |
dataset_name=dataset_name, cache=cache), |
1138 |
1145 |
batch_size=batch_size, |
batch_size=batch_size, |
1139 |
1146 |
num_workers=0, |
num_workers=0, |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
1146 |
1153 |
train=False, |
train=False, |
1147 |
1154 |
batch_size=batch_size, |
batch_size=batch_size, |
1148 |
1155 |
num_workers=0, |
num_workers=0, |
|
1156 |
|
debug=debug, |
1149 |
1157 |
dataset_name=dataset_name, cache=cache), |
dataset_name=dataset_name, cache=cache), |
1150 |
1158 |
batch_size=1, |
batch_size=1, |
1151 |
1159 |
num_workers=0, |
num_workers=0, |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
1157 |
1165 |
shuffle=False, |
shuffle=False, |
1158 |
1166 |
transform=transformer, |
transform=transformer, |
1159 |
1167 |
train=False, |
train=False, |
|
1168 |
|
debug=debug, |
1160 |
1169 |
dataset_name=dataset_name, cache=cache), |
dataset_name=dataset_name, cache=cache), |
1161 |
1170 |
num_workers=0, |
num_workers=0, |
1162 |
1171 |
batch_size=1, |
batch_size=1, |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
1170 |
1179 |
shuffle=False, |
shuffle=False, |
1171 |
1180 |
transform=transformer, |
transform=transformer, |
1172 |
1181 |
train=False, |
train=False, |
|
1182 |
|
debug=debug, |
1173 |
1183 |
dataset_name=dataset_name), |
dataset_name=dataset_name), |
1174 |
1184 |
num_workers=0, |
num_workers=0, |
1175 |
1185 |
batch_size=1, |
batch_size=1, |
File debug/debug_sha.py added (mode: 100644) (index 0000000..4a6691f) |
|
1 |
|
# sha shanghaitech_keepfull is not convergent |
|
2 |
|
from args_util import real_args_parse |
|
3 |
|
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
|
4 |
|
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
5 |
|
from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError |
|
6 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
|
7 |
|
import torch |
|
8 |
|
from torch import nn |
|
9 |
|
import torch.nn.functional as F |
|
10 |
|
from models import CSRNet,PACNN |
|
11 |
|
import os |
|
12 |
|
import cv2 |
|
13 |
|
from torchvision import datasets, transforms |
|
14 |
|
from data_flow import ListDataset |
|
15 |
|
import pytorch_ssim |
|
16 |
|
|
|
17 |
|
from hard_code_variable import HardCodeVariable |
|
18 |
|
from data_util import ShanghaiTechDataPath |
|
19 |
|
from visualize_util import save_img, save_density_map |
|
20 |
|
|
|
21 |
|
|
|
22 |
|
def visualize_shanghaitech_keepfull(): |
|
23 |
|
HARD_CODE = HardCodeVariable() |
|
24 |
|
shanghaitech_data = ShanghaiTechDataPath(root=HARD_CODE.SHANGHAITECH_PATH) |
|
25 |
|
shanghaitech_data_part_a_train = shanghaitech_data.get_a().get_train().get() |
|
26 |
|
saved_folder = "visualize/debug_dataloader_shanghaitech" |
|
27 |
|
os.makedirs(saved_folder, exist_ok=True) |
|
28 |
|
train_list, val_list = get_train_val_list(shanghaitech_data_part_a_train, test_size=0.2) |
|
29 |
|
test_list = None |
|
30 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech_keepfull", visualize_mode=True, |
|
31 |
|
debug=True) |
|
32 |
|
|
|
33 |
|
# do with train loader |
|
34 |
|
train_loader_iter = iter(train_loader) |
|
35 |
|
for i in range(10): |
|
36 |
|
img, label, count = next(train_loader_iter) |
|
37 |
|
save_img(img, os.path.join(saved_folder, "train_img" + str(i) +".png")) |
|
38 |
|
save_path = os.path.join(saved_folder, "train_label" + str(i) +".png") |
|
39 |
|
save_density_map(label.numpy()[0][0], save_path) |
|
40 |
|
print("saved " + save_path) |
|
41 |
|
|
|
42 |
|
if __name__ == "__main__": |
|
43 |
|
visualize_shanghaitech_keepfull() |