List of commits:
Subject Hash Author Date (UTC)
mse-ssim loss function 052de7c879bff7690a7cfc1905c8376bf8605c45 Thai Thien 2020-04-30 11:41:24
delete misinformation note ca07533d585def04eec086685f2e72eacc330ddc Thai Thien 2020-04-29 16:59:47
typo 3505077301af5349665f12121862db2512ad450a Thai Thien 2020-04-29 16:59:03
change, adapt, survive 6755c8375302af2146cb63adc967631d53a7b1c8 Thai Thien 2020-04-29 16:48:10
gpu 861ec0d41cea2eec359c6ddfe207e1ed6b583369 Thai Thien 2020-04-28 18:08:11
ccnn_v7_t4 a84f63e64fe8b31fe22c94d383de5ed4e1a27fe4 Thai Thien 2020-04-28 18:07:33
2500 d37516f53d21f3dfc05143ba5ffda80fc5a07825 Thai Thien 2020-04-28 17:56:57
change max epoch 246ae30e6dd2b4a15b7dc70a3dc05592ac1c48f2 Thai Thien 2020-04-28 17:55:23
h1 t3 h1 t4 fbcd13dd240e06a982a1ce48f27cd1d542a26b63 Thai Thien 2020-04-28 17:53:06
h1 96204cb5262500020371637741131d24e3fa91d0 Thai Thien 2020-04-27 17:35:51
typo adfb213c2564bc90b8b69811469534b004808644 Thai Thien 2020-04-27 17:17:58
batch size 8 for shb c90fa9a5d725a1ef0d29ed23f947ee05b9aa7894 Thai Thien 2020-04-27 17:10:34
change proxy 30cd53782eb17b416c471502f1e6c6e9975a644b Thai Thien 2020-04-27 17:06:28
experiment with CCNN when we max pooling after merge c8e8daec89ee71c2f30e30cf3030298ee7073e56 Thai Thien 2020-04-27 16:52:50
refactor train_compact_cnn 9d1ecece2382b79a98e6cf2d4579ae68172dbb6a Thai Thien 2020-04-27 16:35:47
we add head experiment 82a25e6b89945609486cbafc433eaea20dcdee39 Thai Thien 2020-04-27 16:23:50
CompactCNNV3 7de3766d085ebdbdf82b024eb517568dd82d8d6d Thai Thien 2020-04-27 16:23:20
no_norm da3c84dca19b0d281082679d88af3b9d27165bfe Thai Thien 2020-04-25 17:32:45
M4_t3_sha_c_shb d0d61ff74ed23f595d05d6a813c0a93239f61438 Thai Thien 2020-04-25 17:17:56
training script 624ecec7b12641f734e12ee2ebb6158c7c89683a Thai Thien 2020-04-25 17:08:25
Commit 052de7c879bff7690a7cfc1905c8376bf8605c45 - mse-ssim loss function
Author: Thai Thien
Author date (UTC): 2020-04-30 11:41
Committer name: Thai Thien
Committer date (UTC): 2020-04-30 11:41
Parent(s): ca07533d585def04eec086685f2e72eacc330ddc
Signer:
Signing key:
Signing status: N
Tree: 39310bbb32b85d8c65523e68577ff99f812b4a50
File Lines added Lines deleted
args_util.py 2 0
local_train_script/ccnn_t2_shb.sh 12 0
mse_ssim_loss.py 13 0
pytorch_ssim/README.md 2 0
train_compact_cnn.py 9 4
File args_util.py changed (mode: 100644) (index d83740e..c646aeb)
... ... def meow_parse():
122 122 parser.add_argument('--test', action="store_true", default=False) parser.add_argument('--test', action="store_true", default=False)
123 123 parser.add_argument('--no_norm', action="store_true", default=False, parser.add_argument('--no_norm', action="store_true", default=False,
124 124 help="if true, does not use transforms.Normalize in dataloader") help="if true, does not use transforms.Normalize in dataloader")
125 parser.add_argument('--use_ssim', action="store_true", default=False,
126 help="if true, use mse and negative ssim as loss function")
125 127 arg = parser.parse_args() arg = parser.parse_args()
126 128 return arg return arg
127 129
File local_train_script/ccnn_t2_shb.sh added (mode: 100644) (index 0000000..16ab5e7)
1 task="local_ccnn_t2_shb"
2
3 python train_compact_cnn.py \
4 --task_id $task \
5 --note "" \
6 --model "CompactCNNV7" \
7 --input /data/ShanghaiTech/part_B \
8 --lr 1e-4 \
9 --decay 1e-4 \
10 --batch_size 8 \
11 --datasetname shanghaitech_rnd \
12 --epochs 301
File mse_ssim_loss.py added (mode: 100644) (index 0000000..babe708)
1 import torch
2 from torch import nn
3 from pytorch_ssim import SSIM
4
5
6 class MseSsimLoss(torch.nn.Module):
7 def __init__(self):
8 super(MseSsimLoss, self).__init__()
9 self.mse = nn.MSELoss(reduction='sum')
10 self.ssim = SSIM(window_size=5)
11
12 def forward(self, input, target):
13 return self.mse(input, target) - self.ssim(input, target)
File pytorch_ssim/README.md added (mode: 100644) (index 0000000..8cb0a17)
1 Here is original github
2 https://github.com/Po-Hsun-Su/pytorch-ssim
File train_compact_cnn.py changed (mode: 100644) (index 7b57017..4b570dc)
... ... from ignite.metrics import Loss
7 7 from ignite.handlers import Checkpoint, DiskSaver, Timer from ignite.handlers import Checkpoint, DiskSaver, Timer
8 8 from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError
9 9 from visualize_util import get_readable_time from visualize_util import get_readable_time
10
10 from mse_ssim_loss import MseSsimLoss
11 11 import torch import torch
12 12 from torch import nn from torch import nn
13 from pytorch_ssim import SSIM
13 14
14 15 from models import CompactCNNV2, CompactCNNV7 from models import CompactCNNV2, CompactCNNV7
15 16
 
... ... import os
17 18 from model_util import get_lr from model_util import get_lr
18 19
19 20 COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM" COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM"
20 PROJECT_NAME = "crowd-counting-framework"
21 # PROJECT_NAME = "crowd-counting-debug"
21 # PROJECT_NAME = "crowd-counting-framework"
22 PROJECT_NAME = "crowd-counting-debug"
22 23
23 24
24 25 def very_simple_param_count(model): def very_simple_param_count(model):
 
... ... if __name__ == "__main__":
80 81 model = model.to(device) model = model.to(device)
81 82
82 83 # loss function # loss function
83 loss_fn = nn.MSELoss(reduction='sum').to(device)
84 if args.use_ssim:
85 loss_fn = MseSsimLoss()
86 print("use ssim")
87 else:
88 loss_fn = nn.MSELoss(reduction='sum').to(device)
84 89
85 90 optimizer = torch.optim.Adam(model.parameters(), args.lr, optimizer = torch.optim.Adam(model.parameters(), args.lr,
86 91 weight_decay=args.decay) weight_decay=args.decay)
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main