Java в еру AI. Вплив нових бібліотек на машинне навчання та GPU-обчислення

💡 Усі статті, обговорення, новини про AI — в одному місці. Приєднуйтесь до AI спільноти!

Мене звати Юрій Зайчик, я Senior Java Developer у DXC Luxoft. Працюю у сфері автомобільної навігації. Декілька років тому я вже робив статтю на схожу тему. Сьогодні, коли AI є головним трендом у розробці програмних продуктів, хочу розглянути, що змінилось у Java з DS & UI за два роки.

JEPs and OpenJDK projects

У вересні минулого року стався реліз версії JDK 21, яка отримала LTS — довгострокову підтримку. А у березні цього року був реліз JDK 22.

Інкубаторна фіча, про яку я писав у попередній статті — JEP 460 (Vector API), досі знаходиться у preview (вже сьома ітерація!). За цей період не відбулось значних змін, лише оптимізації та фікси. Тож переваги, які може надати це API для векторних обчислень, ті ж самі. На жаль, поки не відбудеться реліз, їх використання обмежується лише експериментами.

Та за цей час з’явилися й розвинулися ще декілька інкубаторних проєктів, які зможуть бути корисними для роботи у сфері AI та DS.

Foreign Function & Memory API

Одним із таких є JEP 454 — Foreign Function & Memory API, що отримав реліз у JDK версії 22. Основна ідея цього модуля — створити механізм ефективної взаємодії між програмами, які виконуються у JVM, й зовнішніми програмами та даними за межами Java runtime. Тобто він повинен замінити старий та ненадійний механізм JNI (Java Native Interface).

Як це може бути корисним в контексті ML? Існує вже багато бібліотек, які розроблені на інших мовах програмування (Tensorflow, MLpack, quanteda тощо). Тож можна використати цей API, щоб викликати методи зі сторонніх бібліотек та обмінюватися з ними даними.

Як приклад розглянемо роботу з бібліотекою XGBoost. Цей блок коду створює матрицю дійсних чисел та ініціалізує XgBoost.

public class Main {
    public static void main(String[] args) {
        try ( var arena = Arena.ofConfined()) {
            Linker linker = Linker.nativeLinker();
            SymbolLookup targetLookup = SymbolLookup.libraryLookup("libxgboost.so.0",
                                                                            arena);
            MemorySegment xgMatrixCreateMseg = 
			targetLookup.find("XGDMatrixCreateFromMat").orElseThrow();
            MemorySegment xgBoosterCreateMseg = 
			targetLookup.find("XGBoosterCreate").orElseThrow();

            float[] matrix = {1.1f, 2.2f, 3.3f, 4.4f, 5.0f, 6.0f};
            MemorySegment arrayMseg = arena.allocateFrom(ValueLayout.JAVA_FLOAT, 
									matrix);
            MemorySegment dMatrixHandle = arena.allocate(ValueLayout.ADDRESS);

            MethodHandle xgMatrixCrtHandle = linker.downcallHandle(xgMatrixCreateMseg,
                    FunctionDescriptor.of(ValueLayout.JAVA_INT,
                            ValueLayout.ADDRESS,
                            ValueLayout.JAVA_INT,
                            ValueLayout.JAVA_INT,
                            ValueLayout.JAVA_FLOAT,
                            ValueLayout.ADDRESS));

            int res = (int) xgMatrixCrtHandle.invoke(arrayMseg, 3, 2, -1.0f, 
								dMatrixHandle);
            if (res != 0) {
                throw new RuntimeException("Matrix creation failed");
            }

            MemorySegment dMatrixHandleVal = 
					dMatrixHandle.get(ValueLayout.ADDRESS, 0);
            System.out.println("matrix handle address: " + 
						dMatrixHandleVal.address());

            MemorySegment dMatrixHandleArr = arena.allocateFrom(ValueLayout.ADDRESS, 
								dMatrixHandleVal);


            MethodHandle xgBoosterCreateHandle = 
		linker.downcallHandle(xgBoosterCreateMseg,
                    FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, 
							ValueLayout.JAVA_INT));

            var boosterCallMseg = 
		(MemorySegment) xgBoosterCreateHandle.invoke(dMatrixHandleArr, 1);

            if (boosterCallMseg != null) {
                System.out.println("Booster handle address: " + 
					boosterCallMseg.address());
            }

            MethodHandle xgBoostFreeHandle = 
		linker.downcallHandle(
			targetLookup.find("XGBoosterFree").orElseThrow(),
                    	FunctionDescriptor.ofVoid(ValueLayout.ADDRESS));

            xgBoostFreeHandle.invoke(boosterCallMseg);

            MethodHandle xgDMatrixFree = linker.downcallHandle(targetLookup.find("XGDMatrixFree").orElseThrow(),
                    FunctionDescriptor.ofVoid(ValueLayout.ADDRESS));

            xgDMatrixFree.invoke(dMatrixHandle);
        }
        catch (Throwable e)
        {
            throw new RuntimeException(e);
        }
    }
}

Спочатку ініціалізуємо так звану «арену». Це область, яка контролює цикл роботи з зовнішніми блоками пам’яті. Далі створюємо lookup для знаходження необхідних методів у зовнішній бібліотеці й за допомогою метода find знаходимо посилання на потрібні нам методи. Створюємо Java-масив дійсних чисел, за допомогою методу arena.allocateFrom() виділяємо для нього місце у зовнішні пам’яті та переносимо туди дані.

Після цього створюємо керуючу сігнатуру для зовнішньої функції «XGDMatrixCreateFromMat», посилання на яку ми вже попередньо створили. За допомогою FunctionDescriptor.of() описуємо тип даних, який повертає функція (перший аргумент у списку) та типи даних, які вона приймає. Тип ValueLayout.ADDRESS позначає посилання на область у зовнішній пам’яті, як вказівники у мові С.

Далі виконуємо виклик функції з відповідними параметрами, серед яких — посилання на масив даних, кількість рядків та стовпців, значення, яке відображатиме відсутність даних та посилання на створену структуру матриці у зовнішній пам’яті. Вичитуємо дані з посилання на структуру матриці:

MemorySegment dMatrixHandleVal = MatrixHandle.get(ValueLayout.ADDRESS, 0);

Яке потім використаємо як аргумент для виклику функції XGBoosterCreate. Наприкінці вдаємось до функцій «XGBoosterFree» та «XGDMatrixFree», щоб звільнити ресурси.

Babylon

Ще один цікавий проєкт — Babylon. Його основна ідея — створити механізм, за допомогою якого було б можливо транслювати Java-код в інші форми коду. Наприклад, створити з Java-методу GPU kernel. Щоб цього досягти, автори планують розширити функціонал Java Reflection API.

Розглянемо цей процес детальніше. Java-розробники, працюючи над своїми програмами, записують Java-код у текстові файли з розширенням «.java». Щоб із цього вихідного коду зробити програму, треба передати його на обробку Java-компілятору. Компілятор, своєю чергою, перетворює вихідний код у структури AST (Abstract Syntax Trees).

Наступний крок — перевірка AST на консистентність та генерація з них байт-коду. JVM використовує байт-код, щоб перетворювати його на нативні машинні інструкції. Усі ці стадії від текстового файлу до машинних інструкцій представляють однакову програму. AST семантично ближча до Java-синтаксису, а байт-код ближчий до машинних інструкцій, наприклад, type-erasure. Тож жодна з репрезентацій не підходить для відображення рефлексійної моделі коду. Рефлексійна модель не повинна мати жорсткої прив’язки до синтаксису Java, але повинна однозначно описувати структуру програми та інформацію про типи.

Автори пояснюють рефлексійну модель коду таким чином: вона складається із таких елементів, як-то операції, тіла, та блоки, які формують дерево. Операція може складатися з одного або декількох тіл. Тіло складається з одного або декількох блоків. Блок містить ланцюг з однієї або декількох операцій. Він може задавати параметри та значення, Операція може задавати результат, який операція повертає, та значення. Як операнди можуть використовуватися задані значення. Значення повинні мати тип.

Подивимось, як це виглядає на прикладі. Задамо такий Java-метод:

@CodeReflection
static double sub(double a, double b) {
   return a - b;
}

Анотація @CodeReflection вкаже компілятору, що для цього метода треба створити рефлексійну модель.

Якщо пройтись елементами створеної моделі, побачимо таке дерево:

class java.lang.reflect.code.op.CoreOp$FuncOp
  class java.lang.reflect.code.Body
    class java.lang.reflect.code.Block
      class java.lang.reflect.code.op.CoreOp$VarOp
      class java.lang.reflect.code.op.CoreOp$VarOp
      class java.lang.reflect.code.op.CoreOp$VarAccessOp$VarLoadOp
      class java.lang.reflect.code.op.CoreOp$VarAccessOp$VarLoadOp
      class java.lang.reflect.code.op.CoreOp$SubOp
      class java.lang.reflect.code.op.CoreOp$ReturnOp

У його структурі можемо побачити частини рефлексійної моделі, про які говорили вище — операції, тіла, блоки. Ця модель однозначно описує структури програми та може бути використана для створення кода іншою мовою програмування. Також автори додали механізм її текстового відображення у більш зрозумілому для людини форматі:

func @"sub" @loc="19:5:file:/.../ExpressionGraphs.java" (%0 : double, %1 : double)double -> {
    %2 : Var<double> = var %0 @"a" @loc="19:5";
    %3 : Var<double> = var %1 @"b" @loc="19:5";
    %4 : double = var.load %2 @loc="21:16";
    %5 : double = var.load %3 @loc="21:20";
    %6 : double = sub %4 %5 @loc="21:16";
    return %6 @loc="21:9";
};

Такий формат дуже корисний для тестування моделей.

Детальніше про це можна подивитись у презентації Пола Сандоза та подивитись приклади у GitHub-репозиторії проєкту.

Java GPGPU libraries

Щоб продовжити тему обчислень за допомогою GPU, подивимось, які наявні на Java бібліотеки дозволяють це робити.

Aparapi

Одна з них — Aparapi, що дозволяє конвертувати Java-код в Kernel-код OpenCL. Як відомо, у центральному процесорі комп’ютера можуть знаходитись від одиниць до десятків обчислювальних ядер.

Ці ядра залежно від архітектури RISC або CISC мають певний набір команд, яких достатньо для виконання звичайних програм на комп’ютері — математичних обчислень, взаємодій з периферійними пристроями та ресурсами тощо.

Своєю чергою, ядра GPU мають дуже обмежений набір команд для виконання математичних операцій, притаманних потребам комп’ютерної графіки, які піддаються паралелізації. Але коли ми маємо сотні або тисячі таких ядер, це суттєво прискорює роботу. У контексті Machine learning та Data Science, де теж трапляється розпаралелювання операцій над даними, які зазвичай мають матричну або векторну структуру, застосування GPU є дуже ефективним.

Розглянемо приклад розв’язання задачі векторного обчислення з попередньої статті.

Коротко нагадаю зміст задачі: треба обчислити довжину векторної лінії.

A graph with purple lines and dotsDescription automatically generated

Розв’язання даної задачі за допомогою Aparapi матиме такий вигляд:

public static double calculate(VectorLine line, boolean addLogging) {
    var x1 = line.latitudes();
    var x2 = Arrays.copyOfRange(x1, 1, x1.length);
    var y1 = line.longitudes();
    var y2 = Arrays.copyOfRange(y1, 1, y1.length);

    double[] hypotenouses = new double[x2.length];

    var kernel = new Kernel() {
        @Override
        public void run() {
            int gid = getGlobalId();
            double dx = x2[gid] - x1[gid];
            double dy = y2[gid] - y1[gid];

            hypotenouses[gid] = sqrt(dx * dx + dy * dy);
        }
    };

    kernel.execute(Range.create(x2.length));
    if (addLogging) {
        log();
    }
    kernel.dispose();
    return DoubleStream.of(hypotenouses).sum();
}

Вхідні дані надходять у вигляді VectorLine, цей клас містить два вектори: координати абсцис та ординат. Створюємо допоміжні вектори для операцій векторної різниці. Kernel-код буде отримувати по парі координат x та y, виконувати віднімання між парами координат, зведення у квадрат та знаходити квадратний корінь суми квадратів. Щоб задати Kernel-код, який буде переведений в OpenCL, потрібно створити екземпляр класу Kernel і імплементувати код, перевизначаючи абстрактний метод run().

Щоб запустити цей код на GPU, потрібно викликати метод kernel.execute(), вхідним параметром якого є діапазон, який визначає кількість необхідних CUDA-ядер. За допомогою метода getGlobalId() можемо отримати порядковий номер запущеного Kernel. Його будемо використовувати як індекс для отримання даних із вхідних масивів і запису результату у вихідний масив hypotenouses. Коли всі ядра закінчать обчислення, рахуємо суму значень вихідного масиву й отримуємо довжину вектора.

Jcuda

Ще одна бібліотека, схожа на Aparapi — Jcuda. На відміну від Aparapi, де ми могли простий нативний код на Java транслювати у Kernel-код, Jcuda дає інтерфейси взаємодії з зовнішньою пам’яттю та зовнішніми методами CUDA за допомогою проміжної бібліотеки. Цей підхід схожий до того, що використовується у новому Java Foreign Functions and Memorz API, який ми розглядали вище.

Розгляньмо, як виконати задачу з обчислення довжини вектора за допомогою цієї бібліотеки.

public static double calculate(VectorLine line) {
    // Enable exceptions and omit all subsequent error checks
    JCudaDriver.setExceptionsEnabled(true);

    var x1 = line.latitudes();
    var x2 = Arrays.copyOfRange(x1, 1, x1.length);
    var y1 = line.longitudes();
    var y2 = Arrays.copyOfRange(y1, 1, y1.length);

    var vecLength = x2.length;

    VecDouble.init();

    var devPtrX1 = createDoubleVectorPointer(x1, x2.length);
    var devPtrX2 = createDoubleVectorPointer(x2, x2.length);
    var devPtrY1 = createDoubleVectorPointer(y1, y2.length);
    var devPtrY2 = createDoubleVectorPointer(y2, y2.length);

    var devPtrDX = createDoubleVectorPointer(null, vecLength, false);
    var devPtrDY = createDoubleVectorPointer(null, vecLength, false);

    double[] hypotenuses = new double[x2.length];
    var devResultPtr = createDoubleVectorPointer(hypotenuses, hypotenuses.length, false);

    VecDouble.sub(vecLength, devPtrDX, devPtrX2, devPtrX1);
    VecDouble.sub(vecLength, devPtrDY, devPtrY2, devPtrY1);
    VecDouble.mul(vecLength, devPtrDX, devPtrDX, devPtrDX);
    VecDouble.mul(vecLength, devPtrDY, devPtrDY, devPtrDY);
    VecDouble.add(vecLength, devResultPtr, devPtrDX, devPtrDY);
    VecDouble.sqrt(vecLength, devResultPtr, devResultPtr);

    cuMemcpyDtoH(Pointer.to(hypotenuses), devResultPtr, vecLength * Sizeof.DOUBLE);

    Stream.of(devResultPtr, devPtrDX, devPtrDY, devPtrX1, devPtrX2, devPtrY1, devPtrY2)
            .forEach(JCudaDriver::cuMemFree);

    VecDouble.shutdown();

    return DoubleStream.of(hypotenuses).sum();
}

private static CUdeviceptr createDoubleVectorPointer(double[] array, int length, boolean copyToDevice) {
    var devPtr = new CUdeviceptr();
    var size = length * Sizeof.DOUBLE;
    cuMemAlloc(devPtr, size);

    if (copyToDevice)
        cuMemcpyHtoD(devPtr, Pointer.to(array), size);

    return devPtr;
}

Як і у попередньому прикладі, спочатку ініціалізуємо вхідні масиви. Далі потрібно ініціалізувати роботи із векторами для дробових чисел подвійної точності VecDouble.init(). Потім для кожного із вхідних векторів треба зробити Pointer, який виконуватиме роль посередника між контейнером даних у JVM та у GPU-пристрої.

Виконуємо послідовність векторних операцій — VecDouble.sub, VecDouble.mul, VecDouble.add, VecDouble.sqrt. Результати отримання квадратних коренів записуємо у контейнер devResultPtr. За допомогою методу cuMemcpyDtoH копіюємо результати обчислень із пам’яті пристрою GPU у пам’ять JVM, звільняємо пам’ять у GPU і рахуємо суму елементів вихідного масиву.

Якщо порівнювати підходи, застосовані у цих двох бібліотеках, то в Aparapi отримаємо більш наближений до Java варіант, де описуємо Kernel-код на Java. Звісно, маємо на увазі обмеження: виконувати можна лише примітивні операції. В той час, як JCUDA дає нам інтерфейс для роботи з нативною бібліотекою, що підключена через JNI. Aparapi ховає від нас процес взаємодії із GPU, водночас з JCUDA треба власноруч виконувати деякі низькорівневі операції, як-то виділення та очищення пам’яті пристрою GPU.

Java ML libraries

Одним із цікавих проєктів на Java, який створений безпосередньо для розв’язання завдань Machine learning, є бібліотека Tribuo від Oracle labs. Вона може бути використана для типових задач ML: класифікація, регресія, кластеризація. Має зручний інтерфейс для взаємодії зі сторонніми бібліотеками xgboost, liblinear та іншими.

Що її відрізняє від інших інструментів, так це об’єктноорієнтований підхід, тобто API має суворо типізовану форму зі спеціальними параметризованими класами та інтерфейсами для визначення моделей, прогнозів та наборів даних. Розглянемо її використання на прикладі.

Спробуємо побудувати модель для розпізнавання рукописних цифр. Візьмемо відомий для цього датасет MNIST.

MNIST sample images

Спочатку робимо ініціалізацію датасету: визначаємо imgTrainSource та imgTestSource відповідно для тренування та подальшого тестування створеної моделі. Потім визначаємо алгоритм для створення моделі за допомогою convolutional-нейромережі (у нашому випадку це буде AdaGrad) та задаємо параметри для нього.

public static void trainAndTest(Path trainImagesPath, Path trainLabelsPath, Path testImagesPath, Path testLabelsPath,
                                int mlpEpochs, int cnnEpcochs, int imageSize) throws IOException
{
    var labelFactory = new LabelFactory();

    var imgTrainSource = new IDXDataSource<>(trainImagesPath, trainLabelsPath, labelFactory);
    var imgTestSource = new IDXDataSource<>(testImagesPath, testLabelsPath, labelFactory);

    var imgTrain = new MutableDataset<>(imgTrainSource);
    var imgTest = new MutableDataset<>(imgTestSource);

    var inputName = "MNIST_INPUT";

    var gradAlgorithm = GradientOptimiser.ADAGRAD;
    var gradParams = Map.of("learningRate", 0.01F, "initialAccumulatorValue", 0.01F);

    if (mlpEpochs <= 0)
        return;

    doCnnTest(inputName, imgTrain, imgTest, gradAlgorithm, gradParams, cnnEpcochs, imageSize);
}

Далі будуємо модель та перевіряємо її.

private static Model<Label> doCnnTest (String inputName, MutableDataset<Label> trainSet, MutableDataset<Label> testSet,
                        GradientOptimiser gradAlgorithm, Map<String, Float> gradParams, int epochs, int imageSize)
{
    var cnnTuple = CNNExamples.buildLeNetGraph(inputName, imageSize, 255, trainSet.getOutputs().size());
    var imageConverter = new ImageConverter(inputName, imageSize, imageSize, 1);

    var imgCNNTrainer = new TensorFlowTrainer<Label>(
            cnnTuple.graphDef,
            cnnTuple.outputName,
            gradAlgorithm,
            gradParams,
            imageConverter,
            new LabelConverter(),
            16, //train batch size
            epochs,
            16, // test batch size
            -1 // logging interval 
    );

    var cnnStart = System.currentTimeMillis();
    var cnnModel = imgCNNTrainer.train(trainSet);
    var cnnEnd = System.currentTimeMillis();

    System.out.println("CNN training took:" + Util.formatDuration(cnnStart, cnnEnd));

    var labelEval = new LabelEvaluator();
    var predictions = model.predict(testSet);
    var cnnEvaluation = labelEval.evaluate(model, predictions,   
         testSet.getProvenance());

    System.out.println(cnnEvaluation.toString());
    
    return cnnModel;
}

Як видно із коду за лаштунками Tribuo буде використовувати бібліотеку Tensorflow.

Результати перевірки:

A screenshot of a computer screenDescription automatically generated

9 класів — це відповідно цифри від 0 до 9. n — кількість цифр у датасеті, tp, fn, fp — true positive (розпізнано правильно), false negative (розпізнана як інший клас), false positive (інший клас, розпізнаний як заданий) відповідно.

Точність розпізнавання побудованої моделі — 0.989. По коду бачимо, що процес навчання та використання моделі має звичну для Java-розробників об’єктноорієнтовану концепцію та назви класів і методів доволі самоописові. Як на мене, такий процес зручніший за взаємодію з безликими масивами чисел. Більш детально з іншими прикладами можна ознайомитись на сайті Tribuo.

Висновки

На цей час Java є не найліпшим інструментом для роботи з моделями напряму, насамперед через відсутність зручних вбудованих можливостей візуалізації усіх проміжних етапів створення та навчання моделі.

Але якщо застосовувати Java для роботи із вже готовими моделями та даними для вирішення певних бізнес-завдань, в цьому Java може конкурувати. Адже має багату екосистему, яка дозволяє створювати продукти у різних парадигмах.

👍ПодобаєтьсяСподобалось7
До обраногоВ обраному0
LinkedIn
Дозволені теги: blockquote, a, pre, code, ul, ol, li, b, i, del.
Ctrl + Enter
Дозволені теги: blockquote, a, pre, code, ul, ol, li, b, i, del.
Ctrl + Enter

Так сталося, що я в Java на аматорскому рівні займаюсь (для себе) проектами NLP, та, трохи AI. І користуюсь бібліотеками, які Ви не згадуєте — можете прокоментувати — вони є нецікавими/застарілими, або щось інше? Я про deeplearning4j, deepjavalibrary (djl), для прикладу. Там є реалізації старих та більш нових технологій, PyTorch, huggingface моделі можно юзать.

Вітаю, дякую з коментар! Коли я починав занурюватись у тему AI, то хотілося цим займатись у звичному та «рідному» Java-середовищі, бо я не є великим фанатом Пайтону. Тому, шукаючи інструменти для роботи з GPU та ML, я орієнтувався на нформацію від профільних офіційних Java джерел (типу Java magazine, Java newscast та ін.) і там згадувались інструменти, які я згадував у статті.
Про deeplearning4j я теж чув, але не працював із цією бібліотекою. Можливо спробую погратися і порівняти з тими інструментами, що я використовую і освітити це в наступних статтях :)

Той самий випадок коли хтось закрив твій гештальт. Дякую за статтю !

Підписатись на коментарі