
Fine-tuning has proven to be one of the most effective ways to improve the performance of large language models on domain-specific tasks.
While every task may look different to us humans, to a language model they all involve one thing - predicting the next token. Each token is generated at a computational cost that is a function of the input length. In essence, tasks that produce approximately equal numbers of tokens will require roughly the same computational cost.
Essentially, a language model of sufficient architectural capacity will require approximately the same computational cost to produce “if 2x + 2 = 4 then x = 1” as it will to produce “My name is John Smith”.
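If you want to make that concrete, here is a quick sketch that compares the token counts of the two outputs above. The tiktoken library and the cl100k_base encoding are my own choices for illustration; the exact counts depend on the tokenizer, but the point is that both outputs are a similar number of tokens and therefore a similar amount of work.

```python
# Compare token counts of two very different "answers".
# tiktoken and the cl100k_base encoding are illustrative choices, not requirements.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

math_answer = "if 2x + 2 = 4 then x = 1"
name_answer = "My name is John Smith"

# Both encode to a handful of tokens of comparable count, so generating
# either one costs the model roughly the same amount of compute.
print(len(enc.encode(math_answer)), len(enc.encode(name_answer)))
```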
If this is the case, what, then, is the point of getting LLMs to reason when they don’t have the inherent ability to do so?
Well, to a language model, the task has never been to reason about the problem; the task is to generate as many tokens as needed to reach the terminal token. The terminal token signifies the end of the sequence, after which the model stops doing further work and generates no more tokens. Two problems may be entirely different, but if they both need just 5 tokens to reach the terminal token, the LLM’s day’s work is done. That is, there is no planning ahead for the number of tokens or for when the terminal token should be generated.
So, can LLMs then simulate the process of reasoning?
Reasoning, you see, largely involves putting more mental effort into working out the solution to a problem. Although LLMs don’t reason (I promise you, they don’t), you can still force them to put more effort into arriving at the terminal token. Since, to an LLM, sequences of equal token length require approximately the same computational cost, a generation cycle over 30 tokens has done more work than one over 25 tokens. In other words, generating more tokens (and, by effect, processing more input tokens) directly implies doing more work.
We have seen cases where encouraging an LLM to spend more tokens on solving a problem returns better results. For example, “thinking step by step”, a technique now known as Chain of Thought (CoT) prompting, was an early demonstration of this behaviour. Take a look at these examples from the original paper.
Image Source: Wei et al. (2022), screenshot via https://promptingguide.ai/techniques/cot
In the above example, the LLM was encouraged to generate more tokens covering its thinking steps by providing it with examples that demonstrated such reasoning output.
In another paper, simply instructing the LLM to think step by step gave better results.
Image Source: Kojima et al. (2022), screenshot via https://promptingguide.ai/techniques/cot
Thus, if we want an LLM to put more effort into generating each idea for a problem solution, we must encourage it to generate more useful intermediate tokens between these ideas.
Quality by Quantity
To LLMs, simulated reasoning is still all about generating tokens. The question then becomes what these tokens should be about. Indeed, if a language model is simply asked the name of a country’s president and then generates the country’s national anthem before outputting the President’s name, that is undoubtedly useless work. This is why it is necessary to instruct the model to “think out loud” about the problem by generating its chain of thought.
So why does it work?
LLMs use previous tokens to generate new ones via the attention mechanism. Very simply put, each part of the input sequence (which can include the LLM’s own previous outputs) is given a calculated “attention” weight that determines its relationship with other parts of the input as well as its role in the overall sequence or context. This calculated information is then used to “query” the model’s existing knowledge of the world, which includes training examples of sequences with similar attributes and patterns, the model’s general understanding of sequences in the target language, and so on. It’s like asking: given these (possibly millions of) pieces of information about this subject, what is likely to come next?
If the model can gather sufficient information about the relevant tokens of the input sequence, it can be expected to make better predictions of the next token, just as I can more accurately predict the weather if I have enough information about the humidity, the direction of the wind, the cloud density, even down to the distribution of charges in the atmosphere and the collision patterns and changes in momentum of the atmospheric molecules.
As you add more tokens to the sequence, which forms the next input in the generation cycle, you are adding more information for the model to process and also enriching the individual tokens with more information about their relationships with every other token in the sequence, thanks to the new neighbouring tokens. With this, the model now has more (statistically) useful things to say about a token it already visited earlier, and also about the sequence in general. This richer information helps the model find related patterns. Technically, neurons that are sensitive to these patterns become activated. It’s as if they say, “oh, I can see you are trying to think step by step here. I know a thing or two about thinking step by step, I have seen thousands of examples. I can help”.
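For intuition only, here is a toy sketch of the attention computation described above, using plain NumPy with the same random vectors standing in for queries, keys and values (a real transformer learns separate projections for each, plus many layers and heads). The point is simply that every new token adds another row and column of pairwise relationships for the model to weigh.

```python
# Toy scaled dot-product attention over random "embeddings" -- not a real LLM layer,
# just the shape of the computation: every token is weighed against every other token.
import numpy as np

def attention(tokens_embedded):
    """tokens_embedded: (seq_len, d) array; returns (seq_len, d) context vectors."""
    d = tokens_embedded.shape[-1]
    # One score per pair of tokens: longer sequences mean more relationships to compute.
    scores = tokens_embedded @ tokens_embedded.T / np.sqrt(d)          # (seq_len, seq_len)
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ tokens_embedded  # each output mixes information from all neighbours

rng = np.random.default_rng(0)
short_seq = rng.normal(size=(5, 16))   # 5 tokens in the context
long_seq = rng.normal(size=(30, 16))   # 30 tokens: more pairwise relationships, more work
print(attention(short_seq).shape, attention(long_seq).shape)  # (5, 16) (30, 16)
```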
Can we always attain quality by quantity?
No. Not all the time. Getting the best output from chain-of-thought reasoning depends on two factors:
- The capability of the model
- The number of useful CoT reasoning examples.
A capable enough model is usually able to find much deeper patterns between ideas, which are made up of tokens and sequences. It can therefore make use of those patterns by reapplying them to solve new problems. Finding patterns can be seen as the degree to which a model “understands” the data. That is why you could tell GPT-3.5 to think step by step over a problem it hasn’t seen before and still get decent performance compared with, say, Llama 13B.
For a model that isn’t as beefy as a multi-billion-parameter one, very high data quality provides the strength. Instead of leaving the model to figure out the reasoning patterns on the fly when trying to solve reasoning-based problems, we provide it with lots of examples where deep CoT reasoning is applied to solving different kinds of problems. We can get better results if the data reflects the different phases of step-by-step reasoning I like to call reasoning checkpoints: hypothesising, assumptions, self-evaluation, backtracking, procedural deduction, conclusion and so on. It is even better if these examples are tailored to the specific use cases the model will be applied to.
How to take advantage of the CoT reasoning capabilities of LLMs
Our understanding of this behaviour helps us achieve two things:
- Effective Prompting
- Preparing high-quality fine-tuning data.
Attaining Effective Prompting
Chain of Thought prompting and its variants are prompting techniques that encourage LLMs to solve a problem using a step-wise approach. However, merely instructing the model to think step by step doesn’t actually guarantee useful results. Instead, probing the model to demonstrate specific reasoning checkpoints has been observed to give better results.
Here is an example of a prompt I can append to any problem I want the LLM to think deeply about:
"Solve this problem and think out loud like a true intelligent deep thinker would. Show your individual reasoning steps. Show the thought process that leads to every deduction you make. For every deduction you make, try to go over it again to make sure there isn't any fault in your reasoning. If you find any fault, go over the problem again to point out the fault, then correct yourself. Do this continuously until you reach the final answers."
The above prompt gives the model more targets than just the final answer, which increases the number of “quality” tokens that contribute to the final answer. Let’s see it in practice.
Intermediate back-and-forths were removed.
A single-shot success just by also asking it to say more.
A single-shot failure with “Think step by step”.
A single-shot success with “Think step by step” and asking for more.
As we have seen, this is different from simply asking it to think step by step. You are imposing checkpoints on the model, so hypothesis generation, self-correction and so on all become part of the goal and, by definition, part of the expected output of the model.
Powerful prompts that are cleverly engineered to take advantage of the model’s full capabilities can result in high-quality outcomes.
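As a minimal sketch of how this looks in code (the OpenAI Python client and the model name are my own assumptions here; any chat-style API would work the same way), the checkpoint prompt can simply be appended to whatever problem you want solved:

```python
# Append the "reasoning checkpoints" instruction to an arbitrary problem.
# The OpenAI client and model name are assumptions for illustration only.
from openai import OpenAI

CHECKPOINT_PROMPT = (
    "Solve this problem and think out loud like a true intelligent deep thinker would. "
    "Show your individual reasoning steps. Show the thought process that leads to every "
    "deduction you make. For every deduction you make, try to go over it again to make sure "
    "there isn't any fault in your reasoning. If you find any fault, go over the problem again "
    "to point out the fault, then correct yourself. Do this continuously until you reach the "
    "final answers."
)

def solve_with_checkpoints(problem: str, model: str = "gpt-4o-mini") -> str:
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": f"{problem}\n\n{CHECKPOINT_PROMPT}"}],
    )
    return response.choices[0].message.content
```

The wrapper changes nothing about the model itself; it only changes what the model is asked to produce.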
Effective CoT Fine-tuning
High-quality reasoning data is hard to come by. It is not like most people actually “think” when spitting things out on the internet. That is why it is important to carefully engineer fine-tuning data for reasoning. There are a couple of ways to acquire high-quality data for this purpose:
- Human experts: What I like to call fine-tuning with human reasoning.
- Distillation: The outputs of a powerful model that does well when CoT-prompted can be used to train a smaller, less capable model (a rough sketch of this follows below).
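For the distillation route, a rough sketch might look like the following; the OpenAI client, the teacher model name, the example problems and the output file are all assumptions for illustration, not a prescribed pipeline.

```python
# Distillation sketch: prompt a stronger "teacher" model for step-by-step answers
# and keep its outputs as training targets for a smaller model.
# Client, model name, problems and file name are illustrative assumptions.
import json
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
problems = [
    "A pipe fills a tank in 3 hours and a drain empties it in 5 hours. How long to fill the tank with both open?",
    "How many times does the letter 'r' appear in 'strawberry'?",
]

with open("distilled_cot.jsonl", "w") as f:
    for problem in problems:
        teacher_output = client.chat.completions.create(
            model="gpt-4o",  # assumed teacher model
            messages=[{
                "role": "user",
                "content": f"{problem}\n\nThink step by step and show every deduction you make.",
            }],
        ).choices[0].message.content
        # Each record pairs the raw problem with the teacher's full chain of thought.
        f.write(json.dumps({"prompt": problem, "completion": teacher_output}) + "\n")
```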
Whatever the means of acquiring the fine-tuning data, high-quality reasoning data should have the following qualities:
- Reasoning Checkpoints/Artifacts: hypothesising, deductions, assumptions, perspectives, concept identification, pattern identification, concept referencing, recall, constraints, conclusions, etc.
- Reasoning Steps: Showing how each of these reasoning artifacts leads to the next, and the next, and finally to the answer.
Different problems have different approaches. In our “correctionarry” example, our prompt caused the model to realise that the best approach is comparing each letter to “r”, even though we weren’t particular about the reasoning steps. This is because the model was “forced” to visit its representation of how counting is done. However, for cases where such a representation of the problem is not present in the model’s latent space, we must introduce it through fine-tuning.
Suppose a home maintenance company wants to fine-tune a language model for conducting home inspections and repairs; then the fine-tuning data must include the thought processes of professional inspectors. This data must highlight the various reasoning checkpoints mentioned earlier and how they lead to the final solution. This way, we are no longer leaving it to the model to find patterns based on its general understanding of the problem, but are guiding its operation with real-life reasoning data from the domain we are interested in applying it to.
Let’s see an example of what part of such data could look like:
[Situation] Water is noted at the junction of the wall and the ground. The source of this observed moisture at the wall's base is unknown.
[Goal] Investigate the problem and detect its cause.
[Protocol] Visit the site to personally observe the seepage.
[Observation] A liquid water-like substance is observed on the wall.
[Action] Wipe off the water to observe any cracks.
[Observation] The spot remains wet but no visible cracks are observed.
[Deduction] Water must be seeping through the concrete material itself.
[Knowledge] It takes a while for water to make it through concrete.
[Deduction] Water must have been accumulating on the same spot for a while.
[Observation] The other side of the wall is dry.
[Deduction] The water isn't coming from behind the wall.
[Action] Investigate what water channels run behind the wall.
[Observation] A single 6x6 water pipe runs through the channels.
[Action] Check for any leakage from the pipe.
[Observation] No leakage is detected all through the length of the pipe.
[Deduction] The pipe, while not leaking directly, could be a source of condensation due to temperature differences between the water inside and the surrounding environment.
[Action] Observe the ground and surrounding area for any other potential water sources, such as nearby irrigation, rain gutters, or surface runoff.
[Observation] A slight slope directs rainwater towards the wall's base, and the ground directly adjacent to the wall is consistently damp, even when other areas are dry.
[Deduction] Surface runoff and rainwater accumulating at the base of the wall are likely the primary sources of moisture.
[Action] Inspect the surrounding area for any evidence of recent rainfall or irrigation.
[Observation] Recent rainfall was noted and there are no irrigation systems in direct proximity.
[Deduction] The recent rain caused water to run off towards the wall, and the ground is saturated, causing water to wick into the base of the wall through capillary action.
[Action] Consider the composition of the soil and the wall's foundation.
[Observation] The soil is clay-rich, and the wall's foundation appears to be concrete without a proper moisture barrier.
[Deduction] Clay soil retains moisture, and the lack of a moisture barrier allows water to be drawn into the concrete through capillary action.
[Conclusion] The observed moisture at the base of the wall is primarily due to surface runoff and rainwater accumulation, exacerbated by the clay-rich soil and the absence of a moisture barrier in the foundation. Condensation from the pipe may contribute a small amount of moisture, but is not the primary cause.
When the reasoning data contains highly detailed reasoning checkpoints or CoT artifacts, the model isn’t just memorizing the data but is “learning” from the patterns in the data, i.e. learning to make hypotheses from observations, to make references to known ideas or facts, to highlight observations and make deductions, etc. All of these contribute to the model’s ability to solve relatively complex problems, as long as the base model is sufficiently powerful architecturally (architectural optimisation is beyond the scope of this post… and my expertise 😔).
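As a sketch of how such annotated traces might be packaged for fine-tuning, here is one common chat-messages JSONL layout; the system prompt, field names and file name are illustrative assumptions rather than a fixed standard, and the trace is a shortened version of the example above.

```python
# One possible way to turn annotated inspection traces into fine-tuning records.
# The chat-messages JSONL layout is a common convention; the system prompt,
# field names and file name here are illustrative assumptions.
import json

examples = [
    {
        "situation": "Water is noted at the junction of the wall and the ground. The source is unknown.",
        "goal": "Investigate the problem and detect its cause.",
        # The full trace would carry every checkpoint: [Observation], [Action], [Deduction], ...
        "trace": "[Observation] A liquid, water-like substance is observed on the wall.\n"
                 "[Action] Wipe off the water to check for cracks.\n"
                 "[Deduction] Water must be seeping through the concrete itself.\n"
                 "[Conclusion] Surface runoff and rainwater accumulation are the primary sources.",
    },
]

with open("inspection_cot.jsonl", "w") as f:
    for ex in examples:
        record = {
            "messages": [
                {"role": "system", "content": "You are a home-inspection assistant. Reason step by step using explicit checkpoints."},
                {"role": "user", "content": f"{ex['situation']} {ex['goal']}"},
                {"role": "assistant", "content": ex["trace"]},
            ]
        }
        f.write(json.dumps(record) + "\n")
```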
Chain of Thought reasoning is a fairly reliable way to make LLMs work through a given problem before arriving at a solution. In many cases, it is enough to instruct the LLM to produce more intermediate tokens; in other, more specialised cases, by applying CoT fine-tuning with carefully engineered and detailed domain reasoning examples, you can adapt LLMs to domain-specific use cases that require much deeper reasoning and thereby build truly useful problem-solving copilots.