Iteration 11. Pseudo beam-search

25-08-2024

Goal

Can I improve the accuracy of the predictions by using pseudo beam-search?

Motivation

Beam-search has been shown to generate more accurate responses than greedy decoding. However, it is not efficiently implemented in vLLM.

My idea is to generate n responses for the same prompt and select the one with the highest logprob. This would be similar to beam-search, but the implementation would be much more efficient.
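
Below is a minimal sketch of the idea using vLLM's offline API. The model path and prompt are placeholders, not the actual inference script:

```python
# Pseudo beam-search sketch: sample n candidates per prompt and keep the
# one with the highest cumulative logprob.
from vllm import LLM, SamplingParams

llm = LLM(model="/home/gbarbadillo/data/Qwen2-0.5B-arc")  # placeholder path
sampling_params = SamplingParams(
    n=8,              # number of candidate responses per prompt
    temperature=0.8,  # temperature > 0 so that the candidates differ
    max_tokens=1024,
)

for request_output in llm.generate(["<prompt>"], sampling_params):
    # Each RequestOutput holds n CompletionOutputs; select the best one.
    best = max(request_output.outputs, key=lambda o: o.cumulative_logprob)
    print(best.text)
```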

Development

Results

Inference speed

| n  | runtime | estimated runtime (min) |
|----|---------|-------------------------|
| 1  | 1m50    | 1.8                     |
| 2  | 2m32    | 2.5                     |
| 4  | 3m55    | 3.8                     |
| 8  | 6m32    | 6.5                     |
| 16 | 11m46   | 11.8                    |
| 32 | 22m43   | 22.5                    |

The runtime grows roughly linearly with n, but there is a constant overhead, so using n=4 takes only about twice as long as n=1.

```bash
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --temperature=0.1 --output_filepath=submission_qwen05_x8_T01_n1.json --n=1
```

The following table can serve as a rule of thumb for the slowdown when using pseudo beam-search.

| n  | slowdown |
|----|----------|
| 1  | 1        |
| 4  | 2        |
| 10 | 4        |
| 20 | 8        |
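
As a sanity check on these numbers, a quick least-squares fit of the measured runtimes (a sketch; the data points are copied from the inference speed table above):

```python
# Fit runtime ≈ intercept + slope * n to the measured runtimes.
import numpy as np

n = np.array([1, 2, 4, 8, 16, 32])
minutes = np.array([1.8, 2.5, 3.8, 6.5, 11.8, 22.5])

slope, intercept = np.polyfit(n, minutes, 1)
print(f"runtime ≈ {intercept:.2f} + {slope:.2f} * n minutes")
# Roughly 1.1 minutes of constant overhead plus ~0.67 minutes per
# candidate, consistent with the slowdown rule of thumb above.
```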

Accuracy effect

In a previous iteration I was able to see improvements from beam-search with just 8 predictions per task. Let's try to do the same. I will be using n=20 and different temperatures.

[Plot: accuracy vs. temperature with 8 predictions per task]

[Plot: accuracy vs. temperature with 16 predictions per task]

The trend when increasing the temperature is completely different from the one observed in previous experiments, and the improvement in accuracy is not clear.

```bash
# estimated runtime 1h30 ~ 16*6
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T01.json --temperature=0.1
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T02.json --temperature=0.2
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T04.json --temperature=0.4
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T06.json --temperature=0.6
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T08.json --temperature=0.8
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=8 --n=20 --output_filepath=submission_qwen05_x8_n20_T10.json --temperature=1.0

python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T01.json --temperature=0.1
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T02.json --temperature=0.2
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T04.json --temperature=0.4
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T06.json --temperature=0.6
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T08.json --temperature=0.8
python inference.py --model_path=/home/gbarbadillo/data/Qwen2-0.5B-arc --predictions_per_task=16 --n=20 --output_filepath=submission_qwen05_x16_n20_T10.json --temperature=1.0
```

Full experiment

I have run an experiment for 4h49 (submission_qwen05_x128_n20_T08) and compared it to submission_qwen15_x128. The accuracy improves from 5% to 5.2%, so the gain is tiny, while the inference time is close to 10 times longer. This does not seem to be a promising path.

Conclusion

The accuracy improvement is tiny despite the inference time increasing a lot. It is better to use the compute in other ways.

Next steps

TODO

  • Modify inference script to support this
    • Are the outputs provided by the LLM sorted by logprob, or do I have to sort them myself? YES, THEY ARE ALREADY SORTED (see the check sketched below)
  • How does the inference speed change when requesting more responses per prompt?
  • Does the accuracy improve?
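
A minimal sketch of how the sorting can be checked with vLLM's offline API (the model path and prompt are placeholders):

```python
# Sketch: verify that vLLM returns the n candidates already sorted by
# cumulative logprob, so the first candidate is the most likely one.
from vllm import LLM, SamplingParams

llm = LLM(model="/home/gbarbadillo/data/Qwen2-0.5B-arc")  # placeholder path
outputs = llm.generate(["<prompt>"], SamplingParams(n=8, temperature=0.8))

for request_output in outputs:
    logprobs = [o.cumulative_logprob for o in request_output.outputs]
    assert logprobs == sorted(logprobs, reverse=True)
```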