Pull to refresh
38.95

Как ускорить LLM-генерацию текста в 20 раз на больших наборах данных

Reading time7 min
Views6.9K

Всем привет, я Алан, разработчик-исследователь MTS AI. В команде фундаментальных исследований мы занимаемся исследованием LLM, реализацией DPO и валидацией наших собственных языковых моделей. В рамках этих задач у нас возникла потребность в генерации большого количества данных с помощью LLM. Такая генерация обычно занимает много времени. Однако за последний год, с ростом популярности LLM, стали появляться различные инструменты для развертывания таких моделей. Одной из самых эффективных библиотек для инференса языковых моделей является библиотека vLLM. В статье показывается, как с помощью асинхронных запросов и встроенных особенностей vLLM можно увеличить скорость генерации примерно в 20 раз. Приятного чтения!

Библиотека vLLM

vLLM- это библиотека, которая позволяет относительно просто развернуть языковую модель для инференса, в том числе в виде OpenAI- совместимого API. Библиотека поддерживает квантизированные модели (методы GPTQ, AWQ и SqueezeLLM), оптимизацию памяти с помощью PagedAttention, оптимизированные ядра CUDA, непрерывный батчинг и прочее. Также vLLM работает с большинством популярных моделей с HuggingFace.

OpenAI- совместимый сервер в vLLM реализует такие эндпоинты как: list models, ChatCompletion и Completion. Развертывание API выполняется следующим образом:

python -m vllm.entrypoints.openai.api_server --model your_model_path

Теперь мы можем обращаться к модели так, как если бы обращались к официальному API OpenAI:

import openai

openai.api_base = "адрес, на котором развернуто API"
system_prompt = (
        "Your system prompt"
    )
user_prompt = "Как дела?"
messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

result = openai.ChatCompletion.create(
            model="your_model_name",
            messages=messages,
            max_tokens=4096,
            temperature=0.9,
            top_p=0.6
        )
print(result["choices"][0].get("message").get("content"))

Примерно таким образом мы генерируем ответы для X- набора инструкций и это слишком медленно. Но прежде чем переходить к ускорению генерации, для начала вспомним, как происходит инференс в LLM. На вход подается последовательность токенов, затем мы итеративно генерируем следующие токены, пока не будет сгенерирован eos-токен или достигнута максимальная длина последовательности. Рассмотрим некоторые детали инференса:

  • Время обработки входного промпта примерно равно времени генерации выходной последовательности. Это происходит потому, что в начале вычисляются входные данные для механизма внимания, которые остаются постоянными в течении всего времени генерации.

  • Пропускная способность инференса LLM в значительной степени определяется тем, насколько большой батч вы можете поместить в память графического процессора

Вместо того, чтобы каждый раз загружать параметры модели, когда начинается генерация последовательности, мы можем загрузить веса один раз, а затем использовать для обработки батча последовательностей. Кратко рассмотрим два типа батчинга, которые помогут нам генерировать данные значительно быстрее

Статический батчинг

Метод статических матчей предполагает, что размер батча остается постоянным до тех пор, пока каждая последовательность не будет сгенерирована до конца, таким образом мы повышаем пропускную способность. При этом память графического процессора все еще используется не так эффективно, поскольку скорость обработки батча ограничена временем генерации самой длинной последовательности.

Статический батчинг. Изображения взяты отсюда
Статический батчинг. Изображения взяты отсюда

Официальное Completion API OpenAI позволяет генерировать данные батчами, при этом такой способ уже не работает для ChatCompletion API, который используется для работы с современными инструктивными моделями. Ниже представлены два примера: стандартная генерация и батч- генерация

Стандартная генерация:

from openai import OpenAI
client = OpenAI()
 
num_stories = 10
prompt = "Я пошел в магазин и купил"
 

for _ in range(num_stories):
    response = client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=100,
    )
    
    print(prompt + response.choices[0].text)

Генерация статического батча:

from openai import OpenAI
client = OpenAI()
 
num_stories = 10
prompts = ["Я пошел в магазин и купил"] * num_stories

response = client.completions.create(
    model=model_name,
    prompt=prompts,
    max_tokens=100,
)
 
stories = [""] * len(prompts)
for choice in response.choices:
    stories[choice.index] = prompts[choice.index] + choice.text

for story in stories:
    print(story)

Непрерывный батчинг

Вместо того, чтобы ждать, пока каждая последовательность в батче будет сгенерирована, размер батча может определяться отдельно для каждой итерации, с помощью реализации метода ORCA в библиотеке vLLM. Это означает, что как только в батче завершится генерация одной последовательности, на ее место сразу становится новая последовательность, что обеспечивает более высокую нагрузку на видеокарту, чем при статическом батчинге. Подробнее про батчинг в vLLM можно прочитать здесь.

Непрерывный батчинг. Изображения взяты отсюда
Непрерывный батчинг. Изображения взяты отсюда

PagedAttention

Помимо непрерывного батчинга, vLLM также реализует механизм PagedAttention, что также приводит к значительному ускорению инференса. В процессе авторегрессионного декодирования для каждого входного токена генерируются тензоры K и V для слоя внимания, затем эти тензоры сохраняются в памяти GPU для генерации следующих токенов. Кэшированные тензоры key и value часто называют KV- кэшем. KV- кэш занимает достаточно большой объем памяти: при генерации одной последовательности с помощью LLaMA-13B требуется до 1,7 ГБ. Также стоит учитывать, что KV- кэш занимает разный объем памяти, в зависимости от длины последовательности.

В результате, задача эффективного управления KV- кэшем становится проблемой. Команда vLLM обнаружила, что существующие системы впустую тратят около 80 процентов памяти из- за фрагментации и избыточного резервирования.

Чтобы решить эту проблему, vLLM реализовала механизм PagedAttention, который вдохновляется классической идеей виртуальной памяти и страничной организации памяти. В отличие от традиционных алгоритмов внимания, PagedAttention позволяет хранить непрерывные key и value в несмежном пространстве памяти. PagedAttention разбивает KV- кэш каждой последовательности на блоки, каждый блок содержит key и value для фиксированного количества токенов. Во время вычислений ядро PagedAttention эффективно идентифицирует и извлекает эти блоки.

PagedAttention
PagedAttention

Проводя аналогию со страничной организацией памяти, блоки можно рассматривать как страницы, токены - как байты, а последовательности - как процессы. Непрерывные логические блоки последовательности сопоставляются с несмежными физическими блоками через таблицу блоков. Физические блоки распределяются по мере генерации новых токенов.

PagedAttention
PagedAttention

PagedAttention позволяет системе обрабатывать большие батчи последовательностей, увеличивать загрузку графического процессора и тем самым значительно увеличивать скорость инференса. Также PagedAttention естественным образом обеспечивает совместное использование памяти через таблицу блоков. Аналогично тому, как процессы совместно используют физические страницы, различные последовательности в PagedAttention могут совместно использовать блоки, сопоставляя их логические блоки одному и тому же физическому блоку. Чтобы обеспечить безопасный общий доступ, Paged Attention отслеживает количество ссылок на физические блоки и реализует механизм Copy-on-Write.

Подробнее про PagedAttention можно прочитать здесь.

Chat Completion API и asyncio

Для того, чтобы воспользоваться непрерывным батчингом и значительно ускорить генерацию с помощью vLLM, мы должны сделать к API много запросов одновременно. Для этого было решено использовать acreate и библиотеку asyncio.

acreate- это асинхронная версия метода openai.ChatCompletion.create, который возвращает корутину

корутина- это асинхронный блок кода, который может быть приостановлен, когда у нас есть потенциально длительно выполняющаяся задача, а затем возобновлен, когда эта задача завершена

В первой версии запроса мы отправляли данные на vLLM- сервер батчами по 128 запросов, код приведен ниже:

import openai
import asyncio


async def generate_answers(prompt):
    completion = await openai.ChatCompletion.acreate(
        model="model_name",
        messages=[{"role": "user", "content": prompt}],
        max_length=1028
    )

    return completion["choices"][0].get("message").get("content")

  
results = []
batch_size = 128
async def main(batch):
    tasks = []
    global results
    for idx, prompt in enumerate(batch):
        task = asyncio.create_task(generate_answers(prompt["prompt"]["text"]))
        tasks.append(task)

    answers = await asyncio.gather(*tasks)
    results += answers


for idx in range(0, len(data), batch_size):
    print("IDX:", idx)
    asyncio.run(main(data[idx:idx+batch_size]))

Скорость стала выше, но это не совсем эффективный запрос, потому на vLLM можно подать сразу весь набор данных и сервер автоматически организует очередь обработки последовательностей, максимальный размер батча в таком случае будет составлять 256. Теперь мы отправляем весь датасет сразу:

import openai
import asyncio


async def generate_answers(prompt):
    completion = await openai.ChatCompletion.acreate(
        model="model_name",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1024,
        request_timeout=10000
    )
    return completion["choices"][0].get("message").get("content")


async def main(data):
    tasks = []
    for idx, prompt in enumerate(data):
        task = asyncio.create_task(generate_answers(prompt))
        tasks.append(task)

    answers = await asyncio.gather(*tasks)
    return answers

# data- это набор данных, на котором генерируются ответы
results = asyncio.run(main(data))

Время генерации данных для 100 и 2000 инструкций на видеокарте A100 40 гб:

Количество инструкций

Время синхронной генерации

Время асинхронной генерации

100

6 м 10 с

1 м

2000

2 ч 50 м

5-10 м

Как результат мы получаем многократное ускорение благодаря асинхронным запросам, непрерывному батчингу и механизму PagedAttention в vLLM.

Стоит упомянуть, что независимо от длины инструкции, PagedAttention выделит память с максимальным количеством токенов для KV- кэша. Следовательно, ускорение так или иначе зависит от статистического распределения количества токенов в наборе данных, поэтому результаты ускорения на разных наборах данных могут отличаться (ускорение от 20 до 30 раз).

Заключение

Непрерывный батчинг вместе с механизмом внимания PagedAttention действительно позволяют многократно увеличить скорость инференса, однако можно выделить два момента:

  • Качество генерации некоторых ответов может падать, возможно это зависит от используемой языковой модели

  • Необходимо выставлять большое время ожидания при запросе, так как некоторые последовательности могут генерироваться достаточно долго(они генерируются в последнюю очередь)

Даже с учетом этого, реализованные в vLLM методы позволяют значительно ускорить эксперименты большими языковыми моделями

Надеемся, что статья окажется полезной в вашей работе с LLM.

Tags:
Hubs:
Total votes 14: ↑13 and ↓1+16
Comments0

Articles

Information

Website
mts.ai
Registered
Founded
Employees
201–500 employees
Location
Россия
Representative
Анна Родина