List of commits:
Subject Hash Author Date (UTC)
add riemann auc estimator 8ce1b37008f10d6ca825728b1d0106b03bb0529d Xiangru Lian 2019-01-08 21:11:27
fix d4f0a50ffb6afea07df425d759fdf8460839d0ce Xiangru Lian 2019-01-07 22:27:56
add numpy support 2e10283edfc427ea4da6a76c8c3a6f85a44bace4 Xiangru Lian 2019-01-07 22:24:12
update fcfc8a26e275395688fb3fa50538f890c1f612e4 Xiangru Lian 2019-01-07 21:55:04
use markdown for readme f46e30761b4f361d5494484d9bdb20dc04a20d83 Xiangru Lian 2019-01-07 21:53:46
add readme a9a0b91ea0e0485039ea76b6b69c5f983815e8a3 Xiangru Lian 2019-01-07 21:52:31
fix 16a965862eaefa554a540fd2eeee26f73907d5dd Xiangru Lian 2019-01-07 21:37:12
fix setup.py 24e85045cc859ffa7d6f514c80257fde94c0794a Xiangru Lian 2019-01-07 21:35:15
compatible with setup.py 657782ac87ea9251f2e3726b483089eca5384caf Xiangru Lian 2019-01-07 21:33:48
initial commit 27cb0b3b4f5f20240af6d14ead4f373a1aaa5343 Xiangru Lian 2019-01-07 21:26:20
Update .gitignore 52642556fea807b773617938eecd2d9d17f34bc6 ikzk 2019-01-07 20:16:37
Commit 8ce1b37008f10d6ca825728b1d0106b03bb0529d - add riemann auc estimator
Author: Xiangru Lian
Author date (UTC): 2019-01-08 21:11
Committer name: Xiangru Lian
Committer date (UTC): 2019-01-08 21:11
Parent(s): d4f0a50ffb6afea07df425d759fdf8460839d0ce
Signing key:
Tree: c234f044ab6c21f313a121f51b56251551556a23
File Lines added Lines deleted
persia_pytorch_toolkit/__init__.py 1 0
persia_pytorch_toolkit/meter_utils.py 42 0
File persia_pytorch_toolkit/__init__.py changed (mode: 100644) (index ac56390..c716994)
1 1 from . import utils from . import utils
2 2 from . import model_utils from . import model_utils
3 3 from . import tensor_utils from . import tensor_utils
4 from . import meter_utils
File persia_pytorch_toolkit/meter_utils.py added (mode: 100644) (index 0000000..e641078)
1 import torch
2
3 class RiemannAUCMeter():
4 """
5 Efficient PyTorch AUC Estimator.
6 """
7 def __init__(self, num_bins=100000):
8 self.num_bins = num_bins
9 self.p_cnt = torch.zeros(self.num_bins, dtype=torch.long)
10 self.n_cnt = torch.zeros(self.num_bins, dtype=torch.long)
11
12 def add(self, outputs, labels):
13 """Outputs should be probabilities from 0 to 1."""
14 indices = torch.clamp(outputs * self.num_bins, min=0, max=self.num_bins - 1).long()
15 p_indices = torch.masked_select(indices, labels != 0)
16 n_indices = torch.masked_select(indices, labels == 0)
17 p_mask = torch.sparse.FloatTensor(p_indices.unsqueeze(0), torch.ones_like(p_indices), torch.Size([self.num_bins]))
18 n_mask = torch.sparse.FloatTensor(n_indices.unsqueeze(0), torch.ones_like(n_indices), torch.Size([self.num_bins]))
19 self.p_cnt += p_mask
20 self.n_cnt += n_mask
21
22 def value(self):
23 p_sum = self.p_cnt.sum().item()
24 n_sum = self.n_cnt.sum().item()
25
26 prod = torch.dot(self.p_cnt, self.n_cnt).item() * 0.5
27 n_cumsum = self.n_cnt.cumsum(0)
28 up_sum = torch.dot(self.p_cnt, n_cumsum).item() + prod
29 try:
30 return float(up_sum)/float(p_sum*n_sum)
31 except ZeroDivisionError:
32 return 0
33
34
35 if __name__ == "__main__":
36 meter = RiemannAUCMeter()
37 torch.manual_seed(7)
38 outputs = torch.rand(10000)
39 labels = torch.rand(10000) > 0.4
40 for _ in range(100):
41 meter.add(outputs, labels.float())
42 print(meter.value())
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"

Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/ikzk/persia-pytorch-toolkit

Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/ikzk/persia-pytorch-toolkit

Clone this repository using git:
git clone git://git.rocketgit.com/user/ikzk/persia-pytorch-toolkit

You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main