File data_flow.py changed (mode: 100644) (index e05400d..5c329ef)

@@ -750,7 +750,7 @@ def my_collate(batch): # batch size 4 [{tensor image, tensor label},{},{},{}] co
     # so how to sample another dataset entry?
     return torch.utils.data.dataloader.default_collate(batch)
 
-def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1):
+def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech", visualize_mode=False, batch_size=1, train_loader_for_eval_check=False):
     if visualize_mode:
         transformer = transforms.Compose([
             transforms.ToTensor()
@@ -773,6 +773,18 @@ def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech",
         num_workers=0,
         collate_fn=my_collate, pin_memory=False)
 
+    train_loader_for_eval = torch.utils.data.DataLoader(
+        ListDataset(train_list,
+                    shuffle=False,
+                    transform=transformer,
+                    train=False,
+                    batch_size=batch_size,
+                    num_workers=0,
+                    dataset_name=dataset_name),
+        batch_size=1,
+        num_workers=0,
+        collate_fn=my_collate, pin_memory=False)
+
     if val_list is not None:
         val_loader = torch.utils.data.DataLoader(
             ListDataset(val_list,
@@ -798,5 +810,7 @@ def get_dataloader(train_list, val_list, test_list, dataset_name="shanghaitech",
             pin_memory=False)
     else:
         test_loader = None
-
-    return train_loader, val_loader, test_loader
+    if train_loader_for_eval_check:
+        return train_loader, train_loader_for_eval, val_loader, test_loader
+    else:
+        return train_loader, val_loader, test_loader
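
The "# so how to sample another dataset entry?" comment above asks how to recover when the dataset hands back an unusable sample. One common pattern, sketched below on the assumption (not confirmed by this diff) that ListDataset.__getitem__ returns None for a broken entry, is to drop such samples before delegating to default_collate:

    import torch

    def my_collate(batch):
        # Hypothetical variant: discard samples the dataset returned as None
        # (e.g. an unreadable image), then stack the remainder as usual.
        batch = [sample for sample in batch if sample is not None]
        return torch.utils.data.dataloader.default_collate(batch)

Note that the effective batch size shrinks whenever a sample is dropped; resampling a replacement inside __getitem__ instead would keep it constant.
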
File experiment_main.py copied from file experiment_meow_main.py (similarity 77%) (mode: 100644) (index 049432d..3d90052)

@@ -16,7 +16,7 @@ from models.meow_experiment.ccnn_head import H1, H2
 from models.meow_experiment.kitten_meow_1 import H1_Bigtail3
 from models import CustomCNNv2
 import os
-from model_util import get_lr
+from model_util import get_lr, BestMetrics
 
 COMET_ML_API = "S3mM1eMq6NumMxk2QJAXASkUM"
 PROJECT_NAME = "meow-one-experiment-insita"
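
BestMetrics is imported from model_util, whose implementation is not part of this diff. Judging only from its call sites below (BestMetrics(best_metric="mae") and checkAndRecord(value) returning a flag on improvement), a minimal sketch consistent with that usage might look like this; the real class may differ:

    class BestMetrics:
        # Minimal sketch inferred from usage in this diff; the actual
        # implementation lives in model_util and may differ.
        def __init__(self, best_metric="mae"):
            self.best_metric = best_metric
            self.best = None

        def checkAndRecord(self, value):
            # Lower is better for mae/mse: record and signal a new best.
            if self.best is None or value < self.best:
                self.best = value
                return True
            return False
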
|
@@ -41,7 +41,8 @@ if __name__ == "__main__":
     experiment.log_text(args.note)
 
     DATA_PATH = args.input
-    TRAIN_PATH = os.path.join(DATA_PATH, "train_data")
+    TRAIN_PATH = os.path.join(DATA_PATH, "train_data_train_split")
+    VAL_PATH = os.path.join(DATA_PATH, "train_data_validate_split")
     TEST_PATH = os.path.join(DATA_PATH, "test_data")
     dataset_name = args.datasetname
     if dataset_name=="shanghaitech":
@@ -54,10 +55,12 @@ if __name__ == "__main__":
 
     # create list
     train_list = create_image_list(TRAIN_PATH)
+    val_list = create_image_list(VAL_PATH)
     test_list = create_image_list(TEST_PATH)
 
     # create data loader
-    train_loader, train_loader_for_eval, test_loader = get_dataloader(train_list, train_list, test_list, dataset_name=dataset_name, batch_size=args.batch_size)
+    train_loader, train_loader_eval, val_loader, test_loader = get_dataloader(train_list, val_list, test_list, dataset_name=dataset_name, batch_size=args.batch_size,
+                                                                              train_loader_for_eval_check=True)
 
     print("len train_loader ", len(train_loader))
 
@@ -122,6 +125,18 @@ if __name__ == "__main__":
                                                          'mse': CrowdCountingMeanSquaredError(),
                                                          'loss': Loss(loss_fn)
                                                      }, device=device)
+
+    evaluator_test = create_supervised_evaluator(model,
+                                                 metrics={
+                                                     'mae': CrowdCountingMeanAbsoluteError(),
+                                                     'mse': CrowdCountingMeanSquaredError(),
+                                                     'loss': Loss(loss_fn)
+                                                 }, device=device)
+
+    best_mae = BestMetrics(best_metric="mae")
+    best_mse = BestMetrics(best_metric="mse")
+
+
     print(model)
 
     print(args)
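
The hunk above adds a second, independent evaluator (evaluator_test) instead of reusing evaluator_validate for both splits. An Ignite engine keeps only its most recent run's metrics in engine.state.metrics, so a single shared evaluator would overwrite the validation numbers as soon as it ran on the test set. A small illustration of why the separation matters:

    # One evaluator per split keeps the two metric dicts independent.
    evaluator_validate.run(val_loader)
    val_mae = evaluator_validate.state.metrics['mae']   # still valid ...

    evaluator_test.run(test_loader)
    test_mae = evaluator_test.state.metrics['mae']      # ... even after this run
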
@@ -130,7 +145,8 @@ if __name__ == "__main__":
     # timer
     train_timer = Timer(average=True) # time to train whole epoch
     batch_timer = Timer(average=True) # every batch
-    evaluate_timer = Timer(average=True)
+    evaluate_validate_timer = Timer(average=True)
+    evaluate_test_timer = Timer(average=True)
 
     batch_timer.attach(trainer,
                        start=Events.EPOCH_STARTED,
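
The batch_timer.attach(trainer, ...) call is truncated by this hunk. For reference, the documented Ignite pattern for timing each batch attaches the timer to four events; assuming the original follows that pattern, the full call would look roughly like:

    from ignite.engine import Events
    from ignite.handlers import Timer

    batch_timer = Timer(average=True)
    batch_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,       # reset at each epoch
                       resume=Events.ITERATION_STARTED,  # start timing a batch
                       pause=Events.ITERATION_COMPLETED, # stop timing it
                       step=Events.ITERATION_COMPLETED)  # fold it into the average
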
@@ -166,7 +182,7 @@ if __name__ == "__main__":
 
     @trainer.on(Events.EPOCH_COMPLETED)
     def log_training_results(trainer):
-        evaluator_train.run(train_loader_for_eval)
+        evaluator_train.run(train_loader_eval)
         metrics = evaluator_train.state.metrics
         timestamp = get_readable_time()
         print(timestamp + " Training set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
@@ -185,10 +201,10 @@ if __name__ == "__main__":
 
     @trainer.on(Events.EPOCH_COMPLETED)
     def log_validation_results(trainer):
-        evaluate_timer.resume()
-        evaluator_validate.run(test_loader)
-        evaluate_timer.pause()
-        evaluate_timer.step()
+        evaluate_validate_timer.resume()
+        evaluator_validate.run(val_loader)
+        evaluate_validate_timer.pause()
+        evaluate_validate_timer.step()
 
         metrics = evaluator_validate.state.metrics
         timestamp = get_readable_time()
@@ -199,8 +215,27 @@ if __name__ == "__main__":
         experiment.log_metric("valid_loss", metrics['loss'])
 
         # timer
-        experiment.log_metric("evaluate_timer", evaluate_timer.value())
-        print("evaluate_timer ", evaluate_timer.value())
+        experiment.log_metric("evaluate_valid_timer", evaluate_validate_timer.value())
+        print("evaluate_valid_timer ", evaluate_validate_timer.value())
+
+        # check if that validate is best
+        flag_mae = best_mae.checkAndRecord(metrics['mae'])
+        flag_mse = best_mse.checkAndRecord(metrics['mse'])
+
+        if flag_mae or flag_mse:
+            evaluate_test_timer.resume()
+            evaluator_test.run(test_loader)
+            evaluate_test_timer.pause()
+            evaluate_test_timer.step()
+            test_metrics = evaluator_test.state.metrics
+            timestamp = get_readable_time()
+            print(timestamp + " Test set Results - Epoch: {} Avg mae: {:.2f} Avg mse: {:.2f} Avg loss: {:.2f}"
+                  .format(trainer.state.epoch, test_metrics['mae'], test_metrics['mse'], test_metrics['loss']))
+            experiment.log_metric("test_mae", test_metrics['mae'])
+            experiment.log_metric("test_mse", test_metrics['mse'])
+            experiment.log_metric("test_loss", test_metrics['loss'])
+            experiment.log_metric("valid_best_epoch", trainer.state.epoch)
+
 
     def checkpoint_valid_mae_score_function(engine):
         score = engine.state.metrics['mae']
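
checkpoint_valid_mae_score_function is cut off at this point in the diff, but its name suggests it feeds Ignite's ModelCheckpoint handler, which keeps the checkpoints with the highest score. Since lower MAE is better, such a score function typically returns the negated metric. A hedged sketch of the usual wiring (the directory and prefix names here are illustrative, not taken from this diff):

    from ignite.engine import Events
    from ignite.handlers import ModelCheckpoint

    def checkpoint_valid_mae_score_function(engine):
        score = engine.state.metrics['mae']
        return -score  # ModelCheckpoint keeps the highest score, so negate MAE

    best_mae_saver = ModelCheckpoint("saved_models", "best_valid_mae",
                                     score_function=checkpoint_valid_mae_score_function,
                                     score_name="valid_mae", n_saved=1)
    evaluator_validate.add_event_handler(Events.COMPLETED, best_mae_saver,
                                         {"model": model})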