File data_flow.py changed (mode: 100644) (index a9da806..7e02d74) |
...
def load_data_shanghaitech_256(img_path, train=True): |
     crop_size = (crop_sq_size, crop_sq_size)
     dx = int(random.random() * (img.size[0] - crop_sq_size))
     dy = int(random.random() * (img.size[1] - crop_sq_size))
-    if dx < 0 or dy < 0: # we crop more than we can chew, so...
+    if img.size[0] - crop_sq_size < 0 or img.size[1] - crop_sq_size < 0: # we crop more than we can chew, so...
         return None, None
     img = img.crop((dx, dy, crop_size[0] + dx, crop_size[1] + dy))
     target = target[dy:crop_size[1] + dy, dx:crop_size[0] + dx]
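
The guard change above fixes a truncation edge case: when the image is smaller than the crop square, int(random.random() * negative) can still truncate to 0, so the old dx < 0 or dy < 0 test does not always fire. A minimal sketch of that failure mode, with hypothetical sizes and assuming PIL-style img.size:

    import random

    crop_sq_size = 256
    img_w, img_h = 250, 300  # hypothetical image, narrower than the crop square

    dx = int(random.random() * (img_w - crop_sq_size))  # random.random() * -6 lies in (-6, 0]
    dy = int(random.random() * (img_h - crop_sq_size))  # positive range, unproblematic

    # int() truncates toward zero, so dx can land on 0 even though the image is
    # too narrow, and the old guard (dx < 0 or dy < 0) would let the crop through.
    print(dx < 0 or dy < 0)                                      # sometimes False
    # The new guard compares the sizes directly and always catches this case.
    print(img_w - crop_sq_size < 0 or img_h - crop_sq_size < 0)  # True
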
File playground/ccnnv2_playground.py copied from file sanity_check_dataloader.py (similarity 57%) (mode: 100644) (index 9f28698..c7b96b4) |
|
+from models import CompactCNNV2
 from args_util import sanity_check_dataloader_parse
 from data_flow import get_train_val_list, get_dataloader, create_training_image_list
 import torch
 import os
 
-
 if __name__ == "__main__":
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(device)
|
...
if __name__ == "__main__": |
 
     # create data loader
     train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name=dataset_name, batch_size=5)
-
+    model = CompactCNNV2()
+    # model = model.cuda()
     print("============== TRAIN LOADER ====================================================")
     min_1 = 500
     min_2 = 500
     for img, label in train_loader:
-        print("img shape:" + str(img.shape) + " == " + "label shape " + str(label.shape))
-        size_1 = img.shape[2]
-        size_2 = img.shape[3]
-        if min_1 > size_1:
-            min_1 = size_1
-        if min_2 > size_2:
-            min_2 = size_2
-        if size_1 < 256 or size_2 < 256:
-            count_below_256+=1
-        # example: img shape:torch.Size([1, 3, 716, 1024]) == label shape torch.Size([1, 1, 89, 128])
-
-    print("============== VAL LOADER ====================================================")
-    for img, label in val_loader:
-        print("img shape:" + str(img.shape) + " == " + "label shape " + str(label.shape))
-    print(min_1)
-    print(min_2)
-    print("count < 256 = ", count_below_256)
+        out = model(img)
+        print(out.shape)
+        exit()
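
In the playground the model stays on the CPU (the model.cuda() call is commented out), which matches the CPU tensors the dataloader yields. A hedged sketch, not part of the diff, of what the loop would need if the model were moved to the GPU instead:

    model = CompactCNNV2().to(device)  # reuse the device picked at the top of the script
    for img, label in train_loader:
        img = img.to(device)           # the batch must follow the model, or model(img) fails with a device mismatch
        out = model(img)
        print(out.shape, label.shape)
        break                          # inspect a single batch, mirroring the exit() above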