JAX
The Python library JAX is used as the numerical engine for the calculations and also manages the parallelisation. JAX is designed for high-performance numerical computing with a focus on machine learning applications. It combines XLA (Accelerated Linear Algebra) and Autograd: the former is a compiler that optimises models for different hardware platforms, the latter an Automatic Differentiation (AD) tool for Python (see the documentation). Its extensible system of composable function transformations provides a set of features that are important for Computational Science, as illustrated in the figure above. For instance, the vmap function enables complex vectorisation operations, while the pmap function provides Single-Program Multiple-Data (SPMD) parallelisation. Both forward- and reverse-mode automatic differentiation are supported. Finally, just-in-time compilation (jit) relies on the XLA engine to compile and execute functions not only on CPUs but also on accelerators such as GPUs and TPUs, offering a versatile solution for seamlessly connecting the software to various types of hardware.
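As a minimal sketch of how these transformations compose, the snippet below differentiates, vectorises, and JIT-compiles a simple function (the function f is an illustrative example, not taken from the text); pmap would play the role of vmap when the same program is distributed across several devices.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar-valued function of a vector argument.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Compose the transformations: reverse-mode AD, vectorisation, JIT compilation.
grad_f = jax.grad(f)                            # reverse-mode automatic differentiation
batched_grad_f = jax.vmap(grad_f)               # vectorise over a leading batch dimension
fast_batched_grad_f = jax.jit(batched_grad_f)   # compile with XLA for the available backend

x = jnp.linspace(0.0, jnp.pi, 8).reshape(2, 4)  # a batch of 2 inputs of size 4
print(fast_batched_grad_f(x))                   # one gradient per batch element
```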
The XLA module is a domain-specific compiler for linear algebra that optimizes computations for both CPUs and GPUs. XLA is in fact platform agnostic: it achieves optimized performance on the target architecture by orchestrating a process that encompasses a series of optimizations and transformations. The source code is first converted into HLO (High-Level Optimizer) code, a specialized language derived from a graph representation of the computations. XLA then performs optimizations on the HLO code that are geared towards high-level mathematical operations, particularly those arising in linear algebra and machine learning models, and that are independent of the hardware architecture, such as operation fusion. Next, it carries out optimizations specific to the particular architecture in use. From there, the Low Level Virtual Machine (LLVM) toolkit is used to produce an Intermediate Representation that the LLVM compiler can understand, which performs further optimizations before emitting the machine code. When it comes to leveraging the computational power of GPUs, the link between XLA and CUDA kernels is critical: on the one hand, JAX relies on CUDA libraries such as cuBLAS for dense linear algebra; on the other hand, it is capable of generating custom CUDA kernels for operations that are not efficiently covered by the standard libraries.
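To make this lowering pipeline concrete, the sketch below uses JAX's staging interface (jit(...).lower(...) followed by .compile(), available in recent JAX releases; the exact methods and the textual output format vary across versions) to inspect both the HLO handed to XLA and the HLO after backend-specific optimization. The function f is again only an illustrative example.

```python
import jax
import jax.numpy as jnp

# Illustrative function whose compilation we want to inspect.
def f(x):
    return jnp.dot(x, x.T) + 1.0

x = jnp.ones((4, 4))

# Stage 1: lower the jitted function to the (Stable)HLO program that XLA receives.
lowered = jax.jit(f).lower(x)
print(lowered.as_text())        # hardware-independent HLO, before backend optimization

# Stage 2: compile for the current backend (CPU/GPU/TPU) and inspect the result.
compiled = lowered.compile()
print(compiled.as_text())       # HLO after backend-specific optimizations such as fusion
```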