Some checks failed
Periodic Merges (6h) / master → staging-nixos (push) Failing after 12m50s
Periodic Merges (6h) / master → staging-next (push) Failing after 12m54s
Periodic Merges (24h) / merge-base(master,staging) → haskell-updates (push) Failing after 11m54s
Periodic Merges (6h) / staging-next → staging (push) Failing after 12m13s
Periodic Merges (24h) / staging-next-25.05 → staging-25.05 (push) Failing after 13m24s
Periodic Merges (24h) / release-25.05 → staging-next-25.05 (push) Failing after 14m28s
30 lines
594 B
Nix
30 lines
594 B
Nix
# Smoke test for JAX's CUDA support: builds an executable `jax-test-cuda`
# whose Python script checks that the first JAX device is a GPU, runs one
# small computation per CUDA library noted in the trailing comments, and
# prints "success!" only after every computation has finished.
{
  jax,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    # jax itself plus the packages it lists as its `cuda` optional extra.
    libraries = [ jax ] ++ jax.optional-dependencies.cuda;
  }
  ''
    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental import sparse

    # Fail early (and show what was found) if the first device is not a GPU.
    devices = jax.devices()
    assert devices[0].platform == "gpu", devices  # libcuda.so

    rng = random.key(0)  # libcudart.so, libcudnn.so
    x = random.normal(rng, (100, 100))

    # JAX dispatches work asynchronously; block on each result so every
    # kernel really executes (and any failure is raised) before we report
    # success.
    jax.block_until_ready(x @ x)  # libcublas.so
    jax.block_until_ready(jnp.fft.fft(x))  # libcufft.so
    jax.block_until_ready(jnp.linalg.inv(x))  # libcusolver.so
    jax.block_until_ready(sparse.CSR.fromdense(x) @ x)  # libcusparse.so

    print("success!")
  ''