Working with PyTorch

Classes to pay attention to

Dataset and DataLoader from PyTorch, and DataModule (LightningDataModule) from PyTorch Lightning.
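
Here is a rough sketch of how the two PyTorch classes fit together. A Dataset returns one (features, label) pair per index, and a DataLoader batches and shuffles those pairs. The class name SonarDataset and the random tensors are hypothetical stand-ins for real sonar data. The code assumes 60 features per sample and encodes the label as 1 for rock and 0 for mine.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SonarDataset(Dataset):
    """Hypothetical Dataset wrapping pre-loaded feature and label tensors."""

    def __init__(self, features: torch.Tensor, labels: torch.Tensor):
        assert len(features) == len(labels), "one label per sample"
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        # Number of samples; the DataLoader uses this to plan its batches.
        return len(self.features)

    def __getitem__(self, idx: int):
        # Return a single (features, label) pair.
        return self.features[idx], self.labels[idx]


# Random stand-in data: 208 samples with 60 features each (assumed sonar shape).
features = torch.randn(208, 60)
labels = torch.randint(0, 2, (208,)).float()

dataset = SonarDataset(features, labels)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)  # torch.Size([32, 60]) torch.Size([32])
```

In Lightning, a DataModule usually bundles this setup. It builds the datasets and returns loaders from its train_dataloader(), val_dataloader() and test_dataloader() methods, so the Trainer can request the data itself.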

Training and predicting with the model

Lightning's Trainer class makes training straightforward by:

  • Enabling and disabling gradients as needed

  • Invoking callback functions

  • Dispatching data and computations to appropriate devices
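
The list above describes work that a plain PyTorch loop has to do by hand. The minimal sketch below does that work manually. The tiny nn.Sequential classifier, the SGD settings and the random data are illustrative assumptions, not the model this page trains.

```python
import torch
from torch import nn

# Pick a device; the Trainer makes this choice for you.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)
# Hypothetical tiny classifier: 60 sonar features -> 1 logit.
model = nn.Sequential(nn.Linear(60, 16), nn.ReLU(), nn.Linear(16, 1)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()

# Random stand-in data (assumption, not the real dataset).
x = torch.randn(64, 60)
y = torch.randint(0, 2, (64,)).float()


def train_step(xb, yb):
    model.train()                          # training-mode layers (dropout, batch norm)
    xb, yb = xb.to(device), yb.to(device)  # dispatch data to the model's device
    optimizer.zero_grad()
    loss = loss_fn(model(xb).squeeze(1), yb)
    loss.backward()                        # gradients are tracked and computed here
    optimizer.step()
    return loss.item()


@torch.no_grad()                           # gradients disabled for evaluation
def eval_loss(xb, yb):
    model.eval()
    xb, yb = xb.to(device), yb.to(device)
    return loss_fn(model(xb).squeeze(1), yb).item()


before = eval_loss(x, y)
for _ in range(50):
    train_step(x, y)
    # A Trainer would invoke callback hooks (e.g. on_train_epoch_end) around here.
after = eval_loss(x, y)
print(f"loss before: {before:.3f}, after: {after:.3f}")
```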

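```python
import torch
from pytorch_lightning import Trainer

# `model` is a LightningModule and `dm` a LightningDataModule, both defined elsewhere.

# 1. Training (newer Lightning versions use `accelerator` + `devices`
#    instead of the removed `gpus` argument)
trainer = Trainer(accelerator="gpu", devices=4, max_epochs=10)
trainer.fit(model, datamodule=dm)

# Evaluate on the test split provided by the DataModule
trainer.test(datamodule=dm)

# 2. Save your model (stores the learned parameters, not the class definition)
torch.save(model.state_dict(), "model.pt")

# 3. Predict with your model
model.eval()           # put dropout/batch-norm layers in inference mode
with torch.no_grad():  # no gradients are needed for prediction
    rock_feature = torch.tensor([...])
    rock_prediction = model(rock_feature)
```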
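
Calling torch.save(model.state_dict(), 'model.pt') stores only the parameters, not the architecture. To load them later, you first rebuild the same model class. The round-trip sketch below assumes a hypothetical build_model() in place of the real model and writes to a temporary directory instead of the working directory. It also assumes that a sigmoid output above 0.5 means "rock".

```python
import os
import tempfile

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical architecture; it must match the one whose weights were saved.
    return nn.Sequential(nn.Linear(60, 16), nn.ReLU(), nn.Linear(16, 1))


torch.manual_seed(0)
model = build_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pt")
    torch.save(model.state_dict(), path)  # parameters only, no class definition

    restored = build_model()              # same architecture, fresh weights
    restored.load_state_dict(torch.load(path))

restored.eval()                           # inference mode for dropout/batch norm
with torch.no_grad():                     # no gradient tracking when predicting
    rock_feature = torch.rand(1, 60)      # random stand-in for one sonar sample
    logit = restored(rock_feature)
    prob_rock = torch.sigmoid(logit).item()  # assumes label 1 means "rock"

print(f"P(rock) = {prob_rock:.3f}")
```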

Datasets

Connectionist Bench (Sonar, Mines vs. Rocks), UCI Machine Learning Repository: http://archive.ics.uci.edu/ml/datasets/connectionist+bench+(sonar,+mines+vs.+rocks)
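
The linked page distributes the sonar data as a comma-separated file. Assuming each row holds 60 numeric features followed by an R (rock) or M (mine) label, a small loader could look like the sketch below. The function name load_sonar and the 1 = rock encoding are illustrative choices, not part of the dataset.

```python
import csv
import io

import torch

LABELS = {"R": 1.0, "M": 0.0}  # assumed encoding: 1 = rock, 0 = mine


def load_sonar(fileobj):
    """Parse rows of numeric features plus an 'R'/'M' label (assumed format)."""
    features, labels = [], []
    for row in csv.reader(fileobj):
        if not row:
            continue  # skip blank lines
        *values, label = row
        features.append([float(v) for v in values])
        labels.append(LABELS[label.strip()])  # KeyError on an unexpected label
    return torch.tensor(features), torch.tensor(labels)


# Two made-up rows in the assumed layout (not real measurements).
sample = io.StringIO(
    ",".join(["0.02"] * 60) + ",R\n" + ",".join(["0.05"] * 60) + ",M\n"
)
X, y = load_sonar(sample)
print(X.shape, y.tolist())  # torch.Size([2, 60]) [1.0, 0.0]
```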
