Mastering the Art of Saving Multiple Models
Imagine building not just one, but a collection of powerful AI models, each with its own unique strengths. The real challenge? Preserving and managing them all without losing a shred of their potential. In the world of deep learning, saving multiple models is more than just a task; it’s a strategic move that keeps your innovations accessible and adaptable, ready to tackle diverse challenges at any moment. Let’s dive into the art of safeguarding your AI arsenal, where each saved model is a step toward future-proofing your technological breakthroughs.
Table of Contents
- Saving multiple models
- Exporting models in ONNX format
- Serving model with FastAPI
- Benefits of FastAPI
Saving Multiple Models
There may be instances where your overall model is composed of
multiple neural networks. Take a Generative Adversarial Network (GAN)
as an example: it comprises two distinct networks, the generator and the discriminator. In such cases, it is recommended to store the states of all the networks, along with their optimizers, in a single dictionary. Here is how you can achieve this:
import torch

# Store every network and optimizer state in one dictionary
torch.save({
    'first_model_state': model1.state_dict(),
    'second_model_state': model2.state_dict(),
    'first_optimizer_state': optimizer1.state_dict(),
    'second_optimizer_state': optimizer2.state_dict(),
    # ... any other states
}, file_path)
To load the models back into memory, first re-create them with the same architecture, then restore their states:
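import torch

# initialize your models and optimizers first
model1 = Model1Class(*args, **kwargs)
model2 = Model2Class(*args, **kwargs)
# each optimizer must be built over the parameters of its new model
optimizer1 = Optimizer1Class(model1.parameters(), *args, **kwargs)
optimizer2 = Optimizer2Class(model2.parameters(), *args, **kwargs)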
# load the states from the file
saved_states = torch.load(file_path)
model1.load_state_dict(saved_states['first_model_state'])
model2.load_state_dict(saved_states['second_model_state'])
optimizer1.load_state_dict(saved_states['first_optimizer_state'])
optimizer2.load_state_dict(saved_states['second_optimizer_state'])
# switch to evaluation mode or training mode
model1.eval() # or model1.train()
model2.eval() # or model2.train()
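To make the pattern concrete, here is a minimal end-to-end sketch of a GAN-style checkpoint round trip. The tiny generator and discriminator, the key names, the `epoch` entry and the temporary file path are all illustrative assumptions, not a prescribed layout; the point is that one dictionary can hold every network, every optimizer and any extra bookkeeping you need to resume training.

```python
import os
import tempfile

import torch
from torch import nn


# Hypothetical toy networks standing in for a real GAN
def make_generator() -> nn.Module:
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))


def make_discriminator() -> nn.Module:
    return nn.Sequential(nn.Linear(8, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))


generator, discriminator = make_generator(), make_discriminator()
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# One dummy training step so both optimizers have non-empty state
loss = discriminator(generator(torch.randn(4, 16))).mean()
loss.backward()
opt_g.step()
opt_d.step()

checkpoint_path = os.path.join(tempfile.mkdtemp(), "gan_checkpoint.pt")
torch.save({
    "generator_state": generator.state_dict(),
    "discriminator_state": discriminator.state_dict(),
    "generator_optimizer_state": opt_g.state_dict(),
    "discriminator_optimizer_state": opt_d.state_dict(),
    "epoch": 1,  # extra bookkeeping can live in the same dictionary
}, checkpoint_path)

# Re-create fresh objects with the same architecture, then restore their states
new_generator, new_discriminator = make_generator(), make_discriminator()
new_opt_g = torch.optim.Adam(new_generator.parameters(), lr=2e-4)
new_opt_d = torch.optim.Adam(new_discriminator.parameters(), lr=2e-4)

checkpoint = torch.load(checkpoint_path, map_location="cpu")
new_generator.load_state_dict(checkpoint["generator_state"])
new_discriminator.load_state_dict(checkpoint["discriminator_state"])
new_opt_g.load_state_dict(checkpoint["generator_optimizer_state"])
new_opt_d.load_state_dict(checkpoint["discriminator_optimizer_state"])
start_epoch = checkpoint["epoch"] + 1
```

Passing `map_location="cpu"` lets a checkpoint saved on a GPU machine be loaded on a CPU-only one; you can move the models to another device afterwards.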
Exporting Models in ONNX Format
ONNX (Open Neural Network Exchange) provides an open-source format for AI models, covering both deep learning
and traditional machine learning. It defines an extensible computation graph model, as
well as definitions of built-in operators and standard data types.
The main benefits of the ONNX format:
- Interoperability: ONNX is supported by a variety of frameworks
such as PyTorch, TensorFlow and MXNet, and by tools like NVIDIA’s
TensorRT. You can train a model in one framework, export it to
ONNX, and use it in another framework for inference.
- Portability: Models in ONNX format can be deployed on a variety of
platforms, from cloud servers with powerful GPUs to edge
devices like mobile phones and IoT devices.
- Performance: Some runtimes, like ONNX Runtime, can optimize
the execution of the computation graph (for example, by fusing operators), leading to performance
improvements.
To export a PyTorch model to ONNX format, you can use the
torch.onnx.export function. The following code shows an example of
exporting a model and using the ONNX model for inference. When exporting a model to ONNX format, you need to provide a dummy input that matches
the input your model expects. The exporter runs the dummy input through the model to record its
operations, and the shape and data type of the input tensor are
stored in the exported ONNX graph as metadata. This allows ONNX
Runtime to understand what kind of input the model expects, including its
shape and type. Unless you mark some dimensions as dynamic, the exported input shape is fixed to that of the dummy input.
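import torch
import torchvision

# Load a pretrained AlexNet and switch to inference mode (disables dropout)
model = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
model.eval()

# The dummy input defines the expected shape: (batch, channels, height, width)
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx")

# Inference
import onnxruntime
import numpy as np

ort_session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# compute ONNX Runtime output prediction
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1, 3, 224, 224).astype(np.float32)}
ort_outs = ort_session.run(None, ort_inputs)  # list of outputs; ort_outs[0] holds the logits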
The accompanying notebook provides an additional example of exporting and
using Hugging Face models in ONNX.
Serving Model with FastAPI
FastAPI is a modern, fast (high-performance) web framework for building
APIs with Python 3.6+, based on standard Python type hints. It was
developed as an efficient alternative to existing Python frameworks such as
Flask and Django, offering significant performance benefits and a
simpler syntax.
FastAPI takes advantage of Python’s type hints, which enable editor support and static type checking, make your code
more robust and simplify debugging. It is also designed to work well with
modern frontend JavaScript frameworks, which often consume RESTful
APIs.
Key features of FastAPI include automatic interactive API documentation,
built-in request validation using Pydantic models, OAuth2 support with JWT
tokens and password hashing, CORS handling, customizable exception
handling and more. It supports asynchronous (async/await) request handlers
as well as WebSockets.
Benefits of FastAPI
- Performance: FastAPI is one of the fastest Python frameworks
available, behind only Starlette and Uvicorn, on which it is
built. It is faster than traditional Python frameworks and, according to its authors, can compete
with NodeJS and Go.
- Easy to code: FastAPI’s use of Python type hints and Pydantic
models makes it easier to define API schemas, validate request data,
and extract request data such as JSON fields, path parameters and
query parameters.
- Automatic API documentation: FastAPI automatically generates an interactive
API documentation UI, making it easier for developers
and users to understand and try out your API.
- Support for modern Python features: FastAPI supports
asynchronous request handling with async/await, making it suitable for WebSockets and other scenarios that require concurrency. HTTP/2 support depends on the ASGI server
you run it with (Hypercorn, for example, supports it).
- Robustness: Thanks to automatic data validation and serialization
using Pydantic, and Python’s type hints, FastAPI applications tend to
be less prone to bugs and easier to debug and maintain.
Conclusion
Mastering the process of saving multiple models, exporting them in the ONNX format and serving them with FastAPI unlocks a new level of efficiency and scalability in AI deployment. By strategically managing multiple models, you ensure that each innovation is preserved and primed for future use. Exporting these models in the ONNX format not only enhances cross-platform compatibility but also streamlines integration into diverse environments. Leveraging FastAPI to serve these models offers high speed, simplicity and flexibility, allowing you to deliver cutting-edge solutions with minimal overhead. As we embrace these practices, we elevate our AI capabilities, making our models not just powerful, but also agile, responsive and ready to drive impact across industries. This holistic approach to model management and deployment sets the stage for a future where AI can rapidly adapt and scale, meeting the demands of an ever-evolving technological landscape.