File visualize_data_loader.py added (mode: 100644) (index 0000000..ad2acdc) |
|
1 |
|
from args_util import real_args_parse |
|
2 |
|
from data_flow import get_train_val_list, get_dataloader, create_training_image_list |
|
3 |
|
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator |
|
4 |
|
from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError |
|
5 |
|
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
|
6 |
|
import torch |
|
7 |
|
from torch import nn |
|
8 |
|
import torch.nn.functional as F |
|
9 |
|
from models import CSRNet,PACNN |
|
10 |
|
import os |
|
11 |
|
import cv2 |
|
12 |
|
from torchvision import datasets, transforms |
|
13 |
|
from data_flow import ListDataset |
|
14 |
|
import pytorch_ssim |
|
15 |
|
|
|
16 |
|
from hard_code_variable import HardCodeVariable |
|
17 |
|
from visualize_util import save_img, save_density_map |
|
18 |
|
|
|
19 |
|
|
|
20 |
|
if __name__ == "__main__": |
|
21 |
|
HARD_CODE = HardCodeVariable() |
|
22 |
|
DATA_PATH = HARD_CODE.UCF_CC_50_PATH |
|
23 |
|
train_list, val_list = get_train_val_list(DATA_PATH, test_size=0.2) |
|
24 |
|
test_list = None |
|
25 |
|
|
|
26 |
|
# create data loader |
|
27 |
|
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="ucf_cc_50") |
|
28 |
|
train_loader_pacnn = torch.utils.data.DataLoader( |
|
29 |
|
ListDataset(train_list, |
|
30 |
|
shuffle=True, |
|
31 |
|
transform=transforms.Compose([ |
|
32 |
|
transforms.ToTensor() |
|
33 |
|
]), |
|
34 |
|
train=True, |
|
35 |
|
batch_size=1, |
|
36 |
|
num_workers=4, dataset_name="ucf_cc_50_pacnn"), |
|
37 |
|
batch_size=1, num_workers=4) |
|
38 |
|
|
|
39 |
|
img, label = next(iter(train_loader_pacnn)) |
|
40 |
|
|
|
41 |
|
print(img.shape) |
|
42 |
|
save_img(img, "pacnn_loader_img.png") |
|
43 |
|
save_density_map(label[0].numpy()[0], "pacnn_loader_density1.png") |
|
44 |
|
save_density_map(label[1].numpy()[0], "pacnn_loader_density2.png") |
|
45 |
|
save_density_map(label[2].numpy()[0], "pacnn_loader_density3.png") |
|
46 |
|
|
File visualize_util.py added (mode: 100644) (index 0000000..4064de3) |
|
1 |
|
import glob |
|
2 |
|
import PIL.Image as Image |
|
3 |
|
from matplotlib import pyplot as plt |
|
4 |
|
from matplotlib import cm as CM |
|
5 |
|
import os |
|
6 |
|
import numpy as np |
|
7 |
|
|
|
8 |
|
from PIL import Image |
|
9 |
|
|
|
10 |
|
|
|
11 |
|
def save_density_map(density_map, name): |
|
12 |
|
plt.figure(dpi=600) |
|
13 |
|
plt.axis('off') |
|
14 |
|
plt.margins(0, 0) |
|
15 |
|
plt.imshow(density_map, cmap=CM.jet) |
|
16 |
|
plt.savefig(name, dpi=600, bbox_inches='tight', pad_inches=0) |
|
17 |
|
|
|
18 |
|
|
|
19 |
|
def save_img(imgnp, name): |
|
20 |
|
# plt.imshow(imgnp[0].permute(1, 2, 0).numpy()) |
|
21 |
|
plt.imsave(name, imgnp[0].permute(1, 2, 0).numpy()) |
|
22 |
|
# plt.show() |
|
23 |
|
# im = Image.fromarray(imgnp[0].permute(1, 2, 0).numpy()) |
|
24 |
|
# im.save(name) |