Google JAX
Google JAX | |
---|---|
![]() | |
![]() | |
Тип | Machine learning |
Разработчик | Google (компания) |
Написана на | Python, C++ |
Операционные системы | Linux, macOS, Windows |
Первый выпуск | 12 декабря 2018 |
Аппаратные платформы | Python, NumPy |
Последняя версия | |
Тестовая версия | v0.3.13 (16 мая 2022 ) |
Репозиторий | github.com/google/jax |
Лицензия | Apache 2.0 |
Сайт | jax.readthedocs.io/en/la… |
Google JAX — фреймворк машинного обучения для преобразования числовых функций.[2][3][4] Представляет объединение измененной версии autograd (автоматическое получение градиентной функции через дифференцирование функции) и TensorFlow's XLA (Ускоренная линейная алгебра (Accelerated Linear Algebra)). Спроектирован таким образом, чтобы максимально соответствовать структуре и рабочему процессу NumPy для работы с различными существующими фреймворками, такими как TensorFlow и PyTorch.[5][6] Основными функциями JAX являются:[2]
- grad: автоматическое дифференцирование
- jit: компиляция
- vmap: автоматическая векторизация
- pmap: SPMD программирование
grad
Код представленный ниже демонстрирует функцию автоматического дифференцирования пакета grad.
# imports
from jax import grad
import jax.numpy as jnp
# define the logistic function
def logistic(x):
return jnp.exp(x) / (jnp.exp(x) + 1)
# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)
# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
Код должен напечатать:
0.19661194
jit
Код представленный ниже демонстрирует функцию оптимизации через слияние пакета jit.
# imports
from jax import jit
import jax.numpy as jnp
# define the cube function
def cube(x):
return x * x * x
# generate data
x = jnp.ones((10000, 10000))
# create the jit version of the cube function
jit_cube = jit(cube)
# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
Вычислительное время для jit_cube (строка 17) должно быть заметно короче, чем для cube (строка 16). Увеличение значения в строке 7, будет увеличивать разницу.
vmap
Код представленный ниже демонстрирует функцию векторизации пакета vmap.
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
# define function
def grads(self, inputs):
in_grad_partial = partial(self._net_grads, self._net_params)
grad_vmap = jax.vmap(in_grad_partial)
rich_grads = grad_vmap(inputs)
flat_grads = np.asarray(self._flatten_batch(rich_grads))
assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
return flat_grads
Изображение в правой части раздела иллюстрирует идея векторизованного сложения.

pmap
Код представленный ниже демонстрирует распараллеливание для умножения матриц пакета pmap.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
Последняя строка должна напечатать значенияː
[1.1566595 1.1805978]
Библиотеки, использующие Jax
Несколько библиотек Python используют Jax в качестве бэкенда, включая:
- Flax — высокоуровневая библиотека для нейронных сетей изначально разработанная Google Brain.[7]
- Haiku — объектно-ориентированная библиотека для нейронных сетей разработанная DeepMind.[8]
- Equinox — библиотека, основанная на идеи представления параметризованных функций (включая нейронные сети) как PyTrees. Она была создана Патриком Кидгером.[9]
- Optax — библиотека для градиентной обработки и оптимизации разработанная DeepMind.[10]
- RLax — библиотека для разработки агентов для обучения с подкреплением, разработанная DeepMind.[11]
См. также
- NumPy
- TensorFlow
- PyTorch
- CUDA
- Автоматическое дифференцирование
- JIT-компиляция
- Векторизация
- Автоматическое распараллеливание
Примечания
- ↑ https://github.com/google/jax/releases/tag/jax-v0.4.24
- ↑ 1 2 Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, Архивировано 18 июня 2022, Дата обращения: 18 июня 2022
- ↑ Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1—3. Архивировано (PDF) 21 июня 2022.
{{cite journal}}
: Википедия:Обслуживание CS1 (дата и год) (ссылка) - ↑ Using JAX to accelerate our research (англ.). www.deepmind.com. Дата обращения: 18 июня 2022. Архивировано 18 июня 2022 года.
- ↑ Lynley, Matthew Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta (амер. англ.). Business Insider. Дата обращения: 21 июня 2022. Архивировано 21 июня 2022 года.
- ↑ Why is Google's JAX so popular? (амер. англ.). Analytics India Magazine (25 апреля 2022). Дата обращения: 18 июня 2022. Архивировано 18 июня 2022 года.
- ↑ Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, Архивировано 3 сентября 2022, Дата обращения: 29 июля 2022
- ↑ Haiku: Sonnet for JAX, DeepMind, 2022-07-29, Архивировано 29 июля 2022, Дата обращения: 29 июля 2022
- ↑ Kidger, Patrick (2022-07-29), Equinox, Архивировано 19 сентября 2023, Дата обращения: 29 июля 2022
- ↑ Optax, DeepMind, 2022-07-28, Архивировано 7 июня 2023, Дата обращения: 29 июля 2022
- ↑ RLax, DeepMind, 2022-07-29, Архивировано 26 апреля 2023, Дата обращения: 29 июля 2022
Ссылки
- Documentationː jax.readthedocs.io
- Colab (Jupyter/iPython) Quickstart Guideː colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb
- TensorFlow's XLAː tensorflow.org/xla (Accelerated Linear Algebra)
- Intro to JAX: Accelerating Machine Learning research на YouTube
- Original paperː mlsys.org/Conferences/doc/2018/146.pdf