Fine-tuning
Why fine-tuning?
Generalist foundation models are pre-trained on massive datasets, but as the prefix “pre” suggests, there is still plenty of room for improvement.
In most enterprise applications, prompt engineering and Retrieval Augmented Generation (RAG) over generalist foundation models are not sufficient for production-grade accuracy. Fine-tuning addresses this by training a base model on domain-specific data, making it suitable for applications that require domain-specific knowledge.
In addition to outperforming generalist foundation models on domain-specific tasks, a fine-tuned model comes at a much lower cost and yields much faster response times than prompt engineering or RAG over a generalist model.
Fine-tuning methods
There are two general classes of fine-tuning methods: full fine-tuning (FFT), in which all the parameters of the model are updated, and parameter-efficient fine-tuning (PEFT), in which a small number of new trainable parameters (for example, adapter modules) are added to the base model and trained while the original parameters are typically kept frozen.
The table below summarizes popular fine-tuning methods. SeekrFlow currently supports FFT. Future releases will support additional methods.
Method | Category | Description | Status |
---|---|---|---|
Full Fine-Tuning | FFT | Fine-tuning the entire model on the target task. This is computationally expensive and requires a large amount of labeled data. | Current |
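LoRA (Low-Rank Adaptation) | PEFT | Freezing the pre-trained weights and training a low-rank update for selected weight matrices, expressed as the product of two small matrices. | Upcoming |
Prefix Tuning | PEFT | Prepending trainable continuous prefix vectors to the input sequence (and to the activations at each layer) while keeping the base model frozen. | Upcoming |
MoRA (High-Rank Adaptation) | PEFT | Replacing LoRA's pair of low-rank matrices with a single square matrix, enabling high-rank updates with a comparable number of trainable parameters. | Upcoming |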
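To make the FFT vs. LoRA difference in the table concrete, the sketch below counts trainable parameters for a single weight matrix under each approach. The 4096 x 4096 shape and rank 8 are assumed example values chosen for illustration, not SeekrFlow defaults.

```python
# Illustrative sketch: trainable parameters for one weight matrix under
# full fine-tuning vs. LoRA. The matrix shape and rank are ASSUMED values.

def full_ft_params(d_out: int, d_in: int) -> int:
    """Full fine-tuning updates every entry of the d_out x d_in matrix W."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """LoRA freezes W and trains B (d_out x rank) and A (rank x d_in);
    the learned update is delta_W = B @ A."""
    return d_out * rank + rank * d_in


d_out, d_in, rank = 4096, 4096, 8
full = full_ft_params(d_out, d_in)
lora = lora_params(d_out, d_in, rank)
print(f"full fine-tuning: {full:,} trainable parameters")
print(f"LoRA (r={rank}):   {lora:,} trainable parameters")
print(f"ratio: {lora / full:.4%}")
```

For these assumed dimensions, LoRA trains 65,536 parameters instead of 16,777,216, about 0.39% as many, which is why PEFT methods are much cheaper to train than FFT.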
Cost of fine-tuning vs RAG
The cost of a request to an LLM depends on:
1. the number of input tokens and the number of output tokens
2. the size of the LLM and the amount of computing resources it uses (large generalist models require more computing power than smaller specialist models)
With RAG, the cost of (1) can be high because the number of input tokens is driven by how many retrieved documents the LLM has to "read" before it starts generating an answer. Because the LLM does not know the answer to the question, it must first "read" all of the retrieved documents. It is not uncommon for an LLM application to receive hundreds or thousands of requests per minute, or even per second. The total cost across all requests is therefore driven up by the LLM having to "read" the retrieved documents for every single request!
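The effect of retrieval on input tokens can be sketched as a simple count: the model reads the question plus every retrieved document. The question length, document sizes, and price below are hypothetical values for illustration only.

```python
# Hypothetical sketch: how retrieved documents inflate the input side of a request.
# The question length, document sizes, and price are illustrative ASSUMPTIONS.

def input_tokens(question_tokens: int, retrieved_doc_tokens: list[int]) -> int:
    """Tokens the model must read: the question plus every retrieved document."""
    return question_tokens + sum(retrieved_doc_tokens)


def input_cost_usd(tokens: int, usd_per_million_tokens: float) -> float:
    """Input-token cost of a single request."""
    return tokens * usd_per_million_tokens / 1_000_000


QUESTION_TOKENS = 50            # assumed length of the user's question
RETRIEVED_DOCS = [2_000] * 10   # assumed: 10 retrieved chunks of 2,000 tokens each
PRICE = 10.0                    # assumed input price, USD per 1M tokens

rag_tokens = input_tokens(QUESTION_TOKENS, RETRIEVED_DOCS)
direct_tokens = input_tokens(QUESTION_TOKENS, [])  # prompt with no retrieved context

print(f"with retrieval:    {rag_tokens:,} tokens, ${input_cost_usd(rag_tokens, PRICE):.4f}")
print(f"without retrieval: {direct_tokens:,} tokens, ${input_cost_usd(direct_tokens, PRICE):.4f}")
```

With these assumed numbers, the retrieved context accounts for 20,000 of the 20,050 input tokens, and that overhead is paid again on every request.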
As an example, take a proprietary model like `gpt-4-turbo` and an LLM application that, on average:
- retrieves 100K tokens per request
- responds to 100 requests per second

At the `gpt-4-turbo` price at the time of writing ($10 per 1 million input tokens), the input-token cost of a single request is about $1, so the input-token cost of our application is $100 per second!
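The back-of-the-envelope estimate above can be reproduced in a few lines. It counts input tokens only (output tokens would add to the total), and the $10 per 1M price is the figure quoted in the text; prices change, so treat it as an input rather than a constant.

```python
# Reproduces the input-token estimate above (output tokens not included).
PRICE_PER_M_INPUT_TOKENS = 10.0   # USD per 1,000,000 input tokens (price quoted in the text)
TOKENS_PER_REQUEST = 100_000      # retrieved context per request
REQUESTS_PER_SECOND = 100

cost_per_request = TOKENS_PER_REQUEST * PRICE_PER_M_INPUT_TOKENS / 1_000_000
cost_per_second = cost_per_request * REQUESTS_PER_SECOND
cost_per_day = cost_per_second * 60 * 60 * 24   # sustained load for 24 hours

print(f"per request: ${cost_per_request:,.2f}")
print(f"per second:  ${cost_per_second:,.2f}")
print(f"per day:     ${cost_per_day:,.2f}")
```

Sustained for a full day, the same load comes to $8,640,000 in input-token cost alone.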
In addition, RAG works well only with a large generalist model, which in turn requires a large amount of computing resources. As a result, the cost of (2) is also high.
There are other costs associated with RAG, such as maintaining a database to host the documents, developing and maintaining a RAG evaluation framework, and adding serving instances to offset the latency that comes with larger inputs.
SeekrFlow removes these costs by enabling its users to build specialist models that do not need to "read" or look up any information before responding.
SeekrFlow fine-tuned models and RAG can also work together. Where lookups to live data are needed, a SeekrFlow fine-tuned model can be used in a RAG flow, just like any other model.
Fine-tuning a base model
Let's walk through the steps for fine-tuning a model for a specific task. Here, we will demonstrate boosting the conversational question answering and RAG capabilities of `llama-3-8b` on news documents.
(We will use the file that we previously formatted and uploaded to SeekrFlow.)
To create a fine-tuning job, we must first create a project, to which we will associate our fine-tuning run. We can also retrieve project information and get a list of all of our projects:
```python
import os
from seekrai import SeekrFlow

client = SeekrFlow(api_key=os.environ.get("SEEKR_API_KEY"))

# create a project to associate fine-tuning runs with
proj = client.projects.create(name="project-name", description="project-description")

# get project info
client.projects.retrieve(proj.id)

# list all projects
client.projects.list()
```
Next, users need to specify a `TrainingConfig` object and an `InfrastructureConfig` object.
The `TrainingConfig` defines all parameters that affect the training script itself, such as the base model to be fine-tuned, the number of epochs, quantization, etc.
The `InfrastructureConfig` defines the infrastructure that the fine-tuning job will run on. Because SeekrFlow is agnostic to the accelerator hardware it runs on, it accepts several accelerator types, including `Gaudi2`, `A100`, and `H100`. SeekrFlow takes care of provisioning all the compute resources and configuration required for the chosen accelerator.
In this example, we will use 8 Gaudi2 accelerators, which causes SeekrFlow to run in multi-card training mode:
```python
from seekrai.types import TrainingConfig, InfrastructureConfig

training_config = TrainingConfig(
    training_files=[file.id],  # NOTE: We are passing in the ID of a previously uploaded fine-tuning file
    model="meta-llama/Meta-Llama-3-8B",  # Base model choice
    n_epochs=1,
    n_checkpoints=3,
    batch_size=4,
    learning_rate=1e-5,
    experiment_name="experiment-name",
)

infrastructure_config = InfrastructureConfig(
    n_accel=8,  # 8 Gaudi2 cards, which triggers multi-card training
    accel_type="GAUDI2",
)
```
Now that we have created our configuration objects, we are ready to fine-tune a model!
```python
fine_tune = client.fine_tuning.create(
    training_config=training_config,
    infrastructure_config=infrastructure_config,
    project_id=proj.id,  # NOTE: To associate this fine-tune with a project, we are passing in the ID of the project created above
)
```
Monitoring a fine-tuning job
SeekrFlow job runs are tracked using SeekrFlow's event monitoring and tracking system.
To retrieve the status and progress of a run, you can use:
```python
print(client.fine_tuning.retrieve(fine_tune.id).status)
```
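If you would rather wait for the run to finish than check its status by hand, a small polling loop around the status call shown above can help. This is a minimal sketch: the terminal status names are assumptions, not documented SeekrFlow values, so replace them with whatever your SeekrFlow version actually reports.

```python
# Minimal polling sketch. `get_status` is any zero-argument callable that
# returns the job's current status, for example:
#     lambda: client.fine_tuning.retrieve(fine_tune.id).status
# The terminal status names are ASSUMPTIONS; replace them with the values
# your SeekrFlow version actually reports.
import time
from typing import Callable, Iterable

ASSUMED_TERMINAL_STATUSES = ("completed", "failed", "cancelled")


def wait_for_job(
    get_status: Callable[[], object],
    terminal: Iterable[str] = ASSUMED_TERMINAL_STATUSES,
    poll_seconds: float = 30.0,
    timeout_seconds: float = 6 * 60 * 60,
) -> str:
    """Poll until the status is terminal, or raise TimeoutError."""
    terminal_set = {s.lower() for s in terminal}
    deadline = time.monotonic() + timeout_seconds
    while True:
        raw = get_status()
        # Enum-like statuses expose `.value`; plain strings are used as-is.
        status = str(getattr(raw, "value", raw))
        print(f"status: {status}")
        if status.lower() in terminal_set:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"job still {status!r} after {timeout_seconds} s")
        time.sleep(poll_seconds)
```

Usage: `final_status = wait_for_job(lambda: client.fine_tuning.retrieve(fine_tune.id).status)`. Passing a callable rather than the client keeps the helper independent of the SDK and easy to test.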
The following snippet can be used to plot the loss of the fine-tuning run:
```python
import matplotlib.pyplot as plt

ft_id = fine_tune.id
events = client.fine_tuning.retrieve(ft_id).events

# sort the logged events by epoch so the curve is drawn left to right
ft_response_events_sorted = sorted(events, key=lambda x: x.epoch)
epochs = [event.epoch for event in ft_response_events_sorted]
losses = [event.loss for event in ft_response_events_sorted]

plt.figure(figsize=(8, 4))
plt.plot(epochs, losses, marker="o", linestyle="-", color="b")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("training loss over epochs")
plt.grid(True)

# label only every `step`-th epoch so long runs don't crowd the x-axis
max_labels = 10
step = max(1, len(epochs) // max_labels)
plt.xticks(epochs[::step], rotation=45)

plt.tight_layout()
plt.show()
```
Inference
In order to run inference against a trained model, we have to promote it to the inference API:
```python
# create a deployment for the fine-tuned run
dep = client.deployments.create(
    name="model-deployment",
    description="model-deployment-description",
    model_type="Fine-tuned Run",
    model_id=fine_tune.id,
    n_instances=1,
)

# promote the deployment so it can serve inference requests
ft_deploy = client.deployments.promote(dep.id)

# to demote your model (when you are finished with it), run:
# client.deployments.demote(ft_deploy.id)
```
This may take a few minutes. Once a model is promoted, you can run inference to obtain chat completions:
```python
stream = client.chat.completions.create(
    model=ft_deploy.id,
    messages=[
        {"role": "system", "content": "You are SeekrBot, a helpful AI assistant"},
        {"role": "user", "content": "who are you?"},
    ],
    stream=True,
    max_tokens=1024,
)

# print tokens as they arrive; content can be None on some chunks
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```
`model` can be any of our supported base models or any model that has been promoted for inference.
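When you need the full response as a single string (for example, to store or post-process it) instead of printing it as it streams, the chunks can be accumulated. This sketch relies only on the chunk shape used in the streaming loop above (`chunk.choices[0].delta.content`, which may be `None`); the empty-`choices` guard is a defensive assumption, not documented SeekrFlow behavior.

```python
# Sketch: gather a streamed chat completion into one string.
from typing import Iterable


def collect_stream(stream: Iterable) -> str:
    parts = []
    for chunk in stream:
        if not chunk.choices:   # defensive: skip chunks that carry no choices
            continue
        content = chunk.choices[0].delta.content
        if content:             # None or "" contributes nothing
            parts.append(content)
    return "".join(parts)
```

Usage: `text = collect_stream(client.chat.completions.create(model=ft_deploy.id, messages=..., stream=True))`. Note that a stream can only be iterated once, so collect it instead of (not after) printing it.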
Returning token log probabilities during inference
We can also return the token log probabilities, or "logprobs". We follow the OpenAI convention for formatting the request:
- To return the logprobs of the generated tokens, set `logprobs=True`.
- To additionally return the top n most likely tokens and their associated logprobs, set `top_logprobs=n`, where n > 0.
```python
import os
from seekrai import SeekrFlow

client = SeekrFlow(api_key=os.environ.get("SEEKR_API_KEY"))

stream = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Tell me about New York."}],
    stream=True,
    logprobs=True,
    top_logprobs=5,  # NOTE: The maximum, m, depends on the model deployment spec; n > m may raise a validation error
)

for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
    print(chunk.choices[0].logprobs)
```
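Log probabilities are easier to read once converted back to probabilities with `exp()`. The sketch below takes a list of `(token, logprob)` pairs; how you extract those pairs from `chunk.choices[0].logprobs` depends on the response schema, which is assumed rather than confirmed here (in OpenAI's format they sit under each entry's `top_logprobs`). The example values are hypothetical.

```python
# Sketch: convert (token, logprob) pairs into (token, probability) pairs.
import math


def to_probabilities(pairs):
    """Return (token, probability) pairs sorted from most to least likely."""
    probs = [(token, math.exp(logprob)) for token, logprob in pairs]
    return sorted(probs, key=lambda tp: tp[1], reverse=True)


# Hypothetical top-3 alternatives for one generated position
top = [("Jersey", -2.5), ("York", -0.1), ("Zealand", -4.5)]
for token, p in to_probabilities(top):
    print(f"{token!r}: {p:.3f}")
```

The probabilities of the top n alternatives sum to at most 1; the remainder belongs to tokens outside the top n.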
Listing model fine-tunes
We can also list all model fine-tunes that were previously created, or get information about a particular model fine-tune:
```python
# list all fine-tuning runs
client.fine_tuning.list()

# get information about a particular fine-tuning run
client.fine_tuning.retrieve(ft_id)
```