diff --git a/dpctl/__main__.py b/dpctl/__main__.py index 01d55eb2c9..83038551cd 100644 --- a/dpctl/__main__.py +++ b/dpctl/__main__.py @@ -19,6 +19,7 @@ import os.path import platform import sys +import warnings import dpctl @@ -50,6 +51,24 @@ def print_library() -> None: print(ld_flags + " -lSyclInterface") +def _warn_if_any_set(args, li) -> None: + opts_set = [it for it in li if getattr(args, it, True)] + if opts_set: + if len(opts_set) == 1: + warnings.warn( + "The option " + str(opts_set[0]) + " is being ignored.", + stacklevel=3, + ) + else: + warnings.warn( + "Options " + str(opts_set) + " are being ignored.", stacklevel=3 + ) + + +def print_lsplatform(verbosity: int) -> None: + dpctl.lsplatform(verbosity=verbosity) + + def main() -> None: """Main entry-point.""" parser = argparse.ArgumentParser() @@ -68,9 +87,45 @@ def main() -> None: action="store_true", help="Linker flags for SyclInterface library.", ) + parser.add_argument( + "-f", + "--full-list", + action="store_true", + help="Enumerate system platforms, using dpctl.lsplatform(verbosity=2)", + ) + parser.add_argument( + "-l", + "--long-list", + action="store_true", + help="Enumerate system platforms, using dpctl.lsplatform(verbosity=1)", + ) + parser.add_argument( + "-s", + "--summary", + action="store_true", + help="Enumerate system platforms, using dpctl.lsplatform()", + ) args = parser.parse_args() if not sys.argv[1:]: parser.print_help() + if args.full_list: + _warn_if_any_set( + args, ["long_list", "summary", "includes", "cmakedir", "library"] + ) + print_lsplatform(2) + return + if args.long_list: + _warn_if_any_set( + args, ["full_list", "summary", "includes", "cmakedir", "library"] + ) + print_lsplatform(1) + return + if args.summary: + _warn_if_any_set( + args, ["long_list", "full_list", "includes", "cmakedir", "library"] + ) + print_lsplatform(0) + return if args.includes: print_includes() if args.cmakedir: diff --git a/dpctl/tests/test_service.py b/dpctl/tests/test_service.py index 175bdb3521..3ba3f7fd8f 100644 --- a/dpctl/tests/test_service.py +++ b/dpctl/tests/test_service.py @@ -182,3 +182,49 @@ def test_cmakedir(): assert res.stdout cmake_dir = res.stdout.decode("utf-8").strip() assert os.path.exists(os.path.join(cmake_dir, "FindDpctl.cmake")) + + +def test_main_full_list(): + res = subprocess.run( + [sys.executable, "-m", "dpctl", "-f"], capture_output=True + ) + assert res.returncode == 0 + assert res.stdout + assert res.stdout.decode("utf-8") + + +def test_main_long_list(): + res = subprocess.run( + [sys.executable, "-m", "dpctl", "-l"], capture_output=True + ) + assert res.returncode == 0 + assert res.stdout + assert res.stdout.decode("utf-8") + + +def test_main_summary(): + res = subprocess.run( + [sys.executable, "-m", "dpctl", "-s"], capture_output=True + ) + assert res.returncode == 0 + assert res.stdout + assert res.stdout.decode("utf-8") + + +def test_main_warnings(): + res = subprocess.run( + [sys.executable, "-m", "dpctl", "-s", "--includes"], capture_output=True + ) + assert res.returncode == 0 + assert res.stdout + assert "UserWarning" in res.stderr.decode("utf-8") + assert "is being ignored." in res.stderr.decode("utf-8") + + res = subprocess.run( + [sys.executable, "-m", "dpctl", "-s", "--includes", "--cmakedir"], + capture_output=True, + ) + assert res.returncode == 0 + assert res.stdout + assert "UserWarning" in res.stderr.decode("utf-8") + assert "are being ignored." in res.stderr.decode("utf-8")