Java в еру AI. Вплив нових бібліотек на машинне навчання та GPU-обчислення
Мене звати Юрій Зайчик, я Senior Java Developer у DXC Luxoft. Працюю у сфері автомобільної навігації. Декілька років тому я вже робив статтю на схожу тему. Сьогодні, коли AI є головним трендом у розробці програмних продуктів, хочу розглянути, що змінилось у Java з DS & UI за два роки.
JEPs and OpenJDK projects
У вересні минулого року стався реліз версії JDK 21, яка отримала LTS — довгострокову підтримку. А у березні цього року був реліз JDK 22.
Інкубаторна фіча, про яку я писав у попередній статті — JEP 460 (Vector API), досі знаходиться у preview (вже сьома ітерація!). За цей період не відбулось значних змін, лише оптимізації та фікси. Тож переваги, які може надати це API для векторних обчислень, ті ж самі. На жаль, поки не відбудеться реліз, їх використання обмежується лише експериментами.
Та за цей час з’явилися й розвинулися ще декілька інкубаторних проєктів, які зможуть бути корисними для роботи у сфері AI та DS.
Foreign Function & Memory API
Одним із таких є JEP 454 — Foreign Function & Memory API, що отримав реліз у JDK версії 22. Основна ідея цього модуля — створити механізм ефективної взаємодії між програмами, які виконуються у JVM, й зовнішніми програмами та даними за межами Java runtime. Тобто він повинен замінити старий та ненадійний механізм JNI (Java Native Interface).
Як це може бути корисним в контексті ML? Існує вже багато бібліотек, які розроблені на інших мовах програмування (Tensorflow, MLpack, quanteda тощо). Тож можна використати цей API, щоб викликати методи зі сторонніх бібліотек та обмінюватися з ними даними.
Як приклад розглянемо роботу з бібліотекою XGBoost. Цей блок коду створює матрицю дійсних чисел та ініціалізує XgBoost.
public class Main { public static void main(String[] args) { try ( var arena = Arena.ofConfined()) { Linker linker = Linker.nativeLinker(); SymbolLookup targetLookup = SymbolLookup.libraryLookup("libxgboost.so.0", arena); MemorySegment xgMatrixCreateMseg = targetLookup.find("XGDMatrixCreateFromMat").orElseThrow(); MemorySegment xgBoosterCreateMseg = targetLookup.find("XGBoosterCreate").orElseThrow(); float[] matrix = {1.1f, 2.2f, 3.3f, 4.4f, 5.0f, 6.0f}; MemorySegment arrayMseg = arena.allocateFrom(ValueLayout.JAVA_FLOAT, matrix); MemorySegment dMatrixHandle = arena.allocate(ValueLayout.ADDRESS); MethodHandle xgMatrixCrtHandle = linker.downcallHandle(xgMatrixCreateMseg, FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.JAVA_FLOAT, ValueLayout.ADDRESS)); int res = (int) xgMatrixCrtHandle.invoke(arrayMseg, 3, 2, -1.0f, dMatrixHandle); if (res != 0) { throw new RuntimeException("Matrix creation failed"); } MemorySegment dMatrixHandleVal = dMatrixHandle.get(ValueLayout.ADDRESS, 0); System.out.println("matrix handle address: " + dMatrixHandleVal.address()); MemorySegment dMatrixHandleArr = arena.allocateFrom(ValueLayout.ADDRESS, dMatrixHandleVal); MethodHandle xgBoosterCreateHandle = linker.downcallHandle(xgBoosterCreateMseg, FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_INT)); var boosterCallMseg = (MemorySegment) xgBoosterCreateHandle.invoke(dMatrixHandleArr, 1); if (boosterCallMseg != null) { System.out.println("Booster handle address: " + boosterCallMseg.address()); } MethodHandle xgBoostFreeHandle = linker.downcallHandle( targetLookup.find("XGBoosterFree").orElseThrow(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS)); xgBoostFreeHandle.invoke(boosterCallMseg); MethodHandle xgDMatrixFree = linker.downcallHandle(targetLookup.find("XGDMatrixFree").orElseThrow(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS)); xgDMatrixFree.invoke(dMatrixHandle); } catch (Throwable e) { throw new RuntimeException(e); } } }
Спочатку ініціалізуємо так звану «арену». Це область, яка контролює цикл роботи з зовнішніми блоками пам’яті. Далі створюємо lookup для знаходження необхідних методів у зовнішній бібліотеці й за допомогою метода find знаходимо посилання на потрібні нам методи. Створюємо Java-масив дійсних чисел, за допомогою методу arena.allocateFrom() виділяємо для нього місце у зовнішні пам’яті та переносимо туди дані.
Після цього створюємо керуючу сігнатуру для зовнішньої функції «XGDMatrixCreateFromMat», посилання на яку ми вже попередньо створили. За допомогою FunctionDescriptor.of() описуємо тип даних, який повертає функція (перший аргумент у списку) та типи даних, які вона приймає. Тип ValueLayout.ADDRESS позначає посилання на область у зовнішній пам’яті, як вказівники у мові С.
Далі виконуємо виклик функції з відповідними параметрами, серед яких — посилання на масив даних, кількість рядків та стовпців, значення, яке відображатиме відсутність даних та посилання на створену структуру матриці у зовнішній пам’яті. Вичитуємо дані з посилання на структуру матриці:
MemorySegment dMatrixHandleVal = MatrixHandle.get(ValueLayout.ADDRESS, 0);
Яке потім використаємо як аргумент для виклику функції XGBoosterCreate. Наприкінці вдаємось до функцій «XGBoosterFree» та «XGDMatrixFree», щоб звільнити ресурси.
Babylon
Ще один цікавий проєкт — Babylon. Його основна ідея — створити механізм, за допомогою якого було б можливо транслювати Java-код в інші форми коду. Наприклад, створити з Java-методу GPU kernel. Щоб цього досягти, автори планують розширити функціонал Java Reflection API.
Розглянемо цей процес детальніше. Java-розробники, працюючи над своїми програмами, записують Java-код у текстові файли з розширенням «.java». Щоб із цього вихідного коду зробити програму, треба передати його на обробку Java-компілятору. Компілятор, своєю чергою, перетворює вихідний код у структури AST (Abstract Syntax Trees).
Наступний крок — перевірка AST на консистентність та генерація з них байт-коду. JVM використовує байт-код, щоб перетворювати його на нативні машинні інструкції. Усі ці стадії від текстового файлу до машинних інструкцій представляють однакову програму. AST семантично ближча до Java-синтаксису, а байт-код ближчий до машинних інструкцій, наприклад, type-erasure. Тож жодна з репрезентацій не підходить для відображення рефлексійної моделі коду. Рефлексійна модель не повинна мати жорсткої прив’язки до синтаксису Java, але повинна однозначно описувати структуру програми та інформацію про типи.
Автори пояснюють рефлексійну модель коду таким чином: вона складається із таких елементів, як-то операції, тіла, та блоки, які формують дерево. Операція може складатися з одного або декількох тіл. Тіло складається з одного або декількох блоків. Блок містить ланцюг з однієї або декількох операцій. Він може задавати параметри та значення, Операція може задавати результат, який операція повертає, та значення. Як операнди можуть використовуватися задані значення. Значення повинні мати тип.
Подивимось, як це виглядає на прикладі. Задамо такий Java-метод:
@CodeReflection static double sub(double a, double b) { return a - b; }
Анотація @CodeReflection вкаже компілятору, що для цього метода треба створити рефлексійну модель.
Якщо пройтись елементами створеної моделі, побачимо таке дерево:
class java.lang.reflect.code.op.CoreOp$FuncOp class java.lang.reflect.code.Body class java.lang.reflect.code.Block class java.lang.reflect.code.op.CoreOp$VarOp class java.lang.reflect.code.op.CoreOp$VarOp class java.lang.reflect.code.op.CoreOp$VarAccessOp$VarLoadOp class java.lang.reflect.code.op.CoreOp$VarAccessOp$VarLoadOp class java.lang.reflect.code.op.CoreOp$SubOp class java.lang.reflect.code.op.CoreOp$ReturnOp
У його структурі можемо побачити частини рефлексійної моделі, про які говорили вище — операції, тіла, блоки. Ця модель однозначно описує структури програми та може бути використана для створення кода іншою мовою програмування. Також автори додали механізм її текстового відображення у більш зрозумілому для людини форматі:
func @"sub" @loc="19:5:file:/.../ExpressionGraphs.java" (%0 : double, %1 : double)double -> { %2 : Var<double> = var %0 @"a" @loc="19:5"; %3 : Var<double> = var %1 @"b" @loc="19:5"; %4 : double = var.load %2 @loc="21:16"; %5 : double = var.load %3 @loc="21:20"; %6 : double = sub %4 %5 @loc="21:16"; return %6 @loc="21:9"; };
Такий формат дуже корисний для тестування моделей.
Детальніше про це можна подивитись у презентації Пола Сандоза та подивитись приклади у GitHub-репозиторії проєкту.
Java GPGPU libraries
Щоб продовжити тему обчислень за допомогою GPU, подивимось, які наявні на Java бібліотеки дозволяють це робити.
Aparapi
Одна з них — Aparapi, що дозволяє конвертувати Java-код в Kernel-код OpenCL. Як відомо, у центральному процесорі комп’ютера можуть знаходитись від одиниць до десятків обчислювальних ядер.
Ці ядра залежно від архітектури RISC або CISC мають певний набір команд, яких достатньо для виконання звичайних програм на комп’ютері — математичних обчислень, взаємодій з периферійними пристроями та ресурсами тощо.
Своєю чергою, ядра GPU мають дуже обмежений набір команд для виконання математичних операцій, притаманних потребам комп’ютерної графіки, які піддаються паралелізації. Але коли ми маємо сотні або тисячі таких ядер, це суттєво прискорює роботу. У контексті Machine learning та Data Science, де теж трапляється розпаралелювання операцій над даними, які зазвичай мають матричну або векторну структуру, застосування GPU є дуже ефективним.
Розглянемо приклад розв’язання задачі векторного обчислення з попередньої статті.
Коротко нагадаю зміст задачі: треба обчислити довжину векторної лінії.
Розв’язання даної задачі за допомогою Aparapi матиме такий вигляд:
public static double calculate(VectorLine line, boolean addLogging) { var x1 = line.latitudes(); var x2 = Arrays.copyOfRange(x1, 1, x1.length); var y1 = line.longitudes(); var y2 = Arrays.copyOfRange(y1, 1, y1.length); double[] hypotenouses = new double[x2.length]; var kernel = new Kernel() { @Override public void run() { int gid = getGlobalId(); double dx = x2[gid] - x1[gid]; double dy = y2[gid] - y1[gid]; hypotenouses[gid] = sqrt(dx * dx + dy * dy); } }; kernel.execute(Range.create(x2.length)); if (addLogging) { log(); } kernel.dispose(); return DoubleStream.of(hypotenouses).sum(); }
Вхідні дані надходять у вигляді VectorLine, цей клас містить два вектори: координати абсцис та ординат. Створюємо допоміжні вектори для операцій векторної різниці. Kernel-код буде отримувати по парі координат x та y, виконувати віднімання між парами координат, зведення у квадрат та знаходити квадратний корінь суми квадратів. Щоб задати Kernel-код, який буде переведений в OpenCL, потрібно створити екземпляр класу Kernel і імплементувати код, перевизначаючи абстрактний метод run().
Щоб запустити цей код на GPU, потрібно викликати метод kernel.execute(), вхідним параметром якого є діапазон, який визначає кількість необхідних CUDA-ядер. За допомогою метода getGlobalId() можемо отримати порядковий номер запущеного Kernel. Його будемо використовувати як індекс для отримання даних із вхідних масивів і запису результату у вихідний масив hypotenouses. Коли всі ядра закінчать обчислення, рахуємо суму значень вихідного масиву й отримуємо довжину вектора.
Jcuda
Ще одна бібліотека, схожа на Aparapi — Jcuda. На відміну від Aparapi, де ми могли простий нативний код на Java транслювати у Kernel-код, Jcuda дає інтерфейси взаємодії з зовнішньою пам’яттю та зовнішніми методами CUDA за допомогою проміжної бібліотеки. Цей підхід схожий до того, що використовується у новому Java Foreign Functions and Memorz API, який ми розглядали вище.
Розгляньмо, як виконати задачу з обчислення довжини вектора за допомогою цієї бібліотеки.
public static double calculate(VectorLine line) { // Enable exceptions and omit all subsequent error checks JCudaDriver.setExceptionsEnabled(true); var x1 = line.latitudes(); var x2 = Arrays.copyOfRange(x1, 1, x1.length); var y1 = line.longitudes(); var y2 = Arrays.copyOfRange(y1, 1, y1.length); var vecLength = x2.length; VecDouble.init(); var devPtrX1 = createDoubleVectorPointer(x1, x2.length); var devPtrX2 = createDoubleVectorPointer(x2, x2.length); var devPtrY1 = createDoubleVectorPointer(y1, y2.length); var devPtrY2 = createDoubleVectorPointer(y2, y2.length); var devPtrDX = createDoubleVectorPointer(null, vecLength, false); var devPtrDY = createDoubleVectorPointer(null, vecLength, false); double[] hypotenuses = new double[x2.length]; var devResultPtr = createDoubleVectorPointer(hypotenuses, hypotenuses.length, false); VecDouble.sub(vecLength, devPtrDX, devPtrX2, devPtrX1); VecDouble.sub(vecLength, devPtrDY, devPtrY2, devPtrY1); VecDouble.mul(vecLength, devPtrDX, devPtrDX, devPtrDX); VecDouble.mul(vecLength, devPtrDY, devPtrDY, devPtrDY); VecDouble.add(vecLength, devResultPtr, devPtrDX, devPtrDY); VecDouble.sqrt(vecLength, devResultPtr, devResultPtr); cuMemcpyDtoH(Pointer.to(hypotenuses), devResultPtr, vecLength * Sizeof.DOUBLE); Stream.of(devResultPtr, devPtrDX, devPtrDY, devPtrX1, devPtrX2, devPtrY1, devPtrY2) .forEach(JCudaDriver::cuMemFree); VecDouble.shutdown(); return DoubleStream.of(hypotenuses).sum(); } private static CUdeviceptr createDoubleVectorPointer(double[] array, int length, boolean copyToDevice) { var devPtr = new CUdeviceptr(); var size = length * Sizeof.DOUBLE; cuMemAlloc(devPtr, size); if (copyToDevice) cuMemcpyHtoD(devPtr, Pointer.to(array), size); return devPtr; }
Як і у попередньому прикладі, спочатку ініціалізуємо вхідні масиви. Далі потрібно ініціалізувати роботи із векторами для дробових чисел подвійної точності VecDouble.init(). Потім для кожного із вхідних векторів треба зробити Pointer, який виконуватиме роль посередника між контейнером даних у JVM та у GPU-пристрої.
Виконуємо послідовність векторних операцій — VecDouble.sub, VecDouble.mul, VecDouble.add, VecDouble.sqrt. Результати отримання квадратних коренів записуємо у контейнер devResultPtr. За допомогою методу cuMemcpyDtoH копіюємо результати обчислень із пам’яті пристрою GPU у пам’ять JVM, звільняємо пам’ять у GPU і рахуємо суму елементів вихідного масиву.
Якщо порівнювати підходи, застосовані у цих двох бібліотеках, то в Aparapi отримаємо більш наближений до Java варіант, де описуємо Kernel-код на Java. Звісно, маємо на увазі обмеження: виконувати можна лише примітивні операції. В той час, як JCUDA дає нам інтерфейс для роботи з нативною бібліотекою, що підключена через JNI. Aparapi ховає від нас процес взаємодії із GPU, водночас з JCUDA треба власноруч виконувати деякі низькорівневі операції, як-то виділення та очищення пам’яті пристрою GPU.
Java ML libraries
Одним із цікавих проєктів на Java, який створений безпосередньо для розв’язання завдань Machine learning, є бібліотека Tribuo від Oracle labs. Вона може бути використана для типових задач ML: класифікація, регресія, кластеризація. Має зручний інтерфейс для взаємодії зі сторонніми бібліотеками xgboost, liblinear та іншими.
Що її відрізняє від інших інструментів, так це об’єктноорієнтований підхід, тобто API має суворо типізовану форму зі спеціальними параметризованими класами та інтерфейсами для визначення моделей, прогнозів та наборів даних. Розглянемо її використання на прикладі.
Спробуємо побудувати модель для розпізнавання рукописних цифр. Візьмемо відомий для цього датасет MNIST.
Спочатку робимо ініціалізацію датасету: визначаємо imgTrainSource та imgTestSource відповідно для тренування та подальшого тестування створеної моделі. Потім визначаємо алгоритм для створення моделі за допомогою convolutional-нейромережі (у нашому випадку це буде AdaGrad) та задаємо параметри для нього.
public static void trainAndTest(Path trainImagesPath, Path trainLabelsPath, Path testImagesPath, Path testLabelsPath, int mlpEpochs, int cnnEpcochs, int imageSize) throws IOException { var labelFactory = new LabelFactory(); var imgTrainSource = new IDXDataSource<>(trainImagesPath, trainLabelsPath, labelFactory); var imgTestSource = new IDXDataSource<>(testImagesPath, testLabelsPath, labelFactory); var imgTrain = new MutableDataset<>(imgTrainSource); var imgTest = new MutableDataset<>(imgTestSource); var inputName = "MNIST_INPUT"; var gradAlgorithm = GradientOptimiser.ADAGRAD; var gradParams = Map.of("learningRate", 0.01F, "initialAccumulatorValue", 0.01F); if (mlpEpochs <= 0) return; doCnnTest(inputName, imgTrain, imgTest, gradAlgorithm, gradParams, cnnEpcochs, imageSize); }
Далі будуємо модель та перевіряємо її.
private static Model<Label> doCnnTest (String inputName, MutableDataset<Label> trainSet, MutableDataset<Label> testSet, GradientOptimiser gradAlgorithm, Map<String, Float> gradParams, int epochs, int imageSize) { var cnnTuple = CNNExamples.buildLeNetGraph(inputName, imageSize, 255, trainSet.getOutputs().size()); var imageConverter = new ImageConverter(inputName, imageSize, imageSize, 1); var imgCNNTrainer = new TensorFlowTrainer<Label>( cnnTuple.graphDef, cnnTuple.outputName, gradAlgorithm, gradParams, imageConverter, new LabelConverter(), 16, //train batch size epochs, 16, // test batch size -1 // logging interval ); var cnnStart = System.currentTimeMillis(); var cnnModel = imgCNNTrainer.train(trainSet); var cnnEnd = System.currentTimeMillis(); System.out.println("CNN training took:" + Util.formatDuration(cnnStart, cnnEnd)); var labelEval = new LabelEvaluator(); var predictions = model.predict(testSet); var cnnEvaluation = labelEval.evaluate(model, predictions, testSet.getProvenance()); System.out.println(cnnEvaluation.toString()); return cnnModel; }
Як видно із коду за лаштунками Tribuo буде використовувати бібліотеку Tensorflow.
Результати перевірки:
9 класів — це відповідно цифри від 0 до 9. n — кількість цифр у датасеті, tp, fn, fp — true positive (розпізнано правильно), false negative (розпізнана як інший клас), false positive (інший клас, розпізнаний як заданий) відповідно.
Точність розпізнавання побудованої моделі — 0.989. По коду бачимо, що процес навчання та використання моделі має звичну для Java-розробників об’єктноорієнтовану концепцію та назви класів і методів доволі самоописові. Як на мене, такий процес зручніший за взаємодію з безликими масивами чисел. Більш детально з іншими прикладами можна ознайомитись на сайті Tribuo.
Висновки
На цей час Java є не найліпшим інструментом для роботи з моделями напряму, насамперед через відсутність зручних вбудованих можливостей візуалізації усіх проміжних етапів створення та навчання моделі.
Але якщо застосовувати Java для роботи із вже готовими моделями та даними для вирішення певних бізнес-завдань, в цьому Java може конкурувати. Адже має багату екосистему, яка дозволяє створювати продукти у різних парадигмах.
3 коментарі
Додати коментар Підписатись на коментаріВідписатись від коментарів