mirror of
https://github.com/NixOS/nixpkgs.git
synced 2026-03-08 01:24:09 +01:00
python3Packages.flashinfer: 0.6.1 -> 0.6.4 (#495151)
This commit is contained in:
commit
205ef081d6
12 changed files with 314 additions and 90 deletions
69
pkgs/development/python-modules/apache-tvm-ffi/default.nix
Normal file
69
pkgs/development/python-modules/apache-tvm-ffi/default.nix
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
{
|
||||
lib,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
|
||||
# build-system
|
||||
cmake,
|
||||
cython,
|
||||
ninja,
|
||||
scikit-build-core,
|
||||
setuptools-scm,
|
||||
|
||||
# dependencies
|
||||
typing-extensions,
|
||||
|
||||
# tests
|
||||
numpy,
|
||||
pytestCheckHook,
|
||||
writableTmpDirAsHomeHook,
|
||||
}:
|
||||
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "apache-tvm-ffi";
|
||||
version = "0.1.9";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "apache";
|
||||
repo = "tvm-ffi";
|
||||
tag = "v${finalAttrs.version}";
|
||||
fetchSubmodules = true;
|
||||
hash = "sha256-XnlM//WW2TbjbmzYBq6itJQ7R3J646UMVQUVhV5Afwc=";
|
||||
};
|
||||
|
||||
build-system = [
|
||||
cmake
|
||||
cython
|
||||
ninja
|
||||
scikit-build-core
|
||||
setuptools-scm
|
||||
];
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
dependencies = [
|
||||
typing-extensions
|
||||
];
|
||||
|
||||
optional-dependencies = {
|
||||
cpp = [
|
||||
ninja
|
||||
];
|
||||
};
|
||||
|
||||
pythonImportsCheck = [ "tvm_ffi" ];
|
||||
|
||||
nativeCheckInputs = [
|
||||
numpy
|
||||
pytestCheckHook
|
||||
writableTmpDirAsHomeHook
|
||||
];
|
||||
|
||||
meta = {
|
||||
description = "Open ABI and FFI for Machine Learning Systems";
|
||||
changelog = "https://github.com/apache/tvm-ffi/releases/tag/${finalAttrs.src.tag}";
|
||||
homepage = "https://github.com/apache/tvm-ffi";
|
||||
license = lib.licenses.asl20;
|
||||
maintainers = with lib.maintainers; [ GaetanLepage ];
|
||||
};
|
||||
})
|
||||
52
pkgs/development/python-modules/cuda-pathfinder/default.nix
Normal file
52
pkgs/development/python-modules/cuda-pathfinder/default.nix
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
{
|
||||
lib,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
|
||||
# build-system
|
||||
setuptools,
|
||||
setuptools-scm,
|
||||
|
||||
# tests
|
||||
pytest-mock,
|
||||
pytestCheckHook,
|
||||
}:
|
||||
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "cuda-pathfinder";
|
||||
version = "1.4.0";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "NVIDIA";
|
||||
repo = "cuda-python";
|
||||
tag = "cuda-pathfinder-v${finalAttrs.version}";
|
||||
hash = "sha256-Bsou6vLyMBNbVMPT4vtnWpoi05lXG6pjhuee6Hg/Mm8=";
|
||||
};
|
||||
|
||||
sourceRoot = "${finalAttrs.src.name}/cuda_pathfinder";
|
||||
|
||||
build-system = [
|
||||
setuptools
|
||||
setuptools-scm
|
||||
];
|
||||
|
||||
pythonImportsCheck = [
|
||||
"cuda"
|
||||
"cuda.pathfinder"
|
||||
];
|
||||
|
||||
nativeCheckInputs = [
|
||||
pytest-mock
|
||||
pytestCheckHook
|
||||
];
|
||||
|
||||
meta = {
|
||||
description = "one-stop solution for locating CUDA components";
|
||||
homepage = "https://github.com/NVIDIA/cuda-python/tree/main/cuda_pathfinder";
|
||||
changelog = "https://nvidia.github.io/cuda-python/cuda-pathfinder/${finalAttrs.version}/release/${finalAttrs.version}-notes.html";
|
||||
license = lib.licenses.asl20;
|
||||
maintainers = with lib.maintainers; [ GaetanLepage ];
|
||||
platforms = lib.platforms.linux;
|
||||
};
|
||||
})
|
||||
|
|
@ -1,17 +1,24 @@
|
|||
{
|
||||
lib,
|
||||
stdenv,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
|
||||
# build-system
|
||||
cython,
|
||||
fastrlock,
|
||||
numpy,
|
||||
pytestCheckHook,
|
||||
mock,
|
||||
setuptools,
|
||||
|
||||
# nativeBuildInputs
|
||||
cudaPackages,
|
||||
addDriverRunpath,
|
||||
symlinkJoin,
|
||||
addDriverRunpath,
|
||||
|
||||
# dependencies
|
||||
numpy,
|
||||
cuda-pathfinder,
|
||||
|
||||
# tests
|
||||
pytest-mock,
|
||||
pytestCheckHook,
|
||||
}:
|
||||
|
||||
let
|
||||
|
|
@ -39,8 +46,7 @@ let
|
|||
libcurand
|
||||
libcusolver
|
||||
libcusparse
|
||||
# NOTE: libcusparse_lt is too new for CuPy, so we must do without.
|
||||
# libcusparse_lt
|
||||
libcusparse_lt # cusparseLt.h
|
||||
]
|
||||
);
|
||||
cudatoolkit-joined = symlinkJoin {
|
||||
|
|
@ -49,19 +55,26 @@ let
|
|||
outpaths ++ lib.concatMap (outpath: lib.map (output: outpath.${output}) outpath.outputs) outpaths;
|
||||
};
|
||||
in
|
||||
buildPythonPackage.override { stdenv = cudaPackages.backendStdenv; } rec {
|
||||
buildPythonPackage.override { stdenv = cudaPackages.backendStdenv; } (finalAttrs: {
|
||||
pname = "cupy";
|
||||
version = "13.6.0";
|
||||
version = "14.0.1";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "cupy";
|
||||
repo = "cupy";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-nU3VL0MSCN+mI5m7C5sKAjBSL6ybM6YAk5lJiIDY0ck=";
|
||||
tag = "v${finalAttrs.version}";
|
||||
fetchSubmodules = true;
|
||||
hash = "sha256-TaEJ0BveUCXCRrNq9L49Tfbu0334+cANcVm5qnSOE1Q=";
|
||||
};
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace pyproject.toml \
|
||||
--replace-fail \
|
||||
"Cython>=3.1,<3.2" \
|
||||
"Cython"
|
||||
'';
|
||||
|
||||
env = {
|
||||
LDFLAGS = toString [
|
||||
# Fake libcuda.so (the real one is deployed impurely)
|
||||
|
|
@ -83,7 +96,6 @@ buildPythonPackage.override { stdenv = cudaPackages.backendStdenv; } rec {
|
|||
|
||||
build-system = [
|
||||
cython
|
||||
fastrlock
|
||||
setuptools
|
||||
];
|
||||
|
||||
|
|
@ -100,13 +112,13 @@ buildPythonPackage.override { stdenv = cudaPackages.backendStdenv; } rec {
|
|||
];
|
||||
|
||||
dependencies = [
|
||||
fastrlock
|
||||
cuda-pathfinder
|
||||
numpy
|
||||
];
|
||||
|
||||
nativeCheckInputs = [
|
||||
pytest-mock
|
||||
pytestCheckHook
|
||||
mock
|
||||
];
|
||||
|
||||
# Won't work with the GPU, whose drivers won't be accessible from the build
|
||||
|
|
@ -124,12 +136,12 @@ buildPythonPackage.override { stdenv = cudaPackages.backendStdenv; } rec {
|
|||
meta = {
|
||||
description = "NumPy-compatible matrix library accelerated by CUDA";
|
||||
homepage = "https://cupy.chainer.org/";
|
||||
changelog = "https://github.com/cupy/cupy/releases/tag/${src.tag}";
|
||||
changelog = "https://github.com/cupy/cupy/releases/tag/${finalAttrs.src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [
|
||||
"aarch64-linux"
|
||||
"x86_64-linux"
|
||||
];
|
||||
maintainers = [ ];
|
||||
maintainers = with lib.maintainers; [ GaetanLepage ];
|
||||
};
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@
|
|||
config,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
fetchpatch2,
|
||||
|
||||
# build-system
|
||||
apache-tvm-ffi,
|
||||
setuptools,
|
||||
|
||||
# nativeBuildInputs
|
||||
|
|
@ -29,30 +29,24 @@
|
|||
tqdm,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "flashinfer";
|
||||
version = "0.6.1";
|
||||
version = "0.6.4";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "flashinfer-ai";
|
||||
repo = "flashinfer";
|
||||
tag = "v${version}";
|
||||
tag = "v${finalAttrs.version}";
|
||||
fetchSubmodules = true;
|
||||
hash = "sha256-NRjas11VvvCY6MZiZaYtxG5MXEaFqfbhJxflUT/uraE=";
|
||||
hash = "sha256-Hq3oTeEJHRvXwThI8vc06E3Ot/FnPP0sZUfze3ISa2o=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
# TODO: remove patch with update to v0.5.2+
|
||||
# Switch pynvml to nvidia-ml-py
|
||||
(fetchpatch2 {
|
||||
url = "https://github.com/flashinfer-ai/flashinfer/commit/a42f99255d68d1a54b689bd4985339c6b44963a6.patch?full_index=1";
|
||||
hash = "sha256-3XJFcdQeZ/c5fwiQvd95z4p9BzTn8pjle21WzeBxUgk=";
|
||||
})
|
||||
build-system = [
|
||||
apache-tvm-ffi
|
||||
setuptools
|
||||
];
|
||||
|
||||
build-system = [ setuptools ];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
ninja
|
||||
|
|
@ -91,6 +85,7 @@ buildPythonPackage rec {
|
|||
|
||||
pythonRemoveDeps = [
|
||||
"nvidia-cudnn-frontend"
|
||||
"nvidia-cutlass-dsl"
|
||||
];
|
||||
dependencies = [
|
||||
click
|
||||
|
|
@ -119,4 +114,4 @@ buildPythonPackage rec {
|
|||
daniel-fahey
|
||||
];
|
||||
};
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
{
|
||||
lib,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
|
||||
# build-system
|
||||
hatch-vcs,
|
||||
hatchling,
|
||||
lib,
|
||||
|
||||
# dependencies
|
||||
tomlkit,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "hatch-min-requirements";
|
||||
version = "0.2.0";
|
||||
pyproject = true;
|
||||
|
|
@ -14,7 +19,7 @@ buildPythonPackage rec {
|
|||
src = fetchFromGitHub {
|
||||
owner = "tlambert03";
|
||||
repo = "hatch-min-requirements";
|
||||
tag = "v${version}";
|
||||
tag = "v${finalAttrs.version}";
|
||||
hash = "sha256-QKO5fVvjSqwY+48Fc8sAiZazrxZ4eBYxzVElHr2lcEA=";
|
||||
};
|
||||
|
||||
|
|
@ -23,6 +28,10 @@ buildPythonPackage rec {
|
|||
hatch-vcs
|
||||
];
|
||||
|
||||
dependencies = [
|
||||
tomlkit
|
||||
];
|
||||
|
||||
# As of v0.1.0 all tests attempt to use the network
|
||||
doCheck = false;
|
||||
|
||||
|
|
@ -31,9 +40,10 @@ buildPythonPackage rec {
|
|||
meta = {
|
||||
description = "Hatchling plugin to create optional-dependencies pinned to minimum versions";
|
||||
homepage = "https://github.com/tlambert03/hatch-min-requirements";
|
||||
changelog = "https://github.com/tlambert03/hatch-min-requirements/releases/tag/${finalAttrs.src.tag}";
|
||||
license = lib.licenses.bsd3;
|
||||
maintainers = with lib.maintainers; [
|
||||
samuela
|
||||
];
|
||||
};
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@
|
|||
pytestCheckHook,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "kmapper";
|
||||
version = "2.1.0";
|
||||
pyproject = true;
|
||||
|
|
@ -30,7 +30,7 @@ buildPythonPackage rec {
|
|||
src = fetchFromGitHub {
|
||||
owner = "scikit-tda";
|
||||
repo = "kepler-mapper";
|
||||
tag = "v${version}";
|
||||
tag = "v${finalAttrs.version}";
|
||||
hash = "sha256-i909J0yI8v8BqGbCkcjBAdA02Io+qpILdDkojZj0wv4=";
|
||||
};
|
||||
|
||||
|
|
@ -55,11 +55,18 @@ buildPythonPackage rec {
|
|||
pytestCheckHook
|
||||
];
|
||||
|
||||
disabledTests = [
|
||||
# UnboundLocalError: cannot access local variable 'X_blend' where it is not associated with a value
|
||||
"test_tuple_projection"
|
||||
"test_tuple_projection_fit"
|
||||
];
|
||||
|
||||
meta = {
|
||||
description = "Python implementation of Mapper algorithm for Topological Data Analysis";
|
||||
homepage = "https://kepler-mapper.scikit-tda.org/";
|
||||
changelog = "https://github.com/scikit-tda/kepler-mapper/releases/tag/v${version}";
|
||||
downloadPage = "https://github.com/scikit-tda/kepler-mapper";
|
||||
changelog = "https://github.com/scikit-tda/kepler-mapper/releases/tag/${finalAttrs.src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
maintainers = [ ];
|
||||
};
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -2,19 +2,29 @@
|
|||
lib,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
|
||||
# build-system
|
||||
setuptools,
|
||||
|
||||
# dependencies
|
||||
numba,
|
||||
numpy,
|
||||
pillow,
|
||||
pytestCheckHook,
|
||||
scipy,
|
||||
setuptools,
|
||||
config,
|
||||
cudaSupport ? config.cudaSupport,
|
||||
# cuda-only
|
||||
cupy,
|
||||
pyopencl,
|
||||
|
||||
# tests
|
||||
pocl,
|
||||
pytestCheckHook,
|
||||
writableTmpDirAsHomeHook,
|
||||
|
||||
config,
|
||||
cudaSupport ? config.cudaSupport,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "pymatting";
|
||||
version = "1.1.15";
|
||||
pyproject = true;
|
||||
|
|
@ -22,7 +32,7 @@ buildPythonPackage rec {
|
|||
src = fetchFromGitHub {
|
||||
owner = "pymatting";
|
||||
repo = "pymatting";
|
||||
tag = "v${version}";
|
||||
tag = "v${finalAttrs.version}";
|
||||
hash = "sha256-rcatlQE+YgppY//ZgGY9jO5KI0ED30fLlqW9N+xRNqk=";
|
||||
};
|
||||
|
||||
|
|
@ -39,10 +49,19 @@ buildPythonPackage rec {
|
|||
pyopencl
|
||||
];
|
||||
|
||||
nativeCheckInputs = [ pytestCheckHook ];
|
||||
|
||||
pythonImportsCheck = [ "pymatting" ];
|
||||
|
||||
nativeCheckInputs = [
|
||||
pytestCheckHook
|
||||
]
|
||||
++ lib.optionals cudaSupport [
|
||||
# Provides a CPU-based OpenCL ICD so that pyopencl's module-level
|
||||
# cl.create_some_context() succeeds without GPU hardware.
|
||||
pocl
|
||||
# pocl needs a writable $HOME for its kernel cache directory.
|
||||
writableTmpDirAsHomeHook
|
||||
];
|
||||
|
||||
disabledTests = [
|
||||
# no access to input data set
|
||||
# see: https://github.com/pymatting/pymatting/blob/master/tests/download_images.py
|
||||
|
|
@ -52,14 +71,16 @@ buildPythonPackage rec {
|
|||
"test_lkm"
|
||||
];
|
||||
|
||||
# pyopencl._cl.LogicError: clGetPlatformIDs failed: PLATFORM_NOT_FOUND_KHR
|
||||
disabledTestPaths = lib.optional cudaSupport "tests/test_foreground.py";
|
||||
disabledTestPaths = lib.optionals cudaSupport [
|
||||
# Requires a CUDA driver for cupy GPU operations
|
||||
"tests/test_foreground.py"
|
||||
];
|
||||
|
||||
meta = {
|
||||
description = "Python library for alpha matting";
|
||||
homepage = "https://github.com/pymatting/pymatting";
|
||||
changelog = "https://github.com/pymatting/pymatting/blob/v${version}/CHANGELOG.md";
|
||||
changelog = "https://github.com/pymatting/pymatting/blob/${finalAttrs.src.tag}/CHANGELOG.md";
|
||||
license = lib.licenses.mit;
|
||||
maintainers = [ ];
|
||||
};
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
pytestCheckHook,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "pynndescent";
|
||||
version = "0.6.0";
|
||||
pyproject = true;
|
||||
|
|
@ -25,7 +25,7 @@ buildPythonPackage rec {
|
|||
src = fetchFromGitHub {
|
||||
owner = "lmcinnes";
|
||||
repo = "pynndescent";
|
||||
tag = "release-${version}";
|
||||
tag = "release-${finalAttrs.version}";
|
||||
hash = "sha256-RfIbPPyx+Y7niuFrLjA02cUDHTSv9s5E4JiXv4ZBNEc=";
|
||||
};
|
||||
|
||||
|
|
@ -39,14 +39,24 @@ buildPythonPackage rec {
|
|||
scipy
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "pynndescent" ];
|
||||
|
||||
nativeCheckInputs = [ pytestCheckHook ];
|
||||
|
||||
pythonImportsCheck = [ "pynndescent" ];
|
||||
disabledTests = [
|
||||
# AssertionError: Arrays are not almost equal to 6 decimals
|
||||
"test_seuclidean"
|
||||
|
||||
# sklearn.utils._param_validation.InvalidParameterError: The 'metric' parameter of
|
||||
# pairwise_distances must be a str among ...
|
||||
"test_binary_check"
|
||||
"test_sparse_binary_check"
|
||||
];
|
||||
|
||||
meta = {
|
||||
description = "Nearest Neighbor Descent";
|
||||
homepage = "https://github.com/lmcinnes/pynndescent";
|
||||
changelog = "https://github.com/lmcinnes/pynndescent/releases/tag/release-${src.tag}";
|
||||
changelog = "https://github.com/lmcinnes/pynndescent/releases/tag/${finalAttrs.src.tag}";
|
||||
license = lib.licenses.bsd2;
|
||||
maintainers = with lib.maintainers; [ mic92 ];
|
||||
badPlatforms = [
|
||||
|
|
@ -55,4 +65,4 @@ buildPythonPackage rec {
|
|||
"aarch64-linux"
|
||||
];
|
||||
};
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -33,8 +33,12 @@
|
|||
|
||||
# tests
|
||||
versionCheckHook,
|
||||
pocl,
|
||||
|
||||
withCli ? false,
|
||||
config,
|
||||
cudaSupport ? config.cudaSupport,
|
||||
writableTmpDirAsHomeHook,
|
||||
}:
|
||||
|
||||
let
|
||||
|
|
@ -100,8 +104,23 @@ buildPythonPackage (finalAttrs: {
|
|||
postInstall = lib.optionalString (!withCli) "rm -r $out/bin";
|
||||
|
||||
# not running python tests, as they require network access
|
||||
nativeCheckInputs = lib.optionals withCli [
|
||||
nativeCheckInputs =
|
||||
lib.optionals
|
||||
(
|
||||
withCli
|
||||
# Crashes in the sandbox as no drivers are available
|
||||
# opencl._cl.RuntimeError: no CL platforms available to ICD loader
|
||||
&& (!cudaSupport)
|
||||
)
|
||||
[
|
||||
versionCheckHook
|
||||
]
|
||||
++ lib.optionals cudaSupport [
|
||||
# Provides a CPU-based OpenCL ICD so that pyopencl's module-level
|
||||
# cl.create_some_context() succeeds without GPU hardware.
|
||||
pocl
|
||||
# pocl needs a writable $HOME for its kernel cache directory.
|
||||
writableTmpDirAsHomeHook
|
||||
];
|
||||
versionCheckKeepEnvironment = [
|
||||
# Otherwise, fail with:
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@
|
|||
scikit-misc,
|
||||
|
||||
# tests
|
||||
dependency-groups,
|
||||
jinja2,
|
||||
pytest-cov-stub,
|
||||
pytest-mock,
|
||||
|
|
@ -156,6 +157,7 @@ buildPythonPackage (finalAttrs: {
|
|||
};
|
||||
|
||||
nativeCheckInputs = [
|
||||
dependency-groups
|
||||
jinja2
|
||||
pytest-cov-stub
|
||||
pytest-mock
|
||||
|
|
@ -172,6 +174,12 @@ buildPythonPackage (finalAttrs: {
|
|||
export NUMBA_CACHE_DIR=$(mktemp -d);
|
||||
'';
|
||||
|
||||
pytestFlagsArray = [
|
||||
# UserWarning: 'where' used without 'out', expect unitialized memory in output.
|
||||
# If this is intentional, use out=None.
|
||||
"-Wignore::UserWarning"
|
||||
];
|
||||
|
||||
disabledTestPaths = [
|
||||
# try to download data:
|
||||
"tests/test_aggregated.py"
|
||||
|
|
@ -221,6 +229,11 @@ buildPythonPackage (finalAttrs: {
|
|||
# 'write/test.h5ad', errno = 2, error message = 'No such file or directory', flags = 13, o_flags
|
||||
# = 242)
|
||||
"test_write"
|
||||
|
||||
# Snapshot tests failing because of warnings in output
|
||||
"scanpy.datasets._datasets.krumsiek11"
|
||||
"scanpy.datasets._datasets.toggleswitch"
|
||||
"scanpy.preprocessing._simple.filter_cells"
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "scanpy" ];
|
||||
|
|
|
|||
|
|
@ -2,18 +2,30 @@
|
|||
lib,
|
||||
buildPythonPackage,
|
||||
fetchFromGitHub,
|
||||
|
||||
# build-system
|
||||
setuptools,
|
||||
|
||||
# dependencies
|
||||
gym,
|
||||
gymnasium,
|
||||
torch,
|
||||
packaging,
|
||||
tensorboard,
|
||||
torch,
|
||||
tqdm,
|
||||
wandb,
|
||||
packaging,
|
||||
|
||||
# tests
|
||||
flax,
|
||||
jax,
|
||||
optax,
|
||||
pettingzoo,
|
||||
pygame,
|
||||
pymunk,
|
||||
pytestCheckHook,
|
||||
}:
|
||||
|
||||
buildPythonPackage rec {
|
||||
buildPythonPackage (finalAttrs: {
|
||||
pname = "skrl";
|
||||
version = "1.4.3";
|
||||
pyproject = true;
|
||||
|
|
@ -21,49 +33,49 @@ buildPythonPackage rec {
|
|||
src = fetchFromGitHub {
|
||||
owner = "Toni-SM";
|
||||
repo = "skrl";
|
||||
tag = version;
|
||||
tag = finalAttrs.version;
|
||||
hash = "sha256-5lkoYAmMIWqK3+E3WxXMWS9zal2DhZkfl30EkrHKpdI=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [ setuptools ];
|
||||
build-system = [ setuptools ];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
dependencies = [
|
||||
gym
|
||||
gymnasium
|
||||
torch
|
||||
packaging
|
||||
tensorboard
|
||||
torch
|
||||
tqdm
|
||||
wandb
|
||||
packaging
|
||||
];
|
||||
|
||||
nativeCheckInputs = [ pytestCheckHook ];
|
||||
doCheck = torch.cudaSupport;
|
||||
pythonImportsCheck = [ "skrl" ];
|
||||
|
||||
pythonImportsCheck = [
|
||||
"skrl"
|
||||
"skrl.agents"
|
||||
"skrl.agents.torch"
|
||||
"skrl.envs"
|
||||
"skrl.envs.torch"
|
||||
"skrl.models"
|
||||
"skrl.models.torch"
|
||||
"skrl.resources"
|
||||
"skrl.resources.noises"
|
||||
"skrl.resources.noises.torch"
|
||||
"skrl.resources.schedulers"
|
||||
"skrl.resources.schedulers.torch"
|
||||
"skrl.trainers"
|
||||
"skrl.trainers.torch"
|
||||
"skrl.utils"
|
||||
"skrl.utils.model_instantiators"
|
||||
nativeCheckInputs = [
|
||||
flax
|
||||
jax
|
||||
optax
|
||||
pettingzoo
|
||||
pygame
|
||||
pymunk
|
||||
pytestCheckHook
|
||||
];
|
||||
|
||||
disabledTests = [
|
||||
# TypeError: The array passed to from_dlpack must have __dlpack__ and __dlpack_device__ methods
|
||||
"test_env"
|
||||
"test_multi_agent_env"
|
||||
|
||||
# OverflowError
|
||||
"test_key"
|
||||
];
|
||||
|
||||
meta = {
|
||||
description = "Reinforcement learning library using PyTorch focusing on readability and simplicity";
|
||||
changelog = "https://github.com/Toni-SM/skrl/releases/tag/${version}";
|
||||
homepage = "https://skrl.readthedocs.io";
|
||||
downloadPage = "https://github.com/Toni-SM/skrl";
|
||||
changelog = "https://github.com/Toni-SM/skrl/releases/tag/${finalAttrs.src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
maintainers = with lib.maintainers; [ bcdarwin ];
|
||||
};
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -834,6 +834,8 @@ self: super: with self; {
|
|||
|
||||
apache-beam = callPackage ../development/python-modules/apache-beam { };
|
||||
|
||||
apache-tvm-ffi = callPackage ../development/python-modules/apache-tvm-ffi { };
|
||||
|
||||
apcaccess = callPackage ../development/python-modules/apcaccess { };
|
||||
|
||||
apeye = callPackage ../development/python-modules/apeye { };
|
||||
|
|
@ -3384,6 +3386,8 @@ self: super: with self; {
|
|||
|
||||
cuda-bindings = callPackage ../development/python-modules/cuda-bindings { };
|
||||
|
||||
cuda-pathfinder = callPackage ../development/python-modules/cuda-pathfinder { };
|
||||
|
||||
cupy = callPackage ../development/python-modules/cupy {
|
||||
cudaPackages =
|
||||
# CuDNN 9 is not supported:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue