File persia_pytorch_toolkit/meter_utils.py changed (mode: 100644) (index da129aa..5cb420e) |
1 |
1 |
import torch |
import torch |
|
2 |
|
import queue |
2 |
3 |
|
|
3 |
4 |
class RiemannAUCMeter(): |
class RiemannAUCMeter(): |
4 |
5 |
""" |
""" |
|
... |
... |
class RiemannAUCMeter(): |
13 |
14 |
indices = torch.clamp(outputs * self.num_bins, min=0, max=self.num_bins - 1).long() |
indices = torch.clamp(outputs * self.num_bins, min=0, max=self.num_bins - 1).long() |
14 |
15 |
p_indices = torch.masked_select(indices, labels != 0) |
p_indices = torch.masked_select(indices, labels != 0) |
15 |
16 |
n_indices = torch.masked_select(indices, labels == 0) |
n_indices = torch.masked_select(indices, labels == 0) |
16 |
|
p_mask = torch.sparse.FloatTensor(p_indices.unsqueeze(0), torch.ones_like(p_indices), torch.Size([self.num_bins])) |
|
17 |
|
n_mask = torch.sparse.FloatTensor(n_indices.unsqueeze(0), torch.ones_like(n_indices), torch.Size([self.num_bins])) |
|
18 |
|
self.p_cnt += p_mask |
|
19 |
|
self.n_cnt += n_mask |
|
|
17 |
|
self.p_mask = torch.sparse.FloatTensor(p_indices.unsqueeze(0), torch.ones_like(p_indices), torch.Size([self.num_bins])) |
|
18 |
|
self.n_mask = torch.sparse.FloatTensor(n_indices.unsqueeze(0), torch.ones_like(n_indices), torch.Size([self.num_bins])) |
|
19 |
|
self.p_cnt += self.p_mask |
|
20 |
|
self.n_cnt += self.n_mask |
20 |
21 |
|
|
21 |
22 |
def value(self): |
def value(self): |
22 |
23 |
p_sum = self.p_cnt.sum().item() |
p_sum = self.p_cnt.sum().item() |
|
... |
... |
class RiemannAUCMeter(): |
35 |
36 |
self.n_cnt = torch.zeros(self.num_bins, dtype=torch.long) |
self.n_cnt = torch.zeros(self.num_bins, dtype=torch.long) |
36 |
37 |
|
|
37 |
38 |
|
|
|
39 |
|
class RiemannRunningAUCMeter(RiemannAUCMeter): |
|
40 |
|
def __init__(self, num_bins=100000, buffer_size=100): |
|
41 |
|
super().__init__(num_bins) |
|
42 |
|
self.buffer_size = buffer_size |
|
43 |
|
self.buffer = queue.Queue(maxsize=self.buffer_size) |
|
44 |
|
|
|
45 |
|
def add(self, outputs, labels): |
|
46 |
|
super().add(outputs, labels) |
|
47 |
|
if self.buffer.full(): |
|
48 |
|
p_mask, n_mask = self.buffer.get() |
|
49 |
|
self.p_cnt -= p_mask |
|
50 |
|
self.n_cnt -= n_mask |
|
51 |
|
self.buffer.put((self.p_mask, self.n_mask)) |
|
52 |
|
|
|
53 |
|
|
38 |
54 |
if __name__ == "__main__": |
if __name__ == "__main__": |
39 |
55 |
meter = RiemannAUCMeter() |
meter = RiemannAUCMeter() |
40 |
56 |
torch.manual_seed(7) |
torch.manual_seed(7) |