QLoRA Advanced: Continual Learning - Qwen3-4B-Thinking¶
Demonstrates sequential fine-tuning with LoRA to add new knowledge without catastrophic forgetting.
Training Stages:
- Medical Terminology - Medical definitions and concepts
- Legal Terminology - Legal terms and principles
- Technical Terminology - Software/API concepts
Key features demonstrated:
- Sequential domain training
- Retention testing after each stage
- Verification of no catastrophic forgetting
- Incremental knowledge addition
- Reasoning capability preservation
Why this matters: Traditional fine-tuning can cause "catastrophic forgetting" where new training overwrites old knowledge. LoRA adapters help mitigate this by keeping base model weights frozen.
Important: This notebook includes a kernel shutdown cell at the end to release all GPU memory.
# Environment Setup
import os
from dotenv import load_dotenv
load_dotenv()
# Force text-based progress instead of HTML widgets
os.environ["TQDM_NOTEBOOK"] = "false"
# CRITICAL: Import unsloth FIRST for proper TRL patching
import unsloth
from unsloth import FastLanguageModel, is_bf16_supported
import torch
import gc
# Environment summary
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
print(f"Environment: unsloth {unsloth.__version__}, PyTorch {torch.__version__}, {gpu}")
print(f"HF_TOKEN loaded: {'Yes' if os.environ.get('HF_TOKEN') else 'No'}")
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
/opt/pixi/.pixi/envs/default/lib/python3.13/site-packages/trl/__init__.py:203: UserWarning: TRL currently supports vLLM versions: 0.10.2, 0.11.0, 0.11.1, 0.11.2. You have version 0.14.0rc1.dev201+gadcf682fc.cu130 installed. We recommend installing a supported version to avoid compatibility issues. if is_vllm_available():
🦥 Unsloth Zoo will now patch everything to make training faster!
Environment: unsloth 2025.12.10, PyTorch 2.9.1+cu130, NVIDIA GeForce RTX 4080 SUPER
HF_TOKEN loaded: Yes
# Helper Functions
def cleanup_memory():
"""Force garbage collection and clear CUDA cache"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def count_parameters(model):
"""Count trainable vs total parameters"""
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
return {"trainable": trainable, "total": total, "pct": 100 * trainable / total}
print("Helper functions defined.")
Helper functions defined.
# Domain-Specific Datasets (5 samples each)
from datasets import Dataset
# Stage 1: Medical Terminology
medical_data = [
{
"instruction": "Define myocardial infarction.",
"thinking": "Medical term for heart attack. What's the precise definition? It's about blood flow blockage to the heart muscle. Include the key mechanism.",
"response": "Myocardial infarction (heart attack) is the death of heart muscle tissue due to prolonged lack of blood supply, typically caused by a blocked coronary artery."
},
{
"instruction": "What is hypertension?",
"thinking": "Common medical condition. High blood pressure. What are the clinical thresholds? 140/90 is typically the cutoff. Mention risks.",
"response": "Hypertension is persistently elevated blood pressure (typically >140/90 mmHg). It increases risk of heart disease, stroke, and kidney damage."
},
{
"instruction": "Explain diabetes mellitus.",
"thinking": "Metabolic disorder. Two types: Type 1 (autoimmune) and Type 2 (resistance). Focus on the core issue: blood sugar regulation.",
"response": "Diabetes mellitus is a metabolic disorder where the body cannot properly regulate blood glucose due to insufficient insulin production (Type 1) or insulin resistance (Type 2)."
},
{
"instruction": "What is anemia?",
"thinking": "Blood condition. Reduced red blood cells or hemoglobin. What are the symptoms? Fatigue, weakness. Mention causes.",
"response": "Anemia is a condition with insufficient red blood cells or hemoglobin to carry adequate oxygen. It causes fatigue, weakness, and pallor. Common causes include iron deficiency and blood loss."
},
{
"instruction": "Define tachycardia.",
"thinking": "Heart rate term. Fast heart rate. What's the threshold? >100 bpm at rest. Can be normal or pathological.",
"response": "Tachycardia is a heart rate exceeding 100 beats per minute at rest. It can result from exercise, stress, fever, or underlying cardiac conditions."
},
]
# Stage 2: Legal Terminology
legal_data = [
{
"instruction": "Define tort law.",
"thinking": "Civil law area. Deals with wrongful acts causing harm. Not criminal. What's the key purpose? Compensation for victims.",
"response": "Tort law is civil law governing wrongful acts that cause harm to others. It allows injured parties to seek compensation from those responsible, covering negligence, defamation, and intentional wrongs."
},
{
"instruction": "What is due process?",
"thinking": "Constitutional principle. Fair treatment through legal system. Two types: procedural and substantive. Fundamental right.",
"response": "Due process is the constitutional guarantee that legal proceedings will be fair and follow established rules. It ensures notice, hearing, and impartial decision-making before depriving rights."
},
{
"instruction": "Explain habeas corpus.",
"thinking": "Latin: 'produce the body.' Protection against unlawful detention. Court order to bring detained person before judge.",
"response": "Habeas corpus is a legal action requiring authorities to bring a detained person before a court to determine if their imprisonment is lawful. It protects against arbitrary detention."
},
{
"instruction": "What is legal precedent?",
"thinking": "Stare decisis principle. Past decisions guide future cases. Creates consistency in law. Binding vs persuasive precedent.",
"response": "Legal precedent (stare decisis) is the principle that courts should follow previous decisions in similar cases. It creates predictability and consistency in legal interpretation."
},
{
"instruction": "Define liability.",
"thinking": "Legal responsibility. Can be civil or criminal. Obligation arising from wrongful act. Leads to damages or penalties.",
"response": "Liability is legal responsibility for one's actions or omissions. In civil law, it typically means obligation to pay damages. In criminal law, it can mean facing penalties."
},
]
# Stage 3: Technical Terminology
technical_data = [
{
"instruction": "Explain API endpoints.",
"thinking": "Web development concept. URLs that accept requests. Part of REST APIs. Each endpoint serves a specific function.",
"response": "API endpoints are specific URLs where an API receives requests. Each endpoint corresponds to a resource or action, like '/users' for user data or '/posts' for content."
},
{
"instruction": "What is REST?",
"thinking": "Architectural style for web services. Stateless, uses HTTP methods. Resources identified by URIs. Widely adopted.",
"response": "REST (Representational State Transfer) is an architectural style for web APIs using HTTP methods (GET, POST, PUT, DELETE) to perform stateless operations on resources identified by URIs."
},
{
"instruction": "Define microservices.",
"thinking": "Software architecture. Small, independent services. Each handles specific function. Contrast with monolith.",
"response": "Microservices is an architecture where applications are built as collections of small, independent services. Each service handles a specific function and communicates via APIs."
},
{
"instruction": "What is containerization?",
"thinking": "Deployment technology. Packages app with dependencies. Docker is common example. Lighter than VMs.",
"response": "Containerization packages applications with their dependencies into isolated containers. Unlike VMs, containers share the host OS kernel, making them lightweight and portable."
},
{
"instruction": "Explain CI/CD.",
"thinking": "DevOps practice. Continuous Integration, Continuous Deployment. Automate testing and deployment. Speed up releases.",
"response": "CI/CD (Continuous Integration/Continuous Deployment) automates code integration, testing, and deployment. CI merges code changes frequently; CD automatically deploys tested code to production."
},
]
TRAINING_STAGES = [
("medical", medical_data),
("legal", legal_data),
("technical", technical_data),
]
# Test prompts for retention testing
RETENTION_TESTS = {
"medical": "What is hypertension?",
"legal": "What is due process?",
"technical": "What is REST?",
}
print(f"Training stages prepared:")
for name, data in TRAINING_STAGES:
print(f" - {name}: {len(data)} samples")
Training stages prepared:
 - medical: 5 samples
 - legal: 5 samples
 - technical: 5 samples
# Continual Learning Training Loop
from trl import SFTTrainer, SFTConfig
MODEL_NAME = "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit"
OUTPUT_BASE = "outputs_qlora_continual_think"
THINK_END_TOKEN_ID = 151668 # </think> token for Qwen3-Thinking models
trained_stages = []
retention_results = []
# Load model once - will train sequentially
cleanup_memory()
print("Loading model for continual learning...")
model, tokenizer = FastLanguageModel.from_pretrained(
MODEL_NAME,
max_seq_length=512,
load_in_4bit=True,
dtype=None,
)
# Apply LoRA once
print("Applying LoRA...")
model = FastLanguageModel.get_peft_model(
model,
r=16,
lora_alpha=16,
lora_dropout=0,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
bias="none",
use_gradient_checkpointing="unsloth",
random_state=42,
)
params = count_parameters(model)
print(f"Trainable: {params['trainable']:,} ({params['pct']:.2f}%)")
def test_retention(model, tokenizer, domains_to_test):
"""Test if model retains knowledge from previous domains with token-based parsing"""
results = {}
FastLanguageModel.for_inference(model)
for domain in domains_to_test:
prompt_text = RETENTION_TESTS[domain]
messages = [{"role": "user", "content": prompt_text}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=100,
temperature=0.6,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
)
# Token-based parsing for think tokens
generated_ids = outputs[0][inputs["input_ids"].shape[1]:].tolist()
if THINK_END_TOKEN_ID in generated_ids:
end_idx = generated_ids.index(THINK_END_TOKEN_ID)
thinking = tokenizer.decode(generated_ids[:end_idx], skip_special_tokens=True).strip()
response = tokenizer.decode(generated_ids[end_idx + 1:], skip_special_tokens=True).strip()
think_ok = True
else:
thinking = ""
response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
think_ok = False
results[domain] = {
"response": response[:200],
"thinking_tokens": len(generated_ids[:generated_ids.index(THINK_END_TOKEN_ID)]) if think_ok else 0,
"think_token_found": think_ok,
}
return results
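# Optional preview: print one rendered training example so the SFT text format is visible.
# This mirrors format_conversation defined in the loop below; the exact rendering depends
# on the model's chat template.
preview_sample = medical_data[0]
preview_messages = [
    {"role": "user", "content": preview_sample["instruction"]},
    {"role": "assistant", "content": f"<think>\n{preview_sample['thinking']}\n</think>\n\n{preview_sample['response']}"},
]
print(tokenizer.apply_chat_template(preview_messages, tokenize=False, add_generation_prompt=False))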
# Sequential training
for stage_idx, (domain_name, domain_data) in enumerate(TRAINING_STAGES):
print(f"\n{'='*60}")
print(f"Stage {stage_idx + 1}: Training on {domain_name.upper()} domain")
print(f"{'='*60}")
# Format dataset
def format_conversation(sample):
assistant_content = f"<think>\n{sample['thinking']}\n</think>\n\n{sample['response']}"
messages = [
{"role": "user", "content": sample["instruction"]},
{"role": "assistant", "content": assistant_content}
]
return {"text": tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)}
dataset = Dataset.from_list(domain_data)
dataset = dataset.map(format_conversation, remove_columns=["instruction", "thinking", "response"])
# Training config
sft_config = SFTConfig(
output_dir=f"{OUTPUT_BASE}/stage_{stage_idx + 1}_{domain_name}",
per_device_train_batch_size=1,
gradient_accumulation_steps=1,
max_steps=5,
warmup_steps=1,
learning_rate=2e-4,
logging_steps=1,
fp16=not is_bf16_supported(),
bf16=is_bf16_supported(),
optim="adamw_8bit",
weight_decay=0.01,
max_seq_length=512,
seed=42,
report_to="none",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
args=sft_config,
)
# Train
print(f"Training (5 steps)...")
trainer_stats = trainer.train()
final_loss = trainer_stats.metrics.get('train_loss', 0)
print(f"Final loss: {final_loss:.4f}")
# Record stage completion
trained_stages.append(domain_name)
# Save checkpoint
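# Note: save_pretrained on a PEFT-wrapped model writes only the LoRA adapter weights and
# config, not the 4-bit base model.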
checkpoint_path = f"{OUTPUT_BASE}/stage_{stage_idx + 1}_{domain_name}"
model.save_pretrained(checkpoint_path)
print(f"Checkpoint saved: {checkpoint_path}")
# Retention test on ALL previously trained domains
print(f"\nRetention test after {domain_name} training:")
retention = test_retention(model, tokenizer, trained_stages)
# Switch back to training mode; test_retention put the model in inference mode.
FastLanguageModel.for_training(model)
stage_result = {
"stage": stage_idx + 1,
"domain": domain_name,
"retention": retention,
}
retention_results.append(stage_result)
for domain, result in retention.items():
think_status = "✓" if result["think_token_found"] else "⚠"
response_status = "✅" if len(result["response"]) > 50 else "⚠️"
print(f" {response_status} {domain} [{think_status} think]: {result['response'][:100]}...")
# Clean up trainer but keep model
del trainer, dataset
print(f"\n{'='*60}")
print("Continual Learning Complete!")
print(f"{'='*60}")
print(f"\nTrained stages: {trained_stages}")
# Retention Summary
print("="*60)
print("Retention Test Summary")
print("="*60)
for result in retention_results:
print(f"\nAfter Stage {result['stage']} ({result['domain'].upper()}):")
print("-" * 40)
for domain, retention_data in result['retention'].items():
response_text = retention_data["response"]
think_found = retention_data["think_token_found"]
think_tokens = retention_data["thinking_tokens"]
# Check if response seems relevant
status = "✅ RETAINED" if len(response_text) > 50 else "⚠️ WEAK"
think_status = f"✓ think ({think_tokens} tokens)" if think_found else "⚠ no think token"
print(f" [{domain}] {status} [{think_status}]")
print(f" Preview: {response_text[:80]}...")
print("\n" + "="*60)
print("Final Model Capabilities")
print("="*60)
print("The model now has knowledge from:")
for stage in trained_stages:
print(f" ✓ {stage.capitalize()} terminology")
Analysis and Key Findings¶
Continual Learning with LoRA¶
Traditional Fine-tuning Problem:
- New training can overwrite previously learned knowledge
- "Catastrophic forgetting" causes loss of earlier capabilities
LoRA Advantage:
- Base model weights remain frozen (see the sketch after this list)
- Adapter weights accumulate knowledge
- Sequential training adds to, rather than replaces, knowledge
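A quick way to confirm the first point (a minimal sketch, assuming the PEFT-wrapped model from the training cells above is still loaded; run it before the final cleanup cell):
# Only the LoRA adapter tensors should require gradients; every base weight stays frozen.
trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
frozen_count = sum(1 for p in model.parameters() if not p.requires_grad)
print(f"{frozen_count} frozen tensors, {len(trainable_names)} trainable tensors")
assert all("lora" in n.lower() for n in trainable_names), "Unexpected non-LoRA trainable weight"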
Retention Test Results¶
| After Training | Medical | Legal | Technical |
|---|---|---|---|
| Stage 1 (Medical) | ✅ | - | - |
| Stage 2 (Legal) | ✅ | ✅ | - |
| Stage 3 (Technical) | ✅ | ✅ | ✅ |
Expected: All domains should show retention after sequential training.
Thinking Capability¶
The Qwen3-4B-Thinking model's reasoning capability should be preserved throughout:
- <think> tags still appear in outputs
- Self-questioning reasoning patterns maintained
- Domain knowledge integrated into thinking process
Practical Applications¶
- Incremental Knowledge Updates: Add new domain knowledge without retraining from scratch (see the sketch after this list)
- Multi-Domain Expertise: Build models with cross-domain capabilities
- Curriculum Learning: Train on progressively complex topics
- Personalization: Add user-specific knowledge over time
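For the first application above, a later session can pick up from a saved stage checkpoint instead of retraining from scratch. A minimal sketch, assuming the stage_3_technical directory written earlier and that unsloth resolves the saved adapter together with its 4-bit base model; the fourth-domain dataset mentioned in the comments is hypothetical:
from unsloth import FastLanguageModel

# Reload the Stage 3 adapter; the checkpoint directory holds the LoRA weights plus a
# reference to the base model, so this restores the accumulated three-domain state.
model, tokenizer = FastLanguageModel.from_pretrained(
    "outputs_qlora_continual_think/stage_3_technical",
    max_seq_length=512,
    load_in_4bit=True,
)

# Either run inference over the accumulated domains...
FastLanguageModel.for_inference(model)

# ...or switch back to training mode and add a hypothetical fourth domain using the
# same Dataset + SFTTrainer recipe as stages 1-3.
# FastLanguageModel.for_training(model)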
Limitations¶
- Adapter capacity is fixed by the chosen rank, so a single adapter cannot absorb unlimited new knowledge (per-stage checkpoints add only modest storage)
- Long sequences of training stages may still cause some degradation of earlier knowledge
- Trade-off between specialization and generalization
Key Insight¶
LoRA enables practical continual learning for LLMs by keeping base weights frozen while accumulating knowledge in adapter weights. This is more efficient than repeatedly fine-tuning the full model.
# Final cleanup
del model, tokenizer
cleanup_memory()
print("Model unloaded and memory cleared.")
Model unloaded and memory cleared.
# Shutdown kernel to release all GPU memory
import IPython
print("Shutting down kernel to release GPU memory...")
app = IPython.Application.instance()
app.kernel.do_shutdown(restart=False)
Shutting down kernel to release GPU memory...
{'status': 'ok', 'restart': False}