File args_util.py changed (mode: 100644) (index d83740e..c646aeb) |
... |
... |
def meow_parse(): |
122 |
122 |
parser.add_argument('--test', action="store_true", default=False) |
parser.add_argument('--test', action="store_true", default=False) |
123 |
123 |
parser.add_argument('--no_norm', action="store_true", default=False, |
parser.add_argument('--no_norm', action="store_true", default=False, |
124 |
124 |
help="if true, does not use transforms.Normalize in dataloader") |
help="if true, does not use transforms.Normalize in dataloader") |
|
125 |
|
parser.add_argument('--use_ssim', action="store_true", default=False, |
|
126 |
|
help="if true, use mse and negative ssim as loss function") |
125 |
127 |
arg = parser.parse_args() |
arg = parser.parse_args() |
126 |
128 |
return arg |
return arg |
127 |
129 |
|
|
File mse_ssim_loss.py added (mode: 100644) (index 0000000..babe708) |
|
1 |
|
import torch |
|
2 |
|
from torch import nn |
|
3 |
|
from pytorch_ssim import SSIM |
|
4 |
|
|
|
5 |
|
|
|
6 |
|
class MseSsimLoss(torch.nn.Module): |
|
7 |
|
def __init__(self): |
|
8 |
|
super(MseSsimLoss, self).__init__() |
|
9 |
|
self.mse = nn.MSELoss(reduction='sum') |
|
10 |
|
self.ssim = SSIM(window_size=5) |
|
11 |
|
|
|
12 |
|
def forward(self, input, target): |
|
13 |
|
return self.mse(input, target) - self.ssim(input, target) |
File train_compact_cnn.py changed (mode: 100644) (index 7b57017..4b570dc) |
... |
... |
from ignite.metrics import Loss |
7 |
7 |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
from ignite.handlers import Checkpoint, DiskSaver, Timer |
8 |
8 |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError |
9 |
9 |
from visualize_util import get_readable_time |
from visualize_util import get_readable_time |
10 |
|
|
|
|
10 |
|
from mse_ssim_loss import MseSsimLoss |
11 |
11 |
import torch |
import torch |
12 |
12 |
from torch import nn |
from torch import nn |
|
13 |
|
from pytorch_ssim import SSIM |
13 |
14 |
|
|
14 |
15 |
from models import CompactCNNV2, CompactCNNV7 |
from models import CompactCNNV2, CompactCNNV7 |
15 |
16 |
|
|
|
... |
... |
import os |
17 |
18 |
from model_util import get_lr |
from model_util import get_lr |
18 |
19 |
|
|
19 |
20 |
COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM" |
COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM" |
20 |
|
PROJECT_NAME = "crowd-counting-framework" |
|
21 |
|
# PROJECT_NAME = "crowd-counting-debug" |
|
|
21 |
|
# PROJECT_NAME = "crowd-counting-framework" |
|
22 |
|
PROJECT_NAME = "crowd-counting-debug" |
22 |
23 |
|
|
23 |
24 |
|
|
24 |
25 |
def very_simple_param_count(model): |
def very_simple_param_count(model): |
|
... |
... |
if __name__ == "__main__": |
80 |
81 |
model = model.to(device) |
model = model.to(device) |
81 |
82 |
|
|
82 |
83 |
# loss function |
# loss function |
83 |
|
loss_fn = nn.MSELoss(reduction='sum').to(device) |
|
|
84 |
|
if args.use_ssim: |
|
85 |
|
loss_fn = MseSsimLoss() |
|
86 |
|
print("use ssim") |
|
87 |
|
else: |
|
88 |
|
loss_fn = nn.MSELoss(reduction='sum').to(device) |
84 |
89 |
|
|
85 |
90 |
optimizer = torch.optim.Adam(model.parameters(), args.lr, |
optimizer = torch.optim.Adam(model.parameters(), args.lr, |
86 |
91 |
weight_decay=args.decay) |
weight_decay=args.decay) |