Iteration 11. Pseudo beam-search
25-08-2024
Goal
Can I improve the accuracy of the predictions by using a pseudo-beam-search?
Motivation
Beam-search has been shown to generate more accurate responses than greedy decoding. However, it is not implemented efficiently in vLLM.
My idea is to generate n responses for the same prompt and select the one with the highest logprob. This would be similar to beam-search, but the implementation would be much more efficient.
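A minimal sketch of the selection step, assuming the n responses for a prompt have already been sampled. In vLLM, `SamplingParams(n=...)` returns n `CompletionOutput` objects per prompt, each with a `cumulative_logprob` attribute (worth checking that it is not `None` in the vLLM version in use). The `Candidate` class and the `pseudo_beam_search` function are hypothetical names for illustration; the code itself does not depend on vLLM.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Candidate:
    # Hypothetical stand-in for one sampled response; mirrors the two fields
    # of vLLM's CompletionOutput that the selection needs.
    text: str
    cumulative_logprob: float  # sum of the log-probabilities of the generated tokens


def pseudo_beam_search(candidates: List[Candidate]) -> Candidate:
    """Return the sampled response with the highest cumulative logprob."""
    if not candidates:
        raise ValueError("pseudo_beam_search needs at least one candidate")
    return max(candidates, key=lambda c: c.cumulative_logprob)


# Example: three responses sampled for the same prompt.
samples = [
    Candidate(text="grid A", cumulative_logprob=-3.2),
    Candidate(text="grid B", cumulative_logprob=-1.4),
    Candidate(text="grid C", cumulative_logprob=-2.7),
]
print(pseudo_beam_search(samples).text)  # grid B
```

Ranking by the raw sum of token logprobs favors shorter outputs; dividing by the number of generated tokens is a common alternative if that turns out to matter.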
Development
Results
Inference speed
| n | measured runtime | estimated runtime (min) |
|---|---|---|
| 1 | 1m50 | 1.8 | 
| 2 | 2m32 | 2.5 | 
| 4 | 3m55 | 3.8 | 
| 8 | 6m32 | 6.5 | 
| 16 | 11m46 | 11.8 | 
| 32 | 22m43 | 22.5 | 
The runtime increases linearly with n, but there is a constant overhead, so using n=4 takes only about twice as long as n=1.
```bash
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --temperature=0.1 --output_filepath=submission_qwen05_x8_T01_n1.json --n=1
```
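The linear trend can be checked with an ordinary least-squares fit on the measured runtimes from the table, converted to seconds. This is a plain-Python sketch; the intercept captures the cost that does not depend on n (presumably model loading and prompt processing), and it is what keeps n=4 at roughly twice the cost of n=1.

```python
# Measured runtimes from the table (n -> seconds), e.g. 1m50 -> 110 s.
RUNTIMES_S = {1: 110, 2: 152, 4: 235, 8: 392, 16: 706, 32: 1363}


def fit_linear(xs, ys):
    """Ordinary least-squares fit of y = a + b * x; returns (a, b)."""
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return mean_y - slope * mean_x, slope


ns = sorted(RUNTIMES_S)
intercept, slope = fit_linear(ns, [RUNTIMES_S[n] for n in ns])


def predicted_runtime(n):
    """Runtime in seconds predicted by the linear fit."""
    return intercept + slope * n


print(f"fixed cost ~ {intercept / 60:.2f} min, cost per response ~ {slope / 60:.2f} min")
print(f"predicted n=4 / n=1 ratio: {predicted_runtime(4) / predicted_runtime(1):.1f}")
```

On these numbers the fit gives roughly 1.17 min of fixed cost plus 0.67 min per extra response, and a predicted n=4 / n=1 ratio of about 2.1.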
The following table can serve as a rule of thumb for the slowdown caused by pseudo beam-search.
| n | slowdown (relative to n=1) |
|---|---|
| 1 | 1 | 
| 4 | 2 | 
| 10 | 4 | 
| 20 | 8 | 
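To turn the rule of thumb into a quick runtime estimate, one option is to interpolate linearly between the rows of the table. The helper names below are hypothetical, and values of n outside 1..20 are rejected rather than extrapolated.

```python
from bisect import bisect_left

# Rule-of-thumb table: n -> runtime relative to n=1.
SLOWDOWN_TABLE = [(1, 1.0), (4, 2.0), (10, 4.0), (20, 8.0)]


def slowdown(n: int) -> float:
    """Linearly interpolate the rule-of-thumb slowdown for 1 <= n <= 20."""
    xs = [x for x, _ in SLOWDOWN_TABLE]
    if not xs[0] <= n <= xs[-1]:
        raise ValueError(f"n={n} is outside the tabulated range {xs[0]}..{xs[-1]}")
    i = bisect_left(xs, n)
    if xs[i] == n:
        return SLOWDOWN_TABLE[i][1]
    (x0, y0), (x1, y1) = SLOWDOWN_TABLE[i - 1], SLOWDOWN_TABLE[i]
    return y0 + (y1 - y0) * (n - x0) / (x1 - x0)


def estimated_minutes(base_minutes: float, n: int, runs: int = 1) -> float:
    """Estimated wall-clock minutes for `runs` runs, given the n=1 runtime."""
    return base_minutes * slowdown(n) * runs


print(f"{estimated_minutes(1.8, 20, runs=6):.1f} min")
```

For example, six runs at n=20 starting from the 1.8 min baseline come out at about 86 minutes (1.8 × 8 × 6).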
Accuracy effect
In a previous iteration I saw improvements from beam-search with just 8 predictions per task. Let's try to do the same.
I will be using n=20 and different temperatures.


The trend when increasing the temperature is completely different from what was observed in previous experiments, but the improvement in accuracy is not clear.
```bash
# estimated runtime 1h30 ~ 16*6
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T01.json --temperature=0.1
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T02.json --temperature=0.2
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T04.json --temperature=0.4
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T06.json --temperature=0.6
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T08.json --temperature=0.8
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T10.json --temperature=1.0
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T01.json --temperature=0.1
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T02.json --temperature=0.2
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T04.json --temperature=0.4
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T06.json --temperature=0.6
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T08.json --temperature=0.8
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T10.json --temperature=1.0
```
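The twelve commands above form a grid over predictions_per_task (8 and 16) and six temperatures. A small Python sketch that rebuilds the same command lines, reusing only the flags and file-name pattern shown above, makes the sweep easier to extend. `sweep_commands` is a hypothetical helper; it only prints the commands, and launching them (for example with `subprocess.run(args, check=True)`) is left as an option.

```python
import shlex
from itertools import product

MODEL_PATH = "/home/gbarbadillo/data/Qwen2-0.5B-arc"
PREDICTIONS_PER_TASK = [8, 16]
TEMPERATURES = [0.1, 0.2, 0.4, 0.6, 0.8, 1.0]
N = 20


def sweep_commands():
    """Build the inference.py argument lists for every (predictions, temperature) pair."""
    commands = []
    for preds, temp in product(PREDICTIONS_PER_TASK, TEMPERATURES):
        tag = f"T{round(temp * 10):02d}"  # 0.1 -> T01, 1.0 -> T10
        output = f"submission_qwen05_x{preds}_n{N}_{tag}.json"
        commands.append([
            "python", "inference.py",
            f"--model_path={MODEL_PATH}",
            f"--predictions_per_task={preds}",
            f"--n={N}",
            f"--output_filepath={output}",
            f"--temperature={temp}",
        ])
    return commands


if __name__ == "__main__":
    for args in sweep_commands():
        # shlex.join quotes arguments only when needed, so these print as plain commands.
        print(shlex.join(args))
```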
Full experiment
I ran an experiment that took 4h49, submission_qwen05_x128_n20_T08, and compared it to submission_qwen15_x128.
Accuracy improves from 5% to 5.2%, so the improvement is tiny, while inference time is close to 10 times longer.
So this does not seem to be a promising path.
Conclusion
The accuracy improvement is tiny, while inference time increases a lot. It is better to spend compute in other ways.
Next steps
TODO
- Modify the inference script to support this.
  - Are the outputs returned by LLM sorted by logprob, or do I have to sort them myself? Yes, they are already sorted.
- How does the inference speed change when requesting more responses per prompt?
- Does the accuracy improve?
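Relying on the returned order can be guarded with a cheap check that also catches ordering changes across vLLM versions. `Completion` is a hypothetical stand-in for vLLM's `CompletionOutput`, keeping only the attributes used here, and `is_sorted_by_logprob` is an illustrative helper name.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class Completion:
    # Hypothetical stand-in for vLLM's CompletionOutput (only the fields used here).
    text: str
    cumulative_logprob: float


def is_sorted_by_logprob(outputs: Sequence[Completion]) -> bool:
    """Return True if outputs are in non-increasing order of cumulative logprob."""
    return all(
        earlier.cumulative_logprob >= later.cumulative_logprob
        for earlier, later in zip(outputs, outputs[1:])
    )


# Example: assert the assumption before taking outputs[0] as the best response.
outputs = [Completion("a", -1.0), Completion("b", -2.5), Completion("c", -4.0)]
assert is_sorted_by_logprob(outputs)
```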