List of commits:
Subject Hash Author Date (UTC)
training flow that work fb242273e8f696916f9d1ff4bb76b4e5869799ef Thai Thien 2020-02-02 10:42:01
fix the dataloader for shanghaitech 5f2aee9f316e6555e6a70c6ad037a4e6b491867b Thai Thien 2020-02-02 09:19:50
context aware visualize seem ok 1bdb6ffe77ca4e40ef8f299b2506df2266243db4 Thai Thien 2020-02-02 05:07:10
visualize eval context aware network seem ok f3fe45c23dfeab3730624737efabb0b14d23c25b Thai Thien 2020-02-02 04:50:34
visualize_shanghaitech_pacnn_with_perspective run without error 12366a2de2bd60ff4bd36e6132d44e37dedf7462 Thai Thien 2020-02-02 04:21:16
eval context aware network on ShanghaiTechB can run e8c454d2b6d287c830c1286c9a37884b3cfc615f Thai Thien 2020-02-02 04:09:14
import ShanghaiTechDataPath in data_util e81eb56315d44375ff5c0e747d61456601492f8f Thai Thien 2020-02-02 04:04:36
add model_context_aware_network.py 2a36025c001d85afc064c090f4d22987b328977b Thai Thien 2020-02-02 03:46:38
PACNN (TODO: test this) 44d5ae7ec57c760fb4f105dd3e3492148a0cc075 Thai Thien 2020-02-02 03:40:26
add data path 80134de767d0137a663f343e4606bafc57a1bc1f Thai Thien 2020-02-02 03:38:21
test if ShanghaiTech datapath is correct 97ee84944a4393ec3732879b24f614826f8e7798 Thai Thien 2020-02-01 03:57:31
refactor and test ShanghaiTech datapath 9542ebc00f257edc38690180b7a4353794be4019 Thai Thien 2020-02-01 03:53:49
fix the unzip flow b53c5989935335377eb6a88c942713d3eccc5df7 Thai Thien 2020-02-01 03:53:13
data_script run seem ok 67420c08fc1c10a66404d3698994865726a106cd Thai Thien 2020-02-01 03:33:18
add perspective 642d6fff8c9f31e510fda85a7fb631fb855d8a6d Thai Thien 2019-10-06 16:54:44
fix padding with p 86c2fa07822d956a34b3b37e14da485a4249f01b Thai Thien 2019-10-06 02:52:58
pacnn perspective loss fb673e38a5f24ae9004fe2b7b93c88991e0c2304 Thai Thien 2019-10-06 01:38:28
data_flow shanghaitech_pacnn_with_perspective seem working 91d350a06f358e03223966297d124daee94123d0 Thai Thien 2019-10-06 01:31:11
multiscale loss and final loss only mode c65dd0e74ad28503821e5c8651a3b47b4a0c7c64 Thai Thien 2019-10-05 15:58:19
wip : perspective map eac63f2671dc5b064753acc4f40bf0f9f216ad2a Thai Thien 2019-10-04 16:26:56
Commit fb242273e8f696916f9d1ff4bb76b4e5869799ef - training flow that work
Author: Thai Thien
Author date (UTC): 2020-02-02 10:42
Committer name: Thai Thien
Committer date (UTC): 2020-02-02 10:42
Parent(s): 5f2aee9f316e6555e6a70c6ad037a4e6b491867b
Signing key:
Tree: f94031eb54d905e6138a0e54dcf55c0c9ac424c0
File Lines added Lines deleted
args_util.py 26 0
train_context_aware_network.py 18 11
train_script/train_can_short.sh 5 0
File args_util.py changed (mode: 100644) (index fa2f9fc..33e540f)
... ... def like_real_args_parse(data_input):
52 52 args.print_freq = 30 args.print_freq = 30
53 53
54 54
55 def context_aware_network_args_parse():
56 """
57 this is not dummy
58 if you are going to make all-in-one notebook, ignore this
59 :return:
60 """
61 parser = argparse.ArgumentParser(description='CrowdCounting Context Aware Network')
62 parser.add_argument("--task_id", action="store", default="dev")
63 parser.add_argument('-a', action="store_true", default=False)
64
65 parser.add_argument('--input', action="store", type=str, default=HardCodeVariable().SHANGHAITECH_PATH_PART_A)
66 parser.add_argument('--output', action="store", type=str, default="saved_model/context_aware_network")
67 parser.add_argument('--datasetname', action="store", default="shanghaitech_keepfull")
68
69 # args with default value
70 parser.add_argument('--load_model', action="store", default="", type=str)
71 parser.add_argument('--lr', action="store", default=1e-8, type=float)
72 parser.add_argument('--momentum', action="store", default=0.9, type=float)
73 parser.add_argument('--decay', action="store", default=5*1e-3, type=float)
74 parser.add_argument('--epochs', action="store", default=1, type=int)
75 parser.add_argument('--test', action="store_true", default=False)
76
77
78 arg = parser.parse_args()
79 return arg
80
55 81 def real_args_parse(): def real_args_parse():
56 82 """ """
57 83 this is not dummy this is not dummy
File train_context_aware_network.py copied from file main_shanghaitech.py (similarity 70%) (mode: 100644) (index 3c2d5ba..0194e9b)
1 from args_util import real_args_parse
1 from args_util import context_aware_network_args_parse
2 2 from data_flow import get_train_val_list, get_dataloader, create_training_image_list from data_flow import get_train_val_list, get_dataloader, create_training_image_list
3 3 from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
4 4 from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError
5 from ignite.engine import Engine
6 from ignite.handlers import Checkpoint, DiskSaver
5 7 from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError
8
6 9 import torch import torch
7 10 from torch import nn from torch import nn
8 from models import CSRNet
11 from models import CANNet
9 12 import os import os
10 13
14
11 15 if __name__ == "__main__": if __name__ == "__main__":
12 16 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13 17 print(device) print(device)
14 args = real_args_parse()
18 args = context_aware_network_args_parse()
15 19 print(args) print(args)
16 20 DATA_PATH = args.input DATA_PATH = args.input
17 21 TRAIN_PATH = os.path.join(DATA_PATH, "train_data") TRAIN_PATH = os.path.join(DATA_PATH, "train_data")
 
... ... if __name__ == "__main__":
23 27 test_list = create_training_image_list(TEST_PATH) test_list = create_training_image_list(TEST_PATH)
24 28
25 29 # create data loader # create data loader
26 train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list)
30 train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech_keepfull")
27 31
28 32
29 33 # model # model
30 model = CSRNet()
34 model = CANNet()
31 35 model = model.to(device) model = model.to(device)
32 36
33 37 # loss function # loss function
34 loss_fn = nn.MSELoss(size_average=False).cuda()
38 loss_fn = nn.MSELoss(size_average=False).to(device)
35 39
36 40 optimizer = torch.optim.SGD(model.parameters(), args.lr, optimizer = torch.optim.SGD(model.parameters(), args.lr,
37 41 momentum=args.momentum, momentum=args.momentum,
 
... ... if __name__ == "__main__":
49 53 print(args) print(args)
50 54
51 55
52 @trainer.on(Events.ITERATION_COMPLETED)
56 @trainer.on(Events.ITERATION_COMPLETED(every=50))
53 57 def log_training_loss(trainer): def log_training_loss(trainer):
54 58 print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output)) print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))
55 59
 
... ... if __name__ == "__main__":
58 62 def log_training_results(trainer): def log_training_results(trainer):
59 63 evaluator.run(train_loader) evaluator.run(train_loader)
60 64 metrics = evaluator.state.metrics metrics = evaluator.state.metrics
61 print("Validation Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
65 print("Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
62 66 .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll']))
63 67
64 68
 
... ... if __name__ == "__main__":
66 70 def log_validation_results(trainer): def log_validation_results(trainer):
67 71 evaluator.run(val_loader) evaluator.run(val_loader)
68 72 metrics = evaluator.state.metrics metrics = evaluator.state.metrics
69 print("Validation Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
73 print("Validation set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
70 74 .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll'])) .format(trainer.state.epoch, metrics['mae'], metrics['mse'], metrics['nll']))
71 75
72
73 trainer.run(train_loader, max_epochs=10)
76 # docs on save and load
77 to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
78 save_handler = Checkpoint(to_save, DiskSaver('saved_model/context_aware_network', create_dir=True))
79 trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler)
80 trainer.run(train_loader, max_epochs=args.epochs)
File train_script/train_can_short.sh changed (mode: 100644) (index e69de29..4242707)
1 python train_context_aware_network.py \
2 --input /data/ShanghaiTech/part_A/ \
3 --output saved_model/context_aware_network \
4 --datasetname shanghaitech_keepfull \
5 --epochs 3
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main