Fine-tune an LLM

Published: 27. October 2024  •  python, llm

In the previous blog post, I demonstrated how you can run an LLM in the browser with the help of the Transformers.js library.

We saw that running a tiny model like Llama 3.2 1B in the browser is quite feasible, but the quality of the generated text is hit or miss.

In this blog post, I will show you how to fine-tune the model for this specific task and improve the quality of the generated SQL statements.

Preparing training data

The first step in fine-tuning a model is to prepare the training data. This is the most labor-intensive part of the process, but also the most important one, because the quality of the training data determines the quality of the fine-tuned model. For this example, I created a list of questions and the corresponding SQL SELECT statements with the help of GPT-4o and Claude 3.5 Sonnet. The file contains one conversation per line, with the question in the user message and the SQL statement in the assistant message.

[{"role": "user", "content": "What is the largest country?"},{"role": "assistant", "content": "SELECT * FROM countries ORDER BY area DESC LIMIT 1"}]
[{"role": "user", "content": "Which is the largest country?"},{"role": "assistant", "content": "SELECT * FROM countries ORDER BY area DESC LIMIT 1"}]
[{"role": "user", "content": "Which country has the highest population?"},{"role": "assistant", "content": "SELECT * FROM countries ORDER BY population DESC LIMIT 1"}]
[{"role": "user", "content": "Which countries have a population over 100 million?"},{"role": "assistant", "content": "SELECT * FROM countries WHERE population > 100000000 ORDER BY population DESC"}]

countries-training-data.json

This training dataset contains 186 rows. It's difficult to say how many examples are needed to fine-tune a model; it depends on the specific use case and on whether the model is already familiar with the domain. We can assume that Llama 3.2 1B already knows something about SQL SELECT statements, so we don't need many examples. I saw an improvement in the generated SQL statements after fine-tuning the model with this tiny dataset, and I'm sure more examples would improve the quality even further.
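Because the training data determines so much of the result, it can pay off to sanity-check the file before training. The following is a minimal sketch, assuming the layout shown above (each line is a JSON array with one user message followed by one assistant message containing a SELECT statement). The function name validate_training_data is hypothetical and not part of the original project.

```python
import json


def validate_training_data(lines):
    """Check JSON Lines conversations and return (row_count, problems)."""
    problems = []
    rows = 0
    for number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # skip blank lines, e.g. a trailing newline
        rows += 1
        try:
            conversation = json.loads(line)
        except json.JSONDecodeError as error:
            problems.append(f"line {number}: invalid JSON ({error.msg})")
            continue
        if not isinstance(conversation, list) or not all(isinstance(m, dict) for m in conversation):
            problems.append(f"line {number}: expected a JSON array of message objects")
            continue
        roles = [message.get("role") for message in conversation]
        if roles != ["user", "assistant"]:
            problems.append(f"line {number}: expected user and assistant messages, got {roles}")
            continue
        answer = str(conversation[1].get("content", ""))
        if not answer.lstrip().upper().startswith("SELECT"):
            problems.append(f"line {number}: assistant answer is not a SELECT statement")
    return rows, problems


if __name__ == "__main__":
    sample = [
        '[{"role": "user", "content": "What is the largest country?"},'
        '{"role": "assistant", "content": "SELECT * FROM countries ORDER BY area DESC LIMIT 1"}]',
        "",
        '[{"role": "user", "content": "Delete everything"},'
        '{"role": "assistant", "content": "DELETE FROM countries"}]',
    ]
    print(validate_training_data(sample))
```

To check the real file, pass an open file object instead of the sample list, for example inside `with open("countries-training-data.json") as f:`; iterating a file yields its lines.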

Fine-tuning model

I followed this article to write the following fine-tuning code. You can find my code here. It's a Python program managed by Poetry, and it needs the following dependencies.

poetry add transformers datasets accelerate peft trl torch bitsandbytes

As the prompt template, I used the following text.

PROMPT_TEMPLATE = """You are given a database schema and a question.
Based on the schema, generate SQL SELECT statement that answers the question.

Schema:
CREATE TABLE countries (
  id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
  name TEXT NOT NULL,
  area       INTEGER,
  area_land  INTEGER,
  area_water INTEGER,
  population        INTEGER,
  population_growth REAL,
  birth_rate        REAL,
  death_rate        REAL,
  migration_rate    REAL,
  flag_description TEXT
)

Question:
{question}
"""

main.py

It's a bit shorter than the prompt template I used in the previous blog post. This is one benefit of fine-tuning a model. Often, you can shorten your prompt and still reliably get the desired output.


As the base model, I used unsloth/Llama-3.2-1B-Instruct. I named the fine-tuned model Llama-3.2-1B-Instruct-Country-SQL.

First, the program creates the QLoRA (bitsandbytes) configuration and loads the base model and the tokenizer. Then it creates the LoRA configuration and wraps the model with get_peft_model.

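import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

base_model = "unsloth/Llama-3.2-1B-Instruct"
dataset_name = "countries-training-data.json"
new_model = "Llama-3.2-1B-Instruct-Country-SQL"

torch_dtype = torch.float16
attn_implementation = "eager"

# QLoRA config
# The bnb_4bit_* settings only take effect when load_in_4bit=True
bnb_config = BitsAndBytesConfig(
  load_in_4bit=False,
  bnb_4bit_quant_type="nf4",
  bnb_4bit_compute_dtype=torch_dtype,
  bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
  base_model,
  quantization_config=bnb_config,
  device_map="auto",
  attn_implementation=attn_implementation
)

tokenizer = AutoTokenizer.from_pretrained(base_model)

# LoRA config
peft_config = LoraConfig(
  r=16,                 # rank of the low-rank update matrices
  lora_alpha=32,        # updates are scaled by lora_alpha / r
  lora_dropout=0.05,    # dropout on the input of the LoRA branch
  bias="none",          # keep bias terms frozen
  task_type="CAUSAL_LM",
  # attention (q, k, v, o) and MLP (gate, up, down) projections
  target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
# Wrap the model so that only the LoRA matrices are trainable
model = get_peft_model(model, peft_config)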

main.py

LoRA (Low-Rank Adaptation of Large Language Models) is a fine-tuning technique that adapts a pre-trained model by injecting low-rank matrices into its layers. Rather than updating all parameters, LoRA trains only these added matrices, significantly reducing the memory and computational costs while achieving effective fine-tuning for specific tasks.
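To make the idea concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer. It assumes the usual formulation y = W x + (alpha / r) * B (A x), where the d_out × d_in matrix W stays frozen and only A (r × d_in) and B (d_out × r) are trained. The helper names are illustrative and are not the PEFT API.

```python
import random


def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Frozen base output W x plus the scaled low-rank update B (A x)."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))  # A projects down to r dims, B projects back up
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_parameter_counts(d_out, d_in, r):
    """Parameters of the full weight matrix vs. the trainable LoRA matrices A and B."""
    return d_out * d_in, r * (d_in + d_out)


if __name__ == "__main__":
    random.seed(0)
    d_out, d_in, r, alpha = 4, 6, 2, 4
    W = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
    A = [[random.gauss(0, 0.1) for _ in range(d_in)] for _ in range(r)]
    B = [[0.0] * r for _ in range(d_out)]  # B starts at zero, so the adapted layer starts as the base layer
    x = [1.0] * d_in
    print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True
    print(lora_parameter_counts(2048, 2048, 16))  # (4194304, 65536)
```

PEFT's default initialization also starts B at zero, so training begins from the unchanged base model. Assuming a 2048 × 2048 attention projection (the hidden size I assume for Llama 3.2 1B), r=16 means 65,536 trainable parameters instead of about 4.2 million for that matrix.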

QLoRA (Quantized Low-Rank Adaptation) builds upon LoRA by applying quantization techniques to further minimize memory usage, specifically by quantizing the model weights to lower precision (e.g., 4-bit precision). This allows even larger models to be fine-tuned on consumer-grade hardware without sacrificing much performance, making it a popular approach for efficient fine-tuning of large language models.
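The memory savings come from storing each weight in 4 bits instead of 16. The sketch below illustrates the idea with simple blockwise absmax quantization to small integers. This is a deliberate simplification and not what bitsandbytes does: QLoRA's NF4 format maps weights to 16 levels derived from quantiles of a normal distribution, and double quantization additionally compresses the per-block scales. The function names and the 1.24 billion parameter count are assumptions for illustration.

```python
def quantize_block(values, levels=7):
    """Map floats to integers in [-levels, levels], scaled by the block's absolute maximum."""
    scale = max(abs(v) for v in values) or 1.0  # all-zero block: any non-zero scale works
    return [round(v / scale * levels) for v in values], scale


def dequantize_block(codes, scale, levels=7):
    """Approximately reconstruct the original floats from the integer codes."""
    return [c / levels * scale for c in codes]


def quantize(weights, block_size=64):
    """Quantize a flat list of weights block by block; each block stores its own scale."""
    return [quantize_block(weights[i:i + block_size]) for i in range(0, len(weights), block_size)]


def weight_memory_gb(num_params, bits_per_weight):
    """Memory for the weights alone, ignoring scales, activations and optimizer state."""
    return num_params * bits_per_weight / 8 / 1e9


if __name__ == "__main__":
    weights = [0.12, -0.5, 0.03, 0.9, -0.27, 0.0, 0.44, -0.81]
    blocks = quantize(weights, block_size=4)
    restored = [v for codes, scale in blocks for v in dequantize_block(codes, scale)]
    print(max(abs(a - b) for a, b in zip(weights, restored)))  # small rounding error
    # Assuming roughly 1.24 billion parameters for Llama 3.2 1B
    print(round(weight_memory_gb(1.24e9, 16), 2), round(weight_memory_gb(1.24e9, 4), 2))  # 2.48 0.62
```

The symmetric range -7..7 fits into a signed 4-bit integer. Keeping one scale per small block limits the damage a single large outlier weight can do to the precision of all other weights.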

I'm not an expert in this field, so I copied the code from this article. I'm sure you can improve the fine-tuning process by understanding all these parameters better. But I'm quite happy with the results I got with this code.

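The next step is to prepare the training data. The program reads the training file line by line, inserts each question into the prompt template, and uses tokenizer.apply_chat_template to render each conversation as a single string in the model's chat format. Finally, it holds back 10% of the examples as an evaluation set.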

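import json

from datasets import Dataset

messages = []
with open(dataset_name, 'r') as f:
  for line in f:
    if not line.strip():
      continue  # skip blank lines, e.g. a trailing newline
    # Each line is one conversation: a JSON array of messages
    conversation = json.loads(line)
    # Wrap the question in the prompt template (modifies the message in place)
    user_message = next(filter(lambda x: x["role"] == "user", conversation))
    user_message["content"] = PROMPT_TEMPLATE.format(question=user_message["content"])
    # Render the whole conversation as one string in the model's chat format
    obj = {"text": tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)}
    messages.append(obj)

dataset = Dataset.from_list(messages)

# Keep 10% of the examples for evaluation
dataset = dataset.train_test_split(test_size=0.1)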

main.py
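Apart from the tokenizer, the preparation step boils down to two operations: wrapping the question in the prompt template and splitting off an evaluation set. The stdlib sketch below mirrors both, with hypothetical function names. It deliberately skips apply_chat_template, which adds the model's special tokens and role headers, and unlike the code above it copies the messages instead of modifying them in place. The datasets library's train_test_split also shuffles by default, but its rounding and random order may differ from this sketch.

```python
import math
import random


def format_example(conversation, template):
    """Return a copy of the conversation with the user question wrapped in the prompt template."""
    return [
        {**message, "content": template.format(question=message["content"])}
        if message["role"] == "user"
        else dict(message)
        for message in conversation
    ]


def train_test_split(rows, test_size=0.1, seed=42):
    """Shuffle a copy of rows and hold back ceil(test_size * len(rows)) of them for evaluation."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_test = math.ceil(test_size * len(shuffled))
    return {"train": shuffled[n_test:], "test": shuffled[:n_test]}


if __name__ == "__main__":
    conversation = [
        {"role": "user", "content": "What is the largest country?"},
        {"role": "assistant", "content": "SELECT * FROM countries ORDER BY area DESC LIMIT 1"},
    ]
    print(format_example(conversation, "Question:\n{question}\n")[0]["content"])
    split = train_test_split(range(186))
    print(len(split["train"]), len(split["test"]))  # 167 19
```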

Finally, the program defines the training arguments with SFTConfig, creates an SFTTrainer, and trains the model by calling its train method.

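from trl import SFTConfig, SFTTrainer

training_arguments = SFTConfig(
  output_dir=new_model,
  per_device_train_batch_size=1,
  per_device_eval_batch_size=1,
  gradient_accumulation_steps=2,   # effective batch size of 2 per device
  optim="paged_adamw_32bit",
  num_train_epochs=5,
  evaluation_strategy="steps",     # renamed to eval_strategy in newer transformers versions
  eval_steps=0.2,                  # a value below 1 is a ratio of the total training steps
  logging_steps=1,
  warmup_steps=10,
  logging_strategy="steps",
  learning_rate=2e-4,
  fp16=True,
  bf16=False,
  group_by_length=True,
)

trainer = SFTTrainer(
  model=model,
  train_dataset=dataset["train"],
  eval_dataset=dataset["test"],
  peft_config=peft_config,
  # Newer trl versions expect max_seq_length, dataset_text_field and packing in SFTConfig
  max_seq_length=512,
  dataset_text_field="text",
  tokenizer=tokenizer,             # renamed to processing_class in newer trl versions
  args=training_arguments,
  packing=False,
)

# Train model
trainer.train()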

main.py
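To get a feel for what these settings mean, the following sketch estimates the number of optimizer steps and the learning rate over time. It assumes the transformers default lr_scheduler_type="linear" (linear warmup over warmup_steps, then linear decay to zero) and roughly 167 training rows (about 90% of 186). The exact rounding of steps per epoch differs between transformers versions, so treat the counts as approximations; the function names are hypothetical.

```python
import math


def training_schedule(num_examples, batch_size=1, grad_accum=2, epochs=5, eval_ratio=0.2):
    """Approximate optimizer-step counts for the SFTConfig values used above."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # One optimizer step per grad_accum batches
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    return {
        "effective_batch_size": batch_size * grad_accum,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        # eval_steps below 1 is a fraction of the total number of training steps
        "eval_every": math.ceil(total_steps * eval_ratio),
    }


def linear_lr(step, total_steps, warmup_steps=10, peak_lr=2e-4):
    """Learning rate under linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


if __name__ == "__main__":
    schedule = training_schedule(167)
    print(schedule)
    for step in (0, 5, 10, schedule["total_steps"] // 2, schedule["total_steps"]):
        print(step, linear_lr(step, schedule["total_steps"]))
```

With only a few hundred optimizer steps, the 10 warmup steps are a small fraction of the run, and eval_steps=0.2 yields about five evaluations in total.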

I ran the fine-tuning process on a g5.4xlarge EC2 instance on AWS, where it took only about a minute. Because the model is so small, running this on a local machine with a GPU should also be easy.

The last step in the fine-tuning program saves the model and uploads it to Hugging Face.

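# Save the LoRA adapters locally, then upload them to the Hugging Face Hub
trainer.model.save_pretrained(new_model)
trainer.model.push_to_hub(new_model, use_temp_dir=False, token="hf_")  # token: your Hugging Face access token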

main.py

To run this program, I started it with the following command.

poetry run python finetune/main.py

This is not the end of the process, because the program only saved and uploaded the LoRA adapters. To get a standalone model, we have to merge the adapters into the base model, which the following Python program does. You could do both steps in one program, but I separated them to better understand what's happening.

The merge program first loads the base model.

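import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = "unsloth/Llama-3.2-1B-Instruct"
new_model = "Llama-3.2-1B-Instruct-Country-SQL"

tokenizer = AutoTokenizer.from_pretrained(base_model)

# Reload the base model in fp16 (not quantized) so the adapters can be merged into its weights
base_model_reload = AutoModelForCausalLM.from_pretrained(
  base_model,
  return_dict=True,
  low_cpu_mem_usage=True,
  torch_dtype=torch.float16,
  device_map="auto",
  trust_remote_code=True,
)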

merge.py

Then, it merges the adapter with the base model.

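from peft import PeftModel

# Load the LoRA adapters saved by main.py on top of the base model
model = PeftModel.from_pretrained(base_model_reload, new_model)
# Fold the adapter weights into the base weights and remove the PEFT wrappers
model = model.merge_and_unload()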

merge.py
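Conceptually, merging folds the scaled low-rank product into the frozen weights, W' = W + (alpha / r) * B A, so the merged layer produces the same output as base plus adapter without the extra matrix multiplications at inference time. The following pure-Python check of that identity uses hypothetical helper names and small made-up matrices; it is not how PEFT implements merge_and_unload internally.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def matvec(M, v):
    """Multiply a matrix by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def merge_lora(W, A, B, alpha, r):
    """Return the merged weight matrix W + (alpha / r) * B @ A."""
    scale = alpha / r
    BA = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]


def adapter_forward(W, A, B, x, alpha, r):
    """Base output plus scaled adapter output, computed without merging."""
    scale = alpha / r
    return [b + scale * u for b, u in zip(matvec(W, x), matvec(B, matvec(A, x)))]


if __name__ == "__main__":
    W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
    A = [[0.1, 0.2, 0.3]]   # r x d_in with r = 1
    B = [[2.0], [-1.0]]     # d_out x r
    x = [1.0, 2.0, 3.0]
    print(matvec(merge_lora(W, A, B, alpha=2, r=1), x))
    print(adapter_forward(W, A, B, x, alpha=2, r=1))  # same values up to floating-point rounding
```

After merging, the model has exactly the shape of the base model again, which is what makes the later ONNX conversion possible without any LoRA-specific support.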

And lastly, it saves the model and uploads it to Hugging Face.

model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)
model.push_to_hub(new_model, use_temp_dir=False, token="hf_")
tokenizer.push_to_hub(new_model, use_temp_dir=False, token="hf_")

merge.py

I ran this program with the following command.

poetry run python finetune/merge.py

This process only takes a few seconds.

Run fine-tuned model

You can now test the fine-tuned model with the transformers Python library. The model ralscha/Llama-3.2-1B-Instruct-Country-SQL is publicly available on Hugging Face.

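from transformers import AutoModelForCausalLM, AutoTokenizer

# PROMPT_TEMPLATE is the same template used for fine-tuning
tokenizer = AutoTokenizer.from_pretrained("ralscha/Llama-3.2-1B-Instruct-Country-SQL")
model = AutoModelForCausalLM.from_pretrained("ralscha/Llama-3.2-1B-Instruct-Country-SQL")
messages = [
    {
        "role": "user",
        "content": PROMPT_TEMPLATE.format(question="Show me countries where the population is greater than 10 million."),
    }
]
# Render the chat template and append the header that starts the assistant's turn
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# The rendered prompt already contains the BOS token, so don't add special tokens again
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

outputs = model.generate(**inputs, max_new_tokens=150, num_return_sequences=1)

# Decode only the newly generated tokens, i.e. the assistant's answer
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
text = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(text.strip())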

main.py
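Since the schema in the prompt uses SQLite syntax (AUTOINCREMENT), one way to check the model's answers automatically is to compile them against an in-memory SQLite database with that schema. This is a minimal sketch using Python's built-in sqlite3 module; the helper name check_sql and the rule of accepting only a single SELECT statement are my own assumptions, not part of the original project.

```python
import sqlite3

SCHEMA = """CREATE TABLE countries (
  id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
  name TEXT NOT NULL,
  area       INTEGER,
  area_land  INTEGER,
  area_water INTEGER,
  population        INTEGER,
  population_growth REAL,
  birth_rate        REAL,
  death_rate        REAL,
  migration_rate    REAL,
  flag_description TEXT
)"""


def check_sql(sql):
    """Return None if sql is a single SELECT that compiles against the schema, else an error message."""
    statement = sql.strip().rstrip(";").strip()
    if not statement.upper().startswith("SELECT"):
        return "not a SELECT statement"
    connection = sqlite3.connect(":memory:")
    try:
        connection.execute(SCHEMA)
        # EXPLAIN compiles the query into SQLite's bytecode without running it
        connection.execute("EXPLAIN " + statement)
    except (sqlite3.Error, sqlite3.Warning) as error:
        # Older Python versions raise sqlite3.Warning for multiple statements
        return str(error)
    finally:
        connection.close()
    return None


if __name__ == "__main__":
    print(check_sql("SELECT * FROM countries WHERE population > 10000000"))  # None
    print(check_sql("SELECT gdp FROM countries"))  # no such column: gdp
```

Feeding the model's output through such a check makes it easy to count how many generated statements are at least valid SQL for the schema, before and after fine-tuning. It says nothing about whether a query answers the question correctly.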

Converting to ONNX

The model we trained can't be loaded with Transformers.js in this format. We have to convert it to ONNX. Thankfully, the Transformers.js library provides a Python script to convert PyTorch models.

I cloned the Transformers.js repository and ran the following command to convert the model:

python -m scripts.convert --quantize --model_id ralscha/Llama-3.2-1B-Instruct-Country-SQL

This command converts the model to ONNX and quantizes it. The program outputs the model in different quantization formats. I chose the smallest one, q4f16, and copied it into my web application's assets folder.

Make sure that the directory layout looks like this. Transformers.js expects the files in this structure.

assets
└── ralscha
    └── Llama-3.2-1B-Instruct-Country-SQL
        ├── onnx
        │   └── model_q4f16.onnx
        ├── config.json
        ├── generation_config.json
        ├── tokenizer.json
        └── tokenizer_config.json
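A small check script can catch a wrong layout before you start debugging in the browser. This is a sketch, assuming exactly the file names from the tree above; the function find_missing_files and the default assets path are hypothetical.

```python
import sys
from pathlib import Path

REQUIRED_FILES = [
    "onnx/model_q4f16.onnx",
    "config.json",
    "generation_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
]


def find_missing_files(assets_dir, model_id="ralscha/Llama-3.2-1B-Instruct-Country-SQL"):
    """Return the expected model files that are missing under assets_dir/model_id."""
    model_dir = Path(assets_dir) / model_id
    return [name for name in REQUIRED_FILES if not (model_dir / name).is_file()]


if __name__ == "__main__":
    # Usage: python check_assets.py path/to/assets
    missing = find_missing_files(sys.argv[1] if len(sys.argv) > 1 else "assets")
    print("layout OK" if not missing else f"missing: {missing}")
```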

In the TypeScript code, I replaced the prompt template with the shorter version I used in the fine-tuning process and changed the model name in the pipeline call to ralscha/Llama-3.2-1B-Instruct-Country-SQL.

    this.generator = await pipeline('text-generation', 'ralscha/Llama-3.2-1B-Instruct-Country-SQL', {
      device: 'wasm',
      dtype: 'q4f16',
      local_files_only: true,
    });

home.page.ts

One problem is that this fine-tuned model only works with the wasm device, which is much slower than webgpu. I couldn't figure out what I did wrong, so if somebody knows where the problem is, please let me know.

Conclusion

When testing the web application with this model, I saw that the quality of the generated SQL statements improved. The model now generates SQL statements more consistently and with fewer errors. More training data would probably improve the quality even more.

Fine-tuning a model is a powerful technique to improve the quality of the generated answers. Tiny models in particular can benefit greatly from fine-tuning when they are used for a specific task. They also don't take long to fine-tune, and if you have a GPU, you can run the process on your local machine.

With the programs I showed you in this blog post, you have all the tools to fine-tune a model and convert it to ONNX. This article about fine-tuning helped me a lot to get the process up and running.