Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions tests/test_trig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import random

import pytest

import arrayfire_wrapper.dtypes as dtype
import arrayfire_wrapper.lib as wrapper

from . import utility_functions as util


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse sine operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.asin(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse cosine operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.acos(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse tan operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.atan(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_float_types())
def test_atan2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse tan operation across all supported data types."""
util.check_type_supported(dtype_name)
if dtype_name == dtype.f16:
pytest.skip()
lhs = wrapper.randu(shape, dtype_name)
rhs = wrapper.randu(shape, dtype_name)
result = wrapper.atan2(lhs, rhs)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"invdtypes",
[
dtype.int16,
dtype.bool,
],
)
def test_atan2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
"""Test inverse tan operation for unsupported data types."""
with pytest.raises(RuntimeError):
wrapper.atan2(wrapper.randu((10, 10), invdtypes), wrapper.randu((10, 10), invdtypes))


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test cosine operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.cos(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test sin operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.sin(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_tan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test tan operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.tan(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
9 changes: 6 additions & 3 deletions tests/utility_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import pytest

import arrayfire_wrapper.lib as wrapper
from arrayfire_wrapper.dtypes import Dtype, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64
from arrayfire_wrapper.dtypes import Dtype, b8, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64


def check_type_supported(dtype: Dtype) -> None:
"""Checks to see if the specified type is supported by the current system"""
if dtype in [f64, c64] and not wrapper.get_dbl_support():
pytest.skip("Device does not support double types")

if dtype == f16 and not wrapper.get_half_support():
pytest.skip("Device does not support half types.")

Expand All @@ -25,4 +24,8 @@ def get_real_types() -> list:

def get_all_types() -> list:
"""Returns all types"""
return [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
return [b8, s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]

def get_float_types() -> list:
"""Returns all types"""
return [f16, f32, f64]