I have already trained a complete model using PyTorch, but I want to convert the .pth file into a .pb file so that it can be used in TensorFlow. Does anyone have any ideas?
You can use ONNX (Open Neural Network Exchange).
To convert a .pth file to .pb, first export the model defined in PyTorch to ONNX, then import the ONNX model into TensorFlow (PyTorch => ONNX => TensorFlow).
Here is an example from onnx/tutorials that converts a PyTorch MNIST model to TensorFlow using ONNX.
Save the trained model to a file:

    import torch

    torch.save(model.state_dict(), 'output/mnist.pth')

Load the trained model from the file and export it to ONNX:

    import torch

    trained_model = Net()  # Net is the model class defined in the tutorial
    trained_model.load_state_dict(torch.load('output/mnist.pth'))
    trained_model.eval()  # put dropout/batch-norm layers in inference mode before exporting

    # Export the trained model to ONNX.
    # One black-and-white 28 x 28 picture will be the input to the model.
    dummy_input = torch.randn(1, 1, 28, 28)
    torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")

Load the ONNX file and import it into TensorFlow:

    import onnx
    from onnx_tf.backend import prepare

    model = onnx.load('output/mnist.onnx')
    tf_rep = prepare(model)  # import the ONNX model into TensorFlow

Save the TensorFlow model to a file:

    tf_rep.export_graph('output/mnist.pb')

As noted by @tsveti_iko in the comments:
NOTE: prepare() is built into onnx-tf, so you first need to install it from the console:

    pip install onnx-tf

then import it in your code:

    import onnx
    from onnx_tf.backend import prepare

After that you can use it as described in the answer.
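The answer's steps can be collected into a single function. This is a minimal sketch, not code from onnx/tutorials: `output_paths` and `convert_pth_to_pb` are hypothetical helper names, the model is assumed to take a single input tensor, and `onnx.checker.check_model` is an extra validation step the answer itself does not use. The imports are deferred into the function so the file can still be imported on a machine without torch/onnx/onnx-tf.

```python
from pathlib import Path


def output_paths(pth_path, out_dir=None):
    """Derive the .onnx and .pb file names from the .pth checkpoint name.

    Hypothetical helper: it swaps the suffix and keeps the checkpoint's
    directory unless out_dir is given.
    """
    pth = Path(pth_path)
    target_dir = Path(out_dir) if out_dir is not None else pth.parent
    return target_dir / (pth.stem + ".onnx"), target_dir / (pth.stem + ".pb")


def convert_pth_to_pb(model, pth_path, input_shape, out_dir=None):
    """Sketch of PyTorch (.pth) -> ONNX -> TensorFlow (.pb).

    model:       an instance of your model class, e.g. Net()
    input_shape: shape of one dummy input, e.g. (1, 1, 28, 28)
    """
    # Deferred imports (assumption): only needed when actually converting.
    import torch
    import onnx
    from onnx_tf.backend import prepare

    onnx_path, pb_path = output_paths(pth_path, out_dir)
    onnx_path.parent.mkdir(parents=True, exist_ok=True)

    # 1. Restore the trained weights and switch to inference mode.
    model.load_state_dict(torch.load(pth_path, map_location="cpu"))
    model.eval()

    # 2. Trace the model with a random dummy input and export it to ONNX.
    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(model, dummy_input, str(onnx_path))

    # 3. Load the ONNX file, validate it, and convert it with onnx-tf.
    onnx_model = onnx.load(str(onnx_path))
    onnx.checker.check_model(onnx_model)
    tf_rep = prepare(onnx_model)

    # 4. Write the TensorFlow graph out.
    tf_rep.export_graph(str(pb_path))
    return onnx_path, pb_path
```

For the MNIST example this would be called as `convert_pth_to_pb(Net(), 'output/mnist.pth', (1, 1, 28, 28))`. Depending on your onnx-tf and TensorFlow versions, `export_graph` may write a SavedModel directory rather than a single .pb file, so check the onnx-tf documentation for the version you have installed.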
If you are using TF 1.15 or below, the code above may not work for you, because you will end up fighting version-mismatch errors.
Here is a set of package versions that work together with TF 1.x:
    Keras 2.3.0
    Keras-Applications 1.0.8
    Keras-Preprocessing 1.1.2
    numpy 1.21.5
    onnx 1.8.0
    onnx-tf 1.3.0
    protobuf 3.19.4
    tensorboard 1.15.0
    tensorflow 1.15.0
    tensorflow-estimator 1.15.1
    torch 1.6.0+cpu
    torchvision 0.7.0+cpu

After installing these packages, use the answer by Dishin.
Note: Variable is deprecated in newer versions of torch; pass a plain tensor such as torch.randn(1, 1, 28, 28) instead.
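Because this answer relies on exact package versions, it can help to check which installed packages differ from the list before running the conversion. A minimal sketch that reads installed package metadata; `REQUIRED`, `installed_version` and `find_mismatches` are hypothetical names, and version strings are compared exactly, so `1.6.0` and `1.6.0+cpu` count as different.

```python
try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:
    # Python 3.7 has no importlib.metadata; this needs the backport
    # (pip install importlib-metadata).
    from importlib_metadata import PackageNotFoundError, version

# Versions listed in the answer as working together for TF 1.x.
REQUIRED = {
    "Keras": "2.3.0",
    "Keras-Applications": "1.0.8",
    "Keras-Preprocessing": "1.1.2",
    "numpy": "1.21.5",
    "onnx": "1.8.0",
    "onnx-tf": "1.3.0",
    "protobuf": "3.19.4",
    "tensorboard": "1.15.0",
    "tensorflow": "1.15.0",
    "tensorflow-estimator": "1.15.1",
    "torch": "1.6.0+cpu",
    "torchvision": "0.7.0+cpu",
}


def installed_version(name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def find_mismatches(required, lookup=installed_version):
    """Map each package whose installed version differs to (expected, found)."""
    mismatches = {}
    for name, expected in required.items():
        found = lookup(name)
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(REQUIRED)
    for name, (expected, found) in sorted(problems.items()):
        print(f"{name}: expected {expected}, found {found or 'not installed'}")
    if not problems:
        print("All pinned versions match.")
```

Passing a custom `lookup` makes the comparison easy to try without installing anything, e.g. `find_mismatches(REQUIRED, lookup={"torch": "1.7.1"}.get)`.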