<?xml version="1.0" encoding="utf-8"?>
<rss version="2.0" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:atom="http://www.w3.org/2005/Atom">
    <channel>
        <title>Philipp Kaindl</title>
        <link>https://philippkaindl.com</link>
        <description>Writing about ML engineering, agents, and tooling.</description>
        <lastBuildDate>Wed, 08 Apr 2026 01:33:11 GMT</lastBuildDate>
        <docs>https://validator.w3.org/feed/docs/rss2.html</docs>
        <generator>https://github.com/jpmonette/feed</generator>
        <language>en</language>
        <copyright>© 2026 Philipp Kaindl</copyright>
        <atom:link href="https://philippkaindl.com/rss.xml" rel="self" type="application/rss+xml"/>
        <item>
            <title><![CDATA[To Fine-Tune or Not? Looking at It Through an Engineer's Economic Lens]]></title>
            <link>https://philippkaindl.com/posts/fine-tuning-or-not</link>
            <guid isPermaLink="false">https://philippkaindl.com/posts/fine-tuning-or-not</guid>
            <pubDate>Tue, 07 Apr 2026 00:00:00 GMT</pubDate>
            <description><![CDATA[The fine-tune-or-not question is economic, not ideological. Here is the framework for making it from first principles.]]></description>
            <content:encoded><![CDATA[
Fine-tuning discussions are often framed like matters of taste. One camp says small fine-tuned open models are the future. Another says prompting a frontier model through an API is enough for almost everything.

That framing is usually wrong.

For most real projects, the question is much simpler: for a given task, at a given volume, with a given error profile, what setup gives you the lowest expected cost per event?

That is an economics question, not an ideology question.

There are really two separate decisions:

1. Is this task worth automating at all?
2. If yes, is fine-tuning the cheapest way to reach the required quality?

Once you build a cost model, the answer is often much less mysterious than people think.

## The Wrong Framing

When I talk to customers, peers, or people at conferences, I keep hearing two kinds of statements:

- "This is the year of small open-source models that are fine-tuned."
- "Fine-tuning is the way to go for enterprise."

And on the other side:

- "We already run a large model via API and do huge volumes every day. Fine-tuning probably does not make sense for us, because fine-tuning is expensive and the API-based setup is convenient."

Both camps miss the point.

The question is not whether fine-tuning is fashionable. The question is whether it pays off.

## The Engineer's Lens

My background is mechanical engineering.
The one thing engineering teaches you is to reason from first principles. It also teaches you that once you move beyond experiments, the first serious question is always economic: does the investment pay off?

Before switching to a new machine, new part, new production process, or even before making any development decisions, you model the economics.
That is it. Simple. If you have a thoroughly vetted economic model that you can run "what if" scenarios on, then we talk. Before that, you are just guessing and wasting everybody's time.

You want to know throughput, scrap, oversight cost, downtime risk, and break-even. If you do not know a factor, you make an educated guess. You come back to the guess as time progresses and update it with measured data from the system. Your model gets better over time.

In ML and GenAI, life is often simpler than that. And still, teams regularly make decisions with less rigor.

Production engineering has already trained us to ask useful questions:

- What is the throughput?
- What is the good-part rate?
- What is the scrap rate?
- What is the cost of human oversight?
- Can the system run fully automatic or not?
- Do we need real-time processing, or can we batch?

And then the really important questions:

- What is the direct cost of scrap?
- What is the indirect cost of scrap?
- How does the cost of an error increase the longer it stays in the system?

Production engineering also gives you frameworks like Six Sigma and statistical process control: ways to measure defect rates, understand process variation, and decide when a process is stable enough to trust at scale. Once you want to run an AI system with limited or no human oversight, the same mindset matters: define acceptable quality limits, design an inspection strategy, and draw enough samples to estimate whether the process stays inside those limits. That lets you estimate the expected cost of errors statistically instead of arguing from anecdotes.

That mindset transfers almost directly to ML systems. A wrong prediction is just another form of scrap. The mistake is that many teams still evaluate model choices through benchmark scores, gut feeling, or vendor narratives instead of expected business cost per event.

Let us make that concrete with a classification example.

## Start With the Data or Accept That You Are Guessing

Before discussing prompting strategy, model choice, or fine-tuning, get the basic inputs right.

For almost any classification or extraction workflow, you need at least these four groups of numbers. The exact categories depend on your domain, but these cover most of the ground.

### 1. Volume Data

- How many emails, documents, pages, or events do you process?
- Is there seasonality?
- Is the workload steady, bursty, or batch-driven?

### 2. Time Data

- How long does a human take to classify one item on average?
- How long does a human take to extract the required information?
- How much review time remains even after automation?

Once you know average handling time and hourly wage, you already have a baseline cost per event.

### 3. Cost of Errors

This is the part many teams skip, and it is often where most of the value sits.

- What happens when the prediction is wrong?
- Does another human fix it later?
- Does it create delay?
- Does it trigger rework downstream?
- Does it hurt customer satisfaction?
- Does it create compliance or escalation risk?

Even if the estimate is rough, assign a number. A rough number is better than pretending the cost is zero.

When a cost is uncertain, do not default it to zero. Use a low, base, and high estimate instead. Bad estimates can be corrected. Missing estimates make the whole model blind.
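One lightweight way to follow that advice is to carry each uncertain cost as a small scenario table instead of a single number. A minimal sketch, with purely illustrative figures and weights:

```python
# Carry uncertain costs as low/base/high scenarios instead of defaulting to zero.
# All figures and weights below are illustrative placeholders, not measured values.

def expected_cost(scenarios: dict, weights: dict) -> float:
    """Weighted expected cost across low/base/high estimates."""
    return sum(scenarios[k] * weights[k] for k in scenarios)

# Hypothetical cost of one misrouted interaction (EUR):
rework_cost = {"low": 1.0, "base": 4.0, "high": 10.0}
# Simple triangular-style weighting; update as measured data comes in.
weights = {"low": 0.25, "base": 0.5, "high": 0.25}

print(expected_cost(rework_cost, weights))  # 4.75
```

As evidence accumulates, narrow the spread between low and high. The model gets better over time, exactly as in the production-engineering case.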

### 4. System Cost

- API inference cost
- hosting cost
- training or fine-tuning cost
- evaluation cost
- monitoring cost
- human-in-the-loop cost
- retraining or relabeling cost

Without this, you cannot answer the actual question.

If collecting all of this feels like a lot of structure to set up from scratch, there is good news: at the end of this post, we packaged the full question bank and the formulas into a skill you can drop into your favorite AI coding agent. It will walk you through exactly these inputs.

## A Note on Human Benchmarking

In some environments, especially where worker protections are strong, teams hesitate to measure individual handling performance. Fair enough. But that does not remove the economic problem.

If you cannot measure per person, measure per team:

- total hours per week
- total volume per week
- total rework volume

A good estimate is still far better than flying blind. If you cannot get even rough numbers, that is usually not an AI problem. It is a data and operations problem.

## A Note on Hard-to-Measure Metrics

Some of the most important costs are hard to measure directly: customer satisfaction, brand damage, escalation probability, regulatory risk, or the cost of making a customer wait one day too long.

That does not mean they should be ignored. It means someone has to make a decision and assign a rough value.

I often hear things like: "We want to improve NPS by 0.1 points." Fine. What is that worth? If nobody in the room can even make an approximate economic argument, then the KPI may be directionally useful, but it is not yet usable for investment decisions.

This is not KPI shaming. Soft metrics matter. But if they matter enough to drive budgets, then somebody has to put some skin in the game and translate them, however roughly, into business value.

A useful fallback is to price soft metrics through scenarios. Ask: "If this gets 20% worse for a quarter, what is the plausible downside?" That usually gets you to a better number than abstract KPI talk.

## Worked Example: Triage at a Mid-Sized Insurance Company

Let us make this more tangible.

Imagine a mid-sized insurer with a large operations team handling incoming written interactions: claims notifications, billing questions, policy changes, proof-of-insurance requests, cancellation requests, and supporting documents.

Some of these come in as plain emails. Some come through a portal. Many arrive with attachments that still need to be routed to the right queue.

Assume:

- 5,000 inbound written interactions per day
- 10 top-level classes, each with 10 subclasses (100 possible routes)
- average human handling time for classification: 30 seconds
- average hourly labor cost: €20

Now define the business impact of mistakes. In this setting, the real cost is not just labor. It is delay, rework, repeat contact, escalation risk, and in some cases service-level or compliance risk.

To keep the example concrete, use a management proxy instead of pretending this number falls out of physics.

With an average customer ARR of €600, the team might price one day of avoidable delay at roughly €4 — covering repeat contact, churn risk, and satisfaction impact. (€600/year ÷ 250 working days ≈ €2.40/day; using €4/day means assuming that the operational downside of delay is materially larger than the customer's average daily revenue contribution.)

This is not a universal truth. It is a working estimate. If you dislike the number, replace it with low, base, and high scenarios and rerun the model.

- 1 business day of avoidable delay on a customer interaction costs €4 on average
- if an interaction goes to the wrong department, it usually takes about 2 extra business days to recover
- if the top-level department is correct but the subclass is wrong, it usually creates about half a day of additional delay

That means:

- 1 day of delay = **€4**
- wrong department (2 days delay) = **€8**
- wrong subclass (0.5 day delay) = **€2**

Already, that tells us something important: the dominant cost may not be model inference. It may be classification mistakes.

This is a trap I see often in ML. Teams focus on inference cost because it is easy to measure, and they mentally hide everything that is messy or indirect. But the process does not care which part of the cost structure was easy to log.

ML engineers need more of a process-owner mindset. And business owners need to assign value to the hard-to-measure parts as well. Otherwise you optimize the cheapest visible part of the system and ignore the expensive invisible part.

## First Baseline: Human Cost Per Event

The direct labor cost for manual classification:

$$\text{classification cost} = \frac{\text{hourly wage}}{3600} \times \text{seconds per interaction} = \frac{20}{3600} \times 30 = €0.1667 \text{ per interaction}$$

Call it **€0.17 per event**. That is your floor for manual processing before you even consider error cost.

In many workflows, direct labor is not the expensive part. Delay, rework, escalation, and customer friction dominate the economics long before raw inference cost does.
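The baseline arithmetic is trivial, but it is worth keeping in a form you can rerun whenever an input changes. A sketch using the example's figures:

```python
# Direct labor cost per manually handled interaction,
# using the example figures (EUR 20/hour, 30 seconds per item).
hourly_wage = 20.0        # EUR per hour
seconds_per_item = 30.0   # average human handling time

cost_per_event = hourly_wage / 3600 * seconds_per_item
print(round(cost_per_event, 4))  # 0.1667

daily_volume = 5000
print(round(cost_per_event * daily_volume, 2))  # 833.33 EUR/day in direct labor
```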

## Why Accuracy Alone Is Not Enough

A plain accuracy score is often the wrong metric. Not all mistakes cost the same.

A wrong subclass is annoying. A wrong department is much worse.

So instead of looking only at a raw confusion matrix, collapse it into business-relevant error buckets and assign cost to each bucket.

In the general case, the expected error cost is:

$$E[\text{error cost}] = \sum_{i,\,j} P(\hat{y}=j \mid y=i) \cdot c(i,j)$$

Where $i$ is the true class, $j$ is the predicted class, $P(\hat{y}=j \mid y=i)$ is the probability of predicting $j$ when the true label is $i$ (read directly from your confusion matrix), and $c(i,j)$ is the business cost assigned to that specific mistake.

If the system also sends a share $r_{review}$ of cases to human review at cost $C_{review}$, then the all-in expected cost per event becomes:

$$C_{total} = E[\text{error cost}] + r_{review} \cdot C_{review} + C_{system}$$

That extra review term matters because many real automations are not fully automatic. They are partially automatic with selective human fallback.
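Both formulas are easy to check in code. The confusion and cost matrices below are small illustrative placeholders, not real data:

```python
# Expected error cost from a confusion matrix, plus selective review
# and system cost. All numbers below are illustrative placeholders.

# P[i][j] = joint probability of true class i and predicted class j
# (a confusion matrix normalized so all cells sum to 1).
P = [
    [0.30, 0.02, 0.01],
    [0.02, 0.30, 0.01],
    [0.01, 0.01, 0.32],
]
# c[i][j] = business cost (EUR) of predicting j when the truth is i.
c = [
    [0.0, 2.0, 8.0],
    [2.0, 0.0, 8.0],
    [8.0, 8.0, 0.0],
]

n = len(P)
expected_error_cost = sum(P[i][j] * c[i][j] for i in range(n) for j in range(n))

# All-in cost per event with selective human review and system cost.
r_review, C_review, C_system = 0.10, 0.17, 0.02  # illustrative
C_total = expected_error_cost + r_review * C_review + C_system
print(round(expected_error_cost, 3), round(C_total, 3))  # 0.4 0.437
```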

For a full 100-class routing setup, that can become a large matrix. In practice, group cells into a few business-relevant buckets as long as the grouped mistakes really do have similar downstream cost.

Table: Example business cost buckets for routing errors.

| Outcome | Business meaning | Cost |
|---------|-----------------|------|
| Correct route | Correct dept and subclass | €0 |
| Wrong subclass | Correct dept, wrong subclass | €2 |
| Wrong department | Wrong top-level class | €8 |

This is the step many teams never do. But once you do it, you can compare humans, prompted models, and fine-tuned models on the same axis: expected euros per event.

Accuracy becomes much more useful once you attach money to the error buckets. A 2-point improvement in the wrong place can be worth less than a 0.5-point improvement on the expensive mistake.

## Turn the Confusion Matrix Into Money

Assume the current human process:

- top-level (dept) accuracy: 95%
- full-route accuracy: 80%

That implies: 80% fully correct, 15% right dept / wrong subclass, 5% wrong department.

Table: Example expected error-cost calculation for the current human routing process.

| Outcome | Share | Cost | Contribution |
|---------|-------|------|-------------|
| Correct route | 80% | €0 | €0.00 |
| Wrong subclass | 15% | €2 | €0.30 |
| Wrong department | 5% | €8 | €0.40 |
| **Total error cost** | | | **€0.70** |

Add the direct handling cost:

- manual handling: €0.17
- expected error cost: €0.70
- **total expected cost per interaction: €0.87**

At 5,000 interactions per day: **€4,350 / day**.
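The bucket calculation above is small enough to keep as a script you rerun with your own shares and costs. A sketch reproducing the worked example:

```python
# Expected cost per event for the current human routing process,
# using the worked example's figures.

buckets = {  # outcome: (share, cost in EUR)
    "correct_route":    (0.80, 0.0),
    "wrong_subclass":   (0.15, 2.0),
    "wrong_department": (0.05, 8.0),
}

error_cost = sum(share * cost for share, cost in buckets.values())
handling_cost = 20 / 3600 * 30  # EUR 0.1667 direct labor per interaction
total_per_event = round(handling_cost + error_cost, 2)

daily_volume = 5000
print(round(error_cost, 2))                   # 0.7
print(total_per_event)                        # 0.87
print(round(total_per_event * daily_volume))  # 4350
```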

## Do Not Tune the Inputs Until the Output Feels Nice

One thing you absolutely should not do is start changing assumptions just because the first number feels uncomfortable.

I have seen this multiple times. A team does the first pass, sees a big cost number, and immediately starts nudging the assumptions until the result looks emotionally acceptable.

That is exactly the wrong move.

If you have a good reason to change an assumption, change it. But do not lower the cost of a bad customer experience just to make the spreadsheet feel nicer.

Sometimes the right response to an uncomfortable number is to ask whether your value model is incomplete. Maybe the delay cost is lower than you thought. Maybe it is actually higher, because quick responses improve retention or satisfaction, and falling behind means leaving money on the table.

The point is to get a cost model that keeps everybody honest. 

If the first result looks absurd, resist the urge to "fix" the spreadsheet immediately. First ask whether the business process itself is more expensive than people realized.

## Prompted Model vs Fine-Tuned Model

At this point the first decision is already visible: any automated option that gets materially below the manual baseline of €0.87 per interaction (direct human cost plus expected error cost) is worth considering.

Now we can ask the second question — across all 5,000 daily interactions, since every one of them needs to be classified regardless of what happens downstream.

To keep the comparison readable, assume there is no mandatory human review on every interaction and that any selective benchmarking and evaluation workload is already folded into the ops terms below. If your process still reviews a fixed share of cases, add that residual review cost explicitly with the formula above.

Suppose you compare two candidate systems.

### Option A — Prompted General Model

Assume:

- full-route accuracy: 93%
- wrong subclass: 5%
- wrong department: 2%
- inference and ops cost: €0.06 per interaction

Expected variable cost: `0.05 × €2 + 0.02 × €8 + €0.06` = **€0.32 / event**

### Option B — Fine-Tuned Model

This is where fine-tuning tends to shine. For tasks like classification and routing, a fine-tuned model of decent size will rarely lose to a prompted frontier model. The reason is simple: the opaque decision boundaries that make routing hard — which product, which channel, which contract type routes where — are easier to learn from labeled examples than to describe in a prompt. Company-specific intricacies are easier to train than to prompt.
The best part? If you introduce new classes later, you can always add a "prompt escape hatch" for the model even after training. In the end, it is still an LLM.
Example prompt: "If the presented data does not fit any of the existing classes, route it to the new class |XYZ| or send it to a human via |HUMAN|."


Assume:

- full-route accuracy: 96%
- wrong subclass: 3%
- wrong department: 1%
- inference, serving, and routine ops cost: €0.015 per interaction

Expected variable cost: `0.03 × €2 + 0.01 × €8 + €0.015` = **€0.155 / event**

Incremental variable savings: **€0.165 / event** → **€825 / day** → **€74,250 over 90 days**.

If the fine-tuning project costs €40,000 all-in as a fixed investment: break-even at roughly 49 days, clearly positive at 180 days (€148,500 gross savings over 180 days, or **€108,500 net** after the project cost).
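The comparison and break-even arithmetic above fits in a few lines. A sketch using the figures from both options:

```python
# Comparing the two candidate systems and computing break-even
# on the fixed fine-tuning investment. Figures are from the worked example.

def variable_cost(p_subclass, p_department, infra_cost,
                  cost_subclass=2.0, cost_department=8.0):
    """All-in expected variable cost per event (EUR)."""
    return p_subclass * cost_subclass + p_department * cost_department + infra_cost

c_prompted = variable_cost(0.05, 0.02, 0.06)    # Option A: 0.32
c_finetuned = variable_cost(0.03, 0.01, 0.015)  # Option B: 0.155

daily_volume = 5000
savings_per_day = (c_prompted - c_finetuned) * daily_volume

investment = 40_000
break_even_days = investment / savings_per_day

print(round(savings_per_day))     # 825
print(round(break_even_days, 1))  # 48.5
```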

One important caveat: in many real projects, the expensive part is not the training run itself. It is data preparation, labeling cleanup, evaluation, integration, and operating the thing properly afterward. And "afterward" is not a one-time event. Someone has to monitor model quality over time, retrain when the distribution shifts, maintain the serving infrastructure, and handle the inevitable edge cases that only surface in production.

All of those ongoing costs — MLOps time, retraining cycles, serving and monitoring infrastructure — should be folded into $C_{ft}$. If you only count the initial project as your investment and ignore the running cost of keeping the system alive, your model will look too optimistic.

Fine-tuning cost is rarely just "GPU hours for one training run." In real projects, data cleanup and evaluation dominate the upfront work. Ongoing MLOps, retraining, and serving costs dominate the long-term picture. Both belong in your cost model.

That is exactly why volume and horizon matter so much. If the eligible volume was much lower, or the misclassification cost was smaller, fine-tuning may not be worth the additional complexity. The answer comes from the economics, not from fashion.

## The Real Decision Rule

Define:

- $N$ = number of relevant predictions over the decision horizon
- $C_{prompted}$ = all-in variable cost per event of the prompted system
- $C_{ft}$ = all-in variable cost per event of the fine-tuned system
- $I_{ft}$ = fixed fine-tuning investment over the horizon

Then fine-tuning is worth it when:

$$N \cdot (C_{prompted} - C_{ft}) > I_{ft}$$

That is the core inequality. Everything else is implementation detail.

If you prefer to amortize the project investment into $C_{ft}$, that is fine too. Just set $I_{ft} = 0$ in the formula and make sure you do not count the same cost twice.
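The inequality itself is a one-liner, which makes it easy to drop into any analysis script. A minimal sketch, plugging in the 90-day example figures:

```python
# Core decision rule: fine-tune when volume-weighted per-event savings
# exceed the fixed investment over the horizon.

def fine_tuning_pays_off(n_events, c_prompted, c_ft, i_ft):
    """True when N * (C_prompted - C_ft) > I_ft."""
    return n_events * (c_prompted - c_ft) > i_ft

# 90 days x 5,000 events/day, per-event costs from the worked example:
print(fine_tuning_pays_off(450_000, 0.32, 0.155, 40_000))  # True
```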

Once you have this calculation for one project, run it for all of them. Most organizations have more candidate projects than they have team capacity to execute. Fine-tuning projects share a constrained resource — the people who can build and operate them — so some investments are mutually exclusive. Comparing the expected net benefit across projects is essentially capital budgeting: in classical finance, this is where you would compute an internal rate of return (IRR) for each project and rank them. You do not need that level of formality here, but having even rough per-project economics side by side makes the portfolio decision visible instead of political.

## Reverse the Formula

Once a prompted model is already cheaper than a human, the next practical question is often not "Should we fine-tune?" but:

- What is the maximum amount we are allowed to invest in fine-tuning?
- Or, if we already know the investment budget, what is the maximum per-prediction cost difference we can tolerate between the prompted model and the fine-tuned model?

From here on, let $I_{ft}$ mean fixed fine-tuning investment that has not already been amortized into the per-event cost.

Define:

- $N$ = number of relevant predictions over the horizon
- $C_{prompted}$ = all-in variable cost per event of the prompted model
- $C_{ft}$ = all-in variable cost per event of the fine-tuned model
- $I_{ft}$ = total fixed fine-tuning investment over the horizon

Break-even condition:

$$N \cdot (C_{prompted} - C_{ft}) = I_{ft}$$

Maximum justified fine-tuning investment:

$$I_{ft,max} = N \cdot (C_{prompted} - C_{ft})$$

This gives you a hard ceiling. If the project costs more than that, the economics do not work.

This is the formula I would bring into a budgeting discussion. It turns "We think fine-tuning might help" into "This is the most we should be willing to spend."

### Same Token Cost Assumption

If you assume the prompted model and the fine-tuned model have the same direct prediction cost, then the difference is driven only by error reduction. In that case, let $E_{prompted}$ and $E_{ft}$ be the expected error costs per event:

$$I_{ft,max} = N \cdot (E_{prompted} - E_{ft})$$

That is often a very helpful sanity check. It tells you how much accuracy improvement is worth in euros before you even discuss infrastructure details.
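As a quick check with the worked example's error costs:

```python
# Maximum justified fine-tuning investment when the per-prediction cost
# is identical for both systems: only the error reduction matters.
n_events = 450_000                 # 90 days x 5,000 events
e_prompted = 0.05 * 2 + 0.02 * 8   # EUR 0.26 expected error cost per event
e_finetuned = 0.03 * 2 + 0.01 * 8  # EUR 0.14

i_ft_max = n_events * (e_prompted - e_finetuned)
print(round(i_ft_max))  # 54000
```

Under the equal-token-cost assumption, a €40,000 project still clears the €54,000 ceiling with room to spare.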

### Maximum Allowed Prediction Cost Difference

Now take the other direction.

Suppose you already know the total fine-tuning investment budget $I_{ft}$, and you want to know how much more expensive the fine-tuned prediction is allowed to be compared with the prompted baseline.

Define:

- $E_{prompted}$ = expected error cost per event of the prompted model
- $E_{ft}$ = expected error cost per event of the fine-tuned model
- $P_{prompted}$ = direct inference cost per event of the prompted model
- $P_{ft}$ = direct inference cost per event of the fine-tuned model

Then the maximum allowed fine-tuned prediction cost premium is:

$$\Delta P_{max} = (E_{prompted} - E_{ft}) - \frac{I_{ft}}{N}$$

Equivalent form:

$$P_{ft,max} = P_{prompted} + (E_{prompted} - E_{ft}) - \frac{I_{ft}}{N}$$

If that value is negative, the message is simple: your fine-tuned model does not just need to be better — it also needs to be cheaper to run.

This reverse view is useful when infra teams and ML teams talk past each other. One side talks model quality, the other talks serving cost. This formula gives both sides one shared number.
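The reverse calculation is just as short. A sketch using the error costs from the worked example and the €40,000 budget:

```python
# Maximum allowed per-event prediction cost premium for the
# fine-tuned model, given a fixed investment budget.

def max_prediction_premium(e_prompted, e_ft, i_ft, n_events):
    """Delta P_max = (E_prompted - E_ft) - I_ft / N."""
    return (e_prompted - e_ft) - i_ft / n_events

# Error cost drops from EUR 0.26 to EUR 0.14 per event,
# EUR 40,000 budget, 450,000 events over the horizon:
delta_p = max_prediction_premium(0.26, 0.14, 40_000, 450_000)
print(round(delta_p, 4))  # 0.0311
```

A positive result means the fine-tuned model may cost up to that much more per prediction and still break even; a negative result means it must be both better and cheaper to run.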



## A Note on Organizational Readiness

The formulas above are agnostic about who does the work. But in practice, the decision to fine-tune is not just an economic one — it is also a capability question.

Running a fine-tuned model in production means somebody in your organization needs to own the pipeline: data versioning, training runs, evaluation, deployment, monitoring, and retraining when the world changes. That does not require a huge ML platform team — the tooling has gotten dramatically easier — but it does require operational discipline and at least a small team that is comfortable with the workflow.

If your organization does not have that capability today, the cost of building it becomes part of $I_{ft}$. If it does, much of that cost is already sunk and the marginal cost of one more fine-tuned model is lower than people assume.

This is why the framework in this post is more naturally suited to organizations with some MLOps maturity. If you are a two-person startup with no ML infrastructure, the overhead of standing up a fine-tuning pipeline may dominate the economics regardless of what the per-event math says. A prompted API call that costs more per event but requires zero infrastructure may still be the rational choice — for now.

The good news: this is not a permanent constraint. As tooling improves and as your volume grows, the break-even point shifts. The model tells you when.

The per-event cost model assumes someone can actually operate the system. If standing up a fine-tuning pipeline is new for your team, include the ramp-up cost honestly. Do not pretend it is zero, but do not pretend it is insurmountable either.

## A Note on Supply Chain Risk

There is a mirror image of the organizational readiness argument that people bring up far less often.

When teams say "We just call an API, so we do not have to manage a model," they are making an implicit assumption: that calling an API is maintenance-free. It is not.

You are still running a model. You just do not own it. And that creates its own risk profile:

- **Silent distribution shifts.** The provider updates or replaces the model behind the endpoint. Your prompts still run, but accuracy may change. If you are not running evaluations regularly, you will not notice until downstream metrics degrade.
- **Model deprecation.** Models get retired. If your workflow depends on a specific model version, you may be forced to migrate on someone else's timeline.
- **Throughput constraints.** You cannot scale an API endpoint indefinitely. At high volume, you need quota increases, which require lead time and sometimes negotiation. During demand spikes, you may hit rate limits you did not plan for.
- **No control over the training cycle.** If the model's behavior drifts in a way that hurts your use case, your only lever is prompt engineering. You cannot retrain, you cannot freeze a version indefinitely, and you cannot inject your own data to correct the drift.

In supply chain terms, this is a single-supplier dependency with limited contractual control over the product specification. In German manufacturing, we would call this a *Lieferkettenrisiko* — a supply chain risk. The mental model is the same: you trade operational complexity for external dependency, and that dependency has a cost that belongs in your model.

None of this means API-based inference is wrong. For many use cases, especially at lower volumes, it is the clearly rational choice. But the comparison should be honest. If you count MLOps cost for the fine-tuned path, you should also count evaluation overhead, migration risk, and throughput constraints for the API path.

Calling an API does not mean you have no model to manage. It means someone else manages the model, and you manage the dependency. Both have costs. Both belong in the comparison.

## Where Fine-Tuning Usually Makes Sense

The equation tells you whether fine-tuning pays off. But there are also use-case patterns where it tends to make sense more often than people realize.

### 1. When You Already Have Labeled Data and Did Not Notice

One of the best cases for fine-tuning is when training data already exists because humans have been labeling it for operational reasons anyway.

Email classification is a classic example. Most companies already have historical email traffic that ended up in a department, got forwarded, got corrected, or got resolved by a certain team. That operational history is often an implicit label source.

The same is true for many document-heavy workflows:

- document routing
- packet splitting
- extracting key information into UI masks
- assigning documents to specialists

If somebody in the past designed the process well enough that humans were already creating structured outcomes, you already paid a good chunk of the labeling cost. That is a serious head start.

And where the routing history is messy or ambiguous, you can still use humans as judges to clean up the tail of the distribution instead of starting from zero.

One of the best sources of training data is boring operational history. If humans already routed, corrected, approved, or extracted something, there is a good chance you already own part of the dataset you need.

### 2. When the Decision Boundary Lives in People's Heads

Another strong signal is whenever you hear phrases like:

- "You just need to know."
- "It becomes obvious once you have worked here for a while."
- "This one is tricky because of how product A interacts with product B."

Those are signs that the decision boundary is real, valuable, and poorly documented.

Take email classification again. The first few classes may be easy. But then you get opaque routing logic built up over years: if it references one product but comes from another channel, and mentions a certain contract type, then it belongs somewhere unexpected.

At that point you have two choices:

1. clean up the process and simplify the rules
2. learn from historical labeled behavior

Sometimes the correct answer is process redesign. Sometimes the correct answer is fine-tuning on the accumulated operational data.

Phrases like "you just need to know" or "it depends on who reads it" are often signals that useful label information exists, but lives only in people's heads.

### 3. When People Fear That New Classes Mean Starting Over

"If we add a new class, do we have to retrain everything?"

Usually, no.

With LLM-based classifiers, you are often not building a rigid prediction head that must be redesigned every time the taxonomy changes. You are teaching the model to output a label or structured token sequence. That is more flexible than people assume.

Taxonomy changes are not free — they still need evaluation, data updates, and operational discipline. But they are not the catastrophic reset that some teams imagine.

### 4. When Batch Economics Matter

There is another case where fine-tuning can become very attractive: batch-heavy workloads.

Yes, many API providers offer batch inference. But often those windows are around 24 hours, and 24 hours is just a bit too slow for many operational processes. You may not need strict real-time, but you also do not want to wait until tomorrow.

That is where a fine-tuned model served in your own environment can get interesting. You can accumulate enough work, run batch prediction on your own schedule, and drive the cost down.

This can mean:

- a scheduled ECS task
- a containerized batch worker
- a self-hosted inference service
- a serving framework optimized for throughput and caching

The point is not that self-hosting is automatically better. The point is that once volume is high enough and latency constraints are awkward enough, the economics can shift very quickly.

## Food for Thought: Multi-Layer Cost Structures

Everything in this post models a single layer: routing accuracy and its direct cost implications. But real processes rarely have just one layer.

In the insurance example, correctly classifying an email is only the first gate. After it has been routed to the correct department, a specialist in that team opens the correspondence, reads it, and handles it — finding the right template, checking eligibility, or preparing a response. Specialists typically carry a higher cost structure than the routing layer: more domain knowledge, higher hourly rates, and scarcer capacity. For a large share of incoming interactions — FAQ-style requests, standard acknowledgements, routine document requests — that handling step is largely mechanical. A human still does it, but it is the kind of work that an agent system with access to the right business logic could own. That handling cost is real, sits downstream of the routing decision, and is not modelled here.

This creates a multi-layer benefit structure: improvements in upstream classification accuracy unlock downstream automation potential. If the routing is wrong, the downstream handler — whether human or agent — starts from a bad position. If the routing is right and the intent is unambiguous, you have the option to automate the response entirely.

The broader point: upstream efficiency is a prerequisite for downstream automation. Strong routing is not proof that fully automated handling is safe, but it is one of the gates you have to pass before that conversation is even serious. Once each layer clears its own quality bar, the economics can compound.
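To make the layering concrete, here is a minimal two-layer sketch. All rates and costs are invented placeholders; the function only shows how routing accuracy gates the downstream savings:

```python
# Two-layer cost sketch: routing (layer 1) gates automated handling (layer 2).
# Every rate and cost below is an illustrative assumption.

ROUTING_ERROR_COST = 8.0    # assumed cost of a misrouted email (rework, delay)
HUMAN_HANDLING_COST = 6.0   # assumed specialist cost to handle one routine email
AGENT_HANDLING_COST = 0.50  # assumed cost of automated handling
AUTOMATABLE_SHARE = 0.4     # assumed share of emails that are routine/FAQ-style

def expected_cost_per_email(routing_accuracy: float, automate_routine: bool) -> float:
    # Layer 1: misrouting cost.
    cost = (1 - routing_accuracy) * ROUTING_ERROR_COST
    if automate_routine:
        # Layer 2: automation applies only to correctly routed, routine emails.
        handled_by_agent = routing_accuracy * AUTOMATABLE_SHARE
        handled_by_human = 1 - handled_by_agent
        cost += (handled_by_agent * AGENT_HANDLING_COST
                 + handled_by_human * HUMAN_HANDLING_COST)
    else:
        cost += HUMAN_HANDLING_COST
    return cost
```

Under these assumptions, better routing accuracy both cuts misrouting cost directly and enlarges the pool of emails the downstream agent is allowed to handle.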

## Practical Takeaways

- Do not ask "Should we fine-tune?" before asking "What does an error cost us?"
- Do not use raw accuracy as your primary decision metric when different mistakes have different business consequences.
- Collapse the confusion matrix into economically meaningful buckets.
- Compare options on expected cost per event, not on benchmark vanity.
- Separate the decision to automate from the decision to fine-tune.
- Do not massage assumptions until the output feels emotionally acceptable.
- Look for hidden label sources in existing operational systems.

## Conclusion

Apply this framework to any ML or GenAI project you are running right now.

If you can estimate:

- volume
- handling time
- hourly labor cost
- confusion matrix or error bucket rates
- business cost per error type
- system cost

then you can usually answer the fine-tune-or-not question much faster than most people expect.
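Those inputs map directly onto a small calculation. The sketch below is illustrative only; every rate and euro value is a placeholder, not data from a real project:

```python
# Expected cost per event from the inputs listed above. The error buckets
# collapse the confusion matrix into economically meaningful groups.
# Every number here is an illustrative placeholder.

VOLUME_PER_MONTH = 50_000
HANDLING_MINUTES = 2.0           # assumed human handling time per event
HOURLY_LABOR_COST = 45.0         # assumed fully loaded hourly rate, EUR
ERROR_BUCKETS = {                # assumed (rate, business cost) per error type
    "harmless_confusion": (0.060, 1.0),
    "wrong_department":   (0.030, 12.0),
    "missed_deadline":    (0.005, 150.0),
}
SYSTEM_COST_PER_MONTH = 1_500.0  # assumed serving + maintenance cost, EUR

def human_cost_per_event() -> float:
    return HANDLING_MINUTES / 60 * HOURLY_LABOR_COST

def automated_cost_per_event() -> float:
    errors = sum(rate * cost for rate, cost in ERROR_BUCKETS.values())
    return errors + SYSTEM_COST_PER_MONTH / VOLUME_PER_MONTH
```

If the automated cost per event (error costs plus amortized system cost) undercuts the human handling cost, automation pays; whether fine-tuning is the cheapest route to those bucket rates is the second, separate decision.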

And if you cannot answer it, that is useful too. It means the next problem to solve is not model architecture. It is measurement.

## Grab the Skill

If you want to turn this framework into a working prep workflow, I packaged it as a two-file skill you can drop into any coding agent with skill support.

- [SKILL.md](/skills/ai-project-economics-buddy/SKILL.md) — the agent instructions (drop this into your skills folder)
- [REFERENCE.md](/skills/ai-project-economics-buddy/REFERENCE.md) — formulas, question bank, and the worked example from this post

It is designed for BDs, AI strategists, product leads, and solution architects. Instead of giving you generic AI advice, it pushes you to quantify the use case, identify missing assumptions, and generate the right questions to take back to operations, finance, compliance, or the business.

## Outlook

In follow-up posts, I want to go deeper into how to estimate the true cost of fine-tuning projects (spoiler: training is rarely the expensive part), and fine-tune an LLM for information extraction on a realistic use case, breaking down all of the fine-tuning costs we incurred.

Hope to see you there! 

]]></content:encoded>
            <author>Philipp Kaindl</author>
            <category>ML Engineering</category>
            <category>Fine-Tuning</category>
            <category>Economics</category>
        </item>
        <item>
            <title><![CDATA[Secure Execution of Code Generated by Large Language Models]]></title>
            <link>https://philippkaindl.com/posts/secure-llm-code-execution</link>
            <guid isPermaLink="false">https://philippkaindl.com/posts/secure-llm-code-execution</guid>
            <pubDate>Thu, 08 Aug 2024 00:00:00 GMT</pubDate>
            <description><![CDATA[Explore attack vectors, threat modeling, and AWS-based mitigation strategies for securely executing LLM-generated code.]]></description>
            <content:encoded><![CDATA[
*Co-authored with [Michael Tschannen](https://medium.com/u/aae8d67feab7) and [rollwag](https://medium.com/u/80877c7eda78).*

Large Language Models (LLMs) such as Anthropic’s Claude 3 have revolutionized applications, from chatbots to automated content generation. However, their increasing integration into digital services has also introduced novel security challenges. This blog explores the vulnerabilities of LLMs, common attack vectors, and strategies to mitigate these risks. Specifically, we look at a solution that generates code with an LLM and executes it in a corporate environment.

As a prominent example, we look at interrogating CSV files by generating [pandas](https://pandas.pydata.org/) code and creating visualizations from it — a highly popular use case. The [pandasAI package](https://github.com/sinaptik-ai/pandas-ai) already has 12.2k stars on GitHub at the time of writing.

In this context, we keep a fictitious use case in mind:

> *You work as a lead data scientist and ML Architect for the fitness and health tracker app ‘AnyFitness’ and want to allow your users to interrogate their fitness and health data. Your CTO is delighted, but health data is highly sensitive, so you really need to be sure that the privacy and security of the handled data are not compromised.*

This means we think about a Business-to-Consumer (B2C) application where the data is already available in the cloud. You could also extend the thought exercise to a version where the users are allowed to upload the data themselves. However, this opens another threat vector through the instructions potentially contained within this data. More on this later.

What you find in this blog:

* A theoretical introduction to common attack vectors against LLMs and how to mitigate them
* A practical guideline on how to approach security risks when dealing with LLMs by using the threat modeling technique
* A case study on how to securely create an execution environment for LLM-generated code through AWS Lambda functions
* A summary of how AWS services can help you in this journey

## 1. TL;DR:

Large Language Models (LLMs) have revolutionized applications, but their integration into digital services has introduced novel security challenges. This blog explores common attack vectors against LLMs, such as insecure output handling, sensitive information disclosure, and excessive agency. It also provides practical guidance on threat modeling and mitigation strategies, including using ephemeral and isolated execution environments, limiting network connections, isolating languages and libraries, and implementing guardrails. Additionally, using AWS services like AWS Lambda, Amazon S3, and Amazon CloudWatch can streamline the secure execution of LLM-generated code.

## 2. Large Language Models (LLMs): A security perspective

Within this blog, we focus on *using large language models (LLMs)* and exclude the process of creating them.

An LLM consists of two pieces: first, a small amount of code, and second, a large file containing the model weights. In essence, an LLM is not too different from traditional software: it takes some input, runs code on it (i.e., applies some transformation), and returns an output. As a result, many of the security controls commonly used in software are applicable to using an LLM as well.

However, there are also differences. Traditional software is deterministic, which means that while we might be facing untrusted input, the output data is commonly under our control. With LLMs, a single execution step is deterministic, but the *combination of all steps* taken within an inference process is not. Consequently, the output data returned by an LLM is not fully deterministic either, which can be translated to “not entirely under our control”. Additional controls are therefore required to safeguard an application. One example is the use of [guardrails](https://aws.amazon.com/bedrock/guardrails/), which are available in [Amazon Bedrock](https://aws.amazon.com/bedrock/).

In the use-case presented in this article, we have an LLM generate source code and then run it. This is an excellent example of having a potentially untrusted output, which is used in subsequent steps in the process.

### 2.1 LLM-specific attack vectors

Common frameworks and knowledge bases that identify risks and threats in Generative AI applications include the [OWASP Top 10 for Large Language Model Applications](https://genai.owasp.org/llm-top-10/) and [MITRE ATLAS](https://atlas.mitre.org/matrices/ATLAS). The known risks, techniques, and tactics are still evolving; novel approaches surface regularly in ongoing research.

In this chapter, we discuss a few that are relevant to our use case, based on the OWASP Top 10 framework, each with a shortened definition.

1. [LLM02](https://genai.owasp.org/llmrisk/llm02-insecure-output-handling/): Insecure Output Handling
   **OWASP Definition**: Insecure Output Handling refers specifically to insufficient validation, sanitization, and handling of the outputs generated by large language models before they are passed downstream to other components and systems. Since LLM-generated content can be controlled by prompt input, this behavior is similar to providing users indirect access to additional functionality.
   **Why is it relevant?** This is the most directly relevant risk when executing LLM-generated code, because the code is generated by an LLM and is thus not fully under our control.
2. [LLM06](https://genai.owasp.org/llmrisk/llm06-sensitive-information-disclosure/): Sensitive Information Disclosure
   **OWASP Definition**: LLM applications have the potential to reveal sensitive information, proprietary algorithms, or other confidential details through their output. This can result in unauthorized access to sensitive data, intellectual property, privacy violations, and other security breaches.
   **Why is it relevant?** The dynamically generated code could potentially contain or access sensitive information if not properly controlled.
3. [LLM08](https://genai.owasp.org/llmrisk/llm08-excessive-agency/): Excessive Agency
   **OWASP Definition**: An LLM-based system is often granted a degree of agency by its developer — the ability to interface with other systems and undertake actions in response to a prompt. The decision over which functions to invoke may also be delegated to an LLM ‘agent’ to dynamically determine based on input prompt or LLM output. Excessive Agency is the vulnerability that enables damaging actions to be performed in response to unexpected/ambiguous outputs from an LLM (regardless of what is causing the LLM to malfunction; be it hallucination/confabulation, direct/indirect prompt injection, malicious plugin, poorly-engineered benign prompts, or just a poorly-performing model).
   **Why is it relevant?** Giving an LLM the ability to generate and execute code grants it a high degree of agency that needs to be carefully controlled.
4. [LLM09](https://genai.owasp.org/llmrisk/llm09-overreliance/): Overreliance
   **OWASP Definition**: Overreliance can occur when an LLM produces erroneous information and provides it in an authoritative manner. While LLMs can produce creative and informative content, they can also generate content that is factually incorrect, inappropriate, or unsafe. This is referred to as hallucination or confabulation. When people or systems trust this information without oversight or confirmation, it can result in a security breach, misinformation, miscommunication, legal issues, and reputational damage.
   **Why is it relevant?** Blindly trusting and executing LLM-generated code without proper oversight/controls can lead to security vulnerabilities.

While other risks like prompt injection ([LLM01](https://genai.owasp.org/llmrisk/llm01-prompt-injection/)) or training data poisoning ([LLM03](https://genai.owasp.org/llmrisk/llm03-training-data-poisoning/)) could indirectly impact code generation, the four risks above are the most directly relevant when it comes to actually executing dynamically generated LLM code.

### 2.2 General controls

Understanding techniques to compromise an LLM is crucial for both developers and users of LLMs. As mentioned earlier, an LLM shares a lot of the risks that we know from traditional software and, thus, well-known controls. Some of them include but are not limited to:

1. **Robust Access Controls**
   Implement strict access controls to limit who can interact with the LLM and what kind of inputs they can provide. Make sure to follow a least-privilege principle throughout the whole pipeline, especially when additional elements, such as vector stores for RAG, are involved.
2. **Data Classification and Handling**
   Perform a risk analysis and classification of any involved data. This includes training data, supporting data such as knowledge bases (for RAG), etc. Consider removing or sanitizing sensitive data from your workflows.
3. **Input Validation and Sanitization**
   Use input validation to filter out obvious malicious inputs and input sanitization to neutralize potentially harmful elements within user inputs.
4. **Output Validation and Sanitization**
   Validate and sanitize the data that was produced by the LLM before further processing it.
5. **Output Monitoring**
   Employ real-time monitoring of the LLM’s outputs to detect and respond to suspicious activities promptly. Consider logging prompts and responses to allow anomaly detection.
6. **Verify: Is LLM-generated data needed, or is the data deterministic / known?**
   Only use LLM-generated data when you need it. If information can be taken from a deterministic / verified source, use it from there. An example is a user ID: if the ID is already known in your code, use that proven source instead of a generated value.
7. **Regular Security Updates**
   Continuously install patches and update security protocols and defenses to keep up with evolving techniques and tactics.
8. **Awareness and Training**
   Educate developers and users about the potential risks and best practices for interacting with LLMs securely.

In this post, we deliberately focus on a technical perspective and omit equally important topics such as compliance and governance. Those should be considered for a holistic risk analysis, too.

As mentioned earlier, some controls we should consider are different from traditional software. Let’s dive into more details and a practical approach to them in the next chapter.

## 3. Threat modeling our LLM application

### 3.1 Threat Modeling Methodology

When dealing with potential security risks from executing LLM-generated code, it’s crucial to follow a structured approach for identifying and mitigating threats. One effective methodology for this is threat modeling, which we outline in this section before applying it to our “talk to my CSV” example.

On a high level, our use-case involves “executing code”. Let’s assume this code was developed by an engineering team — we would probably use mechanisms like code reviews, code scanners ([Static Application Security Testing](https://en.wikipedia.org/wiki/Static_application_security_testing)), vulnerability scans or penetration testing (incl. [Dynamic Application Security Testing](https://en.wikipedia.org/wiki/Dynamic_application_security_testing)), etc. However, here we generate code dynamically, i.e., “on-the-fly”, and execute it. While we could use some of the mentioned mechanisms on the generated code, this step might add so much latency that usability suffers and, in extreme cases, users reject the solution.

Instead, we think of the LLM-generated code as being “untrusted” and assume it’s potentially malicious. We put on the threat actor’s hat and think about where and how malicious code could be injected. Of course, we don’t want the LLM to produce such code, but having this mental model helps define the controls that are needed to safeguard the execution of this code.

There are various ways to approach risk modeling for an application. One of them is [threat modeling](https://docs.aws.amazon.com/wellarchitected/latest/security-pillar/sec_securely_operate_threat_model.html), for example, using the [STRIDE methodology](https://en.wikipedia.org/wiki/STRIDE_model). During threat modeling, we answer the following questions:

* What are we working on?
* What can go wrong?
* What are we going to do about it?
* Did we do a good job?

To answer the first question, we use a [data flow diagram](https://en.wikipedia.org/wiki/Data-flow_diagram). This not only helps us understand how the system is intended to work, but also where and how data is accessed and used — including code generated by an LLM, as in our example.

Next, we think about “what could go wrong”, i.e., we are identifying the threats to our system. We use the data flow diagram to associate them with components, processes, or communication channels.

Once we have defined the threats, we use a methodology inspired by insurance mathematics (actuarial science) as well as failure mode and effects analysis (FMEA) from engineering. A risk can be viewed as an event that has (1) a certain probability of happening and (2) an impact. The risk score is the product of the two. Rating the scores allows the categorization of risks in a [risk matrix](https://en.wikipedia.org/wiki/Risk_matrix). The color coding indicates how to handle each risk (for example: *red* = must be mitigated, *amber* = can be accepted or mitigated, *green* = is automatically accepted). To prevent inconsistent ratings, make sure to define a specific scale for both dimensions (e.g., “happens every x days/years” for the probability and a financial value for the impact).
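As a toy illustration of that scoring, assuming a 3x3 scale with arbitrary color thresholds of our own choosing:

```python
# Toy risk scoring: score = probability rating x impact rating, each on a
# 1-3 scale. The bucket thresholds are arbitrary assumptions for illustration.

def risk_score(probability: int, impact: int) -> int:
    assert 1 <= probability <= 3 and 1 <= impact <= 3
    return probability * impact

def risk_color(score: int) -> str:
    if score >= 6:
        return "red"    # must be mitigated
    if score >= 3:
        return "amber"  # can be accepted or mitigated
    return "green"      # accepted by default
```

A mitigation that lowers either rating moves the risk toward a greener cell, which is exactly the re-assessment loop described below.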

Next, we plan mitigations for the threats we have identified in our applications and then re-assess our risks “after applying a mitigation”. These are the answers to “What are we going to do about it?”. Each mitigation either reduces the impact or probability of the risk and changes its position in the risk matrix. We do this until we reach a level where the residual risk associated with our application is acceptable.

The fourth and final question (“Did we do a good job?”) aims to improve both the quality of threat models and the velocity at which you are performing threat modeling over time. It is thus similar to retrospectives in agile software development processes.

In the next chapter, we execute these steps to assess the demo use-case we have introduced at the beginning of this article.

### 3.2 Creating the foundation

Let’s apply this and examine the example to interrogate a CSV file we upload in our application. We keep it AWS service agnostic and focus on the system itself.

First, let’s look at a system overview (not yet a fully-fledged data flow diagram; we’ll create that in the next step):

![](https://cdn-images-1.medium.com/max/800/0*BFhCn0wpIJQuBxwH)

The overview diagram for the generative AI application.

1. The user uploads a .csv file, or the data is supplied by a backend system for the user interaction. You can imagine the .csv file as a mental placeholder for any data provisioning you might want to implement for your app.
2. The LLM is supplied with the user query as well as some information about the data for which it should write code.
3. Code and data are supplied to the Code interpreter (we place the Python symbol as a placeholder for other programming languages).

Let’s now create a data flow diagram for the application and start our threat modeling. Note that we have kept all steps minimal here for brevity. In the data flow diagram, we are specifically focusing on what data is passing through which part of the system:

![](https://cdn-images-1.medium.com/max/800/0*nPdP6cGRGHucltvr)

The data flow diagram for the generative AI application.

Here’s a description of the steps depicted in the diagram:

1. The user uploads a CSV containing medical data to a cloud storage
2. The user navigates to the web application
3. The user sends a prompt on the client side, which is forwarded to logic by the API
4. Logic downloads the CSV containing medical data from cloud storage to extract some basic information (e.g., number of rows, name of columns, etc.)
5. Logic sends CSV information and user prompt to the LLM, which generates and returns Python code
6. Logic sends the generated code and pointer to the CSV to the execution environment
7. The execution environment downloads the CSV containing medical data from the cloud storage
8. The execution environment executes the Python code and sends back the response to the logic
9. The execution environment writes to the execution environment logs
10. The logic environment writes to the logic environment logs
11. Logic sends back the final response to the user

This diagram not only helps us better understand what data is involved and how it is used and transmitted, but also facilitates security-related conversations about the solution.

In the next step, we ask, “What can go wrong?” and we think about potential threats and risks.

### 3.3 Define threats

Let’s assume that we have an adversary trying to exploit our system through the prompt. Expecting the worst case is a good basis for this exercise. Try to “think like the threat actor we want to protect against.” They know the system is executing code that was generated on-the-fly; how could this potentially be abused?

This exercise leads to definitions of threats following a standardized pattern:

> A [threat source] with [pre-requisites] can [threat action], which leads to [threat impact], negatively impacting [goal] of [impacted assets].

This format helps to be very specific about a threat, which supports the definition of countermeasures later in the process.

In this article, for the sake of brevity, we focus on some of the most relevant threats and use a formal definition for the first one only as an example (we are very happy for you to reach out and tell us more):

* **(a)** A threat actor with access to the public-facing application can inject malicious prompts that overwrite existing system prompts, which leads to the LLM generating malicious code and the execution environment executing it, for example, to mine cryptocurrencies or execute lateral movement, impacting the isolation of the execution environment.
* **(b)** A threat actor can trick the execution environment into loading a CSV from a different user and thus gain access to sensitive data.
* **(c)** The adversary gains continuous access (persistence), waiting for the next interaction to send out information.
* **(d)** A threat actor can bypass authentication and authorization and thus use the solution as an anonymous user.
* **(e)** A threat actor can trick the execution environment into writing malicious log data in the execution environment, which impacts the log data management software.

In the next step, we rate the impact and probability of our threats and add them to a risk matrix. In a real-world scenario, we recommend pre-defining the different ratings for both impact and probability, which helps ensure a consistent rating of the risks. Also, note that you can use any size (for example, 5x5) that fits your requirements.

![](https://cdn-images-1.medium.com/max/800/0*c9Xx33YwjBDWfFVT)

LLM-Generated Code Execution Risk Matrix.

The color coding is often bound to the expected handling of the risk. An example is:

* **Red**: Not acceptable; must apply mitigations (which either reduce the probability or impact)
* **Orange / Yellow**: Can be accepted or mitigated
* **Green**: Are accepted by default

In a real-world scenario, other risk responses are also common: avoidance and transfer. For simplicity, we focus on acceptance and mitigation/reduction.

With the defined classification in mind, let us re-examine the data flow diagram for our application, color-coding the threats in our data flow to examine which parts of the application are at risk when asking an LLM to generate code.

![](https://cdn-images-1.medium.com/max/800/0*XLtVOAS-jpLp7Ozx)

Threat Propagation in LLM-Powered Data Analysis Architecture

The data flow diagram for the generative AI application. Data flow is extended by the possible threat associated with it. Threat vectors are color-coded according to the severity outlined in the threat severity table above. It is clear that through one attack vector — the malicious prompt — the adversary can realize multiple attacks downstream of the LLM. Therefore, guardrails should be our first level of defense, fighting against attackers at the root of this problem.

Note that in common threat modeling exercises, color-coding a data flow diagram is optional. Use it when it helps you understand or explain the risk profile of your application.

### 3.4 Define countermeasures

Now that we have an overview of the risks for our use case, let’s create a list of countermeasures that address them. The risk matrix comes in handy here: every countermeasure reduces a risk’s probability and/or its impact, so the risk changes its position in the matrix. We continue this exercise until we are fine with the updated risk ratings.

**(1) Ephemeral and isolated execution environment:**

* *Addresses*: a, b, c, e
* *Description*: We want an ephemeral execution environment that is torn down after it fulfills its duty. We also need to ensure that the adversary is not able to stand up long-lasting tasks (e.g., mining cryptocurrencies on our system). Therefore, we limit the execution time to what we would expect for generating the code and the answer. An ephemeral infrastructure allows us to “accept that the environment breaks during execution”.

**(2) No outgoing network connections:**

* *Addresses*: a, c
* *Description*: To contain the reach of the code, we prohibit any outgoing network requests except for the ones needed to download a CSV file.

**(3) Isolate language and libraries in the execution environment**:

* *Addresses*: a, c
* *Description*: We design and build the execution environment in a “final” setup, i.e., we prevent the installation of additional libraries. Also, we enforce the use of the programming language that we define.

**(4) Limit files that can be downloaded from the cloud storage:**

* *Addresses*: b
* *Description*: We ensure that only legitimate files can be loaded by the execution environment. Since we know the user and their files in the “Logic” component, we can, for example, create a temporary download URL for a single file for the execution environment.

**(5) Use Guardrails:**

* *Addresses*: a, b, c, d, e
* *Description*: An effective measure and first layer of defense is setting up guardrails for incoming prompts as well as outgoing generated code.

**(6) Sanitize logging of the execution environment:**

* *Addresses*: e
* *Description*: We sanitize all logs written by the execution environment.

Once we have a final set of countermeasures and are fine with the residual risks, we can start implementing them.
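As a small illustration of the time limit in countermeasure (1), here is a process-level sketch. The function name and limit are ours, and a plain subprocess is of course no substitute for the sandboxed, microVM-based isolation discussed in the next chapter:

```python
import subprocess
import sys

# Run untrusted, generated code in a separate process with a hard wall-clock
# timeout. Illustrative only: real isolation needs a sandbox (e.g. a microVM),
# network restrictions, and a locked-down runtime, not just a subprocess.

def run_generated_code(code: str, timeout_s: float = 10.0) -> str:
    try:
        result = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True, text=True, timeout=timeout_s,
        )
        return result.stdout
    except subprocess.TimeoutExpired:
        # The child process is killed; the caller only sees an error marker.
        return "ERROR: execution exceeded time limit"
```

An infinite loop (our stand-in for crypto mining) simply hits the timeout and is terminated, which is the behavior the ephemeral-environment countermeasure relies on.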

## 4. Architecture on AWS

Let’s walk step-by-step through the defined countermeasures and how AWS services can be used to achieve them.

### 4.1 Ephemeral and isolated execution environment

[AWS Lambda](https://aws.amazon.com/lambda/) is an excellent choice for implementing this.
It not only gives us the possibility to use an ephemeral infrastructure, but it also provides workload isolation of execution environments through [Firecracker](https://firecracker-microvm.github.io/).
Firecracker is a virtual machine monitor (VMM) that uses the Linux Kernel-based Virtual Machine (KVM) to create and manage microVMs, which provide enhanced security and workload isolation over traditional VMs while enabling the speed and resource efficiency of containers.
We could also change the architecture entirely to use the “[Code interpretation](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-code-interpretation.html)” feature of [Agents for Amazon Bedrock](https://aws.amazon.com/bedrock/agents/), which is in preview at the time of writing.

### 4.2 No outgoing network connections

We run AWS Lambda in a [VPC](https://docs.aws.amazon.com/lambda/latest/dg/foundation-networking.html) and attach a [Security Group](https://docs.aws.amazon.com/vpc/latest/userguide/vpc-security-groups.html) to the Elastic Network Interface (ENI) to control any network connectivity. Assuming that CSV is stored in [Amazon S3](https://aws.amazon.com/s3/), we use a [Gateway Endpoint](https://docs.aws.amazon.com/vpc/latest/privatelink/vpc-endpoints-s3.html) to allow this access. In addition to that, we can optionally define [Network Access Control Lists (NACL)](https://docs.aws.amazon.com/vpc/latest/userguide/vpc-network-acls.html) to restrict network traffic further.

### 4.3 Isolate language and libraries in the execution environment

We can manage Python library dependencies either through directly including the required libraries in the Lambda deployment, or via [AWS Lambda Layers](https://docs.aws.amazon.com/lambda/latest/dg/chapter-layers.html).
In both cases, to prevent a threat actor from downloading additional libraries, outgoing network access for the Lambda is blocked (see 4.2).

### 4.4 Limit files that can be downloaded from the cloud storage

We can use short-living, dynamically created [Amazon S3 pre-signed URLs](https://docs.aws.amazon.com/AmazonS3/latest/userguide/ShareObjectPreSignedURL.html) to download the CSV files from Amazon S3.
In combination with a [bucket policy](https://docs.aws.amazon.com/AmazonS3/latest/userguide/example-bucket-policies.html), which only allows access to the bucket to the before-mentioned Gateway endpoint, we can additionally lock down our S3 bucket. You can find more information about this [pattern in the AWS documentation](https://docs.aws.amazon.com/prescriptive-guidance/latest/patterns/set-up-private-access-to-an-amazon-s3-bucket-through-a-vpc-endpoint.html).

### 4.5 Use Amazon Bedrock Guardrails

In [Amazon Bedrock](https://aws.amazon.com/bedrock/), we can use [Guardrails](https://aws.amazon.com/bedrock/guardrails/) to implement safeguards tailored to the application requirements and responsible AI policies. In our example, we would, for instance, use the “prompt attack” guardrail for incoming prompts.

### 4.6 Sanitize logging of the execution environment

We can use the [data protection policies](https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/mask-sensitive-log-data.html) feature of [Amazon CloudWatch](https://aws.amazon.com/cloudwatch/) to sanitize our logs.
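To illustrate what such a data protection policy does, here is a minimal regex-based scrubber; the patterns are examples of our own, far from exhaustive, and no replacement for the managed feature:

```python
import re

# Illustrative log scrubber: mask common sensitive patterns before log lines
# are persisted. The patterns below are examples only, not exhaustive.

PATTERNS = [
    (re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+"), "<EMAIL>"),
    (re.compile(r"\b\d{13,19}\b"), "<CARD-NUMBER?>"),  # crude card heuristic
    (re.compile(r"(?i)(api[_-]?key\s*[=:]\s*)\S+"), r"\1<REDACTED>"),
]

def sanitize(line: str) -> str:
    for pattern, replacement in PATTERNS:
        line = pattern.sub(replacement, line)
    return line
```

Applying this before writing anything to the execution environment logs limits what a malicious payload can smuggle into the log management software.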

## 5. Conclusion

As Large Language Models (LLMs) continue to revolutionize various applications, it is crucial to address the novel security challenges they introduce. This blog has explored common attack vectors against LLMs, such as insecure output handling, sensitive information disclosure, and excessive agency. By following a structured threat modeling approach and implementing robust mitigation strategies, organizations can securely leverage the power of LLMs while minimizing risks.

Key strategies discussed include using ephemeral and isolated execution environments, limiting network connections, isolating languages and libraries, implementing guardrails, and sanitizing logs. Additionally, leveraging AWS services like AWS Lambda, Amazon S3, and Amazon CloudWatch can streamline the secure execution of LLM-generated code.

As the adoption of LLMs increases, developers and organizations should regularly review and update their security protocols to address new potential attack techniques. By prioritizing security and responsible AI practices, we can harness the full potential of LLMs while safeguarding against misuse and protecting sensitive data.

Explore the AWS services and best practices outlined in this blog to secure your LLM-powered applications. Stay informed about the latest developments in LLM security by following trusted cybersecurity resources and actively participating in community discussions.

**Special mentions:**

* Special thanks to Massimiliano Angelino, who created the first “talk to your CSV” implementation on Amazon Bedrock that inspired this blog, and to Luca Perrozzi for the valuable feedback and review.

Further reading:

* [Build safe and responsible generative AI applications with guardrails](https://aws.amazon.com/blogs/machine-learning/build-safe-and-responsible-generative-ai-applications-with-guardrails/)
* [LLM security primer blog post](https://kleiber.me/blog/2024/03/17/llm-security-primer/)
* [Blog on hacking LLMs](https://blog.gopenai.com/hacking-llms-101-attacks-on-llms-186e3ebff0cb)
* [LLM security blog](https://shellcode.blog/llm1/)
* [Blog on prompt hacking of LLMs](https://www.comet.com/site/blog/prompt-hacking-of-large-language-models/)
* [Blog on LLM security digest](https://adversa.ai/blog/llm-security-digest-top-llm-attacks-vc-initiatives-llm-hacking-incidents-and-research-papers-in-november/)]]></content:encoded>
            <author>Philipp Kaindl</author>
            <category>Security</category>
            <category>LLM</category>
            <category>AWS</category>
            <category>Agents</category>
        </item>
        <item>
            <title><![CDATA[Building a RAG Agent with LangGraph, LLaMA3-70b, and Scaling with Amazon Bedrock]]></title>
            <link>https://philippkaindl.com/posts/rag-agent-langgraph-llama3-bedrock</link>
            <guid isPermaLink="false">https://philippkaindl.com/posts/rag-agent-langgraph-llama3-bedrock</guid>
            <pubDate>Wed, 22 May 2024 00:00:00 GMT</pubDate>
            <description><![CDATA[Building a local RAG agent using LLaMA3 and LangGraph with adaptive routing, web search fallback, and self-correction — then scaling with Amazon Bedrock.]]></description>
            <content:encoded><![CDATA[
Explore how to build a local Retrieval-Augmented Generation (RAG) agent using LLaMA3, a powerful language model from Meta. This RAG agent integrates several cutting-edge ideas from recent research to enhance its capabilities.

* **Adaptive Routing** ([paper](https://arxiv.org/abs/2403.14403)): Based on the Adaptive RAG paper, our agent intelligently routes questions to different retrieval approaches, leveraging the strengths of each method for optimal performance.
* **Fallback to Web Search** ([paper](https://arxiv.org/pdf/2401.15884.pdf)): Inspired by the Corrective RAG paper, our agent falls back to web search if the retrieved documents are not relevant, ensuring comprehensive coverage.
* **Self-Correction** ([paper](https://arxiv.org/abs/2310.11511)): Incorporating concepts from the Self-RAG paper, our agent identifies and corrects hallucinations or incomplete answers.

We'll use LangGraph to build complex, multi-step workflows that involve language models and other components. By combining these research ideas, our agent will provide accurate and informative responses to challenging queries.
This blog post expands on the work presented in [Langchain-ai's](https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_rag_agent_llama3_local.ipynb) notebook, illustrating how you can scale with cloud processing power. We will port the original notebook to use Amazon Bedrock for LLM inference, scaling our embedding and text generation capabilities. The choice of vector store (local ChromaDB) remains unchanged; we will cover scaling that part in future blog posts.


**A brief intro to Amazon Bedrock:** Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs) from leading AI companies like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon through a single API, along with a broad set of capabilities you need to build generative AI applications with security, privacy, and responsible AI. Using Amazon Bedrock, you can easily experiment with and evaluate top FMs for your use case, privately customize them with your data using techniques such as fine-tuning and Retrieval Augmented Generation (RAG), and build agents that execute tasks using your enterprise systems and data sources. Since Amazon Bedrock is serverless, you don't have to manage any infrastructure, and you can securely integrate and deploy generative AI capabilities into your applications using the AWS services you are already familiar with.

Get started by setting up the environment and installing the required dependencies. We'll walk through the step-by-step process of building the RAG agent, covering document retrieval, question routing, answer generation, and self-correction mechanisms. Code examples and detailed explanations will be provided to ensure a smooth learning experience. By the end, you'll have a solid understanding of how to build a sophisticated RAG agent tailored to your specific use case.

Let's dive into the exciting world of RAG agents powered by LLaMA3 and cutting-edge research!

First, we need to install some libraries if we do not have them.

```
! pip install -U langchain_community tiktoken chromadb langchain langgraph tavily-python langchain-aws
```

## Setting Up LangSmith Tracing

LangSmith is a powerful tracing tool that allows you to monitor and debug LangChain-based applications. Setting up a connection to LangSmith in this notebook will give us valuable insights into the inner workings of our RAG agent.

To enable tracing, set the `LANGCHAIN_TRACING_V2` environment variable to "true". Then specify the `LANGCHAIN_ENDPOINT` and `LANGCHAIN_API_KEY` with the appropriate values for your LangSmith account. If you don't have an account, sign up at [LangSmith Signup](https://www.langchain.com/langsmith) and obtain your API key.

If you are looking for a self-hosted alternative to LangSmith, check out [Langfuse](https://langfuse.com/).

Furthermore, we will set our keys to connect to the AWS cloud.

If you work on an AWS resource such as Amazon SageMaker, your execution role already provides the default credentials. To keep things broadly accessible, below is a setup that works on Google Colab as well as on your local machine (with the [AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/cli-chap-configure.html) configured through the `aws configure` command).

```
import os
```
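The full configuration cell is not shown in the export above; a minimal sketch of what it could contain follows. The `LANGCHAIN_*` values are placeholders you must replace with your own, and the `aws_*` and `tavily_ai_api_key` names match the variables the later cells expect:

```
import os

# Enable LangSmith tracing (optional); replace the placeholder with your own key
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = "<your-langsmith-api-key>"  # placeholder

# AWS credentials for Google Colab or a local machine; on Amazon SageMaker
# the execution role supplies credentials automatically
aws_region = "us-west-2"
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID", "")
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "")
tavily_ai_api_key = os.environ.get("TAVILY_API_KEY", "")  # used later for web search
```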

Tracing allows LangSmith to monitor and record the execution of our LangChain components. This is incredibly useful for understanding how the different elements interact, identifying potential bottlenecks or issues, and optimizing the overall performance of our RAG agent.

As you build the agent, you can use LangSmith's insights to make adjustments and ensure it functions as intended.

## Verifying ChromaDB Installation

Before we proceed with building our RAG agent, it's crucial to ensure that the ChromaDB library is installed and ready to use. ChromaDB is a vector database that we'll use for efficient document storage and retrieval.

ChromaDB will serve as the backbone of our document retrieval system, allowing us to store and retrieve relevant documents based on vector similarity. By checking its installation status, we can confirm that our development environment is set up correctly and ready for the next steps.

We deliberately did not change the local vector store to keep the notebook as accessible as possible.

```
!pip show chromadb
```

## Creating a Bedrock Runtime Client

We'll create a Bedrock runtime client to connect to the Amazon Bedrock service. Bedrock, a fully managed service by AWS, allows developers to build and deploy generative AI models like large language models (LLMs). This client will enable us to leverage pre-trained LLMs from Amazon, such as the powerful LLaMA3 model from Meta.

Connecting to Bedrock is crucial for building our scalable and secure RAG agent, as it provides the necessary language model for generation capabilities. With the Bedrock runtime client in place, we can integrate LLaMA3 into our workflow and use its advanced natural language processing capabilities to generate accurate responses.

```
import boto3
import json, re

# Create a Bedrock runtime client in the configured region (e.g., us-west-2)
bedrock_rt = boto3.client(
    "bedrock-runtime",
    region_name=aws_region,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)
```

If you use short-term credentials, make sure to also pass the session token (`aws_session_token`) in the call that initializes the `boto3.client`.

```
# bedrock_rt = boto3.client("bedrock-runtime",
#                         region_name=aws_region,  # e.g. us-west-2
#                         aws_access_key_id=aws_access_key_id,
#                         aws_secret_access_key=aws_secret_access_key,
#                         aws_session_token=aws_session_token)
```

For a primer on authentication with AWS boto3, check out the following [documentation](https://docs.aws.amazon.com/sdkref/latest/guide/access-users.html).

## Amazon Bedrock with LangChain

After we have established a connection to the Amazon Bedrock service by creating a Bedrock runtime client, we can now connect the client to the LangChain Embeddings.

Bedrock provides a variety of models to choose from; a selection of options is listed below (available in us-west-2 as of 17 May 2024):

```
from langchain_community.embeddings import BedrockEmbeddings

# Choose from a set of embedding models hosted on Amazon Bedrock
# Provider | Model Name                      | Model ID
# ------------------------------------------------------------------------
# Amazon   | Titan Embeddings G1 - Text 1.x  | amazon.titan-embed-text-v1
# Amazon   | Titan Embedding Text v2 1.x     | amazon.titan-embed-text-v2:0
# Cohere   | Embed English 3.x               | cohere.embed-english-v3
# Cohere   | Embed Multilingual 3.x          | cohere.embed-multilingual-v3
embedding_model_id = "amazon.titan-embed-text-v2:0"
embeddings = BedrockEmbeddings(client=bedrock_rt, model_id=embedding_model_id)
vector = embeddings.embed_documents(
    ["This is a content of the document", "This is another document"]
)
```

Let's inspect the first part of the vector we got for the sentence "This is a content of the document", as well as the embedding dimension.

```
print(vector[0][:50])
print(len(vector[0]))
```

```
[-0.07963294, 0.022934286, 0.035994086, -0.004260362, 0.005773388, -0.0063308184, 0.031534642, -0.017678512, 0.034082897, 0.024049146, -0.028030794, 0.071669646, 0.026278868, -0.004519169, -0.023252817, 0.057654247, -0.041409127, 0.027234465, 0.03169391, 0.041090596, 0.07740321, 0.025960337, -0.020545298, 0.01951007, -0.04172766, -0.022137957, 4.9770583e-06, 0.01003375, 0.02388988, -0.022297222, 0.038701605, 0.050646547, 0.0033644915, -0.072625235, 0.0793144, 0.110849045, 0.066573136, -0.0055743055, -0.01274127, -0.021341627, 0.012582004, -0.015130258, -0.0033445833, 0.023252817, 0.030260516, 0.048416823, 0.023252817, 0.022934286, 0.07230671, 0.008321642]
1024
```

## Setting Up a Document Retrieval System with ChromaDB

To build our Retrieval Augmented Generation (RAG) agent, we start by setting up a document retrieval system using ChromaDB, a robust local vector database.

We'll first load a set of URLs related to LLM agents, prompt engineering, and adversarial attacks. These documents form our initial knowledge corpus. For simplicity, we use a limited number of URLs here, but in a real-world scenario, you would include a comprehensive set of documents relevant to your use case.

Next, we define a custom embedding function called `MyEmbeddingFunction` using `BedrockEmbeddings` from `langchain_community`. This function converts text documents into vector embeddings for efficient storage and retrieval in ChromaDB.

We then load and split the content from the URLs into smaller chunks using `RecursiveCharacterTextSplitter`. This improves retrieval accuracy by allowing the system to fetch relevant document portions rather than whole documents.
Note that `RecursiveCharacterTextSplitter` is one of the simplest ways to chunk a document, and more sophisticated methods might yield better results. Check out [this Medium post](https://medium.com/@anuragmishra_27746/five-levels-of-chunking-strategies-in-rag-notes-from-gregs-video-7b735895694d) for a short but comprehensive overview.

With the document chunks and custom embedding function, we create a ChromaDB vector store called `vectorstore`. This store acts as our persistent storage for the document embeddings, enabling fast and efficient retrieval based on vector similarity.

Finally, we create a `retriever` object from the `vectorstore`, which will fetch relevant documents based on user queries. This retriever is critical for our RAG agent, as it fetches the most relevant information for generating accurate and informative responses.

By setting up this document retrieval system, we are preparing our RAG agent to effectively utilize external knowledge sources and provide context-aware responses to user queries.

### Defining a Custom Embedding Function with `BedrockEmbeddings`

As the landscape of generative AI advances at unprecedented speed, we are often left with a package that does not yet support our favorite model provider. In our case, ChromaDB does not support Amazon Bedrock at the time of writing. However, we simply need to adapt ChromaDB's base `EmbeddingFunction` class to get it to work.

We utilize `BedrockEmbeddings` from the `langchain_community` library to create a custom embedding function.

The custom class `MyEmbeddingFunction` inherits from `EmbeddingFunction` provided by ChromaDB and includes methods `embed_query` and `embed_documents`. These methods convert queries and documents into vector embeddings using Bedrock Embedding models so we can later match the query with the document vectors.

By implementing this custom embedding function, we ensure seamless integration of Bedrock Embeddings into our document retrieval system, enhancing our RAG agent's ability to retrieve relevant documents based on vector similarity.

```
from chromadb import Documents, EmbeddingFunction, Embeddings
```

### Setting Up the Document Retrieval System with ChromaDB

First, load a set of URLs related to LLM agents, prompt engineering, and adversarial attacks. These URLs serve as our initial corpus.

Load the content from the URLs and split documents into smaller chunks using `RecursiveCharacterTextSplitter`. This improves retrieval by allowing relevant document portions rather than entire documents.

Create a ChromaDB vector store called `vectorstore` to store document embeddings. Generate a `retriever` from the vector store to fetch relevant documents based on user queries.

```
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0  # overlap is measured in tokens, not a ratio
)
doc_splits = text_splitter.split_documents(docs_list)

embedding_function = MyEmbeddingFunction(client=bedrock_rt,
                                         region_name=aws_region,
                                         model_id=embedding_model_id)
# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    embedding=embedding_function,
    collection_name="rag-chroma",
)
retriever = vectorstore.as_retriever()
```

## The Graph we are building with LangGraph

Before we start, we want to give a good understanding of the graph we want to build with LangGraph:

![The LangGraph workflow of our RAG agent](https://cdn-images-1.medium.com/max/800/1*wnKYNY0TWCQLyIAkc5S68g.png)

## Implementing the Retrieval Grader

The "retrieval grader" is crucial for ensuring the relevance of retrieved documents to the user's question. It filters out irrelevant or erroneous results before generating an answer.

This grader uses a language model (LLM) and a prompt template. The prompt instructs the LLM to decide if a document contains keywords related to the user's question, providing a binary 'yes' or 'no' score in JSON format.

We define the prompt using `PromptTemplate` from LangChain, which dynamically includes the user's question and retrieved document. `ChatBedrock`, leveraging Meta's LLaMA3 model via Amazon Bedrock, performs the grading task.

To test, provide a sample question and use the `retriever` to get a document, then invoke the `retrieval_grader` to get the binary score.

By incorporating this grader, we ensure that only relevant documents are considered for answer generation, improving the overall response quality.

```
### Retrieval Grader
from langchain.prompts import PromptTemplate
from langchain_aws import ChatBedrock
from langchain_core.output_parsers import JsonOutputParser

llm = ChatBedrock(
    client = bedrock_rt,
    model_id="meta.llama3-70b-instruct-v1:0",
    model_kwargs={"temperature": 0.0},
)

prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance
    of a retrieved document to a user question. If the document contains keywords related to the user question,
    grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question. \n
    Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
     <|eot_id|><|start_header_id|>user<|end_header_id|>
    Here is the retrieved document: \n\n {document} \n\n
    Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """,
    input_variables=["question", "document"],
)

retrieval_grader = prompt | llm | JsonOutputParser()
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
```

```
{'score': 'yes'}
```

## Building the Control Flow with LangGraph

LangGraph allows us to define a graph-based workflow for our RAG agent, integrating document retrieval, question routing, answer generation, and self-correction into an efficient pipeline.

Key steps include:

1. **Question Routing**: Deciding whether the question should go to the document retrieval system or a web search.
2. **Document Retrieval**: Fetching relevant documents from the vector database.
3. **Document Grading**: Evaluating the relevance of retrieved documents.
4. **Answer Generation**: Generating an answer using the retrieved documents and the language model.
5. **Hallucination Grading**: Ensuring the generated answer is grounded in the retrieved documents.
6. **Answer Grading**: Checking the usefulness of the generated answer.
7. **Web Search**: Supplement with a web search if necessary.

LangGraph lets us seamlessly integrate these steps into a modular, adaptable workflow, enhancing the agent's ability to handle diverse queries.

### Generating Answers

Using `PromptTemplate` from LangChain, we create a prompt for the assistant to answer questions concisely, utilizing retrieved context.

The prompt converts the user's question and relevant documents into an input format for the language model. The resulting answer is processed and returned.

Here's the process:

1. Define the prompt template.
2. Format the documents as context.
3. Chain the prompt, LLM, and output parser to generate the response.

By integrating this chain, the RAG agent generates accurate, concise answers based on the retrieved context.

```
### Generate
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Prompt
prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
    Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
    Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
    Question: {question}
    Context: {context}
    Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question", "context"],
)

# Post-processing
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Chain
rag_chain = prompt | llm | StrOutputParser()
# Run
question = "agent memory"
docs = retriever.invoke(question)
generation = rag_chain.invoke({"context": docs, "question": question})
print(generation)
```

```
In the context of LLM-powered autonomous agents, memory refers to the processes used to acquire, store, retain, and later retrieve information. There are two main types of memory: short-term memory, which utilizes the models in-context learning, and long-term memory, which provides the capability to retain and recall information over extended periods, often by leveraging an external vector store and fast retrieval.
```

### Setting Up the Hallucination Grader

To ensure that generated answers are grounded in retrieved documents, use a hallucination grader. This grader assesses whether the answer is supported by the provided facts.

Define a prompt using `PromptTemplate` to instruct the LLM to output a binary 'yes' or 'no' score in JSON, indicating if the answer is grounded in the documents.

Invoke the grader with the generated answer and documents to validate the accuracy.

Incorporating this grader helps maintain the factual accuracy of the RAG agent's responses, ensuring high-quality outputs.

```
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

### Hallucination Grader

# Prompt
prompt = PromptTemplate(
    template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether
    an answer is grounded in / supported by a set of facts. Give a binary score 'yes' or 'no' to indicate
    whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a
    single key 'score' and no preamble or explanation. <|eot_id|><|start_header_id|>user<|end_header_id|>
    Here are the facts:
    \n ------- \n
    {documents}
    \n ------- \n
    Here is the answer: {generation}  <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["generation", "documents"],
)
hallucination_grader = prompt | llm | JsonOutputParser()
hallucination_grader.invoke({"documents": docs, "generation": generation})
```

```
{'score': 'yes'}
```

### Establishing the Answer Grader

To ensure answers are helpful and resolve the user's question, use an answer grader.

Define a prompt with `PromptTemplate` that directs the LLM to output a binary 'yes' or 'no' score, indicating the answer's usefulness.

Invoke the grader with the generated answer and user question to validate the relevance.

This step ensures the RAG agent produces not only accurate but also helpful and relevant responses.

```
### Answer Grader

# Prompt
prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
    answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
    useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
     <|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
    \n ------- \n
    {generation}
    \n ------- \n
    Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["generation", "question"],
)
answer_grader = prompt | llm | JsonOutputParser()
answer_grader.invoke({"question": question, "generation": generation})
```

```
{'score': 'yes'}
```

### Creating the Router

The router decides whether to direct the question to document retrieval or perform a web search.

Use `PromptTemplate` to instruct the LLM to make this decision based on the question's relevance to predefined topics.

Invoke the router with the user question and route accordingly.

Efficient routing ensures the most appropriate and effective method is used for each question, optimizing the RAG agent's performance.

The following prompt template has the topics of the knowledge base hard coded. Ideally, we would add metadata fields containing the topic of each document that we add to our vector store to be able to dynamically expand the correct topics so that the router is up to date.

```
### Router

# Topics should be dynamically fetched and updated whenever a new topic gets put in the vector store
topics = ["LLM Agents", "Prompt Engineering", "Adversarial Attacks on LLMs"]

prompt = PromptTemplate(
    # The f-string fills in the topics now; {{question}} is escaped so that
    # PromptTemplate still sees a {question} placeholder at invocation time
    template=f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an expert at routing a
    user question to a vectorstore or web search. Use the vectorstore for questions on the following topics: {', '.join(topics)}. You do not need to be stringent with the keywords
    in the question related to these topics. Otherwise, use web-search. Give a binary choice 'web_search'
    or 'vectorstore' based on the question. Return a JSON with a single key 'datasource' and
    no preamble or explanation. Question to route: {{question}} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
    input_variables=["question"],
)
question_router = prompt | llm | JsonOutputParser()
question = "llm agent memory"
print(question_router.invoke({"question": question}))
```

```
{'datasource': 'vectorstore'}
```
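To make the router's topic list track the store's contents, one option is to collect the distinct topic values stored with each chunk via Chroma's `get`. This is a sketch that assumes a hypothetical `topic` metadata key, which the documents loaded earlier do not yet carry:

```
def current_topics(vectorstore) -> list:
    """Collect the distinct 'topic' metadata values stored with each chunk.

    Assumes chunks were added with a 'topic' key in their metadata; the
    documents ingested earlier in this notebook were not, so this is
    illustrative only.
    """
    records = vectorstore.get(include=["metadatas"])
    return sorted({m["topic"] for m in records["metadatas"] if m and "topic" in m})
```

Calling this before building the router prompt would keep the `topics` list in sync with whatever has been ingested.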

### Integrating Web Search with Tavily

Set up environmental variables for the Tavily API and initialize the `TavilySearchResults` tool.

When necessary, perform a web search to gather additional documents. Invoke the tool with the user query and append the results to the existing documents.

Integrating web search complements the document retrieval system, covering a broader range of information for generating responses.

```
### Search
os.environ["TAVILY_API_KEY"] = tavily_ai_api_key

from langchain_community.tools.tavily_search import TavilySearchResults

web_search_tool = TavilySearchResults(max_results=3)
```

At this point, it is worth mentioning that the LangChain Tavily Search tool leaves out much of the power of the Tavily AI API.

In a future post, we will explore how to fully utilize the Tavily AI API search with query rewriting to be used by an Agent in conjunction with Amazon Bedrock Agents.

### Defining the Control Flow with LangGraph Nodes and Edges

Implement nodes representing key actions: document retrieval, document grading, web search, and answer generation.

Define conditional edges for decision-making: route the question, decide on document relevance, and grade the generated answer.

Set up the workflow graph with entry points, nodes, and edges to ensure a logical progression through the RAG agent's steps.

```
from typing_extensions import TypedDict
from typing import List

### State
class GraphState(TypedDict):
    """
    Represents the state of our graph.
    Attributes:
        question: question
        generation: LLM generation
        web_search: whether to add search
        documents: list of documents
    """
    question: str
    generation: str
    web_search: str
    documents: List[str]

from langchain.schema import Document
### Nodes

def retrieve(state):
    """
    Retrieve documents from vectorstore
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]
    # Retrieval
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """
    Generate answer using RAG on retrieved documents
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]
    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question
    If any document is not relevant, we will set a flag to run web search
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Filtered out irrelevant documents and updated web_search state
    """
    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]
    # Score each doc
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke(
            {"question": question, "document": d.page_content}
        )
        grade = score["score"]
        # Document relevant
        if grade.lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        # Document not relevant
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            # We do not include the document in filtered_docs
            # We set a flag to indicate that we want to run web search
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def web_search(state):
    """
    Web search based on the question
    Args:
        state (dict): The current graph state
    Returns:
        state (dict): Appended web results to documents
    """
    print("---WEB SEARCH---")
    question = state["question"]
    documents = state["documents"]
    # Web search
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    if documents is not None:
        documents.append(web_results)
    else:
        documents = [web_results]
    return {"documents": documents, "question": question}

### Conditional edge

def route_question(state):
    """
    Route question to web search or RAG.
    Args:
        state (dict): The current graph state
    Returns:
        str: Next node to call
    """
    print("---ROUTE QUESTION---")
    question = state["question"]
    print(question)
    source = question_router.invoke({"question": question})
    print(source)
    print(source["datasource"])
    if source["datasource"] == "web_search":
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "websearch"
    elif source["datasource"] == "vectorstore":
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or add web search
    Args:
        state (dict): The current graph state
    Returns:
        str: Binary decision for next node to call
    """
    print("---ASSESS GRADED DOCUMENTS---")
    question = state["question"]
    web_search = state["web_search"]
    filtered_documents = state["documents"]
    if web_search == "Yes":
        # At least one document was graded as irrelevant,
        # so we supplement the retrieved context with a web search
        print(
            "---DECISION: NOT ALL DOCUMENTS ARE RELEVANT TO QUESTION, INCLUDE WEB SEARCH---"
        )
        return "websearch"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

### Conditional edge

def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the document and answers question.
    Args:
        state (dict): The current graph state
    Returns:
        str: Decision for next node to call
    """
    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    score = hallucination_grader.invoke(
        {"documents": documents, "generation": generation}
    )
    grade = score["score"]
    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question, "generation": generation})
        grade = score["score"]
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"

from langgraph.graph import END, StateGraph
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("websearch", web_search)  # web search
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate
```

Designing the control flow as a structured graph ensures a coherent and efficient pipeline for the RAG agent's operations.

```
# Build graph
workflow.set_conditional_entry_point(
    route_question,
    {
        "websearch": "websearch",
        "vectorstore": "retrieve",
    },
)

workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "websearch": "websearch",
        "generate": "generate",
    },
)
workflow.add_edge("websearch", "generate")
workflow.add_conditional_edges(
    "generate",
    grade_generation_v_documents_and_question,
    {
        "not supported": "generate",
        "useful": END,
        "not useful": "websearch",
    },
)
```

### Compiling and Testing the Workflow

Compile the LangGraph workflow and test it with sample inputs.

Provide questions to the compiled application and track the output through each step.

Debug and refine based on the results to ensure the RAG agent performs as expected and delivers accurate and relevant answers.

Compiling and iterative testing ensure the RAG agent meets quality and performance standards before deployment.

```
# Compile
app = workflow.compile()

# Test
from pprint import pprint
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Finished running: {key}:")
pprint(value["generation"])
```

```
---ROUTE QUESTION---
What are the types of agent memory?
{'datasource': 'vectorstore'}
vectorstore
---ROUTE QUESTION TO RAG---
---RETRIEVE---
'Finished running: retrieve:'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
'Finished running: grade_documents:'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
'Finished running: generate:'
('The types of agent memory are: Sensory Memory, Short-term memory, and '
 'Long-term memory.')
```

### Testing the WebSearch

For a test of the web search, we use the question, "Who are the Bears expected to draft first in the NFL draft?". As this request needs up-to-date information that is not covered by the knowledge base, we expect the agent to end up at a web search. In the trace below, the router actually picks the vectorstore first, but the document grading step finds none of the retrieved documents relevant and falls back to web search, so the final answer is still grounded in fresh results.

```
# Compile
app = workflow.compile()

inputs = {"question": "Who are the Bears expected to draft first in the NFL draft?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Finished running: {key}:")
pprint(value["generation"])
```

```
---ROUTE QUESTION---
Who are the Bears expected to draft first in the NFL draft?
{'datasource': 'vectorstore'}
vectorstore
---ROUTE QUESTION TO RAG---
---RETRIEVE---
'Finished running: retrieve:'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---
'Finished running: grade_documents:'
---WEB SEARCH---
'Finished running: websearch:'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
'Finished running: generate:'
('The Bears are expected to draft USC quarterback Caleb Williams with the No. '
 '1 pick in the 2024 NFL Draft. This is according to multiple reports and '
 'projections. Williams was widely considered the top prospect in a draft '
 'class loaded with talented quarterbacks.')
```

### Further Improvements

If you look to improve the above code further, here are a few suggestions:

* **Expand the Web Search Capabilities**
  Currently, the web search capabilities only use the Tavily-AI Langchain tool, although Tavily AI has much more functionality to offer. You can, for example, include and exclude certain domains and websites, add the current date to target a specific timeframe of the results, and include a generic search on Google or DuckDuckGo to complement the Tavily-AI search.
* **Retrieve the topics dynamically**
  We hard-code the topics represented in our knowledge base. If we know the topics ahead of time, this is a perfectly valid approach. However, for more fine-grained retrieval, we should consider including vector metadata in our vector store.
* **Change the knowledge base to a more scalable option**
  As the ChromaDB vector store runs on our local instance, we are subject to a scalability limit. We can use products like [Pinecone](https://www.pinecone.io/) or [Amazon OpenSearch](https://aws.amazon.com/opensearch-service/), just to name a few.
            <author>Philipp Kaindl</author>
            <category>RAG</category>
            <category>Agents</category>
            <category>LangGraph</category>
            <category>Bedrock</category>
        </item>
        <item>
            <title><![CDATA[Natural Language to SQL: Experiments with CodeLlama on Amazon SageMaker — Part 2]]></title>
            <link>https://philippkaindl.com/posts/nl2sql-codellama-experiments-part-2</link>
            <guid isPermaLink="false">https://philippkaindl.com/posts/nl2sql-codellama-experiments-part-2</guid>
            <pubDate>Wed, 10 Apr 2024 00:00:00 GMT</pubDate>
            <description><![CDATA[Experiments with CodeLlama for NL2SQL: comparing DDL vs SQL-PaLM schema descriptions, chunk length effects, and repetition penalty tuning.]]></description>
            <content:encoded><![CDATA[
Part 1 of the series explores fine-tuning a CodeLlama model for NL2SQL tasks using QLoRA on Amazon SageMaker. The blog guides you through setting up the development environment, preparing the fine-tuning dataset, and dives into the importance of tokenization and chunk length. It concludes with detailed steps on fine-tuning CodeLlama using QLoRA and deploying the model on SageMaker. Finally, we compare the improvements made over the base models by only a few epochs of training. If you missed part 1 and are looking for a practical how-to guide, follow this [link](/posts/nl2sql-fine-tuning-codellama-part-1).

In part 2, we will examine the experiments we undertook with CodeLlama more closely, as they will serve as helpful guidelines whenever we undertake such an endeavour with other models.

## TLDR what you can find in this blog

* A comparison between different database descriptions and their pros and cons ([SQL-PaLM](https://arxiv.org/abs/2306.00739)-inspired and Data Definition Language). A quick check to see if the model is able to generalize between both database descriptions.
* Investigations on the chunk length of the training batches and its link to the performance of the trained model.
* Influence of the repetition penalty on the performance of CodeLlama.

## Exploring Database Description Methods DDL vs SQL-PaLM

First, we’re going to examine the differences between the Data Definition Language (DDL) and the SQL-PALM-based database schema descriptions that we used in [Part 1](/posts/nl2sql-fine-tuning-codellama-part-1).

DDL, or Data Definition Language, is a subset of SQL (Structured Query Language) used in database management systems. It focuses specifically on defining, modifying, and removing database structures but not on manipulating the data within these structures. If you are not too much into SQL — like I was at the beginning of this journey — here’s a brief overview:

### Key Features of DDL for NL2SQL

DDL has three key features that we can use to our advantage:

**Creating Database Structures:** DDL is primarily used to define the structure of a database and its elements, such as tables, indexes, and constraints. Here’s a simple example of a DDL command to create a table:

```
CREATE TABLE Students (
    StudentID INT PRIMARY KEY,
    Name VARCHAR(100),
    Age INT
);
```

This `CREATE TABLE` statement defines a new table named `Students` with columns for student ID, name, and age, and declares the `StudentID` column as the primary key.

**Defining Constraints:** DDL statements include constraints to define rules for data in the database, such as primary keys, foreign keys, unique constraints, and check constraints. Extending the example from above, the tables `Students` and `Courses` are linked through `Enrollments` via foreign and primary keys. Therefore, the DDL carries all the information needed to construct our queries.

```
-- Students and Courses are linked through the Enrollments table:
CREATE TABLE Enrollments (
    StudentID INT,
    CourseID INT,
    PRIMARY KEY (StudentID, CourseID),
    FOREIGN KEY (StudentID) REFERENCES Students(StudentID),
    FOREIGN KEY (CourseID) REFERENCES Courses(CourseID)
);
```

### Ease of recreating a valid DDL statement for the database schema
DDL allows for uniform syntax across different databases, making it easier to target a specific flavor of SQL than the SQL-PaLM schema description does.
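One reason DDL is so easy to come by: most engines can emit it for you. A minimal sketch with SQLite (Python standard library, used here purely as a stand-in for your production database):

```python
import sqlite3

# In-memory database used only for illustration.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE Students (StudentID INTEGER PRIMARY KEY, Name TEXT, Age INTEGER)"
)

# SQLite stores the original CREATE statement in sqlite_master,
# so a valid DDL description of the schema can be read back directly.
ddl_rows = conn.execute(
    "SELECT sql FROM sqlite_master WHERE type = 'table'"
).fetchall()
schema_ddl = "\n".join(row[0] for row in ddl_rows)
print(schema_ddl)
```

Other engines offer equivalents (e.g., dump utilities or information schema views), which makes assembling DDL-based training contexts largely mechanical.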

This ease of creation likely explains the (subjective) observation that more datasets use the DDL format to describe their databases. If your LLM has already been pre-trained on SQL data, there might therefore be a higher chance it has seen DDL than SQL-PaLM-style descriptions.

Additionally, you can extend the type definition for the tables with example data or a description of the content of each table and column. This led to a significant improvement when experimenting with closed-source models like Anthropic Claude 2 or OpenAI GPT-4. One example of how to give extra context and business knowledge through comments can be seen below:

```
CREATE TABLE Students (
    StudentID INT PRIMARY KEY,  -- unique internal student number
    Name VARCHAR(100),          -- full legal name of the student
    Age INT                     -- age in years at enrollment
);
```

Although this is only a qualitative anecdote, I encourage the reader to add those examples and comments to their training datasets and observe the difference. Please leave a comment if you made the same observation.

### Token length for DDL and SQL-PaLM

An essential aspect of a database description framework for NL2SQL tasks is the number of tokens needed to describe our tables. Therefore, we inspect and compare the token lengths for the databases in the training set of the Spider SQL dataset, once in the SQL-PaLM and once in the DDL format.
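The comparison can be sketched as follows. The toy schema is purely illustrative, and a whitespace split stands in for the real CodeLlama tokenizer used for the plots below:

```python
# Toy schema rendered in both description formats.
sql_palm = "Students: StudentID, Name, Age | Courses: CourseID, Title"
ddl = (
    "CREATE TABLE Students (StudentID INT PRIMARY KEY, Name VARCHAR(100), Age INT);\n"
    "CREATE TABLE Courses (CourseID INT PRIMARY KEY, Title VARCHAR(200));"
)

def count_tokens(text):
    # Stand-in tokenizer; swap in transformers' AutoTokenizer for real measurements.
    return len(text.split())

# The DDL rendering spends extra tokens on keywords, types, and punctuation.
print(count_tokens(sql_palm), count_tokens(ddl))
```

Even at this toy scale the DDL variant is longer, which matches the pattern in the plot below at dataset scale.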

![](https://cdn-images-1.medium.com/max/800/1*RfhA-pda8lYa9db8S619-Q.png)

Token length for both the SQL-PaLM and the DDL description of the databases of the training set of SPIDER SQL. The DDL needs more tokens to describe the dataset.

The plot clearly indicates that the DDL description, just like the SQL-PaLM one, has distinct spikes for every database and its associated training questions (the training set has 180 unique databases). However, the DDL version consumes considerably more tokens to describe the underlying structure of the database.

To further inspect the tokenization, we visualize it by color-coding the tokens. For the SQL-PaLM notation:

![](https://cdn-images-1.medium.com/max/800/1*CP7y9p2wyZiMPQdBjkSwPA.png)

SQL-PaLM description tokens, visualized by color-coding. `_` indicates a space for better readability.

And for the DDL-notation:

![](https://cdn-images-1.medium.com/max/800/1*tySg2xT_unCwfJ8vOEV98g.png)

DDL description tokens, visualized by color-coding. The higher overall token count is easy to observe. `<0x0A>` is the hex byte for the `\n` newline. You can check the tokenizer vocabulary via the following [link](https://huggingface.co/codellama/CodeLlama-7b-hf/raw/main/tokenizer.json).

We can observe that the DDL needs far more tokens to represent the table. It is unclear whether it would be beneficial to strip out `\n` and quote characters and convert the schema to a more human-readable, natural text format; more tests are necessary on this end.

### Model performance depending on the database description

First, let’s examine the performance of the models after training them for a few epochs with the different schema descriptions:

![](https://cdn-images-1.medium.com/max/800/1*JhL6AU3kda6UK728d2mQRA.png)

Validation Set Execution Match Accuracy Performance of different training runs

Analyzing the performance scores from the table provided in the context of the database schema descriptions (DDL vs. SQL-PALM) and other training parameters, we can draw several insights:

### Performance Based on Schema Description

* **SQL-PALM Descriptions:** The runs using SQL-PALM descriptions show a performance score range from 38.59% to 60.35%. The lower end of this range is represented by the base models, with the 13B and 7B models achieving 38.59% and 38.88%, respectively. This finding indicates that neither base model underwent extensive Text-to-SQL training.
* **DDL Descriptions:** The runs with DDL descriptions show a marked improvement, with a score of 56.67% for the 7B model at five epochs and an even higher 59.38% for the 13B model at only three epochs. This suggests that the more verbose and detailed DDL format might provide a richer context or structure that the model can leverage to improve performance. However, this comes at the cost of higher token usage.

### Influence of Training Duration

1 vs. 5 Epochs: When looking at the 7B model trained for one epoch with SQL-PALM, we see a performance score of 49.71%. This score increases slightly to 50.87% when the model is trained for five epochs, indicating a modest improvement with additional training for SQL-PALM data.

### Effect of Model Size

There is a clear trend that the larger 13B models outperform the 7B models. Although starting from a similar baseline, the larger models significantly outperform the smaller ones when fine-tuned. However, a further parameter study on the size of the LoRA adapter should be undertaken to judge this more closely.

### Generalization from SQL-PALM to DDL

The performance comparison between models trained on DDL and then used for inference with a SQL-PALM schema offers insightful conclusions. Specifically, the CodeLlama 7B model, which was trained exclusively on a DDL schema for five epochs and then tested with SQL-PALM for inference, exhibited a performance of 51.64%. This is a notable decrease from the 56.67% performance score achieved when inference was conducted with DDL on the same model configuration.

This decrease suggests that while the model has been well-trained to understand and generate predictions based on DDL, it may not generalize as effectively when presented with a schema in a different format, such as SQL-PALM, during inference. The difference in performance shows the importance of consistency between training and inference formats. It also suggests that for optimal performance, additional fine-tuning or a dedicated training regimen might be necessary for the model to become equally proficient with SQL-PALM formatted inputs as it is with DDL. This finding is crucial for practitioners who may consider switching between schema formats, indicating the potential need for format-specific fine-tuning to maintain high performance.

### Conclusion on the formatting for NL2SQL fine-tuning

We can conclude that both SQL-PALM and DDL yield comparable performance results. The availability and ease of generation of the DDL format is the strongest driver for my decision to continue with the DDL format only.

## Analyzing Chunk Length

Building on our exploration from the first blog, where we navigated the intricacies of tokenization and chunk length for natural language processing in decoder-only models, we now turn our focus to the tangible impact these factors have on model performance, particularly for complex database schemas. This is a continuation of our deep dive into fine-tuning the CodeLlama model for the Natural Language to SQL (NL2SQL) tasks, where the crux lies in the model’s ability to translate verbose instructions into precise SQL queries.

In Part 1, we discussed the challenges associated with lengthy instructions that exceed the chunk length, potentially leaving the model with incomplete context for learning. As we move forward, we investigate how varying the chunk lengths affects the model’s performance, especially when dealing with complex examples that are critical to our model’s understanding and generalization capabilities.

We conducted experiments with the 7B CodeLlama model, varying the chunk lengths across five different training jobs (256, 512, 1024, 2048, and the notably larger 4096 tokens) to observe the performance impact on our dataset. Below are the extracted results pertaining to these chunk lengths:

![](https://cdn-images-1.medium.com/max/800/1*0LoFKhiUJe4vVPNDiPQaIA.png)

Performance over the chunk length for training CodeLlama 7B Model with DDL schema information.

The performance scores clearly illustrate that chunk length plays a significant role in model efficacy. The smaller chunk lengths (256 and 512 tokens) result in notably lower performance scores of 31.72% and 44.97%, respectively. This suggests that the model struggles to comprehend the full context of the instructions within these constrained token limits.

As we increase the chunk length to 2048 tokens, a standard size for many model training scripts, we see an improvement in performance across the board. The model’s ability to grasp the entire context — likely fitting most of the instruction and solution pairs within a single chunk — translates into better performance scores, ranging from 49.71% to 51.64%.

Interestingly, when we expand the chunk length further to 4096 tokens, we observe a significant leap in performance to 55.42%. This indicates that for complex queries, where the database schema and instructions are exceptionally verbose, a larger chunk length can provide the model with a more holistic view, thus enabling it to generate more accurate SQL queries.

These findings underscore a crucial consideration for those working with decoder-only models like GPT-style architectures: the trade-off between chunk length and performance. By strategically choosing the chunk length, we can optimize the training process, particularly for complex examples that are paramount for the model’s learning.
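The mechanics behind this trade-off can be sketched directly: in the usual concatenate-and-split packing used by causal-LM fine-tuning scripts, any chunk length shorter than an instruction-plus-answer pair leaves the model training on truncated context. A minimal sketch with stand-in integer token IDs:

```python
def pack_into_chunks(token_ids, chunk_length):
    """Concatenate-and-split packing, as used in many causal-LM fine-tuning scripts."""
    return [
        token_ids[i : i + chunk_length]
        for i in range(0, len(token_ids) - chunk_length + 1, chunk_length)
    ]

# One training example of 600 stand-in tokens (schema + question + SQL answer).
example = list(range(600))

# With a 256-token chunk length, the example is cut into pieces, so some chunks
# contain an answer without the schema that explains it.
short_chunks = pack_into_chunks(example, 256)
print(len(short_chunks))  # 2

# With a 1024-token chunk length, the whole example fits in one chunk;
# in practice it would be packed together with other examples to fill it.
long_chunks = pack_into_chunks(example, 1024)
print(len(long_chunks))   # 0
```

With 256-token chunks, the 600-token example is split mid-instruction; at 1024 tokens it survives intact, which is one plausible reading of the performance gap above.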

## Conclusion

In conclusion, this part of our exploration reveals that while managing tokenization and chunking is a delicate balancing act, doing so effectively can have a substantial impact on the performance of an NL2SQL model. It offers a stepping stone for those seeking to fine-tune their models, providing a benchmark for how chunk length can influence the accuracy and reliability of the resulting model. It also creates a bridge to Part 1, reinforcing the importance of the model’s training environment on SageMaker and how these foundational decisions play a pivotal role in achieving superior model performance.

---

References:

[Natural Language to SQL: Fine-Tuning CodeLlama with Amazon SageMaker — Part 1](/posts/nl2sql-fine-tuning-codellama-part-1)]]></content:encoded>
            <author>Philipp Kaindl</author>
            <category>Fine-Tuning</category>
            <category>NL2SQL</category>
            <category>SageMaker</category>
            <category>LLM</category>
        </item>
        <item>
            <title><![CDATA[Natural Language to SQL: Fine-Tuning CodeLlama with Amazon SageMaker — Part 1]]></title>
            <link>https://philippkaindl.com/posts/nl2sql-fine-tuning-codellama-part-1</link>
            <guid isPermaLink="false">https://philippkaindl.com/posts/nl2sql-fine-tuning-codellama-part-1</guid>
            <pubDate>Thu, 07 Dec 2023 00:00:00 GMT</pubDate>
            <description><![CDATA[Practical strategies for fine-tuning CodeLlama using QLoRA on Amazon SageMaker for natural language to SQL tasks.]]></description>
            <content:encoded><![CDATA[
Natural Language to SQL (NL2SQL) has been a field of open research for years, and the integration of large-scale language models into business applications is rapidly evolving.

With the advancements of Generative AI and Large Language Models (LLMs) for code generation, the focus of the community has partially shifted from using smaller fine-tuned models for NL2SQL to larger, commercially available LLMs like GPT-4, Anthropic Claude, or Amazon Titan. Although the zero-shot capabilities of those models are impressive on benchmark datasets (e.g., [SPIDER](https://yale-lily.github.io/spider)), their performance sees substantial improvements when using few-shot prompting. These results are encouraging for ML practitioners to realize the dream of using natural language for data analytics. Still, many real-world challenges remain outside the benchmarks.

## Small Language Models

Why would you go to the lengths of fine-tuning your own smaller model, like [CodeLlama](https://huggingface.co/codellama/CodeLlama-7b-hf) or [Mistral](https://huggingface.co/mistralai/Mistral-7B-v0.1) on NL2SQL? There are many reasons, but the main ones are:

* **Data privacy and full control over the model**
  Your data, your model. When you fine-tune an open-source model, you are in total control. Run it on the cloud? Sure! Are your databases still on-prem? Then you can take it and run it close to them as well.
* **Overfitting is good**
  In the current benchmarks, we are trying to test our models out-of-sample by presenting them with unseen databases and tables. However, if you fine-tune your own model on your own data, then it is a good thing if the model has already seen the tables and databases before in the training data. Of course, this also introduces the problem of the training data becoming stale as the tables change. We will discuss strategies to combat this problem in the following blogs and below.
* **LoRA Adapters**
  When you fine-tune your model with [LoRA](https://huggingface.co/blog/trl-peft), you can load and unload the LoRA adapters at runtime. As LoRA adapters can be loaded and unloaded by simple addition and subtraction of the weight matrices and are independent of each other, we can swap in and out our fine-tuning knowledge.
  With this ability, we could train one LoRA adapter per database or business domain. If the underlying database schema changes with the consequence of a performance degradation on the generation of correct SQL statements, we simply unload that adapter, retrain it, and load it again.
* **Finding the right information**
  In contrast to most benchmarks, where the correct database and tables are already supplied, you first need to find the correct information to construct a query. A well-oiled system for finding accurate information is crucial. However, it can still be improved by fine-tuning our model to our systems. Especially if the table and column names are not that descriptive, fine-tuning can help make the right choice for the model and goes hand in hand with overfitting.
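The load/unload-by-addition property described above can be shown numerically: a LoRA adapter is a low-rank update BA that is added to the frozen weight matrix to load it and subtracted to unload it. A toy sketch with NumPy (not the PEFT API, just the underlying arithmetic):

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 8, 2                      # hidden size and LoRA rank (toy values)
W = rng.normal(size=(d, d))      # frozen base weight matrix
A = rng.normal(size=(r, d))      # trained low-rank factors
B = rng.normal(size=(d, r))

delta = B @ A                    # the adapter's weight update

W_loaded = W + delta             # "load" the adapter into the base weights
W_unloaded = W_loaded - delta    # "unload" it again

# Unloading restores the base weights (up to float round-off),
# so adapters for different databases can be swapped freely.
print(np.allclose(W, W_unloaded))  # True
```

In practice, libraries like PEFT wrap this merge/unmerge logic for you; the point is that swapping per-database adapters is cheap weight arithmetic, not a retraining step.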

With that in mind, let’s look at a model family that has reportedly been able to outperform even GPT-4 on the task of NL2SQL — [Llama2](https://arxiv.org/abs/2307.09288) (e.g., the blog post by [Kourosh Hakhamaneshi and Rehaan Ahmad](https://www.anyscale.com/blog/fine-tuning-llama-2-a-comprehensive-case-study-for-tailoring-models-to-unique-applications)). One notable derivative of the Llama2 model is [CodeLlama](https://huggingface.co/codellama), which has been further trained on coding tasks. In the following blog, we will investigate how to fine-tune CodeLlama with Amazon SageMaker. More blogs about experiments on the chunk length, table schema format, repetition penalty, and synthetic data generation with Amazon Bedrock will follow.

## TLDR what you can find in this blog:

* A recipe on how to prepare your data and how to fine-tune CodeLlama on Amazon SageMaker with QLoRA, a Parameter Efficient Fine Tuning (PEFT) technique.
* The influence of the chunk length on the performance of the fine-tuned model.
* Deployment to production-grade endpoint with Amazon SageMaker.
* Further improvement ideas for using the methodology on your own company’s data.

NL2SQL spans many interesting challenges (information retrieval, synthetic training data generation with LLMs etc.); however, our focus will not be on how to achieve the highest scores on benchmarks but on how to fine-tune a model for your company.

## Why CodeLlama?

What makes CodeLlama unique is its specialization: it’s fine-tuned on a massive 500 billion tokens related to coding tasks. Compared to its base model, it also processes a substantial context length of 16K tokens during training. This additional context length allows it to take more surrounding code into account, e.g., when infilling code or when used for auto-completion.

To give you an overview of the training methodology for CodeLlama, please consult the [original paper](https://arxiv.org/abs/2308.12950) (Rozière et al., 2023) or the diagram below.

![The training methodology for CodeLlama can be seen below as outlined in the original paper (Rozière et al., 2023).](https://cdn-images-1.medium.com/max/800/1*X9ZaDfcDmW4zbUNjaQrYbg.png)

The training methodology for CodeLlama is outlined in the [original paper](https://arxiv.org/abs/2308.12950) (Rozière et al., 2023).

## QLoRA and PEFT

Although Amazon SageMaker offers us the ability to scale our compute, we want to be frugal with the resources consumed.

Therefore, our fine-tuning approach with CodeLlama will utilize [QLoRA](https://arxiv.org/abs/2305.14314) on Amazon SageMaker.
QLoRA stands out for its ability to effectively quantize pre-trained language models to 4 bits without substantially compromising performance.

Parameter-Efficient Fine-Tuning (PEFT) methods are a way to adapt pre-trained language models for specific tasks while minimizing resource usage. Unlike traditional fine-tuning, which adjusts all of a model’s parameters, PEFT focuses on modifying only a small subset. This approach significantly reduces the computational and storage demands often associated with fine-tuning large-scale LLMs. Despite altering fewer parameters, PEFT methods can achieve performance levels comparable to those of full fine-tuning. The Huggingface library [PEFT](https://github.com/huggingface/peft) is particularly useful for adapting large models efficiently and supports tools like [Accelerate](https://huggingface.co/docs/accelerate/index), which facilitates their use with large-scale models through integrations like [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [Big Model Inference](https://github.com/OpenBMB/BMInf).
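To make the resource argument concrete, here is the parameter arithmetic for a single square weight matrix: full fine-tuning updates d² parameters, while a rank-r LoRA adapter trains only the 2·d·r parameters of its two factors. The numbers below are illustrative, not a measurement of CodeLlama:

```python
d = 4096   # hidden size of a typical 7B-class transformer layer (illustrative)
r = 16     # LoRA rank (an assumed, commonly used choice)

full_params = d * d          # parameters updated by full fine-tuning of one d x d matrix
lora_params = 2 * d * r      # parameters in the B (d x r) and A (r x d) factors

print(full_params)                # 16777216
print(lora_params)                # 131072
print(full_params / lora_params)  # 128.0
```

A 128x reduction per adapted matrix (before quantization even enters the picture) is what lets QLoRA jobs fit on modest GPU instances.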

Before continuing further, I would like to thank Phil Schmid from Huggingface for his [excellent blog](https://www.philschmid.de/sagemaker-llama2-qlora) on fine-tuning LLaMA 2 models, which was the base for most of my work.

## 1. Setting up the development environment

To get started with your fine-tuning journey on Amazon SageMaker, all you need is a laptop with the correct permissions set up to interact with Amazon SageMaker.

To get started, we will need to configure access to
1. Hugging Face — access the CodeLlama model
2. Weights & Biases — track training progress with the wandb library
3. Amazon SageMaker — run the actual training.

Furthermore, we need to make sure that the correct libraries are installed. CodeLlama, for instance, is available from `transformers` version 4.33.0. Use the following commands to ensure that you have the needed packages:

```
import sys

# Install/upgrade the libraries used in this post from within the notebook.
# The exact package set is an assumption; CodeLlama itself requires
# transformers >= 4.33.0.
!{sys.executable} -m pip install --upgrade "transformers>=4.33.0" datasets peft accelerate bitsandbytes sagemaker python-dotenv
```

To access CodeLlama assets, log into your [Hugging Face account](https://huggingface.co/welcome). Create a `.env` file that holds your Huggingface Hub token. If you work locally, you can take advantage of it and specify your [SageMaker Execution](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-amazonsagemakerfullaccess-policy) role there as well. We will use [dotenv](https://pypi.org/project/python-dotenv/) to read the sensitive information from there.

```
import os
import subprocess

from dotenv import load_dotenv

# Load the environment variables from the .env file
load_dotenv()

# Get the Hugging Face token from the environment variable
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Check if the token is available
if huggingface_token is None:
    raise ValueError("Hugging Face token not found. Please check your .env file.")

# Log in using the Hugging Face CLI with the token
subprocess.run(["huggingface-cli", "login", "--token", huggingface_token])
```

For tracking purposes, we are going to connect the SageMaker training container to [Weights and Biases](https://wandb.ai/site), where we can comfortably check the progress made, even when on the go.

Next, log into your [wandb account](https://wandb.ai/home) from the cell output. `wandb.sagemaker_auth(path="scripts")` saves the login data to a `scripts` folder that we pass on to the training container; wandb will fetch the login details from there automatically. However, we still need to install wandb in the training container, which will be done via the `requirements.txt`.

```
import wandb

wandb.login()
wandb.sagemaker_auth(path="scripts")
```

If you are going to use SageMaker in a local environment, you need access to an AWS Identity and Access Management (IAM) role with the required permissions for SageMaker. You can find more information about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html). Next, we are going to set up all our connections to the Amazon SageMaker service.

```
import sagemaker
import boto3

sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
local_mode = True
if not local_mode:
    if sagemaker_session_bucket is None and sess is not None:
        # set to default bucket if a bucket name is not given
        sagemaker_session_bucket = sess.default_bucket()
    try:
        role = sagemaker.get_execution_role()
    except ValueError:
        iam = boto3.client("iam")
        role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]
else:
    # Load .env file
    load_dotenv()
    # Get the SageMaker execution role
    role = os.getenv("SAGEMAKER_EXECUTION_ROLE")
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
print(f"SageMaker Version: {sagemaker.__version__}")
```

## 2. Load and prepare the dataset

Our experiment will be based on the [Spider Dataset](https://arxiv.org/abs/1809.08887), which is an open-source dataset for natural language to SQL, licensed under the [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode) license. In a preprocessing step, I enriched the dataset available on the HuggingFace hub with more information on the table schemas and foreign and primary keys for fine-tuning Llama models. This enrichment step was inspired by SQL-PaLM, and the dataset is available on the [Huggingface dataset hub](https://huggingface.co/datasets/philikai/Spider-SQL-LLAMA2_train).

As the fine-tuned model would most probably be used as a tool in a chat setting, our dataset should reflect an instruction-tuning approach that incorporates instruction, context, and answer, as shown in the example below.

In the answer, we want CodeLlama to include &lt;SQL&gt;&lt;/SQL&gt; tags for the SQL query for easier parsing of the output.
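These tags make downstream parsing a small regex job. A minimal sketch (the helper name and example string are illustrative):

```python
import re

def extract_sql(text: str) -> str:
    """Pull the query out of <SQL></SQL> tags in a model response."""
    match = re.search(r"<SQL>(.*?)</SQL>", text, re.DOTALL)
    return match.group(1).strip() if match else ""

print(extract_sql("### Answer\n<SQL> SELECT 1 </SQL>"))  # SELECT 1
```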

After formatting, a training example has the following parts:

* **Instruction**
  An instruction helps the model to pick up the intent. This instruction has not been tuned. Feel free to improve it!

```
"""
```

* **Context**
  The aim is to include information about the table structure via the “context” of our query. This information can be extracted from the table itself. Below is a shortened example of a relatively simple database.

```
"""
```

* **Task**
  Experimentation with API-based and non-fine-tuned models has shown that constraining the output to include only the SQL query can be challenging. We want our model to follow our formatting rules precisely, even when running thousands of queries. For responses in JSON format, update the instructions accordingly.

```
"""
```

To load the `philikai/Spider-SQL-LLAMA2_train` dataset, we use the `load_dataset()` method from the 🤗 Datasets library.

```
from datasets import load_dataset
from random import randrange

# Load dataset from the hub
dataset = load_dataset("philikai/Spider-SQL-LLAMA2_train")
print(f"Train dataset size: {len(dataset['train'])}")
```

Inspecting the dataset, we find the following features:

```
'db_id': Value(dtype='string', id=None),
'query': Value(dtype='string', id=None),
'question': Value(dtype='string', id=None),
'schema': Value(dtype='string', id=None),
'primary_keys': Value(dtype='string', id=None),
'foreign_keys': Value(dtype='string', id=None),
```

To instruction tune our model via the prompt parts outlined above, we need to convert our structured examples into a collection of tasks described via instructions. We define a `formatting_function` that takes a sample and returns a string with our format instruction.

```
# Play around with the instruction prompt to maximize the model performance further
def format_spider(sample):
    instruction_prompt = f"""Given an input question, use sqlite syntax to generate a sql query by choosing one or multiple of the following tables.
    The foreign and primary keys will be supplied. Write query in between <SQL></SQL>.
    Answer the following question with the context below: \n{sample['question']}"""
    instruction = f"### Instruction\n{instruction_prompt} "
    context = f"### Context\n{sample['schema']} | {sample['foreign_keys']} | {sample['primary_keys']}"
    response = f"### Answer\n<SQL> {sample['query']} </SQL>"
    # join all the parts together
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    return prompt
```

Applying the formatting function to a random sample (shortened):

```
"""
### Instruction
Given an input question, use sqlite syntax to generate a sql query by choosing one or multiple of the following tables.
The foreign and primary keys will be supplied. Write query in between <SQL></SQL>.
Answer the following question with the context below:
What are the different names of the colleges involved in the tryout in alphabetical order?
### Context
[Schema (values) (types)]: | soccer_2 |  College : cname (text)...
[Foreign Keys]: tryout : cname = college : cname |
tryout : pid = player : pid |
[Primary Keys]: college : cname, player : pid, tryout : pid
```

```
### Answer
<SQL> SELECT DISTINCT cName FROM tryout ORDER BY cName </SQL>
"""
```

Although an integral part of any LLM application, the role of the tokenizer is often overlooked. Adding, or forgetting to add, an `eos_token` to your dataset samples can make the difference between a successful and a failed fine-tuning job.
You can download the right tokenizer from the Hugging Face Hub. The `pad_token` gets set to the EOS token to ensure that the model picks it up during the training process.

```
from transformers import AutoTokenizer

model_id = "codellama/CodeLlama-7b-hf"  # the 7B base model fine-tuned in this post
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Use the EOS token for padding so the model picks it up during training
tokenizer.pad_token = tokenizer.eos_token
```

With that tokenizer function, we will first tokenize and check the distribution of the token lengths in our dataset.

To make sure that CodeLlama stops predicting new tokens, we add an `eos_token` at the end of each example. This ensures that we not only pad our examples with it but also insert it when batching multiple examples together.

```
from random import randint
from itertools import chain
from functools import partial

dataset_train = dataset["train"]
dataset_validation = dataset["validation"]

# template dataset to add prompt to each sample
def template_dataset(sample):
    sample["text"] = f"{format_spider(sample)}{tokenizer.eos_token}"
    return sample

# apply prompt template per sample
dataset_train_tokenized = dataset_train.map(
    template_dataset, remove_columns=list(dataset_train.features)
)
dataset_validation_tokenized = dataset_validation.map(
    template_dataset, remove_columns=list(dataset_validation.features)
)
print(dataset_train_tokenized[randint(0, len(dataset_train_tokenized) - 1)]["text"])
```

```
...
### Answer
<SQL> SELECT count(*) FROM Ref_locations </SQL></s>
```

A `</s>` has been added at the end of the example solution. CodeLlama will learn throughout the fine-tuning that it has to close every example with &lt;/SQL&gt;, followed by a &lt;/s&gt;. As soon as the &lt;/s&gt; has been predicted, the generation will stop.

### Importance of chunk length on NL2SQL fine-tuning

The next step would be to run a chunking and batching operation on our dataset. Many ML practitioners choose just to apply defaults to those operations. However, we want first to examine our dataset regarding the token length.
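The distribution shown in the histogram below can be computed directly from the tokenized examples. Here is a self-contained sketch with stand-in lengths (in the notebook, each length comes from tokenizing a formatted training sample):

```python
from collections import Counter

# Stand-in token lengths; in the notebook they come from
# len(tokenizer(sample["text"])["input_ids"]) for each training example
lengths = [180, 200, 1900, 210, 1950, 240, 3100]

# A coarse histogram in 256-token bins makes the schema-driven spikes visible
histogram = Counter(length // 256 * 256 for length in lengths)
for bin_start in sorted(histogram):
    print(f"{bin_start:5d}-{bin_start + 255:5d}: {histogram[bin_start]} examples")
```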

![](https://cdn-images-1.medium.com/max/800/1*lGGD__mxasMym9ZQ2dLo5w.png)

Histogram of the input token lengths (instruction and solution) for training. The token length is strongly governed by the complexity of the database, leading to distinctive spikes.

Observe how the token length has distinctive spikes stemming from the database schema description. As we supply the full database description, the model has to find the right tables and columns in the context.

Next, we want to chunk our dataset and batch it. This is done via the following code snippet:

```
# empty list to save remainder from batches to use in next batch
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

chunk_length = 2048
def chunk(sample, chunk_length=2048):
    # define global remainder variable to save remainder from batches to use in next batch
    global remainder
    # Concatenate all texts and add remainder from previous batch
    concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()}
    concatenated_examples = {
        k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()
    }
    # get total number of tokens for batch
    batch_total_length = len(concatenated_examples[list(sample.keys())[0]])
    # get max number of chunks for batch
    if batch_total_length >= chunk_length:
        batch_chunk_length = (batch_total_length // chunk_length) * chunk_length
    else:
        batch_chunk_length = 0
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
        for k, t in concatenated_examples.items()
    }
    # add remainder to global variable for next batch
    remainder = {
        k: concatenated_examples[k][batch_chunk_length:]
        for k in concatenated_examples.keys()
    }
    # prepare labels
    result["labels"] = result["input_ids"].copy()
    return result

# tokenize and chunk training dataset
lm_dataset = dataset_train_tokenized.map(
    lambda sample: tokenizer(sample["text"]),
    batched=True,
    remove_columns=list(dataset_train_tokenized.features),
).map(
    partial(chunk, chunk_length=chunk_length),
    batched=True,
)

# tokenize and chunk validation dataset
lm_dataset_validation = dataset_validation_tokenized.map(
    lambda sample: tokenizer(sample["text"]),
    batched=True,
    remove_columns=list(dataset_validation_tokenized.features),
).map(
    partial(chunk, chunk_length=chunk_length),
    batched=True,
)
# Print total number of samples
print(f"Total number of train samples: {len(lm_dataset)}")
print(f"Total number of validation samples: {len(lm_dataset_validation)}")
```

In a scenario where you have a decoder-only model, such as GPT-style models, and you’re working with instruction and solution pairs where the instruction is significantly longer than the solution (e.g., 3000 tokens for the instruction and 500 tokens for the solution) while the chunk length is smaller than their combined length, consider the following aspects.

* **Tokenization and Chunking**
  Instruction (e.g., 3000 tokens) and Solution (e.g., 500 tokens): Both the instruction and the solution are tokenized. Given an example chunk size of 2048 tokens, the instruction exceeds this limit significantly.
  The instruction is, therefore, chunked into segments. The first chunk will contain 2048 tokens from the instruction. The remaining 952 tokens become part of the next chunk.
  The solution, being only 500 tokens, can fit entirely in a chunk without needing to be split.
* **Sequence Formation**
  In decoder-only models, the input is often formatted as a single sequence where the instruction and solution are concatenated, typically with a special token separating them (in our case, we train on `### Answer`).
* **Dealing with Length Discrepancy**
  First Chunk: The first 2048 tokens of the instruction are processed. But since there’s no room left in this chunk for the solution, the solution is not included here.
  Second Chunk: The remaining 952 tokens of the instruction are placed in the next chunk, and here, the 500-token solution can also be included, as the total tokens (952 + 500) are within the 2048 token limit.
* **Causal Language Modeling (CLM)**
  Decoder-Only Model: This type of model generates text one token at a time, predicting the next token based on the previous tokens. It does not use the bidirectional context like encoder models.
  Training: During training, the model learns to predict the next token in the sequence. For our dataset, it learns to continue the text from the instruction to the solution.
  Attention Masking: The model uses attention mechanisms to weigh the importance of different tokens in the sequence when predicting the next token. In a chunk containing both instruction and solution, it learns the transition from the instructional context to the solution context.

* **Training Implications**
  The main challenge is that the model may not always see the instruction and solution together in the same chunk, especially for very long instructions. This can impact its ability to learn the relationship between specific instructions and their solutions.
  Partial Context: In cases where the instruction is cut, the model gets only part of the instruction in one chunk and the rest, together with the solution, in the next. This affects learning, as the model does not always see the complete instruction with its corresponding solution.

In many datasets, instruction lengths are often far below the chunk window. However, for complex databases, schema information can get very long and exceed the chunk length. The implication is that we are not training optimally on our most complex examples.
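To make the boundary behavior concrete, here is the chunking logic from above reduced to its core, run on toy token lists with a chunk length of 8 (numbers chosen purely for illustration):

```python
from itertools import chain

remainder = {"input_ids": []}

def chunk(sample, chunk_length=8):
    """Simplified version of the chunk() function above."""
    global remainder
    # Carry over the remainder from the previous batch, then concatenate
    concatenated = {k: remainder[k] + list(chain(*sample[k])) for k in sample}
    total = len(concatenated["input_ids"])
    batch_chunk_length = (total // chunk_length) * chunk_length
    result = {
        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
        for k, t in concatenated.items()
    }
    remainder = {k: v[batch_chunk_length:] for k, v in concatenated.items()}
    return result

# Two "examples" of 10 and 7 tokens -> 17 tokens -> two chunks of 8, 1 token left over
out = chunk({"input_ids": [list(range(10)), list(range(7))]})
print(len(out["input_ids"]), remainder["input_ids"])  # 2 [6]
```

The second chunk mixes the tail of one example with the head of the next, which is exactly the cross-example bleed discussed above.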

To understand the impact of the chunk length on our dataset, we kicked off four training jobs with chunk lengths of 256, 512, 1024, and 4096 tokens.

Finally, upload your datasets to Amazon S3:

```
# Upload the chunked datasets to S3 (prefixes below are illustrative)
from sagemaker.s3 import S3Uploader

training_input_path = f"s3://{sess.default_bucket()}/processed/spider/train"
validation_input_path = f"s3://{sess.default_bucket()}/processed/spider/validation"
lm_dataset.save_to_disk("data/train")
lm_dataset_validation.save_to_disk("data/validation")
S3Uploader.upload("data/train", training_input_path)
S3Uploader.upload("data/validation", validation_input_path)
```

## 3. Fine-Tuning CodeLlama with QLoRA on Amazon SageMaker

In order to optimize CodeLlama using the QLoRA methodology, we’ve drawn inspiration from the invaluable insights shared in Phil Schmid’s blog post on [fine-tuning LLaMA models with QLoRA on SageMaker](https://www.philschmid.de/sagemaker-llama2-qlora).

From the same source we can reuse the [run\_clm.py](https://github.com/philschmid/sagemaker-huggingface-llama-2-samples/blob/master/training/scripts/run_clm.py), which implements QLoRA using PEFT to train a model. Post-training, this script integrates the LoRA weights into the model’s architecture when setting `merge_weights=True`. For models that exceed memory capacity, temporary offloading to disk is implemented.

Please take a look at the code snippet that merges the adapter weights:

```
sagemaker_save_dir = "/opt/ml/model/"
if args.merge_weights:
    # merge adapter weights with base model and save
    # save int 4 model
    trainer.model.save_pretrained(output_dir, safe_serialization=False)
    # clear memory
    del model
    del trainer
    torch.cuda.empty_cache()
```

```
    from peft import AutoPeftModelForCausalLM

    # load PEFT model in fp16
    model = AutoPeftModelForCausalLM.from_pretrained(
        output_dir,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
    # Merge LoRA and base model and save
    model = model.merge_and_unload()
    model.save_pretrained(
        sagemaker_save_dir, safe_serialization=True, max_shard_size="2GB"
    )
else:
    trainer.model.save_pretrained(
        sagemaker_save_dir, safe_serialization=True
    )
```

To initiate a SageMaker training job, we utilize a [HuggingFace Estimator](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html). This HuggingFace estimator simplifies the Amazon SageMaker training and deployment process by managing the necessary infrastructure. SageMaker orchestrates the setup and management of required EC2 instances, supplies the appropriate HuggingFace container, uploads scripts, and downloads data from our S3 bucket to the container at `/opt/ml/input/data`, before commencing the training job.

For even easier LLM fine-tuning, you can try [Amazon SageMaker JumpStart](https://aws.amazon.com/sagemaker/jumpstart/?p=pm&c=sm&z=2), which allows you to fine-tune a large set of models with the click of a button. At the time of writing, CodeLlama was not available on Amazon SageMaker JumpStart.

An important aspect to keep in mind is that the Hugging Face container versions available on SageMaker are not always in sync with the latest release of the [Transformers](https://pypi.org/project/transformers/) library. To address this, pin the desired Transformers version in a `requirements.txt` file inside your `source_dir`; this lets you upgrade to a specific release or even install the latest version directly from the GitHub repository, so your SageMaker environment has the current features your machine-learning project needs.
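For example, a `requirements.txt` placed in the `source_dir` might pin versions like this (the exact pins are illustrative, not a tested set):

```
transformers==4.36.0
peft==0.7.1
bitsandbytes==0.41.1
```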

### Hardware requirements

Please take a look at a few selected options on how to run this training job.

![](https://cdn-images-1.medium.com/max/800/1*yj8qbFpecGJ-f-4y5zqYwQ.png)

Different chunk lengths, instance types, and successful training runs.

You can also use a `g5.2xlarge` instead of the `g5.4xlarge` instance type, but then you cannot use the `merge_weights` parameter, since merging the LoRA weights into the base model requires the full model to fit into memory. You can instead save the adapter weights and merge them with `merge_adapter_weights.py` after training.

```
import time
from sagemaker.huggingface import HuggingFace
from huggingface_hub import HfFolder
```

```
# define training job name; no special chars allowed, no "_"
job_name = f'huggingface-nltosql-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'
# hyperparameters, which are passed into the training job through CLI parsing
hyperparameters = {
    "model_id": model_id,  # pre-trained model
    "dataset_path": "/opt/ml/input/data/training",  # where the dataset lands on the training container
    "epochs": 5,  # number of epochs
    "per_device_train_batch_size": 4,  # batch size for training
    "lr": 2e-4,  # learning rate
    "hf_token": HfFolder.get_token(),  # huggingface token to access CodeLlama
    "merge_weights": True,  # Merging model weights with LoRA adapters after training
    "report_to": "wandb",  # Inspection of the training process via weights and biases
}
# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point="run_clm.py",  # train script
    source_dir="scripts",  # dir that will be copied over to the training container
    instance_type="ml.g5.12xlarge",  # instances type used for the training job
    instance_count=1,  # the number of instances used for training
    base_job_name=job_name,  # the name of the training job
    role=role,  # IAM role used in training job to access AWS resources; use your SageMakerExecutionRole
    volume_size=300,  # the size of the EBS volume in GB
    transformers_version="4.28",  # the transformers version of the container
    pytorch_version="2.0",  # the pytorch_version version used in the training job
    py_version="py310",  # the python version used in the training job
    hyperparameters=hyperparameters,
    environment={
        "HUGGINGFACE_HUB_CACHE": "/tmp/.cache"
    },  # set env variable to cache models in /tmp
    keep_alive_period_in_seconds=600,  # keep the instance warm between jobs (warm pools)
)
```

Start the training job with the `.fit()` method. The training container receives the training and validation datasets directly from S3.

```
# define a data input dictionary with our uploaded s3 uris
data = {"training": training_input_path, "validation": validation_input_path}

# starting the train job with our uploaded datasets as input
huggingface_estimator.fit(data, wait=False)
```

The training job will then run on AWS. Fine-tuning the 7B parameter model for 5 epochs with a chunk length of 2048 on an ml.g5.4xlarge instance took 17 hours in the us-east-1 region, which equates to a cost of roughly 35 USD.
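The cost estimate follows directly from instance-hours; the hourly rate below is an approximate us-east-1 on-demand price for ml.g5.4xlarge at the time of writing, so treat it as illustrative:

```python
# Back-of-the-envelope training cost
hourly_rate_usd = 2.03  # approx. ml.g5.4xlarge on-demand rate, us-east-1
training_hours = 17
cost = hourly_rate_usd * training_hours
print(f"~{cost:.2f} USD")  # ~34.51 USD
```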

As you might lose your kernel while the training job executes on AWS, you can always re-attach to a running job:

```
# In case you can't remember the training job name, look it up via the management console or the API
# Create a SageMaker client
sagemaker_client = boto3.client('sagemaker')
# List all training jobs and get the latest training job
response = sagemaker_client.list_training_jobs(
    SortBy="CreationTime", SortOrder="Descending"
)
# Print the training job names
for job in response['TrainingJobSummaries']:
    print(job['TrainingJobName'])
```

```
TrainingJobName = "YOUR TRAINING JOB NAME"

huggingface_estimator = HuggingFace.attach(TrainingJobName)
```

In case you want to perform local testing, further analysis, or deployment of the model outside of SageMaker, you can, of course, download it. Please see the snippet below.

```
from sagemaker.s3 import S3Downloader

S3Downloader.download(
    s3_uri=huggingface_estimator.model_data,  # S3 URI where the trained model is located
    local_path="./codellama/",  # local path where *.tar.gz is saved
    sagemaker_session=sess,  # SageMaker session used for training the model
)
```

## 4. Deployment to an Amazon SageMaker Real-Time Endpoint

Now that we have laid the foundations on how to fine-tune Code Llama on your own dataset, we need to deploy the model.

With SageMaker, this is relatively simple, as it automatically saves a zipped version of your model to S3 whenever a training job completes successfully.

It is important that the archive directly contains all files, not a folder containing the files. Amazon SageMaker takes care of that for you. If you train outside of SageMaker, your archive should look like this:

```
model.tar.gz/
|- config.json
|- model-00001-of-00005.safetensors
|- tokenizer.json
|- …
```

You can use the [pigz-python](https://pypi.org/project/pigz-python/) package to parallelize the archiving.
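If you build the archive yourself, a minimal packaging sketch looks like this (the directory name and dummy file are stand-ins; `pigz-python` can replace the gzip compression for speed). The key detail is `arcname`, which keeps every file at the root of the tarball:

```python
import tarfile
from pathlib import Path

# Stand-in model directory; in practice it holds config.json, *.safetensors, etc.
model_dir = Path("model_artifacts")
model_dir.mkdir(exist_ok=True)
(model_dir / "config.json").write_text("{}")

# arcname=f.name places every file at the root of the archive, as SageMaker expects
with tarfile.open("model.tar.gz", "w:gz") as tar:
    for f in model_dir.iterdir():
        tar.add(f, arcname=f.name)

with tarfile.open("model.tar.gz") as tar:
    names = tar.getnames()
print(names)  # ['config.json']
```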

### Hugging Face TGI container in Amazon SageMaker

There are many options for running inference on your model. When creating benchmarks, for example, one could simply extend the `run_clm.py` script to include a pass over the benchmark dataset at the end.

Another way is to spin up a SageMaker endpoint with the Text Generation Inference container from Huggingface. First, retrieve the container URI that points to the desired Docker image. Amazon SageMaker’s `get_huggingface_llm_image_uri` method, being part of the Amazon SageMaker SDK, facilitates this. It allows for the acquisition of the Hugging Face LLM DLC’s URI, tailored to your specific requirements, including backend, session, region, and version. To explore available versions, refer to the [list of available images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-text-generation-inference-containers).

```
from sagemaker.huggingface import get_huggingface_llm_image_uri

# retrieve the llm image uri
llm_image = get_huggingface_llm_image_uri("huggingface", version="1.0.3")
# print ecr image uri
print(f"llm image uri: {llm_image}")
```

### Deploying the Fine-Tuned Code Llama on Amazon SageMaker

To deploy models such as CodeLlama on Amazon SageMaker, we create a HuggingFaceModel class. This class forms the basis of our endpoint configuration, encompassing parameters like `hf_model_id`, `instance_type`, and others. We opt for the `ml.g5.12xlarge` instance type, boasting 4 NVIDIA A10G GPUs and a significant 96GB of GPU memory. Feel free to use a smaller instance for your deployment. You can review the instance details on the official [aws ec2 website](https://aws.amazon.com/de/ec2/instance-types/g5).

As we have seen from our token length plot, we need to be able to handle the longest question. Therefore, increase the `MAX_INPUT_LENGTH` parameter to accommodate that.

```
import json
from sagemaker.huggingface import HuggingFaceModel

# sagemaker config
instance_type = "ml.g5.12xlarge"
number_of_gpu = 4
health_check_timeout = 500
# Define Model and Endpoint configuration parameter
config = {
    "HF_MODEL_ID": "/opt/ml/model",  # path to where sagemaker stores the model
    "SM_NUM_GPUS": json.dumps(number_of_gpu),  # Number of GPU used per replica
    "MAX_INPUT_LENGTH": json.dumps(3072),  # Max length of input text
    "MAX_TOTAL_TOKENS": json.dumps(
        4096
    ),  # Max length of the generation (including input text)
    # 'HF_MODEL_QUANTIZE': "bitsandbytes",# Comment in to quantize
}
# create HuggingFaceModel with the image uri
llm_model = HuggingFaceModel(
    role=role, image_uri=llm_image, model_data=s3_model_uri, env=config
)

# Deploy model to an endpoint
# https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    # volume_size=400, # If using an instance with local SSD storage, volume_size must be None, e.g. p4 but not p3
    container_startup_health_check_timeout=health_check_timeout,  # 10 minutes to be able to load the model
)
```

SageMaker will now create our endpoint and deploy the model to it. This can take up to 10 minutes.

### Running a benchmark on Spider validation dataset

We load the Spider SQL dataset for validation purposes and prepare to run a benchmark. The dataset is formatted the same way as for fine-tuning, except that we exclude the answer to the question; it is our model's job to predict it.

```
# Load dataset from the hub
dataset = load_dataset("philikai/Spider-SQL-LLAMA2_train")

def format_spider_validation(sample):
    instruction_prompt = f"""Given an input question, use sqlite syntax to generate a sql query by choosing one or multiple of the following tables.
    The foreign and primary keys will be supplied. Write query in between <SQL></SQL>.
    Answer the following question with the context below: \n{sample['question']}"""
    instruction = f"### Instruction\n{instruction_prompt} "
    context = f"### Context\n{sample['schema']} | {sample['foreign_keys']} | {sample['primary_keys']}"
    response = f"### Answer\n" # LLM Prediction comes here:
    # join all the parts together
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    return prompt
```

We test a single example from the validation set to ensure the formatting and model interaction are working as expected.

```
from random import randrange

example_nr = randrange(len(dataset["validation"]))
print(f"Example number: {example_nr}")
random_example = format_spider_validation(dataset["validation"][example_nr])
print(f"Picked this example: \n\n{random_example}")
print("*" * 190)
print(dataset["validation"][example_nr]["query"])
```

```
"""
Example number: 410
Picked this example:

### Instruction
Given an input question, use sqlite syntax to generate a sql query by choosing one or multiple of the following tables.
The foreign and primary keys will be supplied. Write query in between <SQL></SQL>.
Answer the following question with the context below:
What are the names of the teachers whose courses have not been arranged?
### Context
[Schema (values) (types)]: | course_teach |  course : course_id (text) , staring_date (number) , course (text) | teacher : teacher_id (text) , name (number) , age (text) , hometown (text) | course_arrange : course_id (text) , teacher_id (number) , grade (text); | [Foreign Keys]: course_arrange : teacher_id = teacher : teacher_id | course_arrange : course_id = course : course_id | [Primary Keys]: course : course_id, teacher : teacher_id, course_arrange : course_id
### Answer
*******************************
SELECT Name FROM teacher WHERE Teacher_id NOT IN (SELECT Teacher_id FROM course_arrange)"""
```

Now we can send a request to our endpoint:

```
# hyperparameters for llm
payload = {
    "inputs": random_example,
    "parameters": {
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.001,
        "max_new_tokens": 256,
        "repetition_penalty": 1.03,
        "stop": ["</s>"],
    },
}

# send request to endpoint
response = llm.predict(payload)

# print(response[0]["generated_text"][:-len("</s>")])
print(response[0]["generated_text"])
```

```
 SELECT name FROM teacher WHERE teacher_id NOT IN (SELECT teacher_id FROM course_arrange) 
```

Exactly as we want it: the response includes the &lt;SQL&gt;&lt;/SQL&gt; tags for easier parsing.

Proceed to run the actual benchmark on the entire validation dataset. This involves iterating over the dataset, sending each formatted sample to the model, and collecting the responses. At this point, we call the endpoint synchronously and wait for its response to each query.

```
from time import time
from tqdm import tqdm

answers = []
dataset_val = dataset["validation"]
start = time()
for idx, sample in tqdm(enumerate(dataset_val), total=len(dataset_val)):
    formatted_sample = format_spider_validation(dataset_val[idx])
    # hyperparameters for llm execution
    payload = {
        "inputs": formatted_sample,
        "parameters": {
            "do_sample": True,
            "top_p": 0.95,
            "temperature": 0.001,
            "max_new_tokens": 256,
            "repetition_penalty": 1.03,
            "stop": ["</s>"],
        },
    }
    # send request to endpoint
    response = llm.predict(payload)
    answer = response[0]["generated_text"]
    answers.append(answer)
duration = time() - start
avg_duration = duration / len(dataset_val)
```

The benchmark takes only 16:25 minutes, which means the endpoint needs roughly one second per prediction on average.
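The per-prediction latency is just the wall-clock time divided by the number of validation examples:

```python
# Average synchronous latency over the benchmark run
total_seconds = 16 * 60 + 25   # 16:25 wall-clock time
n_samples = 1034               # Spider validation examples
avg = round(total_seconds / n_samples, 2)
print(avg, "s per prediction")  # 0.95 s per prediction
```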

Save the benchmark answers on disk.

```
import pickle
import os

# Specify the directory and filename
dir_name = "./results"
file_name = "yourfilename.pickle"
# Check if the directory exists, if not, create it
if not os.path.exists(dir_name):
    os.makedirs(dir_name)
# Write the data to the file
with open(os.path.join(dir_name, file_name), "wb") as fp:
    pickle.dump(answers, fp)
```

Analyze the results to determine how many of the model’s responses correctly follow the expected format.

```
# Initialize a count
count = 0

# Loop through the answers
for answer in answers:
    if answer.startswith("<SQL>") and answer.endswith("</SQL>"):
        count += 1
print(f"{count} answers have <SQL> at the beginning and </SQL> at the end.")
```

```
1034 answers have <SQL> at the beginning and </SQL> at the end.
```

For our experiment, we get 100% format alignment, as all of the 1034 samples have the &lt;SQL&gt; tags at the beginning and end.

### Clean up

Finally, to manage resources effectively, we clean up by deleting the model and endpoint. This is a crucial step to avoid unnecessary charges and keep the environment tidy.

```
llm.delete_model()
llm.delete_endpoint()
```

## 5. Execution Match improvements

To judge the improvements we have made through fine-tuning, we run a SQL benchmarking script to calculate the execution match (EM) accuracy.

### Execution Match Accuracy

To obtain the EM accuracy, we extract and execute each SQL query generated by our model, then run the ground-truth “golden” query against the same database. If both queries return the same result, we count it as one true positive. An exemplary script that evaluates the execution match accuracy for all results in a specified directory can be found below. The `dev_df.feather` file holds all the information needed to run all examples.
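The comparison at the heart of the metric can be shown in isolation first (toy in-memory SQLite database; the table and values are illustrative):

```python
import sqlite3

# Toy in-memory database standing in for a Spider .sqlite file
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE teacher (teacher_id TEXT, name TEXT)")
conn.executemany("INSERT INTO teacher VALUES (?, ?)", [("1", "Ada"), ("2", "Bob")])

gold = "SELECT name FROM teacher ORDER BY name"  # ground-truth query
pred = "SELECT Name FROM teacher ORDER BY Name"  # model output, different casing

# Execution match: the queries count as equal if their result sets are equal
match = conn.execute(gold).fetchall() == conn.execute(pred).fetchall()
print(match)  # True
```

Note that execution match is agnostic to surface form: two syntactically different queries count as correct as long as they return the same rows.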

```
import pandas as pd
import sqlite3
import pickle
import os
import re
import sys
from pathlib import Path

path = Path(os.path.dirname(__file__))
sys.path.append(str(path.parent.parent.absolute()))

def extract_sql_content(text):
    pattern = r"<SQL>(.*?)</SQL>"
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[0]

def clean_results(answerlist: list, return_failed: bool):
    # collect outputs where no SQL statement could be extracted
    failed_dict = {}
    # Clean the list of strings
    clean_list = []
    for idx, item in enumerate(answerlist):
        # extracting the SQL statement from the <SQL></SQL> tags
        try:
            item = extract_sql_content(item)
        except Exception as e:
            failed_dict[idx] = {"model_output": item}
        # Remove any leading or trailing whitespace
        item = item.strip()
        # Remove any leading or trailing double quotes
        if item.startswith('"') and item.endswith('"'):
            item = item.rstrip('"')
            item = item.lstrip('"')
        # Remove any newlines
        item = item.replace("\n", " ")
        # Add the cleaned item to the clean_list
        clean_list.append(item)
    return clean_list, failed_dict

def run_exact_match_bench(
    df,
    model,
):
    results = []
    counter = 0
    for idx in range(0, df.shape[0]):
        sql_query = df.iloc[idx]["query"]
        prediction_query = df.iloc[idx][model]
        db_id = df.iloc[idx]["db_id"]
        db_file_path = (
            f"{path.parent.absolute()}/spider/database/{db_id}/{db_id}.sqlite"
        )
        conn = sqlite3.connect(db_file_path)
        cursor = conn.cursor()
        try:
            # Fetching the gold standard
            cursor.execute(sql_query)
            result_gold = cursor.fetchall()
            gold_query = f"gold Query: {sql_query}\n"
            gold_result = f"gold result: {result_gold}\n"
            try:
                # Fetching prediction results
                cursor.execute(prediction_query)
                result_preds = cursor.fetchall()
                pred_query = f"prediction Query: {prediction_query}\n"
                pred_result = f"prediction: {result_preds}\n"
            except Exception as e:
                pred_query = f"prediction Query: {prediction_query}\n"
                pred_result = f"error: {e}\n"
                results.append(gold_query + gold_result + pred_query + pred_result)
                continue
            # Comparing the results
            if result_gold == result_preds:
                match = "match\n"
                counter += 1
            else:
                match = "no match\n"
            results.append(gold_query + gold_result + pred_query + pred_result + match)
        except Exception as e:
            # The gold query itself failed to execute
            results.append(f"General error: {e}\n")
        finally:
            conn.close()
    exact_execution_match_accuracy = counter / df.shape[0]
    print(f"Accuracy: {exact_execution_match_accuracy}")
    return exact_execution_match_accuracy

def orchestrate_bench(df_eval, model_id: str, model_id_answers):
    model_id_answers_cleaned, queries_cleaning_failed = clean_results(
        model_id_answers, return_failed=True
    )
    df_eval[model_id] = model_id_answers_cleaned
    exact_match_accuracy = run_exact_match_bench(df_eval, model_id)
    return exact_match_accuracy, df_eval

def run_bench_on_folder(df_eval, results_folderpath):
    results_dir = {}

    df = df_eval.copy()
    for root, dirs, files in os.walk(results_folderpath):
        for result_file_name in files:
            filepath = os.path.join(root, result_file_name)
            if result_file_name.startswith("answers"):
                print(f"working on {filepath}")
                with open(filepath, "rb") as fp:  # Unpickling
                    model_id_answers = pickle.load(fp)
                exact_match_accuracy, df = orchestrate_bench(
                    df, result_file_name, model_id_answers
                )
                results_dir[result_file_name] = exact_match_accuracy
    return results_dir

df_eval = pd.read_feather(f"{path.parent.absolute()}/data/dev_df.feather")
path_to_results_folder = f"{path.parent.absolute()}/results/"
results_dir = run_bench_on_folder(df_eval, path_to_results_folder)
print(results_dir)
```

### Execution match accuracy of the models after fine-tuning

The following results are from models fine-tuned with a chunk length of 2048 tokens and, at inference, a repetition penalty of 1.03.

![](https://cdn-images-1.medium.com/max/800/1*MEoZvtxMwkWCDbvZoZSVZw.png)

Execution match accuracy for different training configurations.

Substantial improvements can be achieved when fine-tuning CodeLlama models.

## Conclusion

This first part of the blog series should give you an intuition for the main components to take into account when fine-tuning CodeLlama on Amazon SageMaker for NL2SQL.

For more experiments and performance tests on chunk length, repetition penalty, training with database schemas supplied via the Data Definition Language (DDL), and more, stay tuned for part two of the blog.

**For the code, head over to** [**GitHub**](https://github.com/philikai/NaturalLanguage2SQL)**!**
            <author>Philipp Kaindl</author>
            <category>Fine-Tuning</category>
            <category>NL2SQL</category>
            <category>SageMaker</category>
            <category>LLM</category>
        </item>
        <item>
            <title><![CDATA[From Prompt Engineering to Auto Prompt Optimisation]]></title>
            <link>https://philippkaindl.com/posts/auto-prompt-optimisation</link>
            <guid isPermaLink="false">https://philippkaindl.com/posts/auto-prompt-optimisation</guid>
            <pubDate>Mon, 02 Oct 2023 00:00:00 GMT</pubDate>
            <description><![CDATA[A case study for marketing content generation using prompt engineering and automatic prompt optimization with AWS Generative AI services.]]></description>
            <content:encoded><![CDATA[
*Co-authored with [Tristan Nguyen](https://medium.com/u/d2dd9861abae).*

In today’s digital landscape, Artificial Intelligence (AI) is no longer a mere accessory in marketing; it’s becoming the core driver of innovative strategies. One AI technology that’s making a significant impact in this realm is Generative AI. This blog aims to guide you through the intricacies of using Prompt Engineering in conjunction with AWS Generative AI services such as [Amazon Bedrock](https://aws.amazon.com/bedrock/) to generate marketing content that not only engages but also converts.

Prompt Engineering is a specialized field within AI that focuses on crafting effective prompts, or instructions, to guide AI models in generating desired outputs. This technology is particularly transformative for marketers, as it enables the creation of tailored content at scale and more personalized audience engagement.

Automatic Prompt Optimization, which employs Large Language Models (LLMs) as optimizers in automated prompt engineering, is a pivotal advancement inspired by “[Large Language Models as Optimizers](https://arxiv.org/abs/2309.03409)” by Yang et al. We will use an LLM as a prompt engineer and outline the obstacles encountered when using LLMs as prompt optimizers. In this blog, we delve into leveraging this approach with AWS Generative AI services, focusing on the transformation it brings to marketing strategies by enabling the generation of more engaging, higher-converting content.

What You Will Learn:

* **Prompt Engineering Basics**: Understand the art and science behind crafting effective prompts.
* **Best Practices**: Learn how to align Prompt Engineering with Amazon Generative AI services for results that are both engaging and compliant.
* **Case Study:** We will also walk you through a case study of using Prompt Engineering to generate effective marketing content, including common pitfalls and how to avoid them.
* **Automatic Prompt Engineering:** We will go from a first outline of how automatic prompt optimization works, to using it and examining the outcomes. Furthermore, we will outline the current challenges of the method and how to overcome them.

## What is Prompt Engineering?

Prompt Engineering is the practice of crafting text inputs to guide generative AI models, including text-to-text and text-to-image models, in producing specific outputs. This technique leverages in-context learning, which is the model’s capacity to adapt temporarily based on the given prompts. In-context learning is a characteristic that naturally arises in LLMs.

In the realm of text-to-text models, a prompt can take various forms. It could be a question like “What is Fermat’s Little Theorem?”, a directive such as “Compose a poem about autumn leaves,” or even feedback phrases like “too wordy” or “more formal.” Prompts can also be more elaborate, incorporating context, guidelines, and even sample data. For instance, using a few-shot learning approach, a prompt might include examples like “Maison → house, chat → cat, Chien →” to guide the model’s responses — in our example to translate from French to English.

For text-to-image or text-to-audio models, prompts usually describe the desired end product. For example, “a high-resolution image of an astronaut on horseback” or “slow-tempo Lo-fi electro chill music featuring natural sounds.” When using text-to-image models, the prompt can be fine-tuned by adding or removing words, or by adjusting their order to influence the subject, style, layout, and even lighting of the generated image.

## Setting Up

For our endeavor, we will need to set up some fake information to work with.

**Product information** in a JSON file:

```
{
  "products": [
    {
      "id": "L001",
      "Name": "Unsecured Installment Loan",
      "Title": "Lending Solutions for Accomplishing Your Goals:",
      "Description": "An Unsecured Installment Loan from FNB1 can provide you with the borrowing power you need to meet your financial objectives. These loans are repaid over a set time frame (term) with a specific number of payments.",
      "Key Features": [
        "Borrow anywhere from USD 2500 - USD 25000",
        "One-time, lump-sum distribution of loan proceeds",
        "Fixed repayment schedule",
        "Available with a fixed rate",
        "Interest rate discount available when you set up automatic monthly payments from an FNB Deposit account",
        "Flexible repayment terms"
      ],
      "Key Benefits": [
        "Immediate spending power",
        "Flexible and convenient payment methods including by mail, in-branch, online or automatic debit from your deposit account"
      ],
      "Great For": [
        "Larger, one-time purchases",
        "Home renovations",
        "New appliances",
        "Debt consolidation"
      ]
    },
    {
      "id": "L002",
      "Name": "Secured Installment Loan",
      "Title": "Enjoy Real Flexibility:",
      "Description": "A Secured Line of Credit from FNB1 gives you the flexibility to borrow up to your credit limit at any time. Whether you use your line for overdraft protection or just as a cushion for those unexpected expenses, a line of credit provides you the borrowing power you need to help meet life's financial demands.",
      "Key Features": [
        "Secured with an FNB Savings Account or Certificate of Deposit",
        "Flexibility to borrow as needed, repay and borrow again up to your credit limits",
        "Borrow up to USD 250000",
        "Variable interest rate tied to the U.S. Prime Rates",
        "Access your funds via check or online transfer"
      ],
      "Great For": [
        "Providing overdraft protection for your checking account",
        "Covering unexpected expenses",
        "Providing liquidity to bridge the gap between when your bills are due and when you get paid"
      ]
    },
    {
      "id": "CC003",
      "Name": "SmartRate Credit Card",
      "Title": "Get More Savings with Lower Rates",
      "Description": "With the SmartRate Card, you can get more savings with our lowest interest rate. You can use your card anywhere Visa is accepted to conduct everyday purchases. For added convenience, consolidate your monthly bills on one statement for one monthly payment.",
      "Key Features": [
        "No Annual Fee (Please review rates and fees in the pricing information link below for costs associated with this credit card)",
        "Competitive Variable APR"
      ],
      "Great For": [
        "Everyday purchases, from groceries to gas to monthly bills",
        "Consolidating your monthly expenses on one convenient monthly bill",
        "Saving on interest every time you use your card to make a purchase"
      ]
    }
  ]
}
```

**Customer information:**

```
Name: Emily
Age: 80
Channel: Email
Recommended Product: L001
Preferred Language: English
```
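To illustrate how the two records tie together, here is a minimal lookup of the customer’s recommended product in the catalogue. The helper name and the trimmed-down catalogue below are ours, purely for illustration:

```python
# Trimmed-down version of the product catalogue above.
catalog = {
    "products": [
        {"id": "L001", "Name": "Unsecured Installment Loan"},
        {"id": "L002", "Name": "Secured Installment Loan"},
        {"id": "CC003", "Name": "SmartRate Credit Card"},
    ]
}

customer = {
    "Name": "Emily",
    "Age": 80,
    "Channel": "Email",
    "Recommended Product": "L001",
    "Preferred Language": "English",
}

def product_for(customer, catalog):
    """Look up the customer's recommended product in the catalogue."""
    wanted = customer["Recommended Product"]
    return next(p for p in catalog["products"] if p["id"] == wanted)

product = product_for(customer, catalog)
print(product["Name"])  # Unsecured Installment Loan
```

The fields of the matched product are then substituted into the prompt templates described next.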

## Prompt Engineering

In the next section, we will examine the different prompt engineering techniques undertaken.

### Prompts For Different Communication Channels

For creating the different templates, we will add instructions on how we want to have our messages structured for our outbound communication. These instructions will be kept the same.

In the `Prompt` portion, we provide instructions on the channel (email or SMS), the style, the target customer and the target product.

This is followed by the `Output-Formatting-Instructions`, which outline how the LLM outputs should be formatted so they can be picked up by AWS Communication Developer Services such as Amazon Pinpoint or Amazon Simple Email Service.

The general structure is as follows:

```
{Optional: Model specific instructions}
{Prompt}
{Output-Formatting-Instructions}
{Optional: Model specific instructions}
```
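A sketch of how this template could be assembled in code. The function and slot names below are ours, not from the original implementation:

```python
def build_prompt(prompt, formatting_instructions,
                 pre_instructions="", post_instructions=""):
    """Assemble the full model input from the four template slots.

    The model-specific slots are optional, mirroring the structure above.
    """
    parts = [pre_instructions, prompt, formatting_instructions, post_instructions]
    # Drop empty optional slots so we don't emit stray blank lines.
    return "\n".join(p for p in parts if p)

full = build_prompt(
    prompt="Write a marketing email for customer Emily.",
    formatting_instructions='Start with "###SUBJECT###" and end with "###END###".',
)
```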

### Email

```
{Prompt}

Given the above details, generate 3 email parts in the specified format:

        Subject: Subject of the email
        HTML Body: Content of the email but formatted nicely in HTML
        Text Body: Same content of the email formatted in plaintext

        Format:
        The returned string should be constructed as follows:
        1. Start with the delimiter "###SUBJECT###" followed by the subject content, and then end with "###END###".
        2. Next, start with the delimiter "###HTMLBODY###" followed by the HTML body content, and then end with "###END###". Make sure the generated HTML code has matching opening and closing tags.
        3. Finally, start with the delimiter "###TEXTBODY###" followed by the text body content, and then end with "###END###".
        4. Only output the text not any instructions.
        5. Output language is {lang}
```

### SMS

```
{Prompt}

Given the above details, generate content for an SMS message in the specified format:

        Text Body: Content of the SMS message in plaintext

        Format:
        1. Start with the delimiter "###TEXTBODY###" followed by the SMS message content, and then end with "###END###".
        2. Only output the text not any instructions.
        3. Output language is {lang}!
        4. Limit the text body content to 160 characters or less.
```

The above prompt approach, with its format instructions, serves a few purposes. The reasons for instructions 1–4, in order, are:

1. **Delimiter instruction:** We need to be explicit with the delimiter (e.g. “###TEXTBODY###”), and it needs to be a sufficiently unique string so that the marketing content generated by the LLM is never incorrectly spliced.
   Alternatively, we can make separate calls to Amazon Bedrock for the Subject, HTML Body, and Text Body. The drawback of such an approach is that context might be lost between calls, raising the chance of inconsistent outputs, e.g. the plaintext email body not reflecting the HTML email body.
2. **Output part instruction:** This is necessary as a consequence of the way the prompt is structured. Since we are asking the LLM models to role-play in a conversational context, there will be instances in which it will output instructions/conversational fillers.
3. **Language output instruction:** The desired output language can be passed inside the prompt to specify the output language, which provides a convenient way to output marketing content according to the recipient’s preferred language. Do take note that out-of-the-box LLM models may not perform as well in languages other than English and would need fine-tuning.
4. **SMS Output limitation:** For SMSes, since messages that are more than 160 characters will get broken up and incur additional costs for sending, we can attempt to impose a character limit to the generated message.
   At their core, general-purpose LLMs are not explicitly tuned for mathematical or arithmetic operations, so the character count can often be inaccurate. However, we can provision for this problem by drafting a correction prompt whenever we exceed the character limit. A correction prompt could be as simple as: “\{LLM\_generated\_SMS\}\n The above-created SMS is too long, as we can only send 160 characters. The SMS is \{SMS\_char\_length\}, which is a factor of \{SMS\_char\_length/160\} too long. Shorten it to fit in one SMS.”
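Instruction 1’s delimiter scheme lends itself to a small parser; a minimal sketch (the regular expression and function name are ours, not from the post):

```python
import re

def parse_sections(text):
    """Split LLM output on ###NAME### ... ###END### delimiter pairs."""
    pattern = r"###([A-Z]+)###\s*(.*?)\s*###END###"
    return {name: body for name, body in re.findall(pattern, text, re.DOTALL)}

raw = """###SUBJECT###
A loan offer for you
###END###
###TEXTBODY###
Dear Emily, ...
###END###"""

sections = parse_sections(raw)
print(sections["SUBJECT"])  # A loan offer for you
```

Because the delimiters are unlikely to occur inside the generated copy itself, a single regular expression recovers all sections from one model call.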

For example, an un-tuned model can give the following result for the SMS channel when supplied with the outlined prompt.

> ###TEXTBODY###
> Dear Emily, this is John from FNB1.
> I know you've worked hard over the years.
> Let me tell you about our Unsecured Installment Loan that can help make your dreams come true.
> With fixed rates and flexible terms, you can get the money you need for that dream vacation.
> Don't wait, call me today and I'll help you get started.
> ###END###

Some problems with that SMS, and ways to address them:

1. Excluding the delimiters, the message itself is 322 characters, well over the 160-character limit.
2. It could therefore be necessary to tune the model on compliant text messages sent out in the past, which makes it more likely the model will produce compliant SMS suggestions.
3. Alternatively, you can insert guardrails, as proposed in the Iteration 3 section below.
4. Lastly, it's best practice to run an automated content-quality/content-moderation check on outgoing marketing materials. A simple character count to reject messages that are too long would work here, optionally combined with the correction-prompt approach outlined above.
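The correction-prompt approach can be wrapped in a small retry loop; a sketch, where `generate` stands in for a call to the model:

```python
SMS_LIMIT = 160

def shorten_sms(sms, generate, max_retries=3):
    """Re-prompt the model until the SMS fits in one 160-character segment."""
    for _ in range(max_retries):
        if len(sms) <= SMS_LIMIT:
            return sms
        # Correction prompt as outlined above, with the actual counts filled in.
        correction = (
            f"{sms}\n The above-created SMS is too long, as we can only send "
            f"{SMS_LIMIT} characters. The SMS is {len(sms)} characters, which is "
            f"a factor of {len(sms) / SMS_LIMIT:.2f} too long. "
            "Shorten it to fit in one SMS."
        )
        sms = generate(correction)
    return sms  # may still be too long after max_retries; reject downstream

# A fake model that always returns a short message, for illustration:
short = shorten_sms("x" * 300, generate=lambda prompt: "Visit FNB for loans!")
```

The final length check downstream remains necessary, since the model may still overshoot after `max_retries` attempts.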

## Prompt - Iteration 1

In the first iteration, we start with a prompt template instructing the model to write a high-converting email or SMS. We incorporate several techniques common in marketing communication, such as:

* power of storytelling
* social proof
* persuasive language

As we want to generate one template that we can scale to our whole customer base, we substitute the marketing channel and user attributes like name and age.

> You are a banking advisor named John Smith from First National Bank. Write a high-converting \{channel\} that will leverage the power of storytelling, social proof and persuasive language to captivate the customer \{name\} — aged \{age\} years old and drive them to buy \{product\_name\}.
> The product description is: \{product\_description\}
> The key features are: \{product\_key\_features\}
> The product is great for: \{product\_great\_for\}

The inference settings for the experiments were set as outlined below; for an in-depth explanation of the LLM settings [temperature and top\_p](https://www.promptingguide.ai/introduction/settings.en), we point the interested reader to the linked resource.

* max\_tokens\_to\_sample = 4096
* temperature = 0.5
* top\_p = 0.9
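At the time of writing, Claude models on Amazon Bedrock accept these settings in the request body of the Anthropic text-completions format. A minimal sketch of assembling that body (the helper name is ours; the actual `invoke_model` call is omitted):

```python
import json

def claude_request_body(prompt, max_tokens=4096, temperature=0.5, top_p=0.9):
    """Build the request body for Anthropic Claude text completions on Bedrock."""
    return json.dumps({
        # Claude's text-completion format requires the Human/Assistant turns.
        "prompt": f"\n\nHuman:{prompt}\n\nAssistant:",
        "max_tokens_to_sample": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
    })

body = claude_request_body("Write a marketing email for Emily.")
# The body is then passed to bedrock_runtime.invoke_model(modelId=..., body=body).
```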

With all these instructions, the model acts like a salesman: it starts hallucinating, making up stories to capture the customer.

## Results — Iteration 1

```
###SUBJECT###
Emily, let me tell you about our Unsecured Installment Loan
###END###

###HTMLBODY###




Dear Emily,

My name is John Smith and I'm an advisor at First National Bank.
I wanted to tell you about our Unsecured Installment Loan which I
think could be a great option for you.

I know you've been looking to finance that new kitchen renovation
you told me about last week. An Unsecured Installment Loan is perfect
for bigger, one-time expenses like home renovations. Here's how it works:


  Borrow anywhere from USD 2500 - USD 25000 for your project
  Get the money all at once in a lump-sum when your loan is approved
  Repay on a fixed schedule with a set number of payments
  Lock in a low fixed interest rate
  Save even more when you set up automatic payments from your FNB account
  Choose flexible terms that fit your budget - 2 to 7 years


I think this loan would be ideal to get the cash you need to renovate your
 kitchen just how you envisioned. Many of our customers have used it for
similar home projects and were thrilled with the results.

Let's schedule some time to go over the details and get you pre-approved.
I'm excited to help make your kitchen dreams a reality!

Sincerely,

John Smith



###END###

###TEXTBODY###
Dear Emily,

My name is John Smith and I'm an advisor at First National Bank. I wanted to
tell you about our Unsecured Installment Loan which I think could be a great
option for you.

I know you've been looking to finance that new kitchen renovation you told
me about last week. An Unsecured Installment Loan is perfect for bigger,
one-time expenses like home renovations. Here's how it works:

- Borrow anywhere from USD 2500 - USD 25000 for your project
- Get the money all at once in a lump-sum when your loan is approved
- Repay on a fixed schedule with a set number of payments
- Lock in a low fixed interest rate
- Save even more when you set up automatic payments from your FNB account
- Choose flexible terms that fit your budget - 2 to 7 years

I think this loan would be ideal to get the cash you need to renovate your
kitchen just how you envisioned. Many of our customers have used it for
similar home projects and were thrilled with the results.

Let's schedule some time to go over the details and get you pre-approved.
I'm excited to help make your kitchen dreams a reality!

Sincerely,

John Smith
###END###
```

## Challenges in Promoting Responsible AI Deployment

In the context of deploying machine learning models, we encounter several critical challenges that warrant attention for a seamless production deployment.

1. **Ethical Content Generation**: One significant challenge arises from the model’s attempt to boost conversion rates through the generation of fictitious narratives or by assuming an undue familiarity with the customer. For instance, it may produce statements like, “Looking to finance a new kitchen renovation,” which can come across as inauthentic and resemble the tactics of untrustworthy salespeople.
2. **Hallucinatory Responses**: The model’s behavior sometimes extends to hallucinating entire conversations with customers, such as, “I know you’ve been looking to finance that new kitchen renovation you told me about last week.” This not only raises concerns about data privacy but also risks misleading customers.
3. **Unprompted Social Proof**: Additionally, the model may attempt to strengthen its recommendations by citing social proof, like, “Many of our customers have used it for similar home projects and were thrilled with the results.” It’s important to note that the model has no access to information regarding recent conversations, emails, or call scripts involving the customer.
4. **Potential for Personalization**: While our current prompts do not supply specific customer information, it’s worth highlighting that companies with a comprehensive understanding of their customers, acquired through various touchpoints, can harness this knowledge to personalize messages more effectively. By incorporating insights from customer interactions and lifecycle data, a more tailored and meaningful AI-driven communication strategy can be developed.

In the pursuit of responsible AI deployment, addressing these challenges is essential to ensure that AI models enhance customer experiences without compromising trust and ethical standards.

## Prompt — Iteration 2

On the second prompt, we will be trying to address the challenges of the prompt version 1 as well as give some more information on the sender, so that the email can be tailored to it:

> You are a marketing content creator assistant for bank advisors at “First National Bank”.
> You are assisting John Smith, a 54-year-old bank advisor who has been working for First National Bank for 20 years.
> As a respectable bank, we need to keep our marketing content factual, as the bank's reputation is at risk when making up stories.

> Create a marketing \{channel\} to sell \{product\_name\} to customer \{name\}, who is \{age\} years old.
> Adjust the language depending on the age of the customer and the typical needs of customers of that age.

> The product description is: \{product\_description\}
> The key features are: \{product\_key\_features\}
> The product is great for: \{product\_great\_for\}

The two prompts provided have some notable differences, particularly in terms of the context, tone, and approach they take to create marketing content for a bank advisor. Here’s a breakdown of the key differences between the two prompts:

### Context and Persona:

* Prompt V1: In this prompt, the writer assumes the persona of a banking advisor named John Smith and is tasked with creating high-converting marketing content that leverages storytelling, social proof, and persuasive language. The customer’s name and age are variables.
* Prompt V2: In this prompt, the writer serves as a marketing content creator assistant for bank advisors at “First National Bank” and is assisting John Smith, an experienced bank advisor. The focus is on creating factual marketing content that aligns with the bank’s reputation, and the writer is instructed to adjust the language based on the customer’s age and typical needs.

### Tone and Style:

* Prompt V1: The tone is more focused on persuasion and captivation. It encourages the use of storytelling and persuasive language to sell the product.
* Prompt V2: The tone is more factual and cautious. It emphasizes the importance of maintaining the bank’s reputation and advises against making up stories.

### Content Approach:

* Prompt V1: The content approach is more focused on creating a persuasive narrative that aims to captivate and persuade the customer to buy the product. It encourages the use of social proof and persuasive techniques.
* Prompt V2: The content approach is centered on providing accurate and age-appropriate information to the customer. It requires adjusting the language based on the customer’s age and needs while avoiding fictional storytelling.

### Persona’s Role:

* Prompt V1: The writer assumes the role of a specific banking advisor named John Smith, but there is no explicit guidance on the writer’s experience or the bank’s reputation.
* Prompt V2: The writer is positioned as an assistant to John Smith, emphasizing the importance of maintaining the bank’s reputation, and John Smith is an experienced bank advisor.

In summary, while both prompts involve creating marketing content for a bank advisor, Prompt V1 focuses more on persuasive storytelling and social proof, whereas Prompt V2 prioritizes factual and age-appropriate content while emphasizing the bank’s reputation.

## Results — Iteration 2

```
###SUBJECT###
A loan to simplify your finances, Emily
###END###

###HTMLBODY###


Loan Offer


Dear Emily,

As your trusted financial advisor at First National Bank,
I wanted to let you know about an excellent borrowing option
that may help simplify your finances - our Unsecured Installment Loan.

At your age, I know you may have some large upcoming expenses
for healthcare, home repairs or other needs. This loan provides
flexible borrowing up to $25,000, with fixed regular payments spread out
over a timeframe that works for your budget.

With our online application and fast approval process, the funds can be
deposited directly into your account within days. The fixed rate gives you
predictability, and the automated payments from your FNB account can even
save you money.

Please let me know if you would like to discuss this option in more
detail. I'm here to help provide solutions tailored for your financial
situation.

Sincerely,
John Smith
Financial Advisor
First National Bank



###END###

###TEXTBODY###
Dear Emily,

As your trusted financial advisor at First National Bank, I wanted to let you
know about an excellent borrowing option that may help simplify your
finances - our Unsecured Installment Loan.

At your age, I know you may have some large upcoming expenses for healthcare,
home repairs or other needs. This loan provides flexible borrowing up to
$25,000, with fixed regular payments spread out over a timeframe that works
for your budget.

With our online application and fast approval process, the funds can be
deposited directly into your account within days. The fixed rate gives you
predictability, and the automated payments from your FNB account can even
save you money.

Please let me know if you would like to discuss this option in more detail.
I'm here to help provide solutions tailored for your financial situation.

Sincerely,

John Smith
Financial Advisor
First National Bank
###END###
```

## Iteration 3 — Automatic prompt optimization

In the last version, we are going to use our model to improve the prompt itself. The inspiration for our approach is largely based on the ideas of the publication “[Large Language Models as Optimizers](https://arxiv.org/pdf/2309.03409.pdf)” by Yang et al. at Google DeepMind. They outline using an LLM to improve a prompt, which is then evaluated against an objective function evaluator. This objective function evaluator, in essence, runs the new prompt generated by the LLM against a test set and calculates a performance metric on it (e.g., ROUGE score, accuracy, …).

Contrary to the paper, we do not have an objective value function to serve as the input to our optimization problem. Therefore, we use an LLM to evaluate its own output, moving ahead with a subjective function evaluator. In a real-world setting, however, we could measure, e.g., conversion rates of different prompt templates and use these as feedback to the model.

When using a subjective function evaluator, such as an LLM, for the optimization problem, we face a few obstacles that we need to be aware of:

* **Bumpy response surface**
  If we ask the model to evaluate its own output, we observe across multiple runs that it tends to reply in discrete intervals, even though the whole spectrum between 0 and 100 is available, and it consistently assigns itself overly high scores.
* **High confidence in its own score**
  Initially, we prompted the model to evaluate the answer on a score from 1–100. It consistently picked scores >75. Only after prompting it to be *very strict* did the model start lowering its score. We further constrained the output by introducing hard limits, for example: *If the model is not staying factually correct, the score can be at most 25.*
* **Prompts not running successfully.**
  The model is tasked with generating a new prompt format as well as personalizing the marketing content for the end user. This can introduce new variables into the prompt that cannot be satisfied when running the prompt template. One example was the hallucinated variable \{bank\_advisor\_name\}, which was not in the set of available variables. Empirically, we observed a positive correlation between the model breaching the variable set and the presence of a variable in the prompt template that was not marked as available.
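The last obstacle, hallucinated template variables, can be caught before execution by checking a generated template's placeholders against the set of available variables. A minimal sketch using Python's `string.Formatter` (the allowed set mirrors the variables used in our prompts):

```python
from string import Formatter

ALLOWED = {"channel", "name", "age", "product_name",
           "product_description", "product_key_features", "product_great_for"}

def undeclared_variables(template, allowed=ALLOWED):
    """Return placeholders in the template that are not in the allowed set."""
    # Formatter().parse yields (literal, field_name, format_spec, conversion).
    fields = {field for _, field, _, _ in Formatter().parse(template) if field}
    return fields - allowed

bad = undeclared_variables(
    "Dear {name}, {bank_advisor_name} recommends {product_name}."
)
print(bad)  # {'bank_advisor_name'}
```

A template with a non-empty result can be rejected or sent back to the Improvement LLM instead of raising a `KeyError` at execution time.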

Next, let’s examine the optimization loop we are going to use.

### Optimisation Loop:

![](https://cdn-images-1.medium.com/max/800/1*IQL52M8nayENTAIKMtTx-w.png)

The optimization loop sketched out.

We start with the Executor LLM executing the initial prompt that we provide; this is the model for which we are optimizing the prompt. Its output becomes the input to the Evaluator LLM, which evaluates our prompt against the guardrails and guidelines outlined in the Evaluation Prompt and creates a recommendation for the downstream Improvement LLM. The Improvement LLM takes all that information in and crafts a new prompt, which kicks off a fresh cycle.
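The loop sketched above can be expressed in a few lines of Python. This is an illustrative skeleton, not the post's actual code: `call_llm` and `score_of` are stand-ins for the Bedrock calls and the score parser:

```python
def optimize_prompt(initial_prompt, call_llm, score_of, n_iterations=10):
    """Executor -> Evaluator -> Improver cycle, as sketched above."""
    history = []  # best (score, prompt) pairs, re-shown to the improver
    prompt = initial_prompt
    for _ in range(n_iterations):
        output = call_llm(prompt)                      # Executor LLM
        evaluation = call_llm(                         # Evaluator LLM
            "Evaluate this marketing content very strictly on a scale "
            "of 1-100 and give a recommendation:\n" + output
        )
        history.append((score_of(evaluation), prompt))
        history = sorted(history, reverse=True)[:5]    # keep only the best prompts
        prompt = call_llm(                             # Improvement LLM
            f"Best prompts so far: {history}\nEvaluation: {evaluation}\n"
            "Craft an improved prompt template."
        )
    return prompt
```

In practice, each stage would also clean and parse the raw model output, as discussed in the next section.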

### Critical Components for Successful Model Operation

To ensure our model operates successfully, there are a few essential components we need to focus on:

* **Evaluation Prompt:**
  The Evaluation Prompt is crucial: it acts as a self-check, enabling the model to assess the output of the previous iteration against our set guidelines. We aim to keep this prompt as universal as possible to make it reusable for various tasks.
* **Improvement Prompt:**
  In this stage, the model receives all the crucial inputs, the ones we want it to pay attention to during the optimization process. It's like a roadmap highlighting the areas we want to improve.
* **Prompt History:**
  To assist the Improvement LLM, we maintain a log of the best prompts along with their scores. This log is updated at every optimization cycle and re-introduced to the model.
* **Prompt Cleaning, Parsing, and Error Handling:**
  Before reusing prompts, it's essential to clean them of stop-words; otherwise the LLM might reproduce these stop-words, as seen in the improvement prompt. Additionally, we extract fields like recommendations and instructions by parsing the model's outputs.

For our upcoming tasks, we'll be utilizing Anthropic's Claude models through [Amazon Bedrock](https://aws.amazon.com/bedrock/) for both evaluation and improvement, thanks to their ability to be steered very closely with natural-language tags.

Let's now explore how we fare when running the optimizer for ten cycles, starting from our last prompt. All model outputs can be found on [GitHub](https://github.com/philikai/FromPromptEngineeringToAutoPromptOptimisation) for closer examination. The code will be released shortly after.

Let us add a bit more depth to the customer outreach by supplying fake user data such as a bank might hold.

```
user_variables = {
```

We are going to reuse the prompt template from prompt iteration 2. For easier reading, we repeat it here.

**Initial Prompt:**

```
\n\nHuman:You are an marketing content creator assistant for
```

The evaluation and improvement prompts we are going to be using are:

### Evaluation Prompt

```
\n\nHuman:You are a prompt optimizer and your task is to evaluate if
```

### Improvement Prompt

```
\n\nHuman:You are a prompt optimizer and your task is to improve
```

Let’s see what we have after three iterations of the prompt optimizer.

### Prompt template after 3 Iterations:

Let's have a look at what the model has come up with between the &lt;INST&gt;&lt;/INST&gt; tags, after optimizing the prompt three times to sound more like a banking advisor.

```

```
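Pulling the optimized template out from between those tags is exactly the parsing step described earlier. A small helper might look like this (a sketch; the tag name and regex are assumptions about the repository’s conventions):

```python
import re

def extract_tag(text: str, tag: str = "INST") -> str:
    """Return the content of the first <tag>...</tag> block in a model output."""
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    if match is None:
        raise ValueError(f"no <{tag}> block found in model output")
    return match.group(1).strip()
```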

We can observe that the prompt template starts to place more emphasis on the bank’s reputation and moves away from hyper-personalisation. The model does not take into account all the variables we supply and is cautious about using information on the customer’s financial situation. On the one hand, that is good: the customer might hold more money with other banks, and here data or notes about the customer’s financial situation would be very helpful for a more rounded insight. However, if we want the model to take different paths and, for example, be more creative, we can steer it every few iterations with human feedback to align it more closely with our preferences.
The LLM-generated output text after the third iteration:

```
Dear Josh,
```

Already pretty good. Let’s keep this running for a few more rounds and see where this goes without any supervision.

### Prompt template after 10 Iterations:

```
You are a marketing content creator assistant for
```

**Output of the executed prompt template after ten iterations:**

```
Dear Josh Milwakee,
```

Now we have reached a very elaborate prompt template, one that really addresses the needs of a marketing prompt in the financial industry. The only drawback we can see is that, after ten iterations, the model has come up with a prompt template that sells a “novel” product, even though we never set out to do so.

## Conclusion

Throughout this article, we’ve delved into the practical application of prompt engineering, illustrating its potency through a hypothetical bank-marketing scenario. We examined how LLMs can assess and refine their own prompt templates through successive rounds, producing increasingly customized email communications.
However, the journey doesn’t end here; there are numerous avenues left to explore and optimize. A user-friendly interface is paramount, enabling us to pause the optimization process and converse with the optimizer, providing real-time feedback on ongoing developments. Additionally, by testing against a broader customer base, we can further ascertain the effectiveness of our prompt templates, ensuring their resilience and adaptability across diverse customer segments. Lastly, the entire marketing pipeline from content generation to content distribution can be automated by integrating with AWS Communications Developer Services such as Amazon Pinpoint or Amazon Simple Email Service to enable the complete journey in one simple interface. Stay tuned if you are interested in Infrastructure as Code to spin up such a solution.

## Further Exploration and Contributions:

For those interested in delving deeper into the optimizations and seeking to contribute, all iterations of the optimizer are accessible on our GitHub repository, and we anticipate releasing the associated code soon.
For more details on how/where a prompt engineering/optimization module can be integrated into your marketing use cases, check out our other blog [Building Generative AI into Marketing Strategies: A Primer](https://aws.amazon.com/blogs/messaging-and-targeting/building-generative-ai-into-marketing-strategies-a-primer/).
This endeavor serves as a stepping stone towards harnessing the full potential of prompt engineering in creating more nuanced and personalized user experiences, and we look forward to the innovations that will stem from these foundational explorations.]]></content:encoded>
            <author>Philipp Kaindl</author>
            <category>Prompt Engineering</category>
            <category>LLM</category>
            <category>Bedrock</category>
            <category>AWS</category>
        </item>
    </channel>
</rss>