File data_flow.py changed (mode: 100644) (index cb3c78c..9bbb7c1) |
... |
... |
def data_augmentation(img, target): |
279 |
279 |
class ListDataset(Dataset): |
class ListDataset(Dataset): |
280 |
280 |
def __init__(self, root, shape=None, shuffle=True, transform=None, train=False, seen=0, batch_size=1, |
def __init__(self, root, shape=None, shuffle=True, transform=None, train=False, seen=0, batch_size=1, |
281 |
281 |
debug=False, |
debug=False, |
282 |
|
num_workers=4, dataset_name="shanghaitech"): |
|
|
282 |
|
num_workers=0, dataset_name="shanghaitech"): |
283 |
283 |
""" |
""" |
284 |
284 |
if you have different image size, then batch_size must be 1 |
if you have different image size, then batch_size must be 1 |
285 |
285 |
:param root: |
:param root: |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
352 |
352 |
transform=transformer, |
transform=transformer, |
353 |
353 |
train=True, |
train=True, |
354 |
354 |
batch_size=1, |
batch_size=1, |
355 |
|
num_workers=4, |
|
|
355 |
|
num_workers=0, |
356 |
356 |
dataset_name=dataset_name), |
dataset_name=dataset_name), |
357 |
357 |
batch_size=1, |
batch_size=1, |
358 |
358 |
num_workers=4) |
num_workers=4) |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
363 |
363 |
transform=transformer, |
transform=transformer, |
364 |
364 |
train=False, |
train=False, |
365 |
365 |
dataset_name=dataset_name), |
dataset_name=dataset_name), |
366 |
|
num_workers=4, |
|
|
366 |
|
num_workers=0, |
367 |
367 |
batch_size=1) |
batch_size=1) |
368 |
368 |
|
|
369 |
369 |
if test_list is not None: |
if test_list is not None: |
|
... |
... |
def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", |
373 |
373 |
transform=transformer, |
transform=transformer, |
374 |
374 |
train=False, |
train=False, |
375 |
375 |
dataset_name=dataset_name), |
dataset_name=dataset_name), |
376 |
|
num_workers=4, |
|
|
376 |
|
num_workers=0, |
377 |
377 |
batch_size=1) |
batch_size=1) |
378 |
378 |
else: |
else: |
379 |
379 |
test_loader = None |
test_loader = None |
File train_attn_can_adcrowdnet.py changed (mode: 100644) (index 878bdea..2dbf75f) |
... |
... |
if __name__ == "__main__": |
25 |
25 |
|
|
26 |
26 |
# create list |
# create list |
27 |
27 |
train_list, val_list = get_train_val_list(TRAIN_PATH) |
train_list, val_list = get_train_val_list(TRAIN_PATH) |
28 |
|
test_list = create_training_image_list(TEST_PATH) |
|
|
28 |
|
test_list = None |
29 |
29 |
|
|
30 |
30 |
# create data loader |
# create data loader |
31 |
31 |
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech_keepfull") |
train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech_keepfull") |
File train_script/attn_can_adcrowdnet/train_server_31_epoch_shA.sh changed (mode: 100644) (index 49459f1..ee917b1) |
1 |
|
CUDA_VISIBLE_DEVICES=4 nohup python train_context_aware_network.py \ |
|
|
1 |
|
CUDA_VISIBLE_DEVICES=4 nohup python train_attn_can_adcrowdnet.py \ |
2 |
2 |
--task_id attn_can_adcrowdnet_default_shtA_31 \ |
--task_id attn_can_adcrowdnet_default_shtA_31 \ |
3 |
3 |
--input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ |
--input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ |
4 |
4 |
--output saved_model/attn_can_adcrowdnet_default_shtA_31 \ |
--output saved_model/attn_can_adcrowdnet_default_shtA_31 \ |