I have already trained a complete model using PyTorch, but I want to convert the .pth file into a .pb file that can be used in TensorFlow. Does anyone have any ideas?

2 Answers

You can use ONNX (Open Neural Network Exchange format).

To convert a .pth file to .pb, first export the model defined in PyTorch to ONNX, then import the ONNX model into TensorFlow (PyTorch => ONNX => TensorFlow).

This example, taken from onnx/tutorials, converts an MNIST model from PyTorch to TensorFlow using ONNX.

Save the trained model to a file

    import torch
    torch.save(model.state_dict(), 'output/mnist.pth')  # saves only the weights (state_dict), not the model class

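Load the trained model from the file and export it to ONNX

    import torch

    # Net is the model class used during training; the .pth file stores only its weights
    trained_model = Net()
    trained_model.load_state_dict(torch.load('output/mnist.pth'))
    trained_model.eval()  # put dropout/batch-norm layers in inference mode before exporting

    # Export the trained model to ONNX
    # one black and white 28 x 28 picture will be the input to the model
    dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")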

Load the ONNX file

    import onnx
    from onnx_tf.backend import prepare

    model = onnx.load('output/mnist.onnx')

    # Import the ONNX model to TensorFlow
    tf_rep = prepare(model)

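Save the TensorFlow model to a file. With the TensorFlow 1.x releases of onnx-tf this writes a serialized GraphDef (.pb) file; newer onnx-tf releases for TensorFlow 2.x write a SavedModel directory at that path instead.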

    tf_rep.export_graph('output/mnist.pb')
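Because export_graph() produces a single serialized GraphDef .pb file with the TensorFlow 1.x builds of onnx-tf, but a SavedModel directory (containing saved_model.pb) with the TensorFlow 2.x builds, even when the path you pass ends in .pb, it can help to check which layout you actually got. A minimal standard-library sketch; describe_export is a hypothetical helper name, and the classification rule is an assumption based only on those two layouts.

```python
import tempfile
from pathlib import Path


def describe_export(path):
    """Guess which kind of export sits at `path` (hypothetical helper).

    Returns "saved_model" for a directory containing saved_model.pb,
    "frozen_graph" for a single *.pb file, and "unknown" otherwise.
    """
    p = Path(path)
    # Checked first: a SavedModel directory may itself be named e.g. "mnist.pb"
    if p.is_dir() and (p / "saved_model.pb").is_file():
        return "saved_model"
    if p.is_file() and p.suffix == ".pb":
        return "frozen_graph"
    return "unknown"


if __name__ == "__main__":
    # Build throwaway copies of both layouts to show how each one is classified
    with tempfile.TemporaryDirectory() as tmp:
        frozen = Path(tmp, "frozen.pb")
        frozen.write_bytes(b"")            # TF 1.x style: one serialized GraphDef file
        saved = Path(tmp, "mnist.pb")      # TF 2.x style: a directory, despite the .pb name
        saved.mkdir()
        (saved / "saved_model.pb").write_bytes(b"")
        print(describe_export(frozen))     # frozen_graph
        print(describe_export(saved))      # saved_model
```

After a real conversion you would call describe_export('output/mnist.pb') on the path you passed to export_graph().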

As noted by @tsveti_iko in the comments:

NOTE: prepare() is built into onnx-tf, so you first need to install it from the console:

    pip install onnx-tf

Then import it in your code:

    import onnx
    from onnx_tf.backend import prepare

After that you can use it as described in the answer.
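Since the conversion needs onnx, onnx-tf and TensorFlow all importable in the same environment, a quick check before running it can save a confusing ImportError halfway through. A minimal standard-library sketch; REQUIRED_MODULES and missing_modules are hypothetical names, and the import names are assumed to be onnx, onnx_tf and tensorflow for the pip packages onnx, onnx-tf and tensorflow.

```python
import importlib.util

# Assumed import names for the pip packages onnx, onnx-tf and tensorflow
REQUIRED_MODULES = ("onnx", "onnx_tf", "tensorflow")


def missing_modules(names=REQUIRED_MODULES):
    """Return the names that cannot be found in the current environment.

    find_spec locates a top-level module without importing it, and
    returns None when no such module is installed.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("onnx, onnx_tf and tensorflow are importable")
```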

Comments

@tsveti_iko thanks for bringing this to my attention. I thought it belonged in the answer rather than in a comment, so I have added it to the answer.
Indeed, you could even put the code parts into your existing code blocks and remove the NOTE quote.

If you are using TF 1.15 or below, you might not find the above code helpful, because you will likely end up fighting version-mismatch errors.
Here is a set of package versions that match each other and work with TF 1.x:

    Keras                 2.3.0
    Keras-Applications    1.0.8
    Keras-Preprocessing   1.1.2
    numpy                 1.21.5
    onnx                  1.8.0
    onnx-tf               1.3.0
    protobuf              3.19.4
    tensorboard           1.15.0
    tensorflow            1.15.0
    tensorflow-estimator  1.15.1
    torch                 1.6.0+cpu
    torchvision           0.7.0+cpu
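The version list above can also be pinned and checked from Python. A minimal standard-library sketch: PINS copies the versions from this answer, requirements_text() renders them as name==version lines you could save as a requirements.txt, and find_mismatches() compares them with what is installed. The helper names are hypothetical, and importlib.metadata needs Python 3.8+; since TF 1.15 only supports Python up to 3.7, inside that environment you would swap in the importlib_metadata backport, which offers the same version() and PackageNotFoundError. Treat any reported mismatch as a hint to double-check with pip list.

```python
from importlib import metadata

# Versions copied from the list in this answer (TF 1.15 setup)
PINS = {
    "Keras": "2.3.0",
    "Keras-Applications": "1.0.8",
    "Keras-Preprocessing": "1.1.2",
    "numpy": "1.21.5",
    "onnx": "1.8.0",
    "onnx-tf": "1.3.0",
    "protobuf": "3.19.4",
    "tensorboard": "1.15.0",
    "tensorflow": "1.15.0",
    "tensorflow-estimator": "1.15.1",
    "torch": "1.6.0+cpu",
    "torchvision": "0.7.0+cpu",
}


def requirements_text(pins=PINS):
    """Render the pins as requirements.txt lines (name==version)."""
    return "\n".join(f"{name}=={version}" for name, version in pins.items())


def find_mismatches(pins=PINS, lookup=metadata.version):
    """Map each package whose installed version differs from its pin to
    (expected, installed); installed is None when the package is absent."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    print(requirements_text())
    for name, (expected, installed) in find_mismatches().items():
        print(f"{name}: expected {expected}, found {installed}")
```

Note that the +cpu builds of torch and torchvision are not published on PyPI; pip has to be pointed at PyTorch's own wheel index to install them.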

After installing these packages, follow the answer by Dishin.

Note: Variable is deprecated in newer versions of torch; you can pass a plain tensor such as torch.randn(1, 1, 28, 28) as the dummy input instead.
