Subject | Hash | Author | Date (UTC) |
---|---|---|---|
add riemann auc estimator | 8ce1b37008f10d6ca825728b1d0106b03bb0529d | Xiangru Lian | 2019-01-08 21:11:27 |
fix | d4f0a50ffb6afea07df425d759fdf8460839d0ce | Xiangru Lian | 2019-01-07 22:27:56 |
add numpy support | 2e10283edfc427ea4da6a76c8c3a6f85a44bace4 | Xiangru Lian | 2019-01-07 22:24:12 |
update | fcfc8a26e275395688fb3fa50538f890c1f612e4 | Xiangru Lian | 2019-01-07 21:55:04 |
use markdown for readme | f46e30761b4f361d5494484d9bdb20dc04a20d83 | Xiangru Lian | 2019-01-07 21:53:46 |
add readme | a9a0b91ea0e0485039ea76b6b69c5f983815e8a3 | Xiangru Lian | 2019-01-07 21:52:31 |
fix | 16a965862eaefa554a540fd2eeee26f73907d5dd | Xiangru Lian | 2019-01-07 21:37:12 |
fix setup.py | 24e85045cc859ffa7d6f514c80257fde94c0794a | Xiangru Lian | 2019-01-07 21:35:15 |
compatible with setup.py | 657782ac87ea9251f2e3726b483089eca5384caf | Xiangru Lian | 2019-01-07 21:33:48 |
initial commit | 27cb0b3b4f5f20240af6d14ead4f373a1aaa5343 | Xiangru Lian | 2019-01-07 21:26:20 |
Update .gitignore | 52642556fea807b773617938eecd2d9d17f34bc6 | ikzk | 2019-01-07 20:16:37 |
File | Lines added | Lines deleted |
---|---|---|
persia_pytorch_toolkit/__init__.py | 1 | 0 |
persia_pytorch_toolkit/meter_utils.py | 42 | 0 |
File persia_pytorch_toolkit/__init__.py changed (mode: 100644) (index ac56390..c716994) | |||
1 | 1 | from . import utils | from . import utils |
2 | 2 | from . import model_utils | from . import model_utils |
3 | 3 | from . import tensor_utils | from . import tensor_utils |
4 | from . import meter_utils |
File persia_pytorch_toolkit/meter_utils.py added (mode: 100644) (index 0000000..e641078) | |||
1 | import torch | ||
2 | |||
3 | class RiemannAUCMeter(): | ||
4 | """ | ||
5 | Efficient PyTorch AUC Estimator. | ||
6 | """ | ||
7 | def __init__(self, num_bins=100000): | ||
8 | self.num_bins = num_bins | ||
9 | self.p_cnt = torch.zeros(self.num_bins, dtype=torch.long) | ||
10 | self.n_cnt = torch.zeros(self.num_bins, dtype=torch.long) | ||
11 | |||
12 | def add(self, outputs, labels): | ||
13 | """Outputs should be probabilities from 0 to 1.""" | ||
14 | indices = torch.clamp(outputs * self.num_bins, min=0, max=self.num_bins - 1).long() | ||
15 | p_indices = torch.masked_select(indices, labels != 0) | ||
16 | n_indices = torch.masked_select(indices, labels == 0) | ||
17 | p_mask = torch.sparse.FloatTensor(p_indices.unsqueeze(0), torch.ones_like(p_indices), torch.Size([self.num_bins])) | ||
18 | n_mask = torch.sparse.FloatTensor(n_indices.unsqueeze(0), torch.ones_like(n_indices), torch.Size([self.num_bins])) | ||
19 | self.p_cnt += p_mask | ||
20 | self.n_cnt += n_mask | ||
21 | |||
22 | def value(self): | ||
23 | p_sum = self.p_cnt.sum().item() | ||
24 | n_sum = self.n_cnt.sum().item() | ||
25 | |||
26 | prod = torch.dot(self.p_cnt, self.n_cnt).item() * 0.5 | ||
27 | n_cumsum = self.n_cnt.cumsum(0) | ||
28 | up_sum = torch.dot(self.p_cnt, n_cumsum).item() + prod | ||
29 | try: | ||
30 | return float(up_sum)/float(p_sum*n_sum) | ||
31 | except ZeroDivisionError: | ||
32 | return 0 | ||
33 | |||
34 | |||
35 | if __name__ == "__main__": | ||
36 | meter = RiemannAUCMeter() | ||
37 | torch.manual_seed(7) | ||
38 | outputs = torch.rand(10000) | ||
39 | labels = torch.rand(10000) > 0.4 | ||
40 | for _ in range(100): | ||
41 | meter.add(outputs, labels.float()) | ||
42 | print(meter.value()) |