Сегодня я немного расскажу о решении задачи классификации с использованием программного пакета R и его расширений. Задача классификации, пожалуй, одна из самых распространенных в анализе данных. Существует множество методов для ее решения с использованием разных математических техник, но нас с тобой, как апологетов R, не может не радовать, что при этом программировать что-либо с нуля не нужно, — все есть (причем далеко не в единственном экземпляре) в системе пакетов R.

 

Задача классификации

Задача классификации — типичный пример «обучения с учителем». Как правило, мы располагаем данными в виде таблицы, где столбцы содержат значение наборов признаков для каждого случая. Причем все строки заранее размечены таким образом, что один из столбцов (положим, что последний) указывает на класс, к которому принадлежит данная строка. Как хороший пример можно привести задачу классификации писем на спам и не спам. Для того чтобы воспользоваться алгоритмами машинного обучения, нужно для начала иметь размеченные данные — такие, для которых значение класса известно наряду с остальными признаками. Причем набор данных должен быть существенным, особенно если количество признаков велико.

Если у нас есть достаточно данных, то можно начинать обучение модели. Общая стратегия с классификаторами не особо зависит от модели и включает следующие шаги:

  • выбор тренировочного и тестового множества;
  • обучение модели на тренировочном множестве;
  • проверка модели на тестовом множестве;
  • перекрестная проверка;
  • улучшение модели.
 

Точность и полнота

Как оценить, насколько хорошо работает наш классификатор? Непростой вопрос. Дело в том, что различные варианты развития событий возможны, даже если у нас есть всего только два класса. Допустим, мы решаем задачу фильтрации спама. После проверки модели на тестовом множестве мы получим четыре величины:

TP (true positive) — сколько сообщений было правильно классифицировано как спам,
TN (true negative) — сколько сообщений было правильно классифицировано как не спам,
FP (false positive) — сколько сообщений было неправильно классифицировано как спам (то есть письма спамом не были, но модель классифицировала эти сообщения как спам),
FN (false negative) — сколько сообщений было неправильно классифицировано как не спам, а на самом деле это был все-таки Центр американского английского.


В зависимости от деталей решаемой задачи качество модели может быть оценено по-разному. В каждом конкретном случае нам нужно выбрать, что важнее: чтобы в спам попало как можно больше писем, которые на самом деле являются спамом, и мы готовы пожертвовать некоторыми важными сообщениями или же чтобы в спам попадали только гарантированно «плохие» сообщения, а все остальные сообщения терялись в папке спам как можно реже (наверное, у всякого приличного интернетчика хотя бы раз в жизни важное-важное сообщение пропадало в спаме). Редко удается усидеть на двух стульях, и поэтому приходится при выборе классификатора принять решение в пользу той или иной стратегии. Как правильно отразить эти два крайних случая с математической точки зрения? Математической реализацией этих двух случаев служат точность (precision) и полнота (recall) соответственно. Это сравнительно простые характеристики модели, которые определяются следующим образом:

  • точность PPV = TP / (TP + FP);
  • полнота TPR = TP / (TP + FN).

Иногда хочется иметь величину, которая бы показывала качество модели в целом, учитывая и точность и полноту. Для этого можно использовать так называемую F-меру:

F1 = 2*TP / (2*TP + FP + FN)

F-мера представляет собой гармоническое среднее между точностью и полнотой. Это означает, что если хотя бы одна из этих двух величин стремится к нулю, то и F-мера стремится к нему же. На самом деле описанный вариант учитывает и точность, и полноту в равной степени, однако можно корректировать значимость каждой из величин (см. формулу 1).

Формула 1. Общее определение F-меры с коэффициентом beta, который устанавливает значимость полноты в формуле: если beta > 1, то полнота важнее, если beta < 1, то важнее точность
Формула 1. Общее определение F-меры с коэффициентом beta, который устанавливает значимость полноты в формуле: если beta > 1, то полнота важнее, если beta < 1, то важнее точность
 

Эффект переобучения

Бывает, модель «выучивает» те свойства обучающего множества, которые не характерны для данных в общем случае, как иногда говорят — отсутствуют в генеральной совокупности. Такая модель на реальных данных будет давать плохие результаты. Этот феномен известен как переобучение, или overfitting. Способы борьбы с переобучением зависят от алгоритма машинного обучения. К примеру, для деревьев решений от эффекта переобучения можно избавиться за счет отсечения некоторых ветвей (см. далее).

 

Перекрестная проверка

Перекрестная проверка, или cross-validation, — это всего лишь метод проверки модели. Очень часто данные делят тренировочное и тестовое множество в соотношении 2 : 1. Таким образом, обучают модель на 2/3 данных, а тестируют на 1/3. А что будет, если с данными что-то не так? И как понять, что модель ведет себя неправильно? Для того чтобы убедиться, что модель работает хорошо, поступают следующим образом: данные делятся на k частей, из этих частей (k – 1) используется для обучения модели, а одна для тестирования. Процедура повторяется k раз, так, чтобы каждый из семплов использовался для тестирования ровно один раз. Такой способ перекрестной проверки известен как k-fold. Часто используется 10-fold, но на самом деле нет никаких специальных ограничений на k.

Каждый легко может представить и другие виды перекрестной проверки, решающие ту же самую задачу. В целом все методы можно поделить на исчерпывающие (exhaustive) и неисчерпывающие (non-exhaustive). Первые проверяют все возможные варианты для обучения/тестирования, а вторые — лишь часть. Ко второму классу можно отнести методы со случайной выборкой типа Монте-Карло. Очевидно, что k-fold относится к исчерпывающим методам. Его можно усложнить, сделав разбиение на (k – p) и p семплов соответственно.

 

Деревья решений

Теперь можно приступить к практике. Начнем с чего-нибудь интуитивно понятного.

Для начала нужно выбрать данные, с которыми мы планируем работать. В интернете можно найти огромное количество данных для анализа и обработки, но так как цель этой статьи сугубо обучающая, то мы воспользуемся классическим источником: существует прекрасная книга An Intoduction to Statistical Learning with Application with R, которую легко найти в интернете. Для R есть специальный пакет ISLR, нужно будет установить его перед работой. Также нам понадобится пакет tree, в котором реализован алгоритм работы с деревьями решений (на самом деле это далеко не единственный пакет для подобных задач, чуть позже мы рассмотрим и другие варианты). Для начала подключаем все это хозяйство:

library(ISLR)
library(tree)

head(Carseats)

attach(Carseats)

Здесь мы будем работать с данными Carseats из пакета ISLR, это учебный материал по продажам автокресел в 400 различных магазинах. Для начала посмотрим, что представляет собой этот набор данных:

  Sales CompPrice Income Advertising Population Price ShelveLoc Age Education Urban  US
1  9.50       138     73          11        276   120       Bad  42        17   Yes Yes
2 11.22       111     48          16        260    83      Good  65        10   Yes Yes
3 10.06       113     35          10        269    80    Medium  59        12   Yes Yes
4  7.40       117    100           4        466    97    Medium  55        14   Yes Yes
5  4.15       141     64           3        340   128       Bad  38        13   Yes  No
6 10.81       124    113          13        501    72       Bad  78        16    No Yes

С помощью attach мы делаем колонки обычными переменными, то есть вместо того, чтобы писать каждый раз Carseats$Sales, теперь можно просто указывать Sales, что несомненно удобно. Положим, мы хотим просто разделить все продажи на высокие и низкие, а затем понять, как признаки в нашей таблице влияют на качество продаж. Подробное описание формата данных можно получить обычным способом, просто набрав ?Carseats в консоли R.

Сперва надо решить, что считать высокими, а что низкими продажами. Сначала можно посмотреть, в каком промежутке меняются значения Sales во всем объеме данных:

range(Sales)
[1]  0.00 16.27

Положим, что продажи можно считать большими, если значение больше или равно 8. Отразим это, добавив соответствующий столбец к нашему фрейму данных. Чтобы его создать, можно воспользоваться векторной функцией ifelse:

High = ifelse(Sales >= 8, "Yes", "No")

Так мы получили отдельный вектор из Yes и No. Теперь его легко присоединить к нашему фрейму данных Carseats. Строго говоря, колонка Sales нам больше не нужна, и ее можно просто удалить:

Carseats = Carseats[, -1]             # Удаляем Sales
Carseats = data.frame(Carseats, High) # Добавляем High

Перед тем как разбить данные на тестовое и тренировочное множество, было бы неплохо зафиксировать начальное значение генератора случайных чисел:

set.seed(3)

Это сделает наши результаты воспроизводимыми. Если, к примеру, качество данных плохое, то при построении и тестировании модели точность и F-мера могут значительно различаться, что может быть нежелательно. Теперь разделим наши данные на тренировочное и тестовое множества в отношении 2 : 1. Это можно сделать с помощью функции sample:

n = nrow(Carseats)
test = sample(1:n, n/3)
train = -test

Теперь test — это вектор случайных чисел от 1 до n длины n/3. Так как удалять колонки и строки можно, просто используя отрицательные значения индексов, то тренировочное и тестовое множество будут выглядеть следующим образом:

training_set = Carseats[train, ]
testing_set = Carseats[test, ]

Для того чтобы обучить модель, достаточно всего лишь одного вызова:

tree_model = tree(High~., training_set)

Здесь High~. представляет собой формулу, которая говорит системе, что переменная High зависит от всех признаков. Иногда бывает полезно визуализировать модель. Это можно сделать с помощью стандартной функции plot:

plot(tree_model)

Результатом работы такой функции будет отрисованное дерево решений. В таком виде оно довольно бесполезно, поэтому соответствующие ветвления нужно подписать:

text(tree_model)

Результат выполнения команды plot изборажен на Рис 1.

Рис. 1. Визуализация полученной модели, здесь нет подписей, каждый может вызвать функцию text  и тогда станет понятно, почему :)
Рис. 1. Визуализация полученной модели, здесь нет подписей, каждый может вызвать функцию text и тогда станет понятно, почему 🙂

Посмотри на дерево, и тебе станет понятно, что происходит. Алгоритм дает нам просто набор условных выражений, посредством которых происходит классификация. Теперь можно преступить к проверке модели. Для начала попробуем модель на тестовых данных:

tree_pred = predict(tree_model, testing_set, type="class")
mean(tree_pred != High[test])

Здесь первая строчка строит предсказание на основе модели для тестового множества, type=class указывает на то, что мы занимаемся задачей классификации. Наверное, самым простым способом проверки модели является простой подсчет ошибок — функция mean даст нам количество ошибок в долях от единицы. Здесь это значение равно 0.2481203, что в переводе на русский язык означает 24.8%.

Кстати, как насчет улучшения результата? Если внимательно посмотреть на дерево решений, можно предположить, что мы имеем дело с переобучением: некоторые закономерности, которые стали частью модели, на самом деле не свойственны тестовому множеству. В случае деревьев это, как правило, означает, что высота дерева слишком велика и нужно удалить лишние ветки.

А как определить, сколько именно ветвей нужно удалять? В этом деле нам поможет процедура перекрестной проверки. Средствами R она реализуется довольно легко:

set.seed(5)
cv_tree = cv.tree(tree_model, FUN=prune.misclass)

В этом (достаточно понятном и без всяких комментариев) коде мы сначала еще раз устанавливаем начальное значение для генератора псевдослучайных чисел. Затем с помощью встроенной функции мы делаем перекрестную проверку. Результатом будет список cv_tree, в котором есть поля size и dev: первое - размер дерева (его высота), второе - допуск ошибки. Функция prune.misclass - как раз та функция, которая позволит нам сделать отсечение:

plot(cv_tree$size, cv_tree$dev, type="b")

Если посмотреть на построенный график зависимости ошибки от размера дерева (см. рис. 2), то видно, что минимум ошибки приходится где-то на 9:

Рис. 2. Результат перекрестной проверки
Рис. 2. Результат перекрестной проверки

Исходя из полученных данных, можно сделать улучшенную модель, сделав отсечение ветвей с помощью той же функции (см. рис. 3):

prune_model = prune.misclass(tree_model, best=9)
Рис. 3. Визуалиация дерева решений после отсечения
Рис. 3. Визуалиация дерева решений после отсечения

Часто бывает нужно визуализировать дерево решений прямо в консоли. Для этого достаточно просто напечатать модель:

print(prune_model)
node), split, n, deviance, yval, (yprob)
      * denotes terminal node

1) root 267 359.500 No ( 0.59925 0.40075 )  
  2) ShelveLoc: Bad,Medium 213 260.400 No ( 0.69953 0.30047 )  
    4) Price < 92.5 30  36.650 Yes ( 0.30000 0.70000 ) *
    5) Price > 92.5 183 199.500 No ( 0.76503 0.23497 )  
     10) CompPrice < 124.5 79  42.460 No ( 0.92405 0.07595 ) *
     11) CompPrice > 124.5 104 135.400 No ( 0.64423 0.35577 )  
       22) ShelveLoc: Bad 34  15.210 No ( 0.94118 0.05882 ) *
       23) ShelveLoc: Medium 70  97.040 Yes ( 0.50000 0.50000 )  
         46) Price < 115.5 19   7.835 Yes ( 0.05263 0.94737 ) *
         47) Price > 115.5 51  64.920 No ( 0.66667 0.33333 )  
           94) Age < 33.5 8   0.000 Yes ( 0.00000 1.00000 ) *
           95) Age > 33.5 43  44.120 No ( 0.79070 0.20930 )  
            190) Advertising < 15 37  25.350 No ( 0.89189 0.10811 ) *
            191) Advertising > 15 6   5.407 Yes ( 0.16667 0.83333 ) *
  3) ShelveLoc: Good 54  54.590 Yes ( 0.20370 0.79630 )  
    6) Price < 135 46  27.180 Yes ( 0.08696 0.91304 ) *
    7) Price > 135 8   6.028 No ( 0.87500 0.12500 ) *

Теперь можно опять проверить работоспособность нашей модели на тестовых данных:

tree_pred1 = predict(prune_model, testing_set, type="class")
mean(tree_pred1 != High[test])

Надо сказать, в данном конкретном случае отсечение не очень-то и помогло: при данных значениях seed получится 24% вместо 24.8%. В других случаях с помощью этого метода иногда получается добиться лучших результатов (хотя как мы увидим дальше, сейчас тоже получилось неплохо).

Пользуясь случаем, расскажу немного о том, что такое «хорошо» и «плохо». На этом этапе я хочу вернуть читателя в реальный мир и показать, насколько все на самом деле может быть плохо. Если поменять значение seed, можно получить, например, 30% - да так, что никакое отсечение не поможет. Все дело в данных: здесь мы делаем случайную выборку из не самых лучших данных, и поэтому в зависимости от начального значения генератора результат может очень сильно меняться. В задачах анализа данных это происходит довольно часто. Данные редко бывают настолько хороши, чтобы работа с ними была сразу приятна, а результаты стабильны и поражали нас значениями точности и полноты.

Кстати, чуть не забыл! Надо бы посчитать соответствующие величины для полученных нами только что результатов. Для начала сделаем это для нашей оригинальной модели:

t = table(pred=tree_pred, true=High[test])

Если вывести ее на консоль, мы увидим:

     true
pred  No Yes
  No  66  23
  Yes 10  34

Эта таблица как раз и представляет собой вычисленные значения TP, FP, TN, FN. Теперь легко написать функцию, которая сможет рассчитывать точность и полноту:

check <- function(res) {
  tn = res[1,1]
  tp = res[2,2]
  fp = res[1,2]
  fn = res[2,1]
  prec <- tp / (tp + fp)
  recall <- tp / (tp + fn)
  f1 <- 2 * tp / (2 * tp + fp + fn)
  list(precision=prec, recall=recall, f1=f1)
}

Итак, для оригинальной модели соответствующие значения будут:

$precision
[1] 0.5964912

$recall
[1] 0.7727273

$f1
[1] 0.6732673

Теперь все то же самое, только уже для модели с отсечением:

t1 = table(pred=tree_pred1, true=High[test])
check(t1)

Несмотря на то, что при использовании mean разница между моделями не казалось уж такой существенной, точность модели возросла более чем на 5%, а F-мера - на 2%, хотя полнота упала на 2%:

$precision
[1] 0.6491228

$recall
[1] 0.755102

$f1
[1] 0.6981132

Итак, когда мы немного разобрались с тем, как решать задачу классификации с помощью деревьев решений, настало время рассказать о том, какие еще пакеты доступны для работы с ними. Тем более, что методы визуализации для пакета tree, как мы видели выше, не так хороши.

 

Визуализация деревьев

Для начала следует упомянуть пакет party, который имеет довольно приличную систему визуализации (см. рис 4):

library(party)
ct = ctree(High~., training_set)
plot(ct)
Рис. 4. Визуализация дерева решений для пакета party
Рис. 4. Визуализация дерева решений для пакета party

Еще один приятный в использовании вариант - это пакет rpart и пакет для визуализации rpart.plot:

library(rpart)
library(rpart.plot)
rp = rpart(High~., training_set)
prp(rp)
Рис. 5. Результат работы функции prp из пакета rpart.plot
Рис. 5. Результат работы функции prp из пакета rpart.plot

Если воспользоваться функций fancyRpartPlot из пакета rattle, то можно получить прекрасное изображение, как показано на рис 6.

library(rattle)
fancyRpartPlot(rp)
Рис. 6. Результат работы fancyRpartPlot
Рис. 6. Результат работы fancyRpartPlot
 

Вместо заключения

В следующих статьях я расскажу, про работу с другими алгоритмами классификации (такими как наивный байессовский классификатор и метод опорных векторов), а также с нейронными сетями и алгоритмом Random Forest. Более того, я немного расскажу про то, как это работает на самом деле.

Оставить мнение