|
78 | 78 | import numpy as np |
79 | 79 |
|
80 | 80 | from arraycontext.container import ( |
| 81 | + ArithArrayContainer, |
81 | 82 | ArrayContainer, |
82 | 83 | NotAnArrayContainerError, |
83 | 84 | SerializationKey, |
|
88 | 89 | from arraycontext.context import ( |
89 | 90 | Array, |
90 | 91 | ArrayContext, |
| 92 | + ArrayOrArithContainer, |
91 | 93 | ArrayOrContainer, |
92 | 94 | ArrayOrContainerOrScalar, |
93 | 95 | ArrayOrContainerT, |
@@ -987,4 +989,79 @@ def treat_as_scalar(x: Any) -> bool: |
987 | 989 |
|
988 | 990 | # }}} |
989 | 991 |
|
| 992 | + |
| 993 | +# {{{ |
| 994 | + |
| 995 | +def bcast_left( |
| 996 | + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], |
| 997 | + ArrayOrArithContainer], |
| 998 | + left: ArrayOrArithContainer, |
| 999 | + right: ArithArrayContainer, |
| 1000 | + ) -> ArrayOrArithContainer: |
| 1001 | + try: |
| 1002 | + serialized = serialize_container(right) |
| 1003 | + except NotAnArrayContainerError: |
| 1004 | + return op(left, right) |
| 1005 | + |
| 1006 | + return deserialize_container(right, [ |
| 1007 | + (k, op(left, right_v)) for k, right_v in serialized]) |
| 1008 | + |
| 1009 | + |
| 1010 | +def bcast_right( |
| 1011 | + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], |
| 1012 | + ArrayOrArithContainer], |
| 1013 | + left: ArrayOrArithContainer, |
| 1014 | + right: ArithArrayContainer, |
| 1015 | + ) -> ArrayOrArithContainer: |
| 1016 | + try: |
| 1017 | + serialized = serialize_container(left) |
| 1018 | + except NotAnArrayContainerError: |
| 1019 | + return op(left, right) |
| 1020 | + |
| 1021 | + return deserialize_container(right, [ |
| 1022 | + (k, op(left_v, right)) for k, left_v in serialized]) |
| 1023 | + |
| 1024 | + |
| 1025 | +def bcast_left_until_actx_array( |
| 1026 | + actx: ArrayContext, |
| 1027 | + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], |
| 1028 | + ArrayOrArithContainer], |
| 1029 | + left: ArrayOrArithContainer, |
| 1030 | + right: ArithArrayContainer, |
| 1031 | + ) -> ArrayOrArithContainer: |
| 1032 | + try: |
| 1033 | + serialized = serialize_container(right) |
| 1034 | + except NotAnArrayContainerError: |
| 1035 | + return op(left, right) |
| 1036 | + |
| 1037 | + return deserialize_container(right, [ |
| 1038 | + (k, op(left, right_v) |
| 1039 | + if isinstance(right_v, actx.array_types) else |
| 1040 | + bcast_left_until_actx_array(actx, op, left, right_v) |
| 1041 | + ) |
| 1042 | + for k, right_v in serialized]) |
| 1043 | + |
| 1044 | + |
| 1045 | +def bcast_right_until_actx_array( |
| 1046 | + actx: ArrayContext, |
| 1047 | + op: Callable[[ArrayOrArithContainer, ArrayOrArithContainer], |
| 1048 | + ArrayOrArithContainer], |
| 1049 | + left: ArrayOrArithContainer, |
| 1050 | + right: ArithArrayContainer, |
| 1051 | + ) -> ArrayOrArithContainer: |
| 1052 | + try: |
| 1053 | + serialized = serialize_container(left) |
| 1054 | + except NotAnArrayContainerError: |
| 1055 | + return op(left, right) |
| 1056 | + |
| 1057 | + return deserialize_container(right, [ |
| 1058 | + (k, op(left_v, right) |
| 1059 | + if isinstance(left_v, actx.array_types) else |
| 1060 | + bcast_right_until_actx_array(actx, op, left_v, right) |
| 1061 | + ) |
| 1062 | + for k, left_v in serialized]) |
| 1063 | + |
| 1064 | +# }}} |
| 1065 | + |
| 1066 | + |
990 | 1067 | # vim: foldmethod=marker |
0 commit comments