Subject | Hash | Author | Date (UTC) |
---|---|---|---|
add readable timestamp viz | ae1fdb49ddb9ea77659529dceb7fb87c2790c8dc | Thai Thien | 2020-02-24 03:49:57 |
change save name prefix | c53a86f30fb8fd4e8f3a409eb67827d56a43ae5c | Thai Thien | 2020-02-02 10:48:15 |
training flow that work | fb242273e8f696916f9d1ff4bb76b4e5869799ef | Thai Thien | 2020-02-02 10:42:01 |
fix the dataloader for shanghaitech | 5f2aee9f316e6555e6a70c6ad037a4e6b491867b | Thai Thien | 2020-02-02 09:19:50 |
context aware visualize seem ok | 1bdb6ffe77ca4e40ef8f299b2506df2266243db4 | Thai Thien | 2020-02-02 05:07:10 |
visualize eval context aware network seem ok | f3fe45c23dfeab3730624737efabb0b14d23c25b | Thai Thien | 2020-02-02 04:50:34 |
visualize_shanghaitech_pacnn_with_perspective run without error | 12366a2de2bd60ff4bd36e6132d44e37dedf7462 | Thai Thien | 2020-02-02 04:21:16 |
eval context aware network on ShanghaiTechB can run | e8c454d2b6d287c830c1286c9a37884b3cfc615f | Thai Thien | 2020-02-02 04:09:14 |
import ShanghaiTechDataPath in data_util | e81eb56315d44375ff5c0e747d61456601492f8f | Thai Thien | 2020-02-02 04:04:36 |
add model_context_aware_network.py | 2a36025c001d85afc064c090f4d22987b328977b | Thai Thien | 2020-02-02 03:46:38 |
PACNN (TODO: test this) | 44d5ae7ec57c760fb4f105dd3e3492148a0cc075 | Thai Thien | 2020-02-02 03:40:26 |
add data path | 80134de767d0137a663f343e4606bafc57a1bc1f | Thai Thien | 2020-02-02 03:38:21 |
test if ShanghaiTech datapath is correct | 97ee84944a4393ec3732879b24f614826f8e7798 | Thai Thien | 2020-02-01 03:57:31 |
refactor and test ShanghaiTech datapath | 9542ebc00f257edc38690180b7a4353794be4019 | Thai Thien | 2020-02-01 03:53:49 |
fix the unzip flow | b53c5989935335377eb6a88c942713d3eccc5df7 | Thai Thien | 2020-02-01 03:53:13 |
data_script run seem ok | 67420c08fc1c10a66404d3698994865726a106cd | Thai Thien | 2020-02-01 03:33:18 |
add perspective | 642d6fff8c9f31e510fda85a7fb631fb855d8a6d | Thai Thien | 2019-10-06 16:54:44 |
fix padding with p | 86c2fa07822d956a34b3b37e14da485a4249f01b | Thai Thien | 2019-10-06 02:52:58 |
pacnn perspective loss | fb673e38a5f24ae9004fe2b7b93c88991e0c2304 | Thai Thien | 2019-10-06 01:38:28 |
data_flow shanghaitech_pacnn_with_perspective seem working | 91d350a06f358e03223966297d124daee94123d0 | Thai Thien | 2019-10-06 01:31:11 |
File | Lines added | Lines deleted |
---|---|---|
train_context_aware_network.py | 8 | 4 |
train_script/CAN/train_can_server_100epoch_shA.sh | 6 | 0 |
train_script/CAN/train_can_short.sh | 0 | 0 |
visualize_util.py | 8 | 0 |
File train_context_aware_network.py changed (mode: 100644) (index a7404f8..3d5c656) | |||
... | ... | from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError | |
5 | 5 | from ignite.engine import Engine | from ignite.engine import Engine |
6 | 6 | from ignite.handlers import Checkpoint, DiskSaver | from ignite.handlers import Checkpoint, DiskSaver |
7 | 7 | from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError | from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
8 | from visualize_util import get_readable_time | ||
8 | 9 | ||
9 | 10 | import torch | import torch |
10 | 11 | from torch import nn | from torch import nn |
... | ... | if __name__ == "__main__": | |
55 | 56 | ||
56 | 57 | @trainer.on(Events.ITERATION_COMPLETED(every=50)) | @trainer.on(Events.ITERATION_COMPLETED(every=50)) |
57 | 58 | def log_training_loss(trainer): | def log_training_loss(trainer): |
58 | print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output)) | ||
59 | timestamp = get_readable_time() | ||
60 | print(timestamp + " Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output)) | ||
59 | 61 | ||
60 | 62 | ||
61 | 63 | @trainer.on(Events.EPOCH_COMPLETED) | @trainer.on(Events.EPOCH_COMPLETED) |
62 | 64 | def log_training_results(trainer): | def log_training_results(trainer): |
63 | 65 | evaluator.run(train_loader) | evaluator.run(train_loader) |
64 | 66 | metrics = evaluator.state.metrics | metrics = evaluator.state.metrics |
65 | print("Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" | ||
67 | timestamp = get_readable_time() | ||
68 | print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" | ||
66 | 69 | .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) | .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
67 | 70 | ||
68 | 71 | ||
... | ... | if __name__ == "__main__": | |
70 | 73 | def log_validation_results(trainer): | def log_validation_results(trainer): |
71 | 74 | evaluator.run(val_loader) | evaluator.run(val_loader) |
72 | 75 | metrics = evaluator.state.metrics | metrics = evaluator.state.metrics |
73 | print("Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" | ||
76 | timestamp = get_readable_time() | ||
77 | print(timestamp + " Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}" | ||
74 | 78 | .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) | .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) |
75 | 79 | ||
76 | 80 | # docs on save and load | # docs on save and load |
77 | 81 | to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer} | to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer} |
78 | save_handler = Checkpoint(to_save, DiskSaver('saved_model/context_aware_network', create_dir=True), filename_prefix=args.task_id) | ||
82 | save_handler = Checkpoint(to_save, DiskSaver('saved_model/' + args.task_id, create_dir=True), filename_prefix=args.task_id) | ||
79 | 83 | trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler) | trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler) |
80 | 84 | ||
81 | 85 | trainer.run(train_loader, max_epochs=args.epochs) | trainer.run(train_loader, max_epochs=args.epochs) |
File train_script/CAN/train_can_server_100epoch_shA.sh added (mode: 100644) (index 0000000..c80599b) | |||
1 | CUDA_VISIBLE_DEVICES=4 nohup python train_context_aware_network.py \ | ||
2 | --task_id can_default_shtA_100 \ | ||
3 | --input /data/rnd/thient/thient_data/ShanghaiTech/part_A \ | ||
4 | --output saved_model/context_aware_network \ | ||
5 | --datasetname shanghaitech_keepfull \ | ||
6 | --epochs 100 > logs/can_default_shtA_100_nohup.log & |
File train_script/CAN/train_can_short.sh renamed from train_script/train_can_short.sh (similarity 100%) |
File visualize_util.py changed (mode: 100644) (index 0190315..1b0f845) | |||
... | ... | from matplotlib import pyplot as plt | |
4 | 4 | from matplotlib import cm as CM | from matplotlib import cm as CM |
5 | 5 | import os | import os |
6 | 6 | import numpy as np | import numpy as np |
7 | import time | ||
7 | 8 | ||
8 | 9 | from PIL import Image | from PIL import Image |
9 | 10 | ||
... | ... | def save_img(imgnp, name): | |
31 | 32 | # plt.show() | # plt.show() |
32 | 33 | # im = Image.fromarray(imgnp[0].permute(1, 2, 0).numpy()) | # im = Image.fromarray(imgnp[0].permute(1, 2, 0).numpy()) |
33 | 34 | # im.save(name) | # im.save(name) |
35 | |||
36 | def get_readable_time(): | ||
37 | """ | ||
38 | make human readable time with format year-month-day hour-minute | ||
39 | :return: a string of human readable time (ex: '2020-02-24 10:31' ) | ||
40 | """ | ||
41 | return time.strftime('%Y-%m-%d %H:%M') |