Working with PyTorch
Classes to pay attention to
PyTorch's Dataset and DataLoader, and PyTorch Lightning's DataModule.
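Here is a rough sketch of how the two PyTorch classes fit together. A Dataset maps an index to one (features, label) pair, and a DataLoader batches and shuffles those pairs. The class name SonarDataset and the random stand-in data are hypothetical. A Lightning DataModule would typically wrap the construction of datasets and loaders like these.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SonarDataset(Dataset):
    """Hypothetical dataset holding feature vectors and binary labels."""

    def __init__(self, features: torch.Tensor, labels: torch.Tensor):
        assert len(features) == len(labels)
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        # Number of samples; the DataLoader uses this to plan batches
        return len(self.features)

    def __getitem__(self, idx: int):
        # Return one (features, label) pair
        return self.features[idx], self.labels[idx]


# Random stand-in data: 10 samples with 60 features each (hypothetical shapes)
features = torch.rand(10, 60)
labels = torch.randint(0, 2, (10,)).float()
dataset = SonarDataset(features, labels)

# The DataLoader collates samples into batches of 4 and reshuffles every epoch
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_sizes = [xb.shape[0] for xb, yb in loader]
print(batch_sizes)  # 10 samples in batches of 4 -> [4, 4, 2]
```

The last batch is smaller because drop_last defaults to False.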
Training and predicting with the model
Lightning's Trainer class makes training straightforward by enabling and disabling gradients as needed, invoking callback functions, and dispatching data and computations to the appropriate devices.
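To make these three chores concrete, here is a hand-written plain-PyTorch loop that does them explicitly. It turns gradients on for training and off for evaluation, calls user callbacks at the end of each epoch, and moves the model and data to a device. This is a simplified illustration of what the Trainer automates, not Lightning's actual implementation. The tiny linear model, the random data and the (epoch, loss) callback convention are all assumptions.

```python
import torch
from torch import nn

# Pick a device; Lightning's Trainer makes this choice for you
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
model = nn.Linear(60, 1).to(device)            # move parameters to the device
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()

x = torch.rand(32, 60)                          # hypothetical training data
y = torch.randint(0, 2, (32, 1)).float()

epoch_losses = []
# Hypothetical callback: record the evaluation loss after each epoch
callbacks = [lambda epoch, loss: epoch_losses.append(loss)]

for epoch in range(3):
    model.train()
    with torch.set_grad_enabled(True):          # gradients on for training
        xb, yb = x.to(device), y.to(device)     # move the batch to the device
        loss = loss_fn(model(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():                       # gradients off for evaluation
        val_loss = loss_fn(model(x.to(device)), y.to(device)).item()

    for cb in callbacks:                        # invoke callbacks at epoch end
        cb(epoch, val_loss)

print(len(epoch_losses))  # one entry per epoch -> 3
```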
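```python
import torch
from pytorch_lightning import Trainer

# 1. Training (model is a LightningModule, dm a LightningDataModule)
trainer = Trainer(accelerator="gpu", devices=4, max_epochs=10)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
# 2. Save your model's parameters
torch.save(model.state_dict(), 'model.pt')
# 3. Predict with your model
rock_feature = torch.tensor([...])  # fill in the feature values
model.eval()  # disable training-only behaviour such as dropout
with torch.no_grad():  # gradients are not needed for inference
    rock_prediction = model(rock_feature)
```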
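Saving only the state_dict means the architecture must be rebuilt before the weights can be loaded back. The minimal round trip below makes a few assumptions. A hypothetical small classifier stands in for the real LightningModule. The model is assumed to output a single logit where a high value means "rock". A 0.5 probability threshold is an illustrative choice.

```python
import os
import tempfile

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical architecture; it must match the one whose weights were saved
    return nn.Sequential(nn.Linear(60, 16), nn.ReLU(), nn.Linear(16, 1))


torch.manual_seed(0)
model = make_model()
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(model.state_dict(), path)          # save parameters only

restored = make_model()                       # rebuild the same architecture
restored.load_state_dict(torch.load(path))    # then load the saved weights
restored.eval()                               # inference mode

rock_feature = torch.rand(1, 60)              # hypothetical input: one sample, 60 features
with torch.no_grad():
    logit = restored(rock_feature)
    prob = torch.sigmoid(logit).item()        # convert the logit to a probability
is_rock = prob > 0.5                          # assumed threshold and label meaning
print(f"P(rock) = {prob:.3f}, predicted rock: {is_rock}")
```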
Datasets
Connectionist Bench (Sonar, Mines vs. Rocks), UCI Machine Learning Repository: http://archive.ics.uci.edu/ml/datasets/connectionist+bench+(sonar,+mines+vs.+rocks)
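Here is a sketch of turning the raw data into tensors. It assumes the layout of the repository's sonar.all-data file: one sample per line, with 60 comma-separated numeric features followed by a label, R for rock or M for mine. The function name load_sonar and the rock = 1 encoding are illustrative choices. The two rows in the example are synthetic stand-ins, not real records.

```python
import io

import torch

# Assumed label encoding: rock = 1, mine = 0
LABELS = {"R": 1.0, "M": 0.0}


def load_sonar(lines, n_features: int = 60):
    """Parse sonar-style CSV lines into (features, labels) tensors.

    Assumes each line holds n_features numbers followed by 'R' or 'M'.
    """
    rows, labels = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue                              # skip blank lines
        *values, label = line.split(",")
        if len(values) != n_features:
            raise ValueError(f"expected {n_features} features, got {len(values)}")
        rows.append([float(v) for v in values])
        labels.append(LABELS[label])              # KeyError on an unknown label
    return torch.tensor(rows), torch.tensor(labels)


# Synthetic stand-in for the real file: two rows of 60 values each
fake_file = io.StringIO(
    ",".join(["0.02"] * 60) + ",R\n" + ",".join(["0.5"] * 60) + ",M\n"
)
X, y = load_sonar(fake_file)
print(X.shape)  # torch.Size([2, 60])
```

With the downloaded file, the same function can be given an open file handle, for example `with open("sonar.all-data") as f: X, y = load_sonar(f)`. The local filename here is an assumption.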
