Subject | Hash | Author | Date (UTC) |
---|---|---|---|
get ready for short training run with 30 epochs | 66dda0858561897cd5f81e10077459adb39d86dd | Thai Thien | 2020-02-27 15:22:01 |
implement attn_can_adcrowdnet | ffd38664a43d861c20cdc225746b9ce2a00260c7 | Thai Thien | 2020-02-27 15:10:27 |
WIP: add can-adcrowdnet | 5620b83449b31d00a367c8de77e431e19a5ccfb3 | Thai Thien | 2020-02-25 11:31:38 |
add readable timestamp viz | ae1fdb49ddb9ea77659529dceb7fb87c2790c8dc | Thai Thien | 2020-02-24 03:49:57 |
change save name prefix | c53a86f30fb8fd4e8f3a409eb67827d56a43ae5c | Thai Thien | 2020-02-02 10:48:15 |
training flow that work | fb242273e8f696916f9d1ff4bb76b4e5869799ef | Thai Thien | 2020-02-02 10:42:01 |
fix the dataloader for shanghaitech | 5f2aee9f316e6555e6a70c6ad037a4e6b491867b | Thai Thien | 2020-02-02 09:19:50 |
context aware visualize seem ok | 1bdb6ffe77ca4e40ef8f299b2506df2266243db4 | Thai Thien | 2020-02-02 05:07:10 |
visualize eval context aware network seem ok | f3fe45c23dfeab3730624737efabb0b14d23c25b | Thai Thien | 2020-02-02 04:50:34 |
visualize_shanghaitech_pacnn_with_perspective run without error | 12366a2de2bd60ff4bd36e6132d44e37dedf7462 | Thai Thien | 2020-02-02 04:21:16 |
eval context aware network on ShanghaiTechB can run | e8c454d2b6d287c830c1286c9a37884b3cfc615f | Thai Thien | 2020-02-02 04:09:14 |
import ShanghaiTechDataPath in data_util | e81eb56315d44375ff5c0e747d61456601492f8f | Thai Thien | 2020-02-02 04:04:36 |
add model_context_aware_network.py | 2a36025c001d85afc064c090f4d22987b328977b | Thai Thien | 2020-02-02 03:46:38 |
PACNN (TODO: test this) | 44d5ae7ec57c760fb4f105dd3e3492148a0cc075 | Thai Thien | 2020-02-02 03:40:26 |
add data path | 80134de767d0137a663f343e4606bafc57a1bc1f | Thai Thien | 2020-02-02 03:38:21 |
test if ShanghaiTech datapath is correct | 97ee84944a4393ec3732879b24f614826f8e7798 | Thai Thien | 2020-02-01 03:57:31 |
refactor and test ShanghaiTech datapath | 9542ebc00f257edc38690180b7a4353794be4019 | Thai Thien | 2020-02-01 03:53:49 |
fix the unzip flow | b53c5989935335377eb6a88c942713d3eccc5df7 | Thai Thien | 2020-02-01 03:53:13 |
data_script run seem ok | 67420c08fc1c10a66404d3698994865726a106cd | Thai Thien | 2020-02-01 03:33:18 |
add perspective | 642d6fff8c9f31e510fda85a7fb631fb855d8a6d | Thai Thien | 2019-10-06 16:54:44 |
File | Lines added | Lines deleted |
---|---|---|
models/__init__.py | 2 | 1 |
train_attn_can_adcrowdnet.py | 17 | 4 |
train_script/attn_can_adcrowdnet/train_server_31_epoch_shA.sh | 6 | 0 |
File models/__init__.py changed (mode: 100644) (index 556c781..af62baf) | |||
1 | 1 | from .csrnet import CSRNet | from .csrnet import CSRNet |
2 | 2 | from .pacnn import PACNN, PACNNWithPerspectiveMap | from .pacnn import PACNN, PACNNWithPerspectiveMap |
3 | 3 | from .context_aware_network import CANNet | from .context_aware_network import CANNet |
4 | from .deform_conv_v2 import DeformConv2d | ||
4 | from .deform_conv_v2 import DeformConv2d | ||
5 | from .attn_can_adcrowdnet import AttnCanAdcrowdNet |
File train_attn_can_adcrowdnet.py copied from file train_context_aware_network.py (similarity 85%) (mode: 100644) (index 3d5c656..e489e9c) | |||
... | ... | from visualize_util import get_readable_time | |
9 | 9 | ||
10 | 10 | import torch | import torch |
11 | 11 | from torch import nn | from torch import nn |
12 | from models import CANNet | ||
12 | from models import AttnCanAdcrowdNet | ||
13 | 13 | import os | import os |
14 | 14 | ||
15 | 15 | ||
... | ... | if __name__ == "__main__": | |
32 | 32 | ||
33 | 33 | ||
34 | 34 | # model | # model |
35 | model = CANNet() | ||
35 | model = AttnCanAdcrowdNet() | ||
36 | 36 | model = model.to(device) | model = model.to(device) |
37 | 37 | ||
38 | 38 | # loss function | # loss function |
... | ... | if __name__ == "__main__": | |
77 | 77 | print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" | print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" |
78 | 78 | .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) | .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
79 | 79 | ||
80 | def score_function(engine): | ||
81 | """ | ||
82 | saver score function | ||
83 | :param engine: | ||
84 | :return: | ||
85 | """ | ||
86 | engine.state.metrics['mae'] | ||
87 | |||
80 | 88 | # docs on save and load | # docs on save and load |
81 | 89 | to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer} | to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer} |
82 | save_handler = Checkpoint(to_save, DiskSaver('saved_model/' + args.task_id, create_dir=True), filename_prefix=args.task_id) | ||
83 | trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler) | ||
90 | save_handler = Checkpoint(to_save, DiskSaver('saved_model/' + args.task_id, create_dir=True, atomic=True), | ||
91 | filename_prefix=args.task_id, | ||
92 | n_saved=5, | ||
93 | score_function=score_function, | ||
94 | score_name="mae") | ||
95 | |||
96 | trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), save_handler) | ||
84 | 97 | ||
85 | 98 | trainer.run(train_loader, max_epochs=args.epochs) | trainer.run(train_loader, max_epochs=args.epochs) |
File train_script/attn_can_adcrowdnet/train_server_31_epoch_shA.sh added (mode: 100644) (index 0000000..49459f1) | |||
1 | CUDA_VISIBLE_DEVICES=4 nohup python train_context_aware_network.py \ | ||
2 | --task_id attn_can_adcrowdnet_default_shtA_31 \ | ||
3 | --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ | ||
4 | --output saved_model/attn_can_adcrowdnet_default_shtA_31 \ | ||
5 | --datasetname shanghaitech_keepfull \ | ||
6 | --epochs 31 > logs/attn_can_adcrowdnet_default_shtA_31_nohup.log & |