/main_shanghaitech.py (3c2d5ba1c81ef2770ad216c566e268f4ece17262) (2760 bytes) (mode 100644) (type blob)

from args_util import real_args_parse
from data_flow import get_train_val_list, get_dataloader, create_training_image_list
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss, MeanAbsoluteError, MeanSquaredError
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError
import torch
from torch import nn
from models import CSRNet
import os

if __name__ == "__main__":
    # Entry point: train CSRNet on the ShanghaiTech dataset rooted at args.input,
    # printing the loss after every iteration and MAE/MSE/loss on the train and
    # validation splits after every epoch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    args = real_args_parse()
    print(args)

    # The dataset root is expected to contain "train_data" and "test_data" subfolders.
    DATA_PATH = args.input
    TRAIN_PATH = os.path.join(DATA_PATH, "train_data")
    TEST_PATH = os.path.join(DATA_PATH, "test_data")

    # create list (train_data is split into train/val; test_data is kept separate)
    train_list, val_list = get_train_val_list(TRAIN_PATH)
    test_list = create_training_image_list(TEST_PATH)

    # create data loader (test_loader is not used by this script yet)
    train_loader, val_loader, test_loader = get_dataloader(train_list, val_list, test_list)

    # model
    model = CSRNet()
    model = model.to(device)

    # loss function: summed squared error. `reduction="sum"` is the
    # non-deprecated spelling of the legacy `size_average=False`. Place it on
    # the same `device` as the model rather than hard-coding `.cuda()`.
    loss_fn = nn.MSELoss(reduction="sum").to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    # NOTE(review): ignite's `Loss` metric assumes loss_fn returns a per-batch
    # mean and re-weights it by batch size; with reduction="sum" the reported
    # "loss" is only a per-image value when the batch size is 1 -- confirm.
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'mae': CrowdCountingMeanAbsoluteError(),
                                                'mse': CrowdCountingMeanSquaredError(),
                                                'loss': Loss(loss_fn)
                                            }, device=device)
    print(model)

    def report_metrics(split_name, loader, epoch):
        """Run the shared evaluator over `loader` and print its metrics under `split_name`."""
        evaluator.run(loader)
        metrics = evaluator.state.metrics
        print("{} Results - Epoch: {}  Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
              .format(split_name, epoch, metrics['mae'], metrics['mse'], metrics['loss']))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        # Evaluates on the training split; previously mislabeled as "Validation".
        report_metrics("Training", train_loader, trainer.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        report_metrics("Validation", val_loader, trainer.state.epoch)

    trainer.run(train_loader, max_epochs=10)

Mode Type Size Ref File
100644 blob 42 35eb2d2be02bde3d859e08bfd0a4a87849fc2495 .gitignore
100644 blob 153 7065abeef0ce647898419b0a10698cfa2532a8a1 README.md
100644 blob 2320 6728e74e0a9a175e9d67721a8c3e4065404774d0 args_util.py
040000 tree - 5e9d7f0e1fd3a9e4d5a37f3d6de0c3ecd3125af8 backup_notebook
040000 tree - 55d1d196f5b6ed4bfc1e8a715df1cfff1dd18117 bug
100644 blob 1775 1165f1aba0814b448a3595a32bd74f1967509509 crowd_counting_error_metrics.py
100644 blob 9444 7e0162d07beb25c5bd36bc33488ec97576dbec3b data_flow.py
040000 tree - c6749f18a4aed2bde949801eaf79004a1495efcc dataset_script
100644 blob 428 35cc7bfe48a4ed8dc56635fd3a6763612d8af771 evaluator.py
100644 blob 160 836144a442570bb6e334d00b7bd9dcfbb267fa0c hard_code_variable.py
100644 blob 5749 9902b26936ede5c05a10e36fffd713a176c080e0 main_pacnn.py
100644 blob 2760 3c2d5ba1c81ef2770ad216c566e268f4ece17262 main_shanghaitech.py
100644 blob 2683 29189260c1a2c03c8e59cd0b4bd61df19d5ce098 main_ucfcc50.py
100644 blob 635 8bb883aa2f26897c89c5ef13e364e213cd301b53 model_util.py
040000 tree - fe5046dd5067efbd25af7883b5823434bac67436 models
040000 tree - 970ac54d8293aed6667e016f2245547f3a5449c3 pytorch_ssim
040000 tree - d288bbcd0753d5be01e5bfcd2f1b4114c9adecee train_script
100644 blob 2064 cd7d5eb28fcb11143d93a34a07a5cf4c46f9479e visualize_data_loader.py
100644 blob 611 4064de3335374a5d537060f597c1a41c1d305ad0 visualize_util.py
Hints:
Before your first commit, do not forget to set up your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using SSH (do not forget to upload your SSH key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/hahattpro/crowd_counting_framework

Clone this repository using git:
git clone git://git.rocketgit.com/user/hahattpro/crowd_counting_framework

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main