This article was written in collaboration with Dara Akojede.
Lightweight models are AI models designed to be smaller and more efficient than their traditional counterparts. This translates to:
Faster processing, as they need less computational power and can run on devices with limited resources, such as laptops or even smartphones.
Reduced memory usage, because they take up less space in memory.
Lower computational cost to run the model.
Models are made lightweight by using fewer parameters, optimizing the architecture, and applying quantization (representing the model's weights with fewer bits).
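As a rough illustration of quantization, the sketch below maps a handful of float weights to 8-bit integers using a single per-tensor scale (symmetric quantization) and then maps them back. It is a generic example in plain Python, and the function names are hypothetical; it is not the scheme used by Gemma or by any particular library.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization of floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    # One float scale per tensor maps the largest magnitude to 127.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize_int8(codes, scale):
    """Approximately recover the original floats from the integer codes."""
    return [c * scale for c in codes]


weights = [0.82, -1.27, 0.003, 0.5]
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)
print(codes)  # [82, -127, 0, 50]
print(restored)
```

Each code fits in one byte instead of the four bytes of a float32 value. The price is a rounding error of at most half the scale per weight, which is why the tiny weight 0.003 comes back as 0.0.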
Examples of lightweight models include MobileNet, a computer vision model designed for mobile and embedded vision applications; EfficientDet, an object detection model; and EfficientNet, a convolutional neural network (CNN) that uses compound scaling of network depth, width, and input resolution to get better accuracy for its size. All three were developed at Google.
In this article, we look at Gemma, a family of state-of-the-art lightweight models.
Introduction to Gemma
Gemma is a family of lightweight, open machine learning models developed by Google, with openly available weights. The models are designed to be accessible and efficient, making AI development available to a broad range of users. Released on February 21, 2024, Gemma is built from the same research and technology used to create the Gemini models. Besides being lightweight and open, Gemma is text-based: it takes text in and produces text out, and it excels at tasks such as text summarization, question answering, and reasoning.
Based on the number of trainable parameters, Gemma models come in two main sizes: 2B and 7B. Each size also has an instruction-tuned variant (Gemma 2B-IT and 7B-IT) that has been further trained to follow instructions, and both the pretrained and instruction-tuned models can be fine-tuned further on your own datasets. Gemma can be applied in any industry whose work involves processing text.
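To see why the parameter count matters for where a model can run, the arithmetic below estimates the memory taken by the weights alone for the nominal 2B and 7B sizes at a few numeric precisions. It is a back-of-the-envelope sketch under stated assumptions: the sizes are taken from the model names (the released checkpoints' exact counts differ somewhat, and gemma_lm.summary() reports the real figure), and activations, caches, and framework overhead are ignored.

```python
# Approximate storage per parameter for a few common numeric formats.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1, "int4": 0.5}


def weight_memory_gb(num_params, dtype):
    """Memory for the weights alone, in gigabytes (1 GB = 10**9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Nominal sizes taken from the model names, not exact parameter counts.
for name, num_params in [("2B", 2e9), ("7B", 7e9)]:
    estimates = ", ".join(
        f"{dtype}: {weight_memory_gb(num_params, dtype):.1f} GB" for dtype in BYTES_PER_PARAM
    )
    print(f"{name} -> {estimates}")
```

For example, the nominal 2B model needs about 8 GB for float32 weights, 4 GB in bfloat16, and 2 GB at 8 bits; the 7B model needs 3.5 times as much at every precision.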
Getting Started with Gemma
To get started with Gemma, make sure you have the following prerequisites:
Prerequisites
A Kaggle account. Create one at kaggle.com if you don't have one.
Access to Gemma. Open the Gemma model card on Kaggle and select "Request Access." You will be asked to complete the consent form and accept the terms and conditions. Then select a Colab runtime and configure your Kaggle API key. A detailed walkthrough can be found in the Gemma Setup docs.
In this tutorial, we use a Colab notebook to run the model. After completing the Gemma setup, read your Kaggle username and API key from Colab Secrets and set them as environment variables:
import os
from google.colab import userdata
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
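Before downloading any weights, it can help to fail early if the credentials did not make it into the environment. The helper below is a hypothetical convenience function (it is not part of Colab or KerasNLP) that lists which of the two variables set above are missing or empty.

```python
import os

# The two variables set in the cell above; the Kaggle client reads them for authentication.
REQUIRED_KAGGLE_VARS = ("KAGGLE_USERNAME", "KAGGLE_KEY")


def missing_kaggle_credentials(env=None):
    """Return the names of required Kaggle variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_KAGGLE_VARS if not env.get(name)]


missing = missing_kaggle_credentials()
if missing:
    print("Missing Kaggle credentials:", ", ".join(missing))
```

Passing a plain dictionary instead of os.environ makes the check easy to try out without touching the real environment.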
With the environment variables set, the next step is to install the dependencies. Here, Gemma is run through KerasNLP, a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow. Gemma in KerasNLP requires Keras 3, so Keras is upgraded as well.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
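Now that KerasNLP is installed, choose the backend Keras will use to run Gemma. The code block below uses JAX. The backend must be set before Keras is imported.
import os
# Keras reads this variable when it is imported, so set it first.
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".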
The last step is to import the installed libraries and instantiate the Gemma model using the from_preset method on the GemmaCausalLM class, an end-to-end Gemma model for causal language modeling.
import keras
import keras_nlp
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
# get more information about the model
gemma_lm.summary()
Let's start generating some text! The class has a generate method that generates text based on a prompt.
gemma_lm.generate("What is Generative AI?", max_length=64)
The first call may take a while because the generation function is compiled; subsequent calls reuse the compiled function and return much faster.
The generate method can also take a batch of prompts as a list of strings.
gemma_lm.generate(
["What is the greatest thing ever?",
"Why is the sky blue?"],
max_length=256)
Fine-tuning Gemma models using LoRA
Fine-tuning is the process of taking a pre-trained model and adjusting it further with additional training on a more specific dataset. This technique leverages the general capabilities of the model and allows the model to excel at specific tasks rather than remaining a general-purpose tool. One technique for achieving this fine-tuning is LoRA (Low-Rank Adaptation).
LoRA is a technique for fine-tuning pre-trained transformer models efficiently. Instead of updating all of the model's weights, it freezes them and, for selected weight matrices, learns a low-rank update expressed as the product of two small matrices, which is a much smaller set of trainable parameters. These parameters act like a lightweight "adapter" that sits on top of the pre-trained LLM.
By training only this adapter, LoRA changes the model's behavior for the new task without making extensive changes to the underlying weights. This translates to faster training, reduced memory usage, and the ability to fine-tune LLMs on less powerful hardware.
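The adapter idea can be written as h = W x + (alpha / r) B A x, where the pre-trained matrix W stays frozen and only the small matrices A (r x d_in) and B (d_out x r) are trained. The plain-Python sketch below follows the original LoRA formulation, in which B starts at zero so that the adapted layer initially behaves exactly like the pre-trained one. The dimensions, the alpha value, and the function names are illustrative assumptions, and KerasNLP's implementation details may differ.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def lora_forward(W, A, B, x, alpha, rank):
    """h = W x + (alpha / rank) * B (A x); W is frozen, A and B are trainable."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + (alpha / rank) * u for b, u in zip(base, update)]


d_out, d_in, rank, alpha = 3, 4, 2, 2.0
rng = random.Random(0)
W = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]   # frozen
A = [[rng.gauss(0, 0.01) for _ in range(d_in)] for _ in range(rank)]  # random init
B = [[0.0] * rank for _ in range(d_out)]                               # zero init
x = [1.0, 2.0, 3.0, 4.0]

# With B = 0 the adapter contributes nothing, so the output equals W x.
h_initial = lora_forward(W, A, B, x, alpha, rank)

# Once training changes B, the adapter shifts the output. The update can also be
# merged into a single matrix W + (alpha / rank) * B A for inference.
B[0][0] = 0.5
h_adapted = lora_forward(W, A, B, x, alpha, rank)
BA = matmul(B, A)
W_merged = [[w + (alpha / rank) * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, BA)]
```

Because B A has the same shape as W, the trained adapter can either be kept separate (so one frozen base model can serve several adapters) or folded into W so inference costs nothing extra.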
In this section, we fine-tune Gemma on a mental health counseling conversations dataset from Hugging Face.
First, we download the dataset by running the block below.
!wget -O combined_dataset.json https://huggingface.co/datasets/Amod/mental_health_counseling_conversations/raw/main/combined_dataset.json
After downloading the data, we apply some simple preprocessing and keep a subset of 500 examples so that training stays short. More data is needed for high-quality fine-tuning.
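import json

data = []
# Each example becomes a single "Question ... Response ..." training string.
template = "Question:\n{Context}\n\nResponse:\n{Response}"
with open("combined_dataset.json") as file:
    for line in file:
        # Each line is a separate JSON object with "Context" and "Response" keys.
        features = json.loads(line)
        data.append(template.format(**features))

# Keep only the first 500 examples to keep training time short.
data = data[:500]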
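The same preprocessing can be sketched as a small reusable function that stops reading once it has enough examples and skips blank lines. The function name, the blank-line handling, and the two sample records are made up for illustration; the records only mimic the Context and Response fields the template expects. The last lines show that a prompt built with an empty Response ends right after "Response:", which is where the model is expected to continue writing.

```python
import io
import json

TEMPLATE = "Question:\n{Context}\n\nResponse:\n{Response}"


def load_examples(fileobj, limit):
    """Format JSON-lines records with TEMPLATE, keeping at most `limit` of them."""
    examples = []
    for line in fileobj:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        examples.append(TEMPLATE.format(**json.loads(line)))
        if len(examples) >= limit:
            break  # no need to parse the rest of the file
    return examples


# Two made-up records standing in for combined_dataset.json.
sample_file = io.StringIO(
    '{"Context": "I feel anxious at work.", "Response": "Try short breaks."}\n'
    '{"Context": "I cannot sleep.", "Response": "Keep a regular schedule."}\n'
)
examples = load_examples(sample_file, limit=1)
print(examples[0])

# A prompt with an empty Response ends where the model should start writing.
prompt = TEMPLATE.format(Context="What should I do when I feel sad?", Response="")
print(repr(prompt[-10:]))
```

Assuming the downloaded file is in JSON Lines format, as the loop above expects, calling load_examples on it with limit=500 should build the same list without parsing the whole file.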
Before fine-tuning, let's prompt the model and observe the response it generates.
prompt = template.format(
Context="What should I do when I feel sad?",
Response="",
)
print(gemma_lm.generate(prompt, max_length=256))
To get better responses from the model, we now fine-tune it with LoRA on the dataset. We use a LoRA rank of 4, since starting with a small rank keeps the computation cheap.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
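To see why a small rank keeps training cheap, the arithmetic below counts trainable parameters for a single d x k weight matrix: full fine-tuning updates all d * k entries, while LoRA trains only the r * (d + k) entries of its two factors. The 2048 x 2048 shape is a hypothetical example chosen for round numbers, not a claim about Gemma's layer sizes.

```python
def full_params(d, k):
    """Trainable parameters when the whole d x k matrix is fine-tuned."""
    return d * k


def lora_params(d, k, rank):
    """Trainable parameters of the LoRA factors B (d x rank) and A (rank x k)."""
    return rank * (d + k)


d = k = 2048  # hypothetical layer shape
for rank in (4, 8, 16):
    share = lora_params(d, k, rank) / full_params(d, k)
    print(f"rank {rank:>2}: {lora_params(d, k, rank):,} trainable "
          f"({share:.2%} of {full_params(d, k):,})")
```

At rank 4 this layer trains 16,384 parameters instead of 4,194,304, roughly 0.4%. The count grows linearly with the rank, so raising the rank trades more computation for a more expressive adapter.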
Then, we configure and run a training session.
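# Pad or truncate each training example to 512 tokens to limit memory use.
gemma_lm.preprocessor.sequence_length = 512
# AdamW: Adam with decoupled weight decay.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Do not apply weight decay to bias and normalization-scale variables.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])
gemma_lm.compile(
    # The model outputs raw logits over the vocabulary.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    # Weighted so that padding tokens are not counted in the accuracy.
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)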
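For intuition about the optimizer settings, the sketch below performs one AdamW step on a few named scalar parameters in plain Python. Weight decay is decoupled from the gradient and skipped for parameters whose names contain "bias" or "scale", which is the behavior the exclude_from_weight_decay call above requests. The substring matching, the scalar parameters, the order of the decay and Adam updates, and the default betas and epsilon are assumptions of this sketch, not a copy of Keras's implementation.

```python
import math


def adamw_step(params, grads, state, step, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7, exclude=("bias", "scale")):
    """One AdamW update on a dict of scalar parameters (name -> value).

    Weight decay is decoupled from the gradient and skipped for any parameter
    whose name contains one of the `exclude` substrings. `step` is the 1-based
    iteration count used for bias correction; `state` holds the moment estimates.
    """
    new_params = {}
    for name, p in params.items():
        g = grads[name]
        m, v = state.get(name, (0.0, 0.0))
        m = beta1 * m + (1 - beta1) * g          # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
        state[name] = (m, v)
        m_hat = m / (1 - beta1 ** step)          # bias-corrected moments
        v_hat = v / (1 - beta2 ** step)
        if not any(s in name for s in exclude):
            p = p - lr * weight_decay * p        # decoupled weight decay
        new_params[name] = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return new_params


params = {"dense_kernel": 1.0, "dense_bias": 1.0, "layer_norm_scale": 1.0}
grads = {name: 0.0 for name in params}
updated = adamw_step(params, grads, state={}, step=1)
# With zero gradients only weight decay acts: the kernel shrinks slightly,
# while the excluded bias and scale parameters are left unchanged.
print(updated)
```

Keeping the decay separate from the gradient-based update means its strength does not depend on Adam's per-parameter scaling, and excluding biases and normalization scales avoids shrinking parameters that do not benefit from regularization.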
After fine-tuning, we give the model the same prompt and compare its response with the earlier one.
prompt = template.format(
    Context="What should I do when I feel sad?",
    Response="",
)
print(gemma_lm.generate(prompt, max_length=256))
The response differs from the one generated before fine-tuning, and the difference comes from the LoRA adapter trained on the counseling data. To get better responses from the fine-tuned model, you can try the following:
Training for more steps (epochs).
Setting a higher LoRA rank.
Modifying hyperparameter values.
Increasing the size of the fine-tuning dataset.
Conclusion
In this article, we explored Gemma, a family of efficient, text-focused models that can perform a range of tasks on text. We also saw that Gemma supports fine-tuning with LoRA, which makes it possible to adapt the model cheaply to specific tasks and datasets and tailor its behavior to your own requirements.