Introducere în mecanismele de inferență
În lumea în continuă evoluție a modelelor de limbaj de mari dimensiuni (LLM), eficița computatională a devenit un aspect critic. Dacă ați utilizat vreodată Qwen, Claude sau orice alt chatbot bazat pe inteligență artificială, ați observat probabil un fenomen interesant: există o întârziere inițială înainte ca primul cuvânt al răspunsului să apară, urmat de o generare succesivă a cuvintelor, unul câte unul, cu o frecvență regulată. Acest comportament nu este întâmplător, ci rezultatul direct al modului în care funcționează arhitectura internă a acestor modele.
La baza funcționării tuturor LLM-urilor se află un concept fundamental: predicția următorului token. Un model de limbaj procesează întregul prompt inițial pentru a produce un singur token nou. Apoi, continuă să adauge tokeni unul câte unul, citind de fiecare dată tot contextul anterior, până când decide că generarea s-a încheiat. Acest proces este extrem de costisitor din punct de vedere computational, necesitând trecerea inputului prin miliarde de parametri pentru fiecare token generat.
Pentru a face aceste modele practice în aplicațiile reale, în special atunci când deservesc mulți utilizatori simultan, cercetătorii și inginerii au dezvoltat o gamă largă de tehnici eficiente de inferență. Una dintre cele mai impactante optimizări este "batching-ul continuu" (continuous batching), care încearcă să maximizeze performanța prin procesarea paralelă a multiple conversații și înlocuirea lor atunci când s-au terminat.
Mecanismul atenției: Fundamentul procesării limbajului
Mecanismul de atenție reprezintă piesa centrală a funcționării LLM-urilor. Un model de limbaj procesează textul prin împărțirea acestuia în bucăți pe care le numim tokeni. Putem concepe conceptual "tokenii" ca fiind "cuvinte", deși uneori un cuvânt poate fi compus din mai mulți tokeni. Pentru fiecare secvență de tokeni, rețeaua calculează o predicție a următorului token.
Multe operațiuni din rețea sunt "token-wise": fiecare token este procesat independent, iar output-ul pentru un anumit token depinde doar de conținutul acelui token, nu de alți tokeni din secvență. Operațiuni precum normalizarea pe straturi sau înmulțirea matricelor intră în această categorie. Totuși, pentru a crea conexiuni între cuvintele dintr-o propoziție, avem nevoie de operațiuni unde tokenii se pot influența reciproc. Aici intervine atenția.
Straturile de atenție sunt singurul loc unde diferiți tokeni interacționează între ei. Înțelegerea modului în care o rețea conectează tokenii împreună înseamnă înțelegerea atenției. Să analizăm cum funcționează acest mecanism în practică, în cazul în care există un singur prompt de intrare.
Considerăm promptul inițial "I am sure this project", tokenizat ca 7 tokeni: [
Arhitectura tensorială și calculele matriceale
Fiecare token este reprezentat în interiorul rețelei printr-un vector de lungime d (dimensiunea ascunsă). Prin urmare, cei șapte tokeni de intrare formează un tensor x cu forma [1, 7, d]. Primul număr, 1, reprezintă numărul de secvențe sau dimensiunea batch-ului, care în cazul nostru este unul singur. Al doilea număr, 7, este lungimea secvenței, iar d este dimensiunea ascunsă sau mărimea reprezentării fiecărui token.
Tensorul de intrare x este apoi proiectat de trei matrice: proiecția query Wq, proiecția key Wk și proiecția value Wv. Aceasta produce trei tensori Q, K și V, toți cu forma [1, n, A], unde A este dimensiunea capului de atenție. Îi numim stările query, key și value, respectiv.
În continuare, tensorii Q și K sunt înmulțiți pentru a măsura similaritatea dintre tokeni, producând un tensor cu forma [1, n, n]. Aceasta este rațiunea pentru care spunem că atenția are o complexitate pătratică în lungimea secvenței. Calcularea QK^T necesită O(n²d) operațiuni, deci costul este un pătrat al lui n, lungimea secvenței.
Masca de atenție și implicațiile sale computaționale
Ulterior, aplicăm o mască de atenție booleană la QK^T pentru a controla care tokeni pot interacționa. În cazul modelelor de limbaj, masca de atenție este o mască cauzală, ceea ce înseamnă că fiecare token interacționează doar cu tokenii care au venit înaintea sa. Aceasta urmează intuiția că o cauză trebuie să preceadă consecința sa, de aici și denumirea de "mască cauzală".
Masca de atenție este crucială deoarece dictează toate interacțiunile dintre tokeni din rețea. Dacă setăm toate valorile măștii de atenție la False, niciun token nu va interacționa cu altul în întreaga rețea. După aplicarea măștii de atenție, luăm un softmax pe rânduri și înmulțim rezultatul cu proiecția value V pentru a obține output-ul unui cap de atenție.
În contextul batching-ului continuu, Q, K și V pot avea numere diferite de tokeni deoarece, așa cum vom vedea, vom procesa diferite stadii (prefill și decode) în același timp. Pentru a generaliza, putem spune că Q are forma [1, nQ, A], K are forma [1, nK, A], iar V are forma [1, nV, A]. Scorurile de atenție QK^T au atunci forma [1, nQ, nK], iar masca de atenție are aceeași formă.
Faza de Prefill și importanța sa strategică
Procesul descris anterior, în care luăm o întreagă secvență de intrare, o trecem prin multiple straturi de atenție și calculăm un scor pentru următorul token, se numește "prefill". Această denumire este justificată de faptul că o mare parte din computația efectuată poate fi stocată în cache și reutilizată – prin urmare, "prefillăm" cache-ul.
Ultimul strat al modelului produce o predicție de token pentru fiecare token de intrare. În contextul generării continuării unui singur prompt, ne interesează doar predicția următorului token de la ultimul token. De exemplu, în figura analizată, ultimul token este "ject", iar predicția asociată este "will".
Datorită utilizării acestui cache, generarea secvenței poate continua folosind mult mai puțină computație într-o fază numită "decoding". În faza de decoding, generarea unui nou token va fi mult mai rapidă decât computația inițială a întregii secvențe.
KV Cache: Optimizarea memoriei și a calculului
Pentru a continua generarea, începem un nou forward pass. Totuși, pentru a calcula scorurile de atenție ale noului token, avem nevoie încă de proiecțiile key și value ale tokenilor anteriori. Fără o optimizare, am necesita repetarea înmulțirii matriceale a tokenilor vechi cu Wk și Wv pentru a recupera un rezultat care a fost deja calculat o dată. În alți termeni, am irosi resurse computaționale.
Observăm imediat că ultimul token nu impactează calculul atenției celorlalți tokeni. Aceasta urmează ideea măștii cauzale: deoarece "will" vine după toți tokenii anteriori, nu schimbă calculul atenției lor. Pentru generarea de text, atenția cauzală este de departe cea mai comună, așa că ne vom concentra pe acest caz.
Având în vedere că avem nevoie doar de predicția următorului token pentru tokenul "will", putem simplifica mecanismul de atenție calculând doar output-ul pentru acest token. Mai mult, am calculat deja stările K și V pentru tokenii "
Acesta este principiul KV cache-ului: lista stărilor key și value create în timpul generării. În esență, permite reducerea costului computațional al generării tokenului n+1 de la O(n²) la O(n), evitând recalcularea proiecțiilor key și value, plătind totuși un cost de memorie de O(n).
Concluzii și perspective de optimizare
Înțelegerea acestor mecanisme fundamentale – atenția, prefill-ul și KV cache-ul – este esențială pentru a aprecia complexitatea și ingeniozitatea tehnicilor moderne de optimizare a inferenței. Batching-ul continuu reprezintă evoluția naturală a acestor concepte, permițând o utilizare mult mai eficientă a resurselor hardware și o experiență superioară pentru utilizatorii finali.