Working with PyTorch
Training and predicting with the model
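import torch
from pytorch_lightning import Trainer  # assumes Trainer is PyTorch Lightning's

# 1. Training
# `accelerator`/`devices` replace the older `gpus=4` argument, which was removed in Lightning 2.0
trainer = Trainer(accelerator="gpu", devices=4, max_epochs=10)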
trainer.fit(model, dm)
trainer.test(datamodule=dm)
# 2. Save your model
torch.save(model.state_dict(), 'model.pt')
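# 3. Predict with your model
model.eval()  # put layers such as dropout and batch norm into inference mode
rock_feature = torch.tensor([...])
with torch.no_grad():  # gradients are not needed for prediction
    rock_prediction = model(rock_feature)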
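
Saving the `state_dict` stores only the weights, so predicting in a later session requires rebuilding the model class and loading the weights back into it. The following is a minimal sketch of that round trip, not the model used above: `RockClassifier`, its layer sizes, and the random input are hypothetical stand-ins, since the real architecture and features are not shown here.

```python
import os
import tempfile

import torch
from torch import nn


class RockClassifier(nn.Module):
    """Hypothetical stand-in for the trained model (4 features, 3 classes are made up)."""

    def __init__(self, n_features: int = 4, n_classes: int = 3):
        super().__init__()
        self.net = nn.Linear(n_features, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def save_and_reload(model: nn.Module, path: str) -> nn.Module:
    """Save only the weights, then load them into a fresh instance of the same class."""
    torch.save(model.state_dict(), path)
    restored = RockClassifier()
    restored.load_state_dict(torch.load(path))
    restored.eval()  # inference mode for layers such as dropout and batch norm
    return restored


torch.manual_seed(0)
model = RockClassifier()
with tempfile.TemporaryDirectory() as tmp:
    restored = save_and_reload(model, os.path.join(tmp, "model.pt"))

# Made-up input: a batch containing one sample with 4 features.
rock_feature = torch.rand(1, 4)
with torch.no_grad():
    rock_prediction = restored(rock_feature)
print(rock_prediction.shape)  # torch.Size([1, 3])
```

The restored model should reproduce the original model's outputs exactly, which is a quick sanity check worth running before relying on a saved file.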
