Класифікатор вайфу на fast.ai: тренування
Усі статті, обговорення, новини про AI — в одному місці. Підписуйтеся на DOU | AI!
Сьогдні ми створимо класифікатор, який буде розрізняти, на зображенні Рей чи Аска.
Стаття написана для новачків ML, які хочуть на практиці спробувати, як працює ML. Тут не буде глибокого пояснення, як все працює, але буде пояснення, як все це запустити. Ніякого попереднього досвіду не потрібно, лише бажання. Поїхали!
Підготовка
Для того, щоб не витрачати час на налаштування JupiterNotebook і решти штук, щоб нейронка працювала на відеокарті, будемо використовувати хмарне рішення.
Я рекомендую console.paperspace.com. Це повноцінний і безкоштовний сервіс.
1. Реєструємось.
2. Натискаємо «Create new notebook». Рекомендують брати fast.ai але, я не бачу плюсів перед PyTorch, тільки сміття більше. Тому беремо PyTorch.
3. Обираємо Free-GPU. Хоча, може доведеться почекати кілька хвилин, доки звільниться машина, але Free-CPU надто повільна для задачі.
4. Трохи чекаємо, доки запуститься машина, і можемо починати писати наш записник.
5. Переходимо в режим Jupiter.
Працювати з записником в paperspace зручніше, але не працюють віджети.
6. Натискаємо Notebook/Python 3
Наш записник відкрився.
Робота з Jupiter
1. Встановлюємо модулі
Перше що нам треба зробити — це встановити необхідні пакети
!pip install graphviz
Потрібен для fast.ai
!pip install -Uqq fastbook
Встановлюємо сам fastai
!pip install -q jmd_imagescraper
В fast.ai є вбудований збирач даних, але він у мене не запрацював, тому використаємо цей модуль
2. Імпортуємо
import fastbook from fastbook import * #не хвилюйтесь, з * так задумано fastbook.setup_book() from fastai.vision.widgets import * from jmd_imagescraper.core import *
Імпортуємо Path
from pathlib import Path
3. Збираємо дані
Створюємо теку для зображень
path = Path().cwd()/"waifus"
Завантажуємо картинки з DuckDuckGo
duckduckgo_search(path,"asuka","Asuka Langley Sohryu",max_results=200) duckduckgo_search(path,"rei","rei ayanami",max_results=200)
path — це шлях до теки з усіма картинками
«asuka» і «rei» — це теки класів зображень
«Asuka Langley Sohryu» і «rei ayanami» — це пошукові запити, за якими шукати картинки
max_results=200 — це скільки картинок завантажити
Доступ до наших файлів
fns = get_image_files(path)
В мене такого не бувало, але може бути, що завантижились пошкоджені файли. Перевіряємо
failed = verify_images(fns) failed
Якщо масив не порожній, треба їх видалити
failed.map(Path.unlink);
Тепер створимо DataLoader — це те, з чого нейронна мережа буде брати дані. Поки що будемо думати про DataBlock як про шаблон для створення DataLoader.
waifus=DataBlock( blocks=(ImageBlock,CategoryBlock), get_items=get_image_files, splitter=RandomSplitter(valid_pct=0.2,seed=42) get_y=parent_label, item_tfms=RandomResizedCrop(224, min_scale=0.5), batch_tfms=aug_transforms())
splitter=RandomSplitter(valid_pct=0.2,seed=42) візьме 20% даних для валідації.
item_tfms=RandomResizedCrop(224, min_scale=0.5) оскільки картинки часто не квадратні - візьме кватратну частинку зображення для роботи
batch_tfms=aug_transforms() так як у нас доволі скромний об'єм даних - це функція створить додаткові зображення шляхом викривлення та деяких змін кольору в оригінальних зображеннях
dls=waifus.dataloaders(path) dls.valid.show_batch(max_n=4,nrows=1)
dls - це те що ми будемо використовувати. Другий рядок - щоб подиитись набір даних. В даному випадку - для краси.
4. Тренуємо нейронку
learn = cnn_learner(dls, resnet18, metrics=error_rate) learn.fine_tune(4)
resnet18 — це наша архітектура.
metrics — це те по чому ми оптимізуємо нашу нейронну мережу. В даному випадку по кількості помилок класифікації.
fine_tune(4) — означає, що будемо оптимізувати 4 епохи. Тобто чотири повні проходи по всім зображенням.
Глянемо, скільки зображень класифікувались неправильно
interp = ClassificationInterpretation.from_learner(learn) interp.plot_confusion_matrix()
Глянемо, які зображення
interp.plot_top_losses(5)
5. Очистка даних
Як бачимо, вони мають мало стосунку до наших класів.
Логічно почати тренування, з того щоб почистити дані. Але навіщо, якщо сама нейронка може нам в цьому допомогти.
Це покаже зображення з якими у мережі виникли проблеми. Ми можемо видалити їх, або перенести в інший клас.
cleaner = ImageClassifierCleaner(learn) cleaner
Але ImageClassifierCleaner не робить це автоматично.
Видалення:
for idx in cleaner.delete(): cleaner.fns[idx].unlink()
Перенесення класу:
for idx,cat in cleaner.change(): shutil.move(str(cleaner.fns[idx]), path/cat)
Висновок
Вітаю, ви натренували свою першу нейронну мережу. Тепер можна її перетренувати на чистих даних, що я рекомендую зробити для практики і кращих результатів.
Навчились збирати і чистити дані.
В наступній частині обговоримо, як зробити веб додаток з нашої нейронки.
Як це виглядає в реальності і вихідний код:
console.paperspace.com/...6lc1j?file=Untitled.ipynb
PS: Так, матеріал сильно опирається на матеріал course.fast.ai, який я дуже рекомендую.
16 коментарів
Додати коментар Підписатись на коментаріВідписатись від коментарів