Впровадження зовнішніх знань для LLM з використанням CAG. Аналіз і порівняння з RAG

💡 Усі статті, обговорення, новини про AI — в одному місці. Приєднуйтесь до AI спільноти!

Привіт, спільното! Мене звати Олексій, я — Machine Learning Engineer у компанії Svitla Systems. Нещодавно мене зацікавила тема впровадження зовнішніх знань у LLM за допомогою CAG. Тож я вирішив глибше розібратися, як працює цей алгоритм, і порівняти його з RAG, протестувавши на прикладі біографії Степана Гіги.

Зображення редаговано за допомогою ШІ, оригінал

В цій статті ми розглянемо переваги та недоліки різних способів впровадження зовнішніх знань:

  • Через передачу знань через промпт як контекст для моделі.
  • Через алгоритм RAG.
  • За допомогою генерації з використанням кешу або алгоритму CAG.

Ми порівняємо час генерації та використання пам’яті для кожного з цих підходів і зробимо висновки про їх використання.

TL;DR: Алгоритм CAG демонструє кращі результати та швидшу генерацію, коли контекст повністю поміщається в контекстне вікно, але потребує стільки ж пам’яті, як і з контекстом в промпті. RAG демонструє дещо гірші за CAG результати, і потребує більше часу на генерацію, але дає гнучкість в кількості токенів контекстного вікна.

Для прикладів, як працює впровадження зовнішніх знань, я використав модель Llama3.1-8b. Весь код і вихідні дані ви можете знайти за посиланням в репозиторії.

Код моделі:

hf_model_id = "meta-llama/Llama-3.1-8B-Instruct"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # Enable double quantization
    bnb_4bit_compute_dtype=torch.bfloat16  # Set compute dtype to float16
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    hf_model_id,
    quantization_config=quantization_config,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(hf_model_id, device_map=device)

Впровадження зовнішніх знань для LLM

Кожна LLM після навчання має обмеження щодо актуальності та деталізації своїх знань. Наприклад, ще у 2024 році ChatGPT володів інформацією лише до січня 2022-го. Проте час не стоїть на місці й нові дані з’являються в інтернеті щосекунди, і LLM мають орієнтуватися в них без потреби у дорогому перенавчанні через кожен новий шматок інформації.

Саме тому під час використання LLM ми або самостійно додаємо потрібний контекст до запиту, або доручаємо це спеціальному алгоритму, наприклад, RAG. Крім того, впровадження додаткових знань допомагає зменшити кількість «галюцинацій» моделей, оскільки забезпечує їх релевантним контекстом для відповіді на поставлене питання.

А зараз перевіримо, як працює впровадження зовнішніх знань на практиці. Спершу протестуємо, чи має модель необхідні знання вже зараз. Для цього я склав наступний системний промпт для моделі:

SYSTEM_PROMPT = """
You are an experienced pop culture observer who is eager to help others learn about pop culture.
Answer questions with a concise response or just a few sentences. Answer in Ukrainian language.
"""

І далі на базі цього промпту в усіх експериментах від імені користувача я даю моделі одне і те саме запитання: «Who is Stepan Giga?»

Спершу розглянемо варіант без надання жодного контексту. Модель Llama 3.1, яка майже не стикалася з інформацією про українську попкультуру та естраду, не знає, що відповісти, і просто повертає токен завершення діалогу.

def get_input_tokens(tokenizer, chat: list[dict]) -> dict:
    """
    Converts a list of chat messages into input tokens suitable for a HuggingFace model.

    This function applies the tokenizer's chat template to a chat conversation,
    tokenizes the input, and returns a dictionary containing input tensors
    (e.g., input_ids, attention_mask) for model inference.

    Args:
        tokenizer: HuggingFace tokenizer with a defined chat template.
        chat: Conversation represented as a list of message dicts,
            where each dict has 'role' (e.g., 'system', 'user') and 'content' (str).

    Returns:
        Dictionary of model-ready input tensors, such as 'input_ids' and 'attention_mask'.
    """
    return tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        return_dict=True,
        continue_final_message=True,
        return_tensors='pt'
    )

def generate_completion(model, tokenizer, chat: list[dict]) -> tuple[str, int]:
    """
    Generates text with sampling and applied chat-template.

     Args:
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        chat: list of system and user's prompts

    Returns:
        Generated completion with number of input tokens.
    """
    input_tokens = get_input_tokens(tokenizer, chat)
    input_tokens.to(device)
    outputs = model.generate(**input_tokens)
    input_length = input_tokens.input_ids.shape[1]
    generated_tokens = outputs[0, input_length:]
    generated_answer = tokenizer.decode(generated_tokens, special_tokens=False)
    return generated_answer, input_tokens["input_ids"].shape[-1]

SYSTEM_PROMPT = """
You are an experienced pop culture observer who is eager to help others learn about pop culture.
Answer questions with a concise response or just a few sentences. Answer in Ukrainian language.
"""

chat = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "Who is Stepan Giga?"},
]
generated_text, n_tokens = generate_completion(model_4bit, tokenizer, chat)
print(f"Кількість вхідних токенів: {n_tokens}")
print(generated_text)

Результат:
```
<|eot_id|>
```

Генерація відповіді на запитання без додаткових знань зазвичай дає порожній або неінформативний результат. Проте навіть незначне додавання контексту до системного промпту вже дозволяє отримати змістовну відповідь.

У якості контекстних даних я використав інформацію про Степана Петровича Гігу зі сторінки Вікіпедії, попередньо обробивши її за допомогою ChatGPT. Ознайомитися з повним прикладом оброблених даних можна за цим посиланням. А ось фрагмент, як ці дані виглядають:

```
Гіга Степан Петрович (іноді Ґіґа , 16 листопада 1959 (1959-11-16), Білки, Закарпатська область) – український естрадний співак (тенор), композитор, народний артист України (2002).

# Життєпис

Степан Гіга народився 16 листопада 1959 року в селі Білки на Закарпатті. Коли навчався в загальноосвітній школі, то брав уроки вокалу та гри на баяні у місцевого вчителя музики Михайла Копинця. Однак до Ужгородського музучилища він вступив лише з четвертої спроби: після 8, 9 та 10 класу стати студентом не вдалося. Тому, закінчивши десятирічку, він пішов працювати слюсарем до місцевої сільгосптехніки, а згодом – водієм вантажівки. Звідти – в армію, а ще через два роки, після повернення із війська, Степан врешті став учнем Ужгородського музучилища.
…
```

Після зчитування даних і збереженняїх в змінну ‘giga_summary’, я сформував наступний системний промпт:

SYSTEM_PROMPT_WITH_CONTEXT = SYSTEM_PROMPT + f"""
Please take the following context about Stepan Giga into account when answering:
<|start_context|>
{giga_summary}
<|end_context|>

"""

І, об’єднавши новий контекст з запитанням користувача, ми отримаємо наступну відповідь:

chat = [
    {"role": "system", "content": SYSTEM_PROMPT_WITH_CONTEXT},
    {"role": "user", "content": "Who is Stepan Giga?"},
]
generated_text, stats_with_context = measure_generation_stats(generate_completion)(model_4bit, tokenizer, chat)
print(generated_text)

Результат:

Кількість вхідних токенів: 1944
Час на генерацію: 2.872 секунд
Пікове використання VRAM: 6085.79 MB
 
Що за людина Степан Гіга?

Степан Гіга — український співак, композитор та народний артист України.<|eot_id|>

Як видно з наведеного вище результату, якість відповіді суттєво покращилась — модель вже повертає релевантну інформацію. Однак кількість вхідних токенів значно зросла порівняно з попереднім прикладом. У цьому випадку ми можемо дозволити собі розмістити всю необхідну інформацію в контексті, щоб отримати повну відповідь. Але зі збільшенням обсягу даних з’являються такі недоліки:

  • Інформація може виходити за межі контекстного вікна.
  • Дані можуть бути неактуальними, але займати місце в промпті.
  • Моделі стає дедалі важче фокусуватися на важливих деталях — вона може їх просто не помітити. Це називається contextial drift (контекстуальний зсув).

І саме для розв’язання подібних питань використовується RAG.

Впровадження зовнішніх знань з RAG

RAG — це алгоритм, який на базі вашого запиту до моделі підбирає найрелевантніші дані, і надає їх LLM для генерації відповіді. Таким чином LLM отримує уточнені та актуальні дані, які попередньо зберігаються в базі, як правило векторній, і вставляє їх в промпт для відповіді.

Для подальшого порівняння розберемо, як працює RAG детальніше:

  1. Спочатку ми розділяємо релевантну інформацію на окремі фрагменти — документи. Документом може бути, наприклад, речення або параграф тексту.
  2. Для кожного документа створюється embedding — векторне представлення його змісту. Для цього використовують спеціалізовані моделі, наприклад Sentence-BERT або MiniLM, що добре підходять для задач семантичного пошуку.
  3. Документи та їхні embedding-и зберігаються у векторну базу даних, де вони індексуються для ефективного пошуку. Прикладами таких баз є Pinecone, Qdrant або ChromaDB.
  4. Під час запиту користувача текст запиту також перетворюється на вектор, який порівнюється з embedding-ами в базі. Для цього використовується функція відстані — вона визначає, наскільки близькі між собою вектори. Найпоширеніші функції: Cosine Similarity або L2 (евклідова) відстань. Так ми отримуємо топ-N найбільш схожих документів.
  5. Знайдені документи вставляються в промпт як контекст, після чого модель генерує відповідь на основі цієї додаткової інформації.

Щоб застосувати RAG для поточної задачі, я розбив файл з Вікіпедії на документи за параграфами. Для побудови алгоритму використав бібліотеку ‘faiss’ і модель `all-MiniLM-L6-v2` для генерації embedding-ів у поєднанні з функцією пошуку на основі L2. Виглядає код наступним чином:

from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

class SimpleRAG:
    """
    Simple Retrieval-Augmented Generation (RAG) utility.
    
    This class embeds a list of documents, creates a FAISS vector index,
    and enables fast retrieval of the most relevant documents for a given query.
    """

    def __init__(self, documents: list[str]):
        """
        Initialize the SimpleRAG retriever.
        
        Args:
            documents: List of documents (strings) to index and retrieve from.
        """
        self.documents = documents
        self.embeddings_model = SentenceTransformer("all-MiniLM-L6-v2")
        embeddings = self.embeddings_model.encode(documents, convert_to_numpy=True)
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)  # L2 = Euclidean distance
        self.index.add(embeddings)

    def retrieve(self, query: str, top_k: int=2) -> list[str]:
        """
        Retrieve the top_k most relevant documents for a query.
        
        Args:
            query: The query string to search for relevant documents.
            top_k: The number of top documents to retrieve. Default is 2.
        
        Returns:
            List of retrieved documents, ranked by relevance to the query.
        """
        query_embedding = self.embeddings_model.encode([query], convert_to_numpy=True)
        distances, indices = self.index.search(query_embedding, k=top_k)  # top_k results
        return [self.documents[i] for i in indices[0]]

simple_rag = SimpleRAG(giga_summary_documents, top_k=2)
simple_rag.retrieve("Who is Stepan Giga?")

Результат:

['У 1991 році, коли цей гурт розформували, Степан Гіга залишився без роботи. Саме тоді він вперше спробував займатись аранжуванням , почав писати пісні, а згодом створив власну студію звукозапису «GIGARecords».',
 'Гіга Степан Петрович (іноді Ґіґа , 16 листопада 1959 ( 1959-11-16 ) , Білки , Закарпатська область )— український естрадний співак ( тенор ), композитор , народний артист України (2002).']

І тепер, поєднавши результати від пошуку за допомогою RAG з функцією генерації, ми можемо отримати бажаний результат.

def generate_completion_with_RAG(
    model,
    tokenizer,
    chat: list[dict],
    rag_retriever: SimpleRAG,
    top_k: int
) -> tuple[str, int]:
    """
    Generate a language model completion using Retrieval-Augmented Generation (RAG).
    
    This function retrieves the top-k relevant documents from a retriever (such as SimpleRAG)
    based on the user's last message in the chat, injects them as additional context into the
    system prompt, and generates a model completion. The number of input tokens used for the prompt
    is also returned for analysis.
    
    Args:
        model: HuggingFace transformer model for generation.
        tokenizer: HuggingFace tokenizer corresponding to the model.
        chat: List of chat messages in the format 
            [{'role': 'system'/'user', 'content': <str>}].
        rag_retriever: The retriever instance for relevant context.
        top_k: Number of top documents to retrieve for context. Default is 2.
    
    Returns:
        tuple: Generated text completion and number of input tokens.
    """
    user_message = next((m['content'] for m in reversed(chat) if m['role'] == 'user'), "")
    retrieved_docs = rag_retriever.retrieve(user_message, top_k=top_k)
    context_block = "\n\n".join(retrieved_docs)
    SYSTEM_PROMPT_WITH_RAG = (
        SYSTEM_PROMPT + f"""
Please take the following context about Stepan Giga into account when answering:
<|start_context|>
{context_block}
<|start_context|>
        """
    )
    chat_with_rag = [
        {"role": "system", "content": SYSTEM_PROMPT_WITH_RAG},
        {"role": "user", "content": user_message}
    ]
    return generate_completion(model, tokenizer, chat_with_rag)

# Suppose your SimpleRAG instance is named simple_rag
simple_rag = SimpleRAG(giga_summary_documents)
chat = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": "Who is Stepan Giga?"},
]
generated_text, stats_with_RAG = measure_generation_stats(generate_completion_with_RAG)(
    model_4bit, 
    tokenizer, 
    chat, 
    simple_rag, 
    top_k=2
)
print(generated_text)

Результат:

Кількість вхідних токенів: 223
Час на генерацію: 2.564 секунд
Пікове використання VRAM: 5895.84 MB
 
У нього є власна студія звукозапису «GIGARecords», яка була створена після розпаду його попереднього гурту.<|eot_id|>

Як видно з результату, модель повернула не точну відповідь на запитання, але з набагато меншим часом генерації та значно меншою кількістю вхідних токенів. Якщо раніше весь промпт містив 1941 токен, то з використанням RAG — лише 223. Це відкриває можливість додавати додатковий контекст у разі потреби та скорочує час генерації відповіді.

Впровадження зовнішніх знань з CAG

До цього моменту ми розглядали способи інтеграції додаткових знань виключно через текст — шляхом додавання нових токенів до промпту. Альтернативою цьому підходу є алгоритм Cache-Augmented Generation (CAG).

Цей CAG, цей CAG мені щоночі сниться (зображення редаговано, оригінал)

Суть алгоритму полягає в тому, що за умови наявності достатньо великого контекстного вікна можна помістити увесь контекст у промпт і розпочати генерацію відповіді. Після цього — зберегти стан матриць Key та Value, тобто закешувати ці ваги. Далі, коли користувач надсилає запит, ми можемо повторно використати цей кеш, додавши до нього сам запит, і згенерувати відповідь на основі вже збереженого контексту.

Схематично CAG виглядає наступним чином:

Приклад роботи алгоритма CAG. Ілюстрація авторська

Переваги цього підходу порівняно з RAG:

  • Значно спрощує архітектуру.
  • Усуває етап з пошуком і отриманням необхідних даних.
  • Забезпечує неподільний контекст, що дозволяє моделі краще орієнтуватися в інформації.

Розробники CAG зазначають, що алгоритм демонструє покращення точності відповідей у таких бенчмарках, як HotPotQA та SQuAD. Це також проілюстровано на скріншоті нижче.

Результати порівняння CAG та RAG

За наступними посиланнями ви можете знайти подробиці цього дослідження і відповідного коду, на базі якого ґрунтувався і мій алгоритм.

from transformers.cache_utils import DynamicCache

def sample_with_temperature(logits, temperature=1.0):
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token

def clean_up(kv: DynamicCache, origin_len: int):
    """
    Truncate the KV Cache to the original length.
    """
    for i in range(len(kv.key_cache)):
        kv.key_cache[i] = kv.key_cache[i][:, :, :origin_len, :]
        kv.value_cache[i] = kv.value_cache[i][:, :, :origin_len, :]

def generate_completion_with_CAG(
    model,
    tokenizer,
    prompt: str,
    past_key_values,
    max_new_tokens: int | None = None
) -> tuple[str, int]:
    """
    Generates text with CAG cache and applied chat-template.

    Args:
        model: HuggingFace model
        tokenizer: HuggingFace tokenizer
        chat: list of system and user's prompts
        past_key_values: KV Cache for knowledge
        max_new_tokens: Maximum new tokens to generate

    Returns:
        Generated completion and number of input tokens.
    """
    max_new_tokens = getattr(model.generation_config, "max_new_tokens", 200)
    temperature = getattr(model.generation_config, "temperature", 1)

    kv_len = past_key_values.key_cache[0].shape[-2]

    prompt = f"""{prompt}<|eot_id|>\n\n<|start_header_id|>assistant<|end_header_id|>\n\n"""
    # 1. Prepare input tokens
    input_tokens = tokenizer(
        prompt,
        return_tensors='pt',
    )
    input_ids = input_tokens['input_ids'].to(model.device)
    origin_ids = input_ids
    num_input_tokens = input_ids.shape[-1]

    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        # Next token prediction
        for _ in range(max_new_tokens):
            outputs = model(
                input_ids=next_token,
                past_key_values=past_key_values,  # Insert knowledge cache
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token = sample_with_temperature(next_token_logits, temperature)
            next_token = next_token.to(model.device)
            past_key_values = outputs.past_key_values
            output_ids = torch.cat([output_ids, next_token], dim=1)

            # Stop if EOS token
            if next_token.item() in {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids('<|eot_id|>')}:
                break

    # 2. Decode ONLY the generated tokens
    generated_tokens = output_ids[:, num_input_tokens:]
    generated_answer = tokenizer.decode(generated_tokens[0], skip_special_tokens=False)

    clean_up(past_key_values, kv_len)
    return generated_answer, num_input_tokens

def preprocess_knowledge(
    model,
    tokenizer,
    chat: list[dict],
) -> DynamicCache:
    """
    Prepare knowledge kv cache for CAG.

    Args:
        model: HuggingFace model with automatic device mapping
        tokenizer: HuggingFace tokenizer
        prompt: The knowledge to preprocess, which is basically a prompt

    Returns:
        DynamicCache: KV Cache
    """
    # Chat template with add_generation_prompt parameter
    input_tokens = tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        return_dict=True,
        continue_final_message=True,
        return_tensors='pt'
    )
    input_ids = input_tokens["input_ids"].to(device)
    past_key_values = DynamicCache()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False
        )
    return outputs.past_key_values

Як видно з наведеного вище коду, весь алгоритм поділяється на два основні етапи:

  1. Підготовка кешу в функції «preprocess_knowledge».
  2. Використання кешу для генерації в функції «generate_completion_with_CAG».

У своєму коді дослідники формували кеш на основі системного промпту, а контекст додавали вже у промпт користувача. Я ж вирішив діяти трохи інакше: використав системний промпт, який уже містив контекст.

У результаті структура chat template поділяється на дві частини:

  • Системний промпт + початок запиту користувача (на етапі підготовки кешу).
  • Продовження запиту користувача + промпт асистента (для запуску генерації).

Запит до моделі після кешування виглядає наступним чином:

prompt = f"""{prompt}<|eot_id|>\n\n<|start_header_id|>assistant<|end_header_id|>\n\n"""

Тож, зібравши все до купи, можна зробити кеш і згенерувати відповідь на поставлене питання:

chat = [
    {"role": "system", "content": SYSTEM_PROMPT_WITH_CONTEXT},
    {"role": "user", "content": "Question:"}
]
# Створення кешу з використанням chat_templates.
past_key_values = preprocess_knowledge(model_4bit, tokenizer, chat)

# Запит до моделі вже без chat_templates, ніби з продовження закешованого запиту.
user_prompt = "Who is Stepan Giga?"
generated_text, stats_with_CAG = measure_generation_stats(generate_completion_with_CAG)(
    model_4bit,
    tokenizer,
    user_prompt,
    past_key_values,
)
print(generated_text)

Результат:

Кількість вхідних токенів: 14
Час на генерацію: 2.441 секунд
Пікове використання VRAM: 6093.39 MB
Степан Гіга український естрадний співак (тенор), композитор та народний артист України.<|eot_id|>

Окрім основного тесту я вирішив перевірити, як CAG працює з додатковими запитаннями. Наприклад, перечитавши біографію Степана Петровича, я знаю, що його першим музичним інструментом був баян. Тож я поставив відповідне запитання й отримав таку відповідь:

user_prompt = "What was the first musical instrument of Stepan Giga?"
generated_text, _ = generate_completion_with_CAG(
    model_4bit,
    tokenizer,
    user_prompt,
    past_key_values,
)
print(generated_text)

Результат:

Першим інструментом Степана Гіги був баян.<|eot_id|>

Далі я запитав, яким був його найуспішніший альбом (спойлер: Вулиця Наталі, 1 млн копій).

user_prompt = "Name the most successful album of Stepan Giga."
generated_text, _ = generate_completion_with_CAG(
    model_4bit,
    tokenizer,
    user_prompt,
    past_key_values,
)
print(generated_text)

Результат:

Найуспішнішим альбомом Степана Гіги є «Вулиця Наталі».<|eot_id|>

Як видно з наведених прикладів, алгоритм чудово впорався з поставленими запитаннями. Тепер настав час порівняти результати всіх трьох підходів.

Аналіз і порівняння Context vs RAG vs CAG

Маючи результати всіх трьох алгоритмів, я заміряв обсяг використаної відеопам’яті та швидкість генерації для кожного з них.

Дослідники CAG стверджують, що генерація за допомогою цього алгоритму в багатьох випадках швидша, ніж у RAG. Це залежить як від кількості прикладів у контексті, так і від способу отримання даних із векторної бази. Результати їхніх досліджень наведені на скріншоті нижче.

Як видно з результатів, CAG може поступатися RAG при використанні Top-3 документів. Проте він демонструє перевагу у випадках, коли залучено Top-10 документів або повністю використовується контекстне вікно.

У моєму випадку дослідження проводилось не на бенчмарках, а лише на одному запитанні. RAG повертав лише 2 документи з можливих 21, тому різниця в часі генерації між підходами була несуттєвою. Враховуючи ці фактори, можна сказати, що результати моїх експериментів узгоджуються з висновками розробників: швидкість генерації CAG вища, ніж при повному контексті та RAG із Top-2 документами. Водночас якщо збільшити кількість документів у RAG з 2 до 21, то час на генерацію відповіді зменшується.

Графіки показників швидкості генерації. Ліворуч — RAG Top-2 документи, праворуч — Top-21 документ

Далі — порівняння обсягу необхідної відеопам’яті для кожного з підходів. Оскільки CAG базується на збереженні кешу, то потреба в пам’яті більша порівняно з RAG, де ми отримаємо тільки невелику частину документів для контексту. Але при збільшенні кількості документів споживання пам’яті може перевершити CAG.

Порівняння в споживанні пам’яті між різними підходами. Ліворуч — RAG Top-2 документи, праворуч — Top-21 документ

Висновок

Порівнюючи три різні підходи до впровадження зовнішніх знань, не можна однозначно сказати, що якийсь один є універсальним розв’язання усіх проблем. Хоча CAG демонструє високу швидкість генерації та точність відповідей, він споживає більше відеопам’яті та залишає менше простору в контекстному вікні, так само як і in-context learning. До того ж він вимагає попередньої обробки даних: створення кешу та його завантаження безпосередньо в модель, чого не потребує RAG. Натомість RAG зазвичай використовує менше відеопам’яті, однак його точність і швидкість залежать від кількості отриманих документів і їх якості.

Окрім того, в процесі розробки цих прикладів я зіткнувся з тим, що model.__call__ методи бібліотеки transformers не мають такого зручного інтерфейсу для семплінгу, як top_k, top_p токенів і на відміну від model.generation методів, їх потрібно імпортувати окремо.

Проблема полягає у тому, що для продовження генерації з DynamicCache потрібно передати коректний cache_position, а внутрішня реалізація Transformers очікує то скаляр, то діапазон-тензор; ця невідповідність бібліотеки не дає завершити інтеграцію CAG через model.generate(), незалежно від коду. Тому якщо вам цікаво, я (перекладаю відповідальність на читача) пропоную вам ознайомитися з кодом в іншій гілці мого репозиторію, це може стати непоганою відправною точкою. Якщо ж у вас є ідеї або поради, як обійти ці обмеження, буду радий побачити ваші думки в коментарях.

Детально свої висновки я виклав в таблиці нижче.

CAG — не універсальний інструмент, але безперечно перспективний. Його доречно використовувати в сучасних завданнях генерації, особливо з огляду на тренд збільшення контекстного вікна в нових моделях. Крім описаного мною підходу з кешем, вже існують гібридні рішення, де кеш зберігається у векторній базі даних і завантажується за запитом користувача. Це дозволяє поєднати переваги обох підходів і отримати ще більш точні відповіді.

Дякую, що дочитали до кінця! Сподіваюся, мені вдалося додати щось корисне до вашого інструментарію та покращити розуміння теми впровадження зовнішніх знань у генеративні моделі.

Сподобалась стаття? Підписуйтесь на автора, щоб отримувати сповіщення про нові публікації на пошту.

👍ПодобаєтьсяСподобалось15
До обраногоВ обраному4
LinkedIn
Дозволені теги: blockquote, a, pre, code, ul, ol, li, b, i, del.
Ctrl + Enter
Дозволені теги: blockquote, a, pre, code, ul, ol, li, b, i, del.
Ctrl + Enter

Ваша стаття не пройшла повз. Додав собі у закладки. Дякую за працю!

Яка реальна різниця між CAG i Context?
і в одному і в іншому випадку дані ідуть до ЛЛМ разом промптом. Можливо в першому випадку дані ідуть з позначкою, що це системна інструкція або додатковий контекст.
Якщо я в промпті напишу «Context: » потім той текст про гігу а потім сам промпт то буде то саме?

І в першому і в другому випадку на вхід до моделі йде контекст з текстом про Гігу.

В випадку з контекстом модель, перед передбаченням наступного токену для відповіді, повинна обчислити увагу, або значення активації для нейронів, для всіх попередніх токенів контексту. Тобто перерахувати значення для контексту від першого до останнього токену і одразу згенерувати відповідь.

В випадку з CAG значення контексту попередньо обчислюються тільки для KV матриць, тобто умовно для 2/3 нейронів щоб модель «зхопила контекст» і далі ми ці значення кешуємо, на цьому кроці модель ніякої відповіді не генерує. Далі, маючи якийсь промпт від користувача, ми завантажуємо цей кеш в модель і тільки після цього генеруємо відповідь згідно з промпту.

Зрозуміло. Дякую.
Тобто, для rag це не буде працювати, бо там кешування нучого не дасть.

Якщо вікно контексту велике, типу мільйон токенів. То получається, що можна робити якісь операції з цілими книгами, наприклад художній переклад роману з збереженням єдиного стилю.
Весь роман закешувати і передавати як CAG і в промптах перекладати сторінку за сторінкою.
Цікаво чи хтось таке робив вже

Якщо порівнювати rag / cag — в якому підході моделі краще дослухаються до зовнішніх знань, навіть якщо вони суперечать наявним знанням моделі (zero-shoot rag inference)
Бо з RAG — моделі часто ігнорували знання навіть якщо вони максимально релевантні до питання

Дякую за запитання. Більшість мені відомих порівнянь RAG vs CAG виконано на HotPotQA та SQuAD, тобто на даних, які імовірно вже з’являлися у передтренінгу моделей бо обидва засновани на даних з вікі. Але є приклад IBM, де вони робили свій корпоративний FAQ (developer.ibm.com/...​ache-augmented-generation) і там теж непогані вийшли результати.

В моєму випадку, коли я надавав контекст через RAG параграфами top_k=2, то модель давала несенітницю, не дивлячись на отриману релевантну інформацію. В випадку коли надавав контекст семантичними розділами (400-500 токенів в розділі), то результати були більш сталими. В випадку з CAG проблема не виникла, тому що вся довідкова стаття потрапляє в KV-кеш під час пред-обробки, і модель бачить повний контекст при кожній відповіді.

Підписатись на коментарі