
Natooz/TorchToolkit

TorchToolkit

Hi 👋, this is a small Python package containing useful functions to use with PyTorch. It includes utilities, metrics and sampling methods to use during and after training a model.
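The README does not detail which sampling methods are included. As a hedged illustration of the general idea only, here is a generic top-p (nucleus) sampling sketch in plain Python. It is not TorchToolkit's implementation, and `softmax` and `nucleus_sample` are hypothetical names chosen for this example.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    """Convert raw scores to probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def nucleus_sample(logits: List[float], p: float = 0.9,
                   rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: sample an index from the smallest set of most-probable
    classes whose cumulative probability reaches p (top-p / nucleus sampling)."""
    rng = rng or random.Random()
    probs = softmax(logits)
    # Class indices sorted from most to least probable
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    # random.choices treats the weights as relative, so no renormalization is needed
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

With a sharply peaked distribution such as `[10.0, 0.0, 0.0]` and `p=0.9`, only the first class survives the cut, so the sample is always `0`.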

Feel free to use it or take the code for your own projects. If you have a question, open an issue; if you want to contribute, open a pull request.

pip install torchtoolkit

It requires Python 3.8 or above.

A minimal example:

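from torchtoolkit.metrics import Accuracy
from torch import randint, randn
from pathlib import Path

acc = Accuracy(mode='top_k', top_kp=5)  # top-k accuracy, here with k = 5
for _ in range(10):
    res = randn((16, 32))  # scores for a batch of 16 samples over 32 classes
    expected = randint(0, 32, (16, ))  # ground-truth class index of each sample
    acc(res, expected)  # saving results
acc.save(Path('path', 'to', 'save', 'file.csv'))  # placeholder path, the directory must exist
acc.analyze()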
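For reference, top-k accuracy counts a prediction as correct when the expected class is among the k highest scores. The sketch below computes it in plain Python; `top_k_accuracy` is a hypothetical helper, not part of the torchtoolkit API, and the package's `Accuracy` class may compute or aggregate its results differently.

```python
from typing import Sequence


def top_k_accuracy(logits: Sequence[Sequence[float]],
                   targets: Sequence[int], k: int = 5) -> float:
    """Hypothetical helper: fraction of samples whose target class is
    among the k highest-scoring classes."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    hits = 0
    for scores, target in zip(logits, targets):
        # Indices of the k largest scores for this sample
        top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
        if target in top_k:
            hits += 1
    return hits / len(targets)
```

For example, `top_k_accuracy([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], [1, 1], k=1)` returns `0.5` (only the first sample ranks class 1 highest), while `k=2` returns `1.0`.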

I built it for my own use, so apart from the docstrings there is no documentation.