Класифікатор вайфу на fast.ai: тренування

Усі статті, обговорення, новини про AI — в одному місці. Підписуйтеся на DOU | AI!

Сьогдні ми створимо класифікатор, який буде розрізняти, на зображенні Рей чи Аска.
Стаття написана для новачків ML, які хочуть на практиці спробувати, як працює ML. Тут не буде глибокого пояснення, як все працює, але буде пояснення, як все це запустити. Ніякого попереднього досвіду не потрібно, лише бажання. Поїхали!

Підготовка

Для того, щоб не витрачати час на налаштування JupiterNotebook і решти штук, щоб нейронка працювала на відеокарті, будемо використовувати хмарне рішення.

Я рекомендую console.paperspace.com. Це повноцінний і безкоштовний сервіс.

1. Реєструємось.

2. Натискаємо «Create new notebook». Рекомендують брати fast.ai але, я не бачу плюсів перед PyTorch, тільки сміття більше. Тому беремо PyTorch.

3. Обираємо Free-GPU. Хоча, може доведеться почекати кілька хвилин, доки звільниться машина, але Free-CPU надто повільна для задачі.

4. Трохи чекаємо, доки запуститься машина, і можемо починати писати наш записник.

5. Переходимо в режим Jupiter.

Працювати з записником в paperspace зручніше, але не працюють віджети.

6. Натискаємо Notebook/Python 3

Наш записник відкрився.

Робота з Jupiter

1. Встановлюємо модулі

Перше що нам треба зробити — це встановити необхідні пакети

!pip install graphviz

Потрібен для fast.ai

!pip install -Uqq fastbook

Встановлюємо сам fastai

!pip install -q jmd_imagescraper

В fast.ai є вбудований збирач даних, але він у мене не запрацював, тому використаємо цей модуль

2. Імпортуємо

import fastbook
from fastbook import *  #не хвилюйтесь, з * так задумано
fastbook.setup_book()
from fastai.vision.widgets import *
from jmd_imagescraper.core import *

Імпортуємо Path

from pathlib import Path

3. Збираємо дані

Створюємо теку для зображень

path = Path().cwd()/"waifus"

Завантажуємо картинки з DuckDuckGo

duckduckgo_search(path,"asuka","Asuka Langley Sohryu",max_results=200)
duckduckgo_search(path,"rei","rei ayanami",max_results=200)

path — це шлях до теки з усіма картинками

«asuka» і «rei» — це теки класів зображень

«Asuka Langley Sohryu» і «rei ayanami» — це пошукові запити, за якими шукати картинки

max_results=200 — це скільки картинок завантажити

Доступ до наших файлів

fns = get_image_files(path)

В мене такого не бувало, але може бути, що завантижились пошкоджені файли. Перевіряємо

failed = verify_images(fns)
failed

Якщо масив не порожній, треба їх видалити

failed.map(Path.unlink);

Тепер створимо DataLoader — це те, з чого нейронна мережа буде брати дані. Поки що будемо думати про DataBlock як про шаблон для створення DataLoader.

waifus=DataBlock(
    blocks=(ImageBlock,CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2,seed=42)
    get_y=parent_label,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())
splitter=RandomSplitter(valid_pct=0.2,seed=42) візьме 20% даних для валідації.
item_tfms=RandomResizedCrop(224, min_scale=0.5) оскільки картинки часто не квадратні - візьме кватратну частинку зображення для роботи
batch_tfms=aug_transforms() так як у нас доволі скромний об'єм даних - це функція створить додаткові зображення шляхом викривлення та деяких змін кольору в оригінальних зображеннях
dls=waifus.dataloaders(path)
dls.valid.show_batch(max_n=4,nrows=1)
dls - це те що ми будемо використовувати. Другий рядок - щоб подиитись набір даних. В даному випадку - для краси.

4. Тренуємо нейронку

learn = cnn_learner(dls, resnet18, metrics=error_rate)
learn.fine_tune(4)

resnet18 — це наша архітектура.

metrics — це те по чому ми оптимізуємо нашу нейронну мережу. В даному випадку по кількості помилок класифікації.

fine_tune(4) — означає, що будемо оптимізувати 4 епохи. Тобто чотири повні проходи по всім зображенням.

Глянемо, скільки зображень класифікувались неправильно

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

Глянемо, які зображення

interp.plot_top_losses(5)

5. Очистка даних

Як бачимо, вони мають мало стосунку до наших класів.

Логічно почати тренування, з того щоб почистити дані. Але навіщо, якщо сама нейронка може нам в цьому допомогти.

Це покаже зображення з якими у мережі виникли проблеми. Ми можемо видалити їх, або перенести в інший клас.

cleaner = ImageClassifierCleaner(learn)
cleaner 

Але ImageClassifierCleaner не робить це автоматично.

Видалення:

for idx in cleaner.delete(): cleaner.fns[idx].unlink()

Перенесення класу:

for idx,cat in cleaner.change(): shutil.move(str(cleaner.fns[idx]), path/cat)

Висновок

Вітаю, ви натренували свою першу нейронну мережу. Тепер можна її перетренувати на чистих даних, що я рекомендую зробити для практики і кращих результатів.
Навчились збирати і чистити дані.
В наступній частині обговоримо, як зробити веб додаток з нашої нейронки.

Як це виглядає в реальності і вихідний код:

console.paperspace.com/...​6lc1j?file=Untitled.ipynb

PS: Так, матеріал сильно опирається на матеріал course.fast.ai, який я дуже рекомендую.

👍ПодобаєтьсяСподобалось3
До обраногоВ обраному2
LinkedIn
Дозволені теги: blockquote, a, pre, code, ul, ol, li, b, i, del.
Ctrl + Enter
Дозволені теги: blockquote, a, pre, code, ul, ol, li, b, i, del.
Ctrl + Enter
animals

Где эта переменная используется?
Автор хоть сам имеет понимание как оно под капотом учится какие операции проводятся над этими наборами пикселей?

Одрук, дякую, поправив.

Автор хоть сам имеет понимание как оно под капотом

Ще ні, я в процесі навчання.

Було би добре для початку описати задачу, чого ми хочемо добитися і що в результаті отримали. А то для людей, котрі не в курсі що таке вайфу стаття виглядає як «кудись зайшли і щось зробили», а що і для чого — незрозуміло.

Тегом pre можна обгортати багато строк коду — бо зара воно якось дивно виглядає

Сорі, перша стаття. Як в цьому редакторі подивитись теги?

В мене при натисканні лінку «Редагувати» над топіком одразу raw html показує

У вас rst/md?

В мене показує редактор як в старій пошті. Жирний, італік, курсив, список і код. Тегів чомусь не бачу.

Пфффф
В мене tiny.cloud / tinymce.com заблокувались через uMatrix
Тому я бачив rst + html при редагуванні постів %)

Як варіант заблокувати ці домени і редагувати rst + html

Це явно баг — але він в 3rd party коді

Ще, я дізнався, що можна додати до url ?old в кінці

Упс, забув додати посилання на записник. Додав в висновок.

Підписатись на коментарі