havakv / torchtuples

Training neural networks in PyTorch

License: BSD 2-Clause "Simplified" License

Python 100.00%
pytorch python neural-network machine-learning deep-learning

torchtuples's Introduction

torchtuples

torchtuples is a small Python package for training PyTorch models. It works equally well with NumPy arrays and torch tensors. One of the main benefits of torchtuples is that it handles data in the form of nested tuples (see the example below).
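As a minimal illustration of what handling nested tuples involves, the sketch below recursively slices every array in a nested tuple while keeping its structure, which is the kind of operation needed to cut such data into mini-batches. It is not torchtuples' own code, and the helper names apply_leaf and slice_nested are hypothetical. Because it only relies on ordinary slicing, the same helper works for Python lists, NumPy arrays and torch tensors.

```python
from typing import Any, Callable


def apply_leaf(func: Callable[[Any], Any], data: Any) -> Any:
    """Apply func to every leaf (non-tuple element), preserving the tuple nesting."""
    if isinstance(data, tuple):
        return tuple(apply_leaf(func, item) for item in data)
    return func(data)


def slice_nested(data: Any, start: int, stop: int) -> Any:
    """Take rows start:stop from every array in a nested tuple (hypothetical helper)."""
    return apply_leaf(lambda leaf: leaf[start:stop], data)


# Same layout as the README example, x = (x0, (x0, x1, x2)), here with plain lists.
x0, x1, x2 = [0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]
x = (x0, (x0, x1, x2))
batch = slice_nested(x, 1, 3)
print(batch)  # ([1, 2], ([1, 2], [11, 12], [21, 22]))
```

Keeping the recursion in one generic function (apply_leaf) means any per-array operation, such as moving to a device or converting to tensors, can reuse the same traversal.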

Installation

torchtuples depends on PyTorch, which should be installed first by following the instructions on the official PyTorch website.

Next, torchtuples can be installed with pip:

pip install torchtuples

Or, via conda:

conda install -c conda-forge torchtuples

For the bleeding-edge version, install directly from GitHub (consider adding --force-reinstall):

pip install git+https://github.com/havakv/torchtuples.git

or by cloning the repo:

git clone https://github.com/havakv/torchtuples.git
cd torchtuples
pip install .

Example

import torch
from torch import nn
from torchtuples import Model, optim

Create a dataset with three sets of covariates, x0, x1 and x2, and a target y. The covariates are structured as a nested tuple x.

n = 500
x0, x1, x2 = [torch.randn(n, 3) for _ in range(3)]
y = torch.randn(n, 1)
x = (x0, (x0, x1, x2))

Create a simple ReLU net that takes as input the tensor x_tensor and the tuple x_tuple. Note that x_tuple can be of arbitrary length. The tensors in x_tuple are passed through a layer lin_tuple, averaged, and concatenated with x_tensor. We then pass our new tensor through the layer lin_cat.

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin_tuple = nn.Linear(3, 2)
        self.lin_cat = nn.Linear(5, 1)
        self.relu = nn.ReLU()

    def forward(self, x_tensor, x_tuple):
        x = [self.relu(self.lin_tuple(xi)) for xi in x_tuple]
        x = torch.stack(x).mean(0)
        x = torch.cat([x, x_tensor], dim=1)
        return self.lin_cat(x)

    def predict(self, x_tensor, x_tuple):
        x = self.forward(x_tensor, x_tuple)
        return torch.sigmoid(x)

We can now fit the model with

model = Model(Net(), nn.MSELoss(), optim.SGD(0.01))
log = model.fit(x, y, batch_size=64, epochs=5)

and make predictions with either the Net.predict method

preds = model.predict(x)

or with the Net.forward method

preds = model.predict_net(x)
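To show what model.fit(x, y, ...) takes care of, here is a rough plain-PyTorch sketch of an equivalent training loop for the same data and network. It is an illustration under assumptions, not torchtuples' actual implementation: index_nested is a hypothetical helper, and the unpacking of the top-level tuple into forward's arguments is inferred from how Net.forward(x_tensor, x_tuple) receives x = (x0, (x0, x1, x2)) in the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Same data layout as above: x = (x0, (x0, x1, x2)) and target y.
n = 500
x0, x1, x2 = [torch.randn(n, 3) for _ in range(3)]
y = torch.randn(n, 1)
x = (x0, (x0, x1, x2))


def index_nested(data, idx):
    """Select the same rows from every tensor in a nested tuple."""
    if isinstance(data, tuple):
        return tuple(index_nested(item, idx) for item in data)
    return data[idx]


class Net(nn.Module):
    # Same architecture as the Net class above (predict method omitted).
    def __init__(self):
        super().__init__()
        self.lin_tuple = nn.Linear(3, 2)
        self.lin_cat = nn.Linear(5, 1)
        self.relu = nn.ReLU()

    def forward(self, x_tensor, x_tuple):
        h = torch.stack([self.relu(self.lin_tuple(xi)) for xi in x_tuple]).mean(0)
        return self.lin_cat(torch.cat([h, x_tensor], dim=1))


net = Net()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
batch_size, epochs = 64, 5
epoch_losses = []

for epoch in range(epochs):
    perm = torch.randperm(n)  # reshuffle the rows every epoch
    total = 0.0
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        x_batch, y_batch = index_nested(x, idx), y[idx]
        optimizer.zero_grad()
        # The top-level tuple is unpacked into forward(x_tensor, x_tuple).
        loss = loss_fn(net(*x_batch), y_batch)
        loss.backward()
        optimizer.step()
        total += loss.item() * len(idx)
    epoch_losses.append(total / n)  # mean training loss for this epoch

with torch.no_grad():
    preds = net(*x)  # comparable to model.predict_net(x): no sigmoid applied
print(preds.shape)  # torch.Size([500, 1])
```

The bookkeeping in the inner loop (shuffling, indexing every tensor of the nested tuple with the same rows, unpacking the tuple into forward) is what a tuple-aware fit method saves you from writing by hand.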

For more examples, see the examples folder.

torchtuples's People

Contributors

havakv, sarthakpati

torchtuples's Issues

Can I use PyTorch's lr_scheduler in model.fit?

First of all, thank you for your great work!

I've noticed that there is an lr_scheduler.py in this repo; however, I couldn't find out from the examples how to use it.
If it's possible, how can I modify the following code to use a learning-rate scheduler?

log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)

Many thanks
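One possible approach, sketched minimally under stated assumptions: a small callback object that steps a standard PyTorch scheduler at the end of every epoch. It assumes that model.fit calls on_epoch_end() on each object in callbacks after every epoch (in practice the class would subclass torchtuples' callback base class, tt.callbacks.Callback, which is also an assumption here), and that the underlying torch optimizer is reachable as model.optimizer.optimizer, as stated in another issue on this page. It does not use the lr_scheduler.py module mentioned in the question, whose API is not shown here, and the class name TorchLRScheduler is hypothetical.

```python
import torch


class TorchLRScheduler:
    """Hypothetical callback that steps a torch.optim.lr_scheduler once per epoch.

    Assumption: the training loop calls on_epoch_end() after every epoch; with
    torchtuples this class would subclass tt.callbacks.Callback.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def on_epoch_end(self):
        self.scheduler.step()
        return False  # assumed convention: a falsy value means "keep training"


# Stand-alone demo with a plain torch optimizer. With torchtuples one would
# (assumption) build the scheduler on model.optimizer.optimizer instead.
net = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
callback = TorchLRScheduler(scheduler)

for epoch in range(2):
    optimizer.step()  # stands in for one epoch of training
    callback.on_epoch_end()
print(optimizer.param_groups[0]["lr"])  # 0.025
```

Under these assumptions, the fit call from the question would receive the extra callback, for example callbacks + [TorchLRScheduler(scheduler)] with the scheduler created on model.optimizer.optimizer. Schedulers meant to be stepped per batch, such as OneCycleLR, would instead need a per-batch hook.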

File paths on Windows

The generated file paths contain colons (":"), which are not allowed in file names on Windows.
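The issue does not say which code produces these paths. Assuming the colons come from time stamps such as those returned by datetime.isoformat(), here is a minimal sketch of two possible fixes: format the time stamp without colons in the first place, or replace every character that Windows forbids in file names (< > : " / \ | ? *). The function names and the .pt suffix are hypothetical.

```python
import re
from datetime import datetime
from typing import Optional

# Characters that Windows does not allow in file names.
_WINDOWS_RESERVED = re.compile(r'[<>:"/\\|?*]')


def windows_safe(name: str, replacement: str = "-") -> str:
    """Replace Windows-reserved characters in a single file-name component."""
    return _WINDOWS_RESERVED.sub(replacement, name)


def timestamped_filename(prefix: str, when: Optional[datetime] = None) -> str:
    """Build a time-stamped file name without colons (hypothetical naming scheme)."""
    when = when or datetime.now()
    # isoformat() gives '2020-01-02T03:04:05'; strftime lets us avoid the colons.
    return f"{prefix}_{when.strftime('%Y-%m-%d_%H-%M-%S')}.pt"


stamp = datetime(2020, 1, 2, 3, 4, 5)
print(windows_safe(f"weights_{stamp.isoformat()}.pt"))  # weights_2020-01-02T03-04-05.pt
print(timestamped_filename("weights", stamp))  # weights_2020-01-02_03-04-05.pt
```

Note that windows_safe also replaces "/" and "\", so it should be applied to a single file-name component, not to a full path.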

Using LSTM instead of MLPVanilla

In order to use an LSTM instead of MLPVanilla with the CoxTime and CoxPH models, I have written the following model class. It works mechanically, but I want to make sure that the implementation is theoretically correct. I'm trying to make each patient an input sequence for the LSTM, so that the hidden and cell states are carried over within that patient's sequence rather than across the whole batch of patients treated as one sequence. Would you be able to share some insights?

import torchtuples as tt
from pycox.models import CoxPH
from torch import nn

class LSTMCox(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, n_layers, output_size):
        super(LSTMCox, self).__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers)
        self.fc = nn.Linear(hidden_dim, output_size)
        self.activation = nn.ReLU()

    def forward(self, input):
        # nn.LSTM defaults to batch_first=False, so this shape is read as
        # (seq_len=len(input), batch=1, embedding_dim).
        input = input.view(len(input), 1, self.embedding_dim)

        lstm_out, _ = self.lstm(input)
        lstm_out = lstm_out.contiguous().view(len(input), -1)

        out = self.fc(lstm_out)
        out = self.activation(out)

        return out

# in_features, batch_size, epochs, callbacks, x_train, y_train and val
# are defined elsewhere in the script.
net = LSTMCox(in_features, 512, 1, 1)
model = CoxPH(net, tt.optim.Adam)
model.optimizer.set_lr(0.01)
log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, val_batch_size=batch_size)
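In the LSTMCox code above, nn.LSTM uses its default batch_first=False, so input.view(len(input), 1, self.embedding_dim) is read as one sequence of length len(input) with batch size 1. The patients in a mini-batch therefore become consecutive time steps of a single sequence, and the hidden and cell states are carried from one patient to the next. Below is a minimal sketch of the per-patient alternative. It assumes the data can be arranged as (patients, time steps, features); the class name LSTMPerPatient is hypothetical, and the output is left linear because the network output in a Cox model is an unconstrained log-risk score. If each patient only has a flat feature vector, one option (an assumption, not something the issue specifies) is x.view(len(x), -1, 1), treating the features as a sequence of scalars.

```python
import torch
from torch import nn


class LSTMPerPatient(nn.Module):
    """LSTM risk model where each batch row is one patient's own sequence.

    Expected input shape: (batch, seq_len, n_features). Because no initial
    state is passed, nn.LSTM starts every sequence in the batch from zero
    hidden and cell states, so states flow along one patient's sequence only.
    """

    def __init__(self, n_features: int, hidden_dim: int, n_layers: int = 1, out_features: int = 1):
        super().__init__()
        self.lstm = nn.LSTM(n_features, hidden_dim, n_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, (h_n, _) = self.lstm(x)  # h_n: (n_layers, batch, hidden_dim)
        return self.fc(h_n[-1])  # last layer's final state -> one score per patient


torch.manual_seed(0)
net = LSTMPerPatient(n_features=4, hidden_dim=8)
x = torch.randn(5, 7, 4)  # 5 patients, 7 time steps, 4 features each
out = net(x)
print(out.shape)  # torch.Size([5, 1])

# Changing the other patients' data leaves patient 0's output unchanged.
x_other = x.clone()
x_other[1:] = torch.randn(4, 7, 4)
print(torch.allclose(out[0], net(x_other)[0]))  # True
```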

Stop wrapping optimizers to simplify use of torch optimizers

Right now all optimizers are wrapped, so to access the torch optimizer object we need to call model.optimizer.optimizer. It might make sense to be able to get the torch optimizer directly with model.optimizer.

If we continue to wrap torch optimizers, the wrapper could instead be exposed as a separate object, e.g. model.optimizer_wrapper.
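A minimal sketch of one possible middle ground between these two options: keep a wrapper for extra helpers such as set_lr (used as model.optimizer.set_lr in another issue on this page), but forward every other attribute to the wrapped torch optimizer so that code written for torch optimizers also works on the wrapper. The class below is hypothetical and is not torchtuples' current implementation.

```python
import torch


class OptimizerWrapper:
    """Hypothetical wrapper that adds helpers but delegates everything else."""

    def __init__(self, optimizer: torch.optim.Optimizer):
        self.optimizer = optimizer  # the plain torch optimizer stays accessible

    def set_lr(self, lr: float) -> None:
        """Set the learning rate of every parameter group."""
        for group in self.optimizer.param_groups:
            group["lr"] = lr

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. Guard against
        # infinite recursion if self.optimizer has not been set yet.
        if name == "optimizer":
            raise AttributeError(name)
        return getattr(self.optimizer, name)


net = torch.nn.Linear(3, 1)
wrapped = OptimizerWrapper(torch.optim.SGD(net.parameters(), lr=0.1))
wrapped.set_lr(0.01)  # helper defined on the wrapper
wrapped.zero_grad()  # forwarded to torch.optim.SGD
print(wrapped.param_groups[0]["lr"])  # 0.01, also forwarded
```

A caveat of this design is that isinstance(wrapped, torch.optim.Optimizer) is still False, so torch utilities that check the type, such as the learning-rate schedulers, still need wrapped.optimizer.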

Add conda recipe

Adding a conda recipe would make distribution simpler, especially in combination with packages that require C/C++ libraries.

Cheers,
Sarthak

Make it simpler to access train/val metrics

Currently, one needs to call model.val_metrics.scores['loss']['score'][-1] and model.train_metrics.scores['loss']['score'][-1] to get the latest validation and training scores. These should be easier to access.

The same change should probably also be made to the MonitorMetrics callback.
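Until such an accessor exists, a small helper can hide the nested lookup. This is a minimal sketch assuming the layout implied by the expressions above, where scores[metric]['score'] holds one value per epoch; the name last_scores is hypothetical. Usage would then be last_scores(model.val_metrics.scores)['loss'].

```python
from typing import Dict, List


def last_scores(scores: Dict[str, Dict[str, List[float]]]) -> Dict[str, float]:
    """Return the latest value of each metric.

    Assumes scores[metric]['score'] is a list with one entry per epoch,
    as in model.train_metrics.scores; metrics with no entries are skipped.
    """
    return {name: entry["score"][-1] for name, entry in scores.items() if entry["score"]}


# Toy stand-in for model.val_metrics.scores (values made up for the demo).
val_scores = {"loss": {"score": [0.92, 0.61, 0.48]}}
print(last_scores(val_scores))  # {'loss': 0.48}
```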
