# main_pacnn.py -- train and evaluate PACNN on a crowd counting dataset

from args_util import real_args_parse
from data_flow import get_train_val_list, get_dataloader, create_training_image_list, ListDataset
import torch
from torch import nn
from torchvision import transforms
from models import PACNN
import os
import pytorch_ssim
from time import time
from evaluator import MAECalculator

from model_util import save_checkpoint

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # device = "cpu"
    print(device)
    args = real_args_parse()
    print(args)
    DATA_PATH = args.input
    DATASET_NAME = "shanghaitech"

    # create list
    if DATASET_NAME == "shanghaitech":
        TRAIN_PATH = os.path.join(DATA_PATH, "train_data")
        TEST_PATH = os.path.join(DATA_PATH, "test_data")
        train_list, val_list = get_train_val_list(TRAIN_PATH)
        test_list = create_training_image_list(TEST_PATH)
    elif DATASET_NAME == "ucf_cc_50":
        train_list, val_list = get_train_val_list(DATA_PATH, test_size=0.2)
        test_list = None
    else:
        raise ValueError("unknown dataset name: " + DATASET_NAME)

    # create data loader
    train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name=DATASET_NAME)
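
    # NOTE: the generic loaders above are not used by the PACNN loop below; the
    # "shanghaitech_pacnn" ListDataset variant is assumed to yield, for each image,
    # three ground-truth density maps (d1, d2, d3), one per PACNN output scale.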
    train_loader_pacnn = torch.utils.data.DataLoader(
        ListDataset(train_list,
                    shuffle=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                    std=[0.229, 0.224, 0.225]),
                    ]),
                    train=True,
                    batch_size=1,
                    num_workers=4, dataset_name="shanghaitech_pacnn"),
        batch_size=1, num_workers=4)
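    # batch_size stays at 1 (presumably because ShanghaiTech images vary in
    # resolution, so they cannot be stacked into larger batches without cropping).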

    val_loader_pacnn = torch.utils.data.DataLoader(
        ListDataset(val_list,
                    shuffle=False,
                    transform=transforms.Compose([
                        transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                    std=[0.229, 0.224, 0.225]),
                    ]),
                    train=False,
                    batch_size=1,
                    num_workers=4, dataset_name="shanghaitech_pacnn"),
        batch_size=1, num_workers=4)

    # create model
    net = PACNN().to(device)
    criterion_mse = nn.MSELoss(reduction="sum").to(device)  # size_average=False is deprecated; sum over pixels
    criterion_ssim = pytorch_ssim.SSIM(window_size=11).to(device)

    optimizer = torch.optim.SGD(net.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)
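
    # Each predicted density map is trained with a summed-MSE term plus a structural
    # term; SSIM measures similarity (1.0 for a perfect match), so it enters the
    # loss as 1 - SSIM in the loop below.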
    num_epochs = 1  # kept at 1 here; increase for a real training run
    for e in range(num_epochs):
        print("start epoch ", e)
        loss_sum = 0
        sample = 0
        start_time = time()
        counting = 0
        net.train()
        for train_img, label in train_loader_pacnn:
            # zero the parameter gradients
            optimizer.zero_grad()

            # load data
            d1_label, d2_label, d3_label = label
            d1_label = d1_label.to(device)
            d2_label = d2_label.to(device)
            d3_label = d3_label.to(device)

            # forward pass
            d1, d2, d3 = net(train_img.to(device))

            # per-scale loss: summed MSE plus (1 - SSIM); SSIM expects 4D (N, C, H, W) input
            loss_1 = criterion_mse(d1, d1_label) + (1 - criterion_ssim(d1.unsqueeze(0), d1_label.unsqueeze(0)))
            loss_2 = criterion_mse(d2, d2_label) + (1 - criterion_ssim(d2.unsqueeze(0), d2_label.unsqueeze(0)))
            loss_3 = criterion_mse(d3, d3_label) + (1 - criterion_ssim(d3.unsqueeze(0), d3_label.unsqueeze(0)))

            loss = loss_1 + loss_2 + loss_3
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            sample += 1
            counting += 1
            if counting % 10 == 0:
                print("counting ", counting, " -- avg loss", loss_sum / sample)
            # if counting == 100:
            #     break

        end_time = time()
        avg_loss = loss_sum / sample
        epoch_time = end_time - start_time
        print("epoch time ", epoch_time, " avg loss ", avg_loss, " samples ", sample)


        save_checkpoint({
            'state_dict': net.state_dict(),
        }, False, "test2")
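
        # save_checkpoint presumably writes "<task_id>checkpoint.pth.tar" (loaded
        # back below as "test2checkpoint.pth.tar"); the False flag marks this
        # checkpoint as not (yet) the best one.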



    # evaluate on the validation set, reporting MAE per output scale

    best_checkpoint = torch.load("test2checkpoint.pth.tar", map_location=device)
    net = PACNN().to(device)
    print(net)
    net.load_state_dict(best_checkpoint['state_dict'])

    # device = "cpu"
    mae_calculator_d1 = MAECalculator()
    mae_calculator_d2 = MAECalculator()
    mae_calculator_d3 = MAECalculator()
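    # MAECalculator (from evaluator.py) is assumed to accumulate per-image absolute
    # count errors via eval() and to report the running mean via get_mae().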
    net.eval()
    with torch.no_grad():
        for val_img, label in val_loader_pacnn:
            # load data
            d1_label, d2_label, d3_label = label

            # forward pass
            d1, d2, d3 = net(val_img.to(device))

            d1_label = d1_label.to(device)
            d2_label = d2_label.to(device)
            d3_label = d3_label.to(device)

            # score
            mae_calculator_d1.eval(d1.cpu().detach().numpy(), d1_label.cpu().detach().numpy())
            mae_calculator_d2.eval(d2.cpu().detach().numpy(), d2_label.cpu().detach().numpy())
            mae_calculator_d3.eval(d3.cpu().detach().numpy(), d3_label.cpu().detach().numpy())
        print("count ", mae_calculator_d1.count)
        print("d1_val ", mae_calculator_d1.get_mae())
        print("d2_val ", mae_calculator_d2.get_mae())
        print("d3_val ", mae_calculator_d3.get_mae())



