/persia_pytorch_toolkit/meter_utils.py (57d3a08765ebb17b5b7df946fb2761b522749473) (2306 bytes) (mode 100644) (type blob)
import torch
import queue
class RiemannAUCMeter():
"""
Efficient PyTorch AUC Estimator.
"""
def __init__(self, num_bins=100000):
self.num_bins = num_bins
self.reset()
def add(self, outputs, labels):
"""Outputs should be probabilities from 0 to 1."""
indices = torch.clamp(outputs * self.num_bins, min=0, max=self.num_bins - 1).long()
p_indices = torch.masked_select(indices, labels != 0)
n_indices = torch.masked_select(indices, labels == 0)
self.p_mask = torch.sparse.FloatTensor(p_indices.unsqueeze(0), torch.ones_like(p_indices), torch.Size([self.num_bins]))
self.n_mask = torch.sparse.FloatTensor(n_indices.unsqueeze(0), torch.ones_like(n_indices), torch.Size([self.num_bins]))
self.p_cnt += self.p_mask
self.n_cnt += self.n_mask
def value(self):
p_sum = self.p_cnt.sum().item()
n_sum = self.n_cnt.sum().item()
prod = torch.dot(self.p_cnt, self.n_cnt).item() * 0.5
n_cumsum = self.n_cnt.cumsum(0)
up_sum = torch.dot(self.p_cnt, n_cumsum).item() + prod
try:
return float(up_sum)/float(p_sum*n_sum)
except ZeroDivisionError:
return 0
def reset(self):
self.p_cnt = torch.zeros(self.num_bins, dtype=torch.long)
self.n_cnt = torch.zeros(self.num_bins, dtype=torch.long)
class RiemannRunningAUCMeter(RiemannAUCMeter):
def __init__(self, num_bins=100000, buffer_size=100):
super().__init__(num_bins)
self.buffer_size = buffer_size
self.buffer = queue.Queue(maxsize=self.buffer_size)
self.p_cnt = torch.zeros(self.num_bins, dtype=torch.long)
self.n_cnt = torch.zeros(self.num_bins, dtype=torch.long)
def add(self, outputs, labels):
super().add(outputs, labels)
if self.buffer.full():
p_mask, n_mask = self.buffer.get()
self.p_cnt -= p_mask
self.n_cnt -= n_mask
self.buffer.put((self.p_mask, self.n_mask))
def reset(self):
pass
if __name__ == "__main__":
meter = RiemannAUCMeter()
torch.manual_seed(7)
outputs = torch.rand(10000)
labels = torch.rand(10000) > 0.4
for _ in range(100):
meter.add(outputs, labels.float())
print(meter.value())
Mode |
Type |
Size |
Ref |
File |
100644 |
blob |
1203 |
894a44cc066a027465cd26d634948d56d13af9af |
.gitignore |
100644 |
blob |
106 |
31e29114018d9b0254e2e8d2d46797c484362eff |
README.md |
040000 |
tree |
- |
db847373d98d5ed84adf118786b63925b7aa504d |
persia_pytorch_toolkit |
100644 |
blob |
354 |
60b2cf036b92d2c0c4be4ceec72b98c30ca04ccb |
setup.py |
Hints:
Before first commit, do not forget to setup your git environment:
git config --global user.name "your_name_here"
git config --global user.email "your@email_here"
Clone this repository using HTTP(S):
git clone https://rocketgit.com/user/ikzk/persia-pytorch-toolkit
Clone this repository using ssh (do not forget to upload a key first):
git clone ssh://rocketgit@ssh.rocketgit.com/user/ikzk/persia-pytorch-toolkit
Clone this repository using git:
git clone git://git.rocketgit.com/user/ikzk/persia-pytorch-toolkit
You are allowed to anonymously push to this repository.
This means that your pushed commits will automatically be transformed into a
merge request:
... clone the repository ...
... make some changes and some commits ...
git push origin main