# experiment_main.py

from comet_ml import Experiment

from args_util import meow_parse
from data_flow import get_dataloader, create_image_list
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Loss
from ignite.handlers import Checkpoint, DiskSaver, Timer
from crowd_counting_error_metrics import CrowdCountingMeanAbsoluteError, CrowdCountingMeanSquaredError, CrowdCountingMeanAbsoluteErrorWithCount, CrowdCountingMeanSquaredErrorWithCount
from visualize_util import get_readable_time

import torch
from torch import nn
from models.meow_experiment.kitten_meow_1 import M1, M2, M3, M4
from models.meow_experiment.ccnn_tail import BigTailM1, BigTailM2, BigTail3, BigTail4, BigTail5, BigTail6, BigTail7
from models.meow_experiment.ccnn_head import H1, H2, H3
from models.meow_experiment.kitten_meow_1 import H1_Bigtail3
from models import CustomCNNv2, CompactCNNV7
from models.compact_cnn import CompactCNNV8, CompactCNNV9
import os
from model_util import get_lr, BestMetrics

COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM"
PROJECT_NAME = "crowd-counting-train-val"
# PROJECT_NAME = "crowd-counting-debug"
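# NOTE: the API key is hardcoded here; comet_ml can also read it from the
# COMET_API_KEY environment variable or a .comet.config file, which keeps the
# key out of version control.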


def very_simple_param_count(model):
    """Count all parameters of the model (trainable and frozen alike)."""
    return sum(p.numel() for p in model.parameters())


if __name__ == "__main__":
    torch.set_num_threads(2)  # limit CPU threads used by PyTorch
    experiment = Experiment(project_name=PROJECT_NAME, api_key=COMET_ML_API)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    args = meow_parse()
    print(args)

    experiment.set_name(args.task_id)
    experiment.set_cmd_args()
    experiment.log_text(args.note)

    DATA_PATH = args.input
    TRAIN_PATH = os.path.join(DATA_PATH, "train_data_train_split")
    VAL_PATH = os.path.join(DATA_PATH, "train_data_validate_split")
    TEST_PATH = os.path.join(DATA_PATH, "test_data")
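    # expected layout under args.input:
    #   train_data_train_split/   train_data_validate_split/   test_data/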
    dataset_name = args.datasetname
    if dataset_name == "shanghaitech":
        print("will use shanghaitech dataset with crop")
    elif dataset_name == "shanghaitech_keepfull":
        print("will use shanghaitech_keepfull")
    else:
        print("unrecognized dataset_name")
        print("current dataset_name is ", dataset_name)

    # create list
    train_list = create_image_list(TRAIN_PATH)
    val_list = create_image_list(VAL_PATH)
    test_list = create_image_list(TEST_PATH)

    # create data loader
    train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name=dataset_name, batch_size=args.batch_size,
                                                                              train_loader_for_eval_check=True)
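    # train_loader_for_eval_check=True presumably makes get_dataloader return an
    # extra, evaluation-only loader over the training images (train_loader_eval),
    # so training-set MAE/MSE can be measured; see data_flow.get_dataloader.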

    print("len train_loader ", len(train_loader))

    # model
    model_name = args.model
    experiment.log_other("model", model_name)
    experiment.add_tag(model_name)
    # map model names to constructors and look up, instead of a long if/elif chain
    model_registry = {
        "M1": M1, "M2": M2, "M3": M3, "M4": M4,
        "CustomCNNv2": CustomCNNv2,
        "BigTailM1": BigTailM1, "BigTailM2": BigTailM2,
        "BigTail3": BigTail3, "BigTail4": BigTail4, "BigTail5": BigTail5,
        "BigTail6": BigTail6, "BigTail7": BigTail7,
        "H1": H1, "H2": H2, "H3": H3,
        "H1_Bigtail3": H1_Bigtail3,
        "CompactCNNV7": CompactCNNV7, "CompactCNNV8": CompactCNNV8, "CompactCNNV9": CompactCNNV9,
    }
    if model_name not in model_registry:
        print("error: unknown model name ", model_name)
        exit(-1)
    model = model_registry[model_name]()
    n_param = very_simple_param_count(model)
    experiment.log_other("n_param", n_param)
    if hasattr(model, 'model_note'):
        experiment.log_other("model_note", model.model_note)
    model = model.to(device)

    # loss function
    # loss_fn = nn.MSELoss(reduction='sum').to(device)
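    # Reduction matters here: with reduction='sum' the per-pixel errors over the
    # whole density map are summed, so bigger images contribute bigger losses;
    # with reduction='mean' the loss is averaged per pixel, which shrinks the
    # gradients by the pixel count and effectively changes the usable learning rate.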
    if args.loss_fn == "MSE":
        loss_fn = nn.MSELoss(reduction='sum').to(device)
        print("use MSELoss")
    elif args.loss_fn == "L1":
        loss_fn = nn.L1Loss(reduction='sum').to(device)
        print("use L1Loss")
    elif args.loss_fn == "L1Mean":
        loss_fn = nn.L1Loss(reduction='mean').to(device)
        print("use L1Mean")
    elif args.loss_fn == "MSEMean":
        loss_fn = nn.MSELoss(reduction='mean').to(device)
        print("use MSEMean")
    elif args.loss_fn == "MSENone":
        """
        Doesnt work
        because 
        RuntimeError: grad can be implicitly created only for scalar outputs
        """
        loss_fn = nn.MSELoss(reduction='none').to(device)
        print("use MSE without any reduction")
    experiment.add_tag(args.loss_fn)

    if args.optim == "adam":
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                weight_decay=args.decay)
        print("use adam")
    elif args.optim == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    weight_decay=args.decay,
                                    momentum=args.momentum)
        print("use sgd")
    experiment.add_tag(args.optim)

    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator_train = create_supervised_evaluator(model,
                                            metrics={
                                                'mae': CrowdCountingMeanAbsoluteErrorWithCount(),
                                                'mse': CrowdCountingMeanSquaredErrorWithCount(),
                                               #  'loss': Loss(loss_fn)
                                            }, device=device)

    evaluator_validate = create_supervised_evaluator(model,
                                            metrics={
                                                'mae': CrowdCountingMeanAbsoluteErrorWithCount(),
                                                'mse': CrowdCountingMeanSquaredErrorWithCount(),
                                               # 'loss': Loss(loss_fn)
                                            }, device=device)

    evaluator_test = create_supervised_evaluator(model,
                                            metrics={
                                                'mae': CrowdCountingMeanAbsoluteErrorWithCount(),
                                                'mse': CrowdCountingMeanSquaredErrorWithCount(),
                                               # 'loss': Loss(loss_fn)
                                            }, device=device)
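
    # Assumed Ignite semantics: create_supervised_trainer runs one optimizer step
    # per batch and stores the batch loss in trainer.state.output; each
    # create_supervised_evaluator performs a single no-grad evaluation pass per
    # .run() call and exposes the accumulated metrics via engine.state.metrics.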

    best_mae = BestMetrics(best_metric="mae")
    best_mse = BestMetrics(best_metric="mse")
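    # Assumption from usage below: BestMetrics tracks the best value of its
    # watched metric, and checkAndRecord(mae, mse) returns True when that
    # metric improves (see model_util.BestMetrics).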


    print(model)

    print(args)


    # timer
    train_timer = Timer(average=True)  # time to train whole epoch
    batch_timer = Timer(average=True)  # every batch
    evaluate_validate_timer = Timer(average=True)
    evaluate_test_timer = Timer(average=True)
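    # Ignite Timer with average=True reports total time divided by the number of
    # step() calls, so batch_timer.value() is the mean time per iteration and
    # train_timer.value() the mean time per epoch.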

    batch_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,
                       resume=Events.ITERATION_STARTED,
                       pause=Events.ITERATION_COMPLETED,
                       step=Events.ITERATION_COMPLETED)

    train_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,
                       resume=Events.EPOCH_STARTED,
                       pause=Events.EPOCH_COMPLETED,
                       step=Events.EPOCH_COMPLETED)

    if len(args.load_model) > 0:
        load_model_path = args.load_model
        print("load model from " + load_model_path)
        # Checkpoint.load_objects restores everything in to_load, including the
        # trainer state (epoch/iteration counters), so training resumes where it left off.
        to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
        checkpoint = torch.load(load_model_path)
        Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
        print("load model complete")
        # override the checkpoint's learning rate with the one from the command line
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            print("change lr to ", args.lr)
    else:
        print("no model to load, training from scratch")


    @trainer.on(Events.ITERATION_COMPLETED(every=100))
    def log_training_loss(trainer):
        timestamp = get_readable_time()
        print(timestamp + " Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        experiment.log_metric("epoch", trainer.state.epoch)
        if not args.skip_train_eval:
            evaluator_train.run(train_loader_eval)
            metrics = evaluator_train.state.metrics
            timestamp = get_readable_time()
            print(timestamp + " Training set Results - Epoch: {}  Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
                  .format(trainer.state.epoch, metrics['mae'], metrics['mse'], 0))
            # experiment.log_metric("epoch", trainer.state.epoch)
            experiment.log_metric("train_mae", metrics['mae'])
            experiment.log_metric("train_mse", metrics['mse'])
            experiment.log_metric("lr", get_lr(optimizer))

            experiment.log_metric("batch_timer", batch_timer.value())
            experiment.log_metric("train_timer", train_timer.value())

            print("batch_timer ", batch_timer.value())
            print("train_timer ", train_timer.value())

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        evaluate_validate_timer.resume()
        evaluator_validate.run(val_loader)
        evaluate_validate_timer.pause()
        evaluate_validate_timer.step()

        metrics = evaluator_validate.state.metrics
        timestamp = get_readable_time()
        print(timestamp + " Validation set Results - Epoch: {}  Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
              .format(trainer.state.epoch, metrics['mae'], metrics['mse'], 0))
        experiment.log_metric("valid_mae", metrics['mae'])
        experiment.log_metric("valid_mse", metrics['mse'])

        # timer
        experiment.log_metric("evaluate_valid_timer", evaluate_validate_timer.value())
        print("evaluate_valid_timer ", evaluate_validate_timer.value())

        # check whether this validation result is the best seen so far
        flag_mae = best_mae.checkAndRecord(metrics['mae'], metrics['mse'])
        flag_mse = best_mse.checkAndRecord(metrics['mae'], metrics['mse'])

        if flag_mae or flag_mse:
            experiment.log_metric("valid_best_mae", metrics['mae'])
            experiment.log_metric("valid_best_mse", metrics['mse'])
            experiment.log_metric("valid_best_epoch", trainer.state.epoch)
            print("BEST VAL, evaluating on test set")
            evaluate_test_timer.resume()
            evaluator_test.run(test_loader)
            evaluate_test_timer.pause()
            evaluate_test_timer.step()
            test_metrics = evaluator_test.state.metrics
            timestamp = get_readable_time()
            print(timestamp + " Test set Results - Epoch: {}  Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
                  .format(trainer.state.epoch, test_metrics['mae'], test_metrics['mse'], 0))
            experiment.log_metric("test_mae", test_metrics['mae'])
            experiment.log_metric("test_mse", test_metrics['mse'])
            # experiment.log_metric("test_loss", test_metrics['loss'])

    def checkpoint_valid_mae_score_function(engine):
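        # Ignite's Checkpoint keeps the n_saved checkpoints with the HIGHEST score,
        # so return the negated MAE: the retained "best" models are the lowest-MAE ones.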
        score = engine.state.metrics['mae']
        return -score

    # checkpointing: periodic snapshots plus the best models by validation MAE
    to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer}
    save_handler = Checkpoint(to_save, DiskSaver('saved_model/' + args.task_id, create_dir=True, atomic=True),
                              filename_prefix=args.task_id,
                              n_saved=3)

    save_handler_best = Checkpoint(to_save, DiskSaver('saved_model_best/' + args.task_id, create_dir=True, atomic=True),
                              filename_prefix=args.task_id, score_name="valid_mae", score_function=checkpoint_valid_mae_score_function,
                              n_saved=3)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=10), save_handler)
    evaluator_validate.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_handler_best)
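    # Cadence of the handlers above: save_handler fires every 10 training epochs;
    # save_handler_best fires after every validation pass (each
    # evaluator_validate.run() counts as one epoch for that engine) and keeps the
    # n_saved=3 checkpoints with the best (lowest) validation MAE.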

    trainer.run(train_loader, max_epochs=args.epochs)
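
    # Hypothetical invocation, assuming the CLI flags mirror the attribute names
    # read from args above (see args_util.meow_parse for the real spelling):
    #   python experiment_main.py --task_id bigtail3_shA --input /path/to/ShanghaiTech/part_A \
    #       --model BigTail3 --datasetname shanghaitech_keepfull --loss_fn L1 \
    #       --optim adam --lr 1e-4 --decay 1e-5 --epochs 400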

