Skip to content

CPS

Для чего это вообще надо?
Почему мы это учим?
Зачем проводить допуск на знание CPS?
Как не запутаться и ней уйти на комсу?

Про это всё наша методичка.

Автор текста - Homka122

Для начала рассмотрим хвостые функции и какую они роль играют.

Хвостовые функции

Рассмотрим довольно простую функцию, которая считает сумму элементов массива

(* int list -> int *)
let rec sum xs =
match xs with
| [] -> 0
| h :: tl -> h + sum tl

В чем проблема такой функции?

Проблема в том, что h + sum tl порождает каждый раз новый фрейм на стеке для того, чтобы обработать повторный вызов sum tl.
Так как нам нужно посчитать h + sum tl мы обязаны оставить старый фрейм целым с переменной h и создать новый фрейм для подсчета sum tl
Это заставляет нас все время возвращаться назад после получения результата.

Хвостовая рекурсия

Рассмотрим немного видоизменную функцию, в которую мы добавим аккумулятор:

(* int list -> int *)
let rec sum_tail xs acc =
match xs with
| [] -> acc
| h :: tl -> sum_tail tl (acc + h)

Отличие этой фукциии от предыдущий в том, что мы смогли добиться хвостовой рекурсии (tail recursion).
Называтеся это так, потому наш рекурсивный вызов sum_tail является последним действием нашей функции
Так как вызов рекурсивной функции является последним, что выполняется в текущем вызове функции - нам нет нужды возвращаться назад! А значит и хранить текущий фрейм нам тоже не нужно — мы можем беспощадно его удалить со стека.

Оптимизация хвостовой рекурсии

Благодаря хвостовой рекурсии мы можем применить больше не хранить старые фреймы и сразу исполнять следующий рекурсивный вызов, тем самым не заполняя стек.
Если раньше при вызове функции sum на списке из 10млн элеметов мы падали с ошибкой Stack Overflow, то теперь мы можем эффективно это считать и не потреблять лишнюю память

Рассмотрим наглядное применение sum и sum_tail на массиве [1; 2; 3; 4]

(* sum (1 :: 2 :: 3 :: 4 :: [])
1 + sum (2 :: 3 :: 4 :: [])
1 + 2 + sum (3 :: 4 :: [])
1 + 2 + 3 + sum (4 :: [])
1 + 2 + 3 + 4 + sum ([])
1 + 2 + 3 + 4 + 0 // Конец фазы 1
1 + 2 + 3 + 4
1 + 2 + 7
1 + 9
10
*)

Это было довольно трудоемко, не так? Более того, в конце 1 фазы мы имеем в памяти 5 стековых фрейвов! Каждый из них хранит лишь одно число и ждет ожидания вычисления функции sum tl, что очень затратно и излишне.
sum_tail лишен этого недостатка и при использовании этой функции у нас не создаются новые лишние стек фреймы.

(* sum_tail (1 :: 2 :: 3 :: 4 :: []) 0
sum_tail (2 :: 3 :: 4 :: []) 1
sum_tail (3 :: 4 :: []) 3
sum_tail (4 :: []) 6
sum_tail ([]) 10
10
*)

Что же такое CPS?

Если хвостовые функции настолько крутые — можно ли все крутые рекурсивные функции сделать хвостовыми?
Хорошая новость: да!
Плохая новость: не всегда это можно сделать просто

Для этого нам на помощь приходит Continutaion Passing Style (CPS преобразование)

CPS преобразование

let sum a b = a + b ⇝ let sumk a b k = k (a + b)

Особенность CPS преобразования в том, что это позволяет перевести нашу рекурсивную функцию в функции с хвостовой рекурсией!
Причем — всегда. Поэтому предупреждения от Какаду:

Примеры CPS преобразования

Рассмотрим несколько примеров CPS преобразований.

Factorial

Классический пример, с которого начинается наше путешствествие — факториал.

(* int -> int *)
let rec fact x = if x = 1 then 1 else x * fact (x - 1)

Заметили в чем проблема? Помимо существования отрицательных чисел.
То, что мы уже видели с функцией sum. Для подсчета fact n нужно заходить в функцию fact (n - 1), возвращаться и считать n * fact (n - 1).
Хотим хвостовую рекурсию. Поэтому применим CPS преобразование:

(* int -> (int -> 'a) -> 'a *)
let rec factk x k = if x = 1 then k 1 else factk (x - 1) (fun r -> k (x * r))

Что же произошло?

Константа

В случае fact при x = 1 нам возвращалось значение 1.
Для CPS преобразования достаточно превратить 1 в k 1

Вызов fact

Заметим, что функция k имеет тип int -> 'a.
Это функция, которая принимает ответ нашей исходной функции.
Поэтому, когда мы видим вызов x * fact (x - 1) стараемся вычислить ответ и передать это в функцию k.
Получаем: x * (factk (x - 1) (fun r -> k r))
А теперь осознаем, что нам в ответе нужно не fact (x - 1), а fact (x - 1) * x, поэтому x также засовываем внутрь функции
Получаем: factk (x - 1) (fun r -> k (r * x))

Готово!